Source code for mavis.assemble

import networkx as nx
import itertools
import warnings
from .bam import cigar as cigar_tools
from .bam.read import nsb_align
from .constants import reverse_complement


class Contig:
    """
    an assembled sequence together with its score and the record of which
    input sequences have been remapped onto it
    """

    def __init__(self, sequence, score):
        self.seq = sequence
        self.score = score
        self.strand_specific = False
        # sequence => weight; a lower weight means the sequence was shared between more contigs
        self.remapped_sequences = {}
        self.alignments = []
        self.input_reads = set()

    def __hash__(self):
        # a contig hashes the same as its sequence
        return hash(self.seq)

    def add_mapped_sequence(self, read, multimap=1):
        """
        record that a sequence maps to this contig

        the entry is stored under the reverse complement when that is already a key,
        otherwise under the sequence itself. The stored weight only ever decreases

        Args:
            read: the sequence being recorded (passed to reverse_complement)
            multimap (int): the number of contigs this sequence maps to
        """
        flipped = reverse_complement(read)
        key = flipped if flipped in self.remapped_sequences else read
        current = self.remapped_sequences.get(key, 1)
        self.remapped_sequences[key] = min(current, 1 / multimap)

    def remap_score(self):
        """
        returns the total of the weights of all sequences recorded against this contig
        """
        total = sum(self.remapped_sequences.values())
        return total
class DeBruijnGraph(nx.DiGraph):
    """
    wrapper for a basic digraph
    enforces edge weights
    """

    def get_edge_freq(self, n1, n2):
        """
        returns the freq from the data attribute for a specified edge

        Raises:
            KeyError: if there is no edge from n1 to n2
        """
        if not self.has_edge(n1, n2):
            raise KeyError('missing edge', n1, n2)
        data = self.get_edge_data(n1, n2)
        return data['freq']

    def add_edge(self, n1, n2, freq=1):
        """
        add a given edge to the graph, if it exists add the frequency to the existing frequency count
        """
        if self.has_edge(n1, n2):
            data = self.get_edge_data(n1, n2)
            freq += data['freq']
        nx.DiGraph.add_edge(self, n1, n2, freq=freq)

    def trim_tails_by_freq(self, min_weight):
        """
        for any paths where all edges are lower than the minimum weight trim

        Args:
            min_weight (int): the minimum weight for an edge to be retained
        """
        for n in list(self.nodes()):
            if not self.has_node(n):
                continue
            # follow until the path forks or we run out of low weight edges
            curr = n
            while self.degree(curr) == 1:
                # next(iter(...)) rather than [0]: the edge containers are not
                # subscriptable in every networkx release
                if self.out_degree(curr) == 1:
                    curr, other, data = next(iter(self.out_edges(curr, data=True)))
                    if data['freq'] < min_weight:
                        self.remove_node(curr)
                        curr = other
                    else:
                        break
                elif self.in_degree(curr) == 1:
                    other, curr, data = next(iter(self.in_edges(curr, data=True)))
                    if data['freq'] < min_weight:
                        self.remove_node(curr)
                        curr = other
                    else:
                        break
                else:
                    break
        # drop any node left without edges
        for n in list(self.nodes()):
            if not self.has_node(n):
                continue
            if self.degree(n) == 0:
                self.remove_node(n)

    def trim_noncutting_paths_by_freq(self, min_weight):
        """
        trim any low weight edges where another path exists between the source and target
        of higher weight

        Args:
            min_weight (int): edges at or above this weight stop the extension of a candidate path
        """
        # NOTE(review): the weight of the starting edge itself is never compared to
        # min_weight, only the edges the path is extended through - confirm this is intended
        current_edges = list(self.edges(data=True))
        for src, tgt, _ in sorted(current_edges, key=lambda x: (x[2]['freq'], x[0], x[1])):
            # come up with the path by extending this edge either direction until the degree > 2
            if not self.has_node(src) or not self.has_node(tgt) or not self.has_edge(src, tgt):
                continue
            path = []
            closed_cycle = False

            # extend backwards through unbranched nodes joined by low weight edges
            visited = {src}
            while self.in_degree(src) == 1 and self.out_degree(src) == 1:
                temp = next(iter(self.in_edges(src, data=True)))
                if temp[2]['freq'] >= min_weight:
                    break
                if temp[0] in visited:
                    # walked back onto a node already passed: the edge sits on a closed
                    # low weight cycle and the walk would otherwise never finish
                    closed_cycle = True
                    break
                path.insert(0, src)
                src = temp[0]
                visited.add(src)
            if closed_cycle:
                continue
            path.insert(0, src)

            # extend forwards the same way
            visited = {tgt}
            while self.in_degree(tgt) == 1 and self.out_degree(tgt) == 1:
                temp = next(iter(self.out_edges(tgt, data=True)))
                if temp[2]['freq'] >= min_weight:
                    break
                if temp[1] in visited:
                    closed_cycle = True
                    break
                path.append(tgt)
                tgt = temp[1]
                visited.add(tgt)
            if closed_cycle:
                continue
            path.append(tgt)

            # cut the path at both ends and see if its end points are still connected
            start_edge_data = self.get_edge_data(path[0], path[1])
            self.remove_edge(path[0], path[1])
            end_edge_data = None
            if len(path) > 2:
                end_edge_data = self.get_edge_data(path[-2], path[-1])
                self.remove_edge(path[-2], path[-1])

            if not nx.has_path(self, src, tgt):
                # the path was the only connection so put the edges back
                self.add_edge(path[0], path[1], **start_edge_data)
                if len(path) > 2:
                    self.add_edge(path[-2], path[-1], **end_edge_data)
            else:
                # another route exists so the interior of the path can go
                for node in path[1:-1]:
                    self.remove_node(node)

    def get_sinks(self, subgraph=None):
        """
        returns all nodes with an outgoing degree of zero

        Args:
            subgraph: the nodes to check, defaults to all nodes in the graph

        Returns:
            set: the nodes with no outgoing edges
        """
        if subgraph is None:
            subgraph = self.nodes()
        return {node for node in subgraph if self.out_degree(node) == 0}

    def get_sources(self, subgraph=None):
        """
        returns all nodes with an incoming degree of zero

        Args:
            subgraph: the nodes to check, defaults to all nodes in the graph

        Returns:
            set: the nodes with no incoming edges
        """
        if subgraph is None:
            subgraph = self.nodes()
        return {node for node in subgraph if self.in_degree(node) == 0}
def digraph_connected_components(graph, subgraph=None):
    """
    the networkx module does not support deriving connected components from digraphs
    (only simple graphs)
    this function assumes that connection != reachable
    this means there is no difference between connected components in a simple graph and a digraph

    Args:
        graph (networkx.DiGraph): the input graph to gather components from
        subgraph: the nodes to restrict the components to, defaults to all nodes in the graph

    Returns:
        the result of ``networkx.connected_components`` on the undirected copy: an
        iterable of components, each a collection of node names
    """
    if subgraph is None:
        subgraph = graph.nodes()
    # materialise once so every membership test below is a hash lookup
    members = set(subgraph)

    g = nx.Graph()
    for src, tgt in graph.edges():
        if src in members and tgt in members:
            g.add_edge(src, tgt)
    # nodes without any retained edge still form their own component
    for n in graph.nodes():
        if n in members:
            g.add_node(n)
    return nx.connected_components(g)
def _pull_assembled_paths(assembly, assembly_min_edge_weight, assembly_max_paths, log=lambda *pos, **kwargs: None):
    """
    collect the source-to-sink paths of each connected component of the assembly graph

    components with more than assembly_max_paths paths are trimmed with an increasing
    edge weight filter and re-split until they are small enough. Components containing
    a cycle are dropped

    Args:
        assembly (DeBruijnGraph): the graph to pull paths from, it is trimmed in place
        assembly_min_edge_weight (int): the edge weight filter the graph has already been trimmed with
        assembly_max_paths (int): the maximum number of paths to accept from a single component
        log (callable): called with progress messages

    Returns:
        dict: the sequence spelled by each path => the highest sum of edge frequencies
        of any path spelling that sequence
    """
    path_scores = {}  # path_str => score_int
    unresolved_components = [(assembly_min_edge_weight, c) for c in digraph_connected_components(assembly)]

    while unresolved_components:
        w, component = unresolved_components.pop(0)
        # trimming below acts on the whole graph, so a component queued earlier may
        # name nodes that no longer exist
        component = [node for node in component if assembly.has_node(node)]
        if not component:
            continue
        sources = assembly.get_sources(component)
        sinks = assembly.get_sinks(component)
        paths_est = len(sinks) * len(sources)

        subgraph = assembly.subgraph(component)
        if not nx.is_directed_acyclic_graph(subgraph):
            log('dropping cyclic subgraph', time_stamp=False)
            continue

        # the estimate can be low, so enumerate the paths but stop once over the limit
        all_paths = []
        if paths_est <= assembly_max_paths:
            for source, sink in itertools.product(sources, sinks):
                all_paths.extend(nx.all_simple_paths(assembly, source, sink))
                if len(all_paths) > assembly_max_paths:
                    break
            paths_est = max(paths_est, len(all_paths))

        if paths_est > assembly_max_paths:
            # too many sinks/sources: raise the edge weight filter and re-queue what is left
            log(
                'reducing estimated paths. Current estimate is {}+ from'.format(paths_est),
                len(component), 'nodes', 'filter increase', w + 1, time_stamp=False)
            w += 1
            assembly.trim_noncutting_paths_by_freq(w)
            assembly.trim_tails_by_freq(w)
            for new_comp in digraph_connected_components(assembly, component):
                unresolved_components.append((w, new_comp))
        else:
            # the enumeration above ran to completion, so all_paths is the full set
            for path in all_paths:
                # consecutive nodes overlap, so each node after the first adds one character
                s = path[0] + ''.join([p[-1] for p in path[1:]])
                score = sum(
                    assembly.get_edge_freq(path[i], path[i + 1]) for i in range(len(path) - 1)
                )
                path_scores[s] = max(path_scores.get(s, 0), score)
    return path_scores
def assemble(
    sequences,
    assembly_max_kmer_size=None,
    assembly_min_edge_weight=3,
    assembly_min_match_quality=0.95,
    assembly_min_read_mapping_overlap=None,
    assembly_min_contig_length=None,
    assembly_min_exact_match_to_remap=6,
    assembly_max_paths=20,
    assembly_max_kmer_strict=False,
    log=lambda *pos, **kwargs: None
):
    """
    for a set of sequences creates a DeBruijnGraph
    simplifies trailing and leading paths where edges fall
    below a weight threshold and the return all possible unitigs/contigs

    Args:
        sequences (:class:`list` of :class:`str`): a list of strings/sequences to assemble
        assembly_max_kmer_size (int): the size of the kmer to use. Defaults to 75% of the
            shortest input sequence, or the smaller of 10 and that sequence's length if
            that would be under 10
        assembly_min_edge_weight (int): see :term:`assembly_min_edge_weight`
        assembly_min_match_quality (float): percent match for re-aligned reads to contigs
        assembly_min_read_mapping_overlap (int): the minimum amount of overlap required when aligning
            reads to contigs. Defaults to the kmer size
        assembly_min_contig_length (int): contigs shorter than this are discarded. Defaults to
            one more than the length of the shortest input sequence
        assembly_min_exact_match_to_remap (int): the size of the kmers a read must share with
            a contig before an alignment to that contig is attempted
        assembly_max_paths (int): see :term:`assembly_max_paths`
        assembly_max_kmer_strict (bool): when False a kmer size larger than the shortest
            input sequence is lowered to that length with a warning. When True the size
            is kept and shorter sequences are skipped
        log (callable): called with progress messages

    Returns:
        :class:`list` of :class:`Contig`: a list of putative contigs
    """
    if len(sequences) == 0:
        return []
    min_seq = min(len(s) for s in sequences)
    if assembly_max_kmer_size is None:
        temp = int(min_seq * 0.75)
        if temp < 10:
            assembly_max_kmer_size = min(min_seq, 10)
        else:
            assembly_max_kmer_size = temp
    elif assembly_max_kmer_size > min_seq:
        if not assembly_max_kmer_strict:
            assembly_max_kmer_size = min_seq
            warnings.warn(
                'cannot specify a kmer size larger than one of the input sequences. reset to {0}'.format(min_seq))
    if assembly_min_read_mapping_overlap is None:
        assembly_min_read_mapping_overlap = assembly_max_kmer_size
    if assembly_min_contig_length is None:
        assembly_min_contig_length = min_seq + 1

    assembly = DeBruijnGraph()
    log('hashing kmers')
    for s in sequences:
        if len(s) < assembly_max_kmer_size:
            continue
        for kmer in kmers(s, assembly_max_kmer_size):
            # each kmer is an edge between its leading and trailing (k-1)-mers
            assembly.add_edge(kmer[:-1], kmer[1:])
    # now just work with connected components
    assembly.trim_noncutting_paths_by_freq(assembly_min_edge_weight)
    # trim all paths from sources or to sinks where the edge weight is low
    assembly.trim_tails_by_freq(assembly_min_edge_weight)

    path_scores = _pull_assembled_paths(
        assembly,
        assembly_min_edge_weight=assembly_min_edge_weight,
        assembly_max_paths=assembly_max_paths,
        log=log
    )

    # keep paths that are long enough and are not simply one of the inputs
    # a set so the check is not a scan of the full read list for every path
    input_sequences = set(sequences)
    contigs = {}
    for seq, score in path_scores.items():
        if seq not in input_sequences and len(seq) >= assembly_min_contig_length:
            contigs[seq] = Contig(seq, score)

    # keep one contig of each reverse complement pair (first in sorted order)
    filtered_contigs = {}
    for seq, contig in sorted(contigs.items()):
        rseq = reverse_complement(seq)
        if seq not in filtered_contigs and rseq not in filtered_contigs:
            filtered_contigs[seq] = contig
    contigs = list(filtered_contigs.values())

    # remap the input reads
    input_seq_kmers = {}
    for seq in sequences:
        input_seq_kmers[seq] = set(kmers(seq, assembly_min_exact_match_to_remap))

    contig_kmers = {}
    for contig in contigs:
        contig_kmers[contig.seq] = set(kmers(contig.seq, assembly_min_exact_match_to_remap))

    log('remapping reads to {} contigs'.format(len(contigs)))
    for input_seq in sequences:
        maps_to = {}  # contig, score
        for contig in contigs:
            # only attempt an alignment when the read and contig share an exact kmer
            if len(input_seq_kmers[input_seq] & contig_kmers[contig.seq]) == 0:
                continue
            a = nsb_align(
                contig.seq,
                input_seq,
                min_overlap_percent=assembly_min_read_mapping_overlap / len(contig.seq),
                min_match=assembly_min_match_quality
            )
            # anything other than exactly one alignment is not used
            if len(a) != 1:
                continue
            if cigar_tools.match_percent(a[0].cigar) < assembly_min_match_quality:
                continue
            maps_to[contig] = a[0]
        for contig, read in maps_to.items():
            contig.add_mapped_sequence(read, len(maps_to))
    log('assemblies complete')
    return contigs
def kmers(s, size):
    """
    for a sequence, compute and return a list of all kmers of a specified size

    Args:
        s (str): the input sequence
        size (int): the size of the kmers

    Returns:
        :class:`list` of :class:`str`: the list of kmers

    Example:
        >>> kmers('abcdef', 2)
        ['ab', 'bc', 'cd', 'de', 'ef']
    """
    total = len(s)
    # a window starts at every index from which `size` characters still fit
    window_starts = range(min(total, total - size + 1))
    return [s[start:start + size] for start in window_starts]