import networkx as nx
import itertools
import warnings
from .bam import cigar as cigar_tools
from .bam.read import nsb_align, calculate_alignment_score
from .constants import reverse_complement
from .util import devnull
class Contig:
    """
    A candidate sequence together with its build score and the weights of the
    reads that have been remapped onto it

    Attributes:
        seq (str): the contig sequence
        score: the score supplied at construction
        remapped_sequences (dict): alignment score contribution on the contig by read
    """

    def __init__(self, sequence, score):
        self.seq = sequence
        # read => weight; filled in by add_mapped_sequence
        self.remapped_sequences = {}
        self.score = score
        # the following are initialised empty and are not modified by this class
        self.alignments = []
        self.input_reads = set()
        self.strand_specific = False

    def __hash__(self):
        # hashing uses the sequence only; no __eq__ is defined so equality
        # between two Contig objects remains identity-based
        return hash(self.seq)

    def add_mapped_sequence(self, read, multimap=1):
        """
        record a read as mapped to this contig

        Args:
            read: the key to store the weight under
            multimap (int): the number of contigs the read maps equally well to; the read contributes 1 / multimap
        """
        weight = 1 / multimap
        self.remapped_sequences[read] = weight

    def remap_score(self):
        """
        Returns:
            the sum of the weights of all reads recorded by :meth:`add_mapped_sequence`
        """
        total = 0
        for weight in self.remapped_sequences.values():
            total += weight
        return total
class DeBruijnGraph(nx.DiGraph):
    """
    wrapper for a basic digraph
    enforces edge weights

    every edge carries a ``freq`` attribute which counts how many times the edge was added
    """

    def get_edge_freq(self, n1, n2):
        """
        returns the freq from the data attribute for a specified edge

        Raises:
            KeyError: if there is no edge from n1 to n2
        """
        if not self.has_edge(n1, n2):
            raise KeyError('missing edge', n1, n2)
        return self.get_edge_data(n1, n2)['freq']

    def add_edge(self, n1, n2, freq=1):
        """
        add a given edge to the graph, if it exists add the frequency to the existing frequency count

        Args:
            n1: the source node
            n2: the target node
            freq (int): the amount to add to the edge frequency
        """
        if self.has_edge(n1, n2):
            freq += self.get_edge_data(n1, n2)['freq']
        nx.DiGraph.add_edge(self, n1, n2, freq=freq)

    def trim_tails_by_freq(self, min_weight):
        """
        for any paths where all edges are lower than the minimum weight trim

        starting from every node with exactly one edge, nodes are removed while the single
        edge has a freq below min_weight. Nodes left without any edges are removed afterwards

        Args:
            min_weight (int): the minimum weight for an edge to be retained
        """
        for n in list(self.nodes()):
            if not self.has_node(n):
                continue  # already removed while following an earlier tail
            # follow until the path forks or we run out of low weight edges
            curr = n
            while self.degree(curr) == 1:
                # edge views cannot be indexed on networkx >= 2.0, so take a list snapshot first
                if self.out_degree(curr) == 1:
                    _, other, data = list(self.out_edges(curr, data=True))[0]
                else:
                    # a total degree of 1 without an outgoing edge means the one edge is incoming
                    other, _, data = list(self.in_edges(curr, data=True))[0]
                if data['freq'] >= min_weight:
                    break
                self.remove_node(curr)
                curr = other
        # drop whatever has been left isolated
        isolated = [n for n in self.nodes() if self.degree(n) == 0]
        self.remove_nodes_from(isolated)

    def trim_forks_by_freq(self, min_weight):
        """
        for all nodes in the graph, if the node has an out-degree > 1 and one of the outgoing
        edges has freq < min_weight. then that outgoing edge is deleted

        Args:
            min_weight (int): the minimum weight for an outgoing edge of a fork to be retained
        """
        for node in list(self.nodes()):
            # snapshot the edges: removing edges while iterating the live view
            # raises a RuntimeError on networkx >= 2.0
            outgoing_edges = list(self.out_edges(node, data=True))
            if len(outgoing_edges) > 1:
                for src, tgt, data in outgoing_edges:
                    if data['freq'] < min_weight:
                        self.remove_edge(src, tgt)

    def trim_noncutting_paths_by_freq(self, min_weight):
        """
        trim any low weight edges where another path exists between the source and target
        of higher weight

        edges are visited from lowest to highest freq. Each edge is extended in both directions
        over non-branching nodes joined by edges with freq < min_weight. The path is removed only
        if its end points are still connected without it

        Args:
            min_weight (int): edges at or above this weight are not used to extend a path
        """
        current_edges = list(self.edges(data=True))
        for src, tgt, data in sorted(current_edges, key=lambda x: (x[2]['freq'], x[0], x[1])):
            # earlier iterations may have removed this edge (or one of its nodes)
            if not self.has_edge(src, tgt):
                continue
            if src == tgt and data['freq'] < min_weight:
                self.remove_edge(src, tgt)
                continue
            # come up with the path by extending this edge either direction until the degree > 2
            path = []
            while self.in_degree(src) == 1 and self.out_degree(src) == 1:
                # list(...) because edge views cannot be indexed on networkx >= 2.0
                s, _, edge_data = list(self.in_edges(src, data=True))[0]
                if edge_data['freq'] >= min_weight or s in path:
                    break
                path.insert(0, src)
                src = s
            path.insert(0, src)
            while self.in_degree(tgt) == 1 and self.out_degree(tgt) == 1:
                _, t, edge_data = list(self.out_edges(tgt, data=True))[0]
                if edge_data['freq'] >= min_weight or t in path:
                    break
                path.append(tgt)
                tgt = t
            path.append(tgt)
            # temporarily disconnect the path at both ends to test if it is the only connection
            start_edge_data = self.get_edge_data(path[0], path[1])
            self.remove_edge(path[0], path[1])
            end_edge_data = None
            if len(path) > 2:
                end_edge_data = self.get_edge_data(path[-2], path[-1])
                self.remove_edge(path[-2], path[-1])
            if not nx.has_path(self, src, tgt):
                # the path was cutting: restore the removed edges
                self.add_edge(path[0], path[1], **start_edge_data)
                if len(path) > 2:
                    self.add_edge(path[-2], path[-1], **end_edge_data)
            else:
                # another route exists: remove the interior of the path as well
                for node in path[1:-1]:
                    self.remove_node(node)

    def get_sinks(self, subgraph=None):
        """
        returns all nodes with an outgoing degree of zero

        Args:
            subgraph: the nodes to check, defaults to all nodes in the graph

        Returns:
            set: the nodes without outgoing edges
        """
        if subgraph is None:
            subgraph = self.nodes()
        return {node for node in subgraph if self.out_degree(node) == 0}

    def get_sources(self, subgraph=None):
        """
        returns all nodes with an incoming degree of zero

        Args:
            subgraph: the nodes to check, defaults to all nodes in the graph

        Returns:
            set: the nodes without incoming edges
        """
        if subgraph is None:
            subgraph = self.nodes()
        return {node for node in subgraph if self.in_degree(node) == 0}
def digraph_connected_components(graph, subgraph=None):
    """
    the networkx module does not support deriving connected
    components from digraphs (only simple graphs)
    this function assumes that connection != reachable
    this means there is no difference between connected components
    in a simple graph and a digraph

    Args:
        graph (networkx.DiGraph): the input graph to gather components from
        subgraph: the nodes to restrict the components to, defaults to all nodes in the graph

    Returns:
        :class:`list`: returns a list of components which are collections of node names
    """
    # build the membership set once so each lookup below is constant time
    # even when the caller passes a list of nodes
    allowed = set(graph.nodes()) if subgraph is None else set(subgraph)
    g = nx.Graph()
    for src, tgt in graph.edges():
        if src in allowed and tgt in allowed:
            g.add_edge(src, tgt)
    # nodes without any retained edge still form their own component
    for n in graph.nodes():
        if n in allowed:
            g.add_node(n)
    return list(nx.connected_components(g))
def pull_contigs_from_component(
    assembly, component, assembly_min_nc_edge_weight, assembly_max_paths, log=devnull
):
    """
    builds contigs from the a connected component of the assembly DeBruijn graph

    while the estimated number of paths (sources x sinks) of a component is above
    assembly_max_paths, the edge weight filter is raised and the graph trimmed again.
    Note that the trimming is applied to the assembly graph as a whole

    Args:
        assembly (DeBruijnGraph): the assembly graph
        component (list): list of nodes which make up the connected component
        assembly_min_nc_edge_weight (int): the minimum weight to not remove a non cutting edge/path
        assembly_max_paths (int): the maximum number of paths allowed before the graph is further simplified
        log (function): the log function

    Returns:
        :class:`Dict` of :class:`int` by :class:`str`: the paths/contigs and their scores
    """
    path_scores = {}  # path_str => score_int
    w = assembly_min_nc_edge_weight
    unresolved_components = [component]

    while unresolved_components:
        component = unresolved_components.pop(0)
        # the graph is not modified between here and the last use of these sets
        sources = assembly.get_sources(component)
        sinks = assembly.get_sinks(component)
        paths_est = len(sinks) * len(sources)

        if paths_est > assembly_max_paths:
            # default guards against the end nodes having no outgoing edges at all
            min_edge_weight = min(
                (data['freq'] for _, _, data in assembly.edges(sources | sinks, data=True)),
                default=w
            )
            w = max(w + 1, min_edge_weight)
            log(
                'reducing estimated paths. Current estimate is {}+ from'.format(paths_est),
                len(component), 'nodes', 'filter increase', w)
            assembly.trim_forks_by_freq(w)
            assembly.trim_noncutting_paths_by_freq(w)
            assembly.trim_tails_by_freq(w)
            # trimming may have split the component; re-queue whatever is left of it
            unresolved_components.extend(digraph_connected_components(assembly, component))
        else:
            # every simple path from a source to a sink is a candidate contig
            for source, sink in itertools.product(sources, sinks):
                for path in nx.all_simple_paths(assembly, source, sink):
                    # the first node in full, then the last character of each following node
                    s = path[0] + ''.join(p[-1] for p in path[1:])
                    score = sum(assembly.get_edge_freq(n1, n2) for n1, n2 in zip(path, path[1:]))
                    path_scores[s] = max(path_scores.get(s, 0), score)
    return path_scores
def assemble(
    sequences,
    assembly_max_kmer_size=None,
    assembly_min_nc_edge_weight=3,
    assembly_min_edge_weight=2,
    assembly_min_match_quality=0.95,
    assembly_min_read_mapping_overlap=None,
    assembly_min_contig_length=None,
    assembly_min_exact_match_to_remap=6,
    assembly_max_paths=20,
    assembly_max_kmer_strict=False,
    log=lambda *pos, **kwargs: None
):
    """
    for a set of sequences creates a DeBruijnGraph
    simplifies trailing and leading paths where edges fall
    below a weight threshold and the return all possible unitigs/contigs

    Args:
        sequences (:class:`list` of :class:`str`): a list of strings/sequences to assemble
        assembly_max_kmer_size: see :term:`assembly_max_kmer_size`
        assembly_min_nc_edge_weight: see :term:`assembly_min_nc_edge_weight`
        assembly_min_edge_weight: see :term:`assembly_min_edge_weight`
        assembly_min_match_quality: see :term:`assembly_min_match_quality`
        assembly_min_read_mapping_overlap: see :term:`assembly_min_read_mapping_overlap`
        assembly_min_contig_length: see :term:`assembly_min_contig_length`
        assembly_min_exact_match_to_remap: see :term:`assembly_min_exact_match_to_remap`
        assembly_max_paths: see :term:`assembly_max_paths`
        assembly_max_kmer_strict (bool): when True a given kmer size larger than the shortest
            sequence is kept as is and sequences shorter than the kmer size are skipped
        log (function): the log function

    Returns:
        :class:`list` of :class:`Contig`: a list of putative contigs
    """
    if not sequences:
        return []

    # settle the kmer size and the defaults which depend on the shortest input sequence
    min_seq = min(len(s) for s in sequences)
    if assembly_max_kmer_size is None:
        temp = int(min_seq * 0.75)
        assembly_max_kmer_size = min(min_seq, 10) if temp < 10 else temp
    elif assembly_max_kmer_size > min_seq:
        if not assembly_max_kmer_strict:
            assembly_max_kmer_size = min_seq
            warnings.warn(
                'cannot specify a kmer size larger than one of the input sequences. reset to {0}'.format(min_seq))
    if assembly_min_read_mapping_overlap is None:
        assembly_min_read_mapping_overlap = assembly_max_kmer_size
    if assembly_min_contig_length is None:
        assembly_min_contig_length = min_seq + 1

    # build the graph: each kmer is an edge from its prefix to its suffix
    assembly = DeBruijnGraph()
    log('hashing kmers')
    for s in sequences:
        if len(s) < assembly_max_kmer_size:
            continue
        for kmer in kmers(s, assembly_max_kmer_size):
            assembly.add_edge(kmer[:-1], kmer[1:])

    # use the ab min edge weight to remove all low weight edges first
    for s, t, data in list(assembly.edges(data=True)):
        if data['freq'] < assembly_min_edge_weight:
            assembly.remove_edge(s, t)
    # then remove all nodes with no edges
    for n in list(assembly.nodes()):
        if assembly.in_degree(n) == 0 and assembly.out_degree(n) == 0:
            assembly.remove_node(n)

    # drop all cyclic components
    for component in digraph_connected_components(assembly):
        subgraph = assembly.subgraph(component)
        if not nx.is_directed_acyclic_graph(subgraph):
            log('dropping cyclic component', time_stamp=False)
            # remove via a snapshot of the component: on networkx >= 2.0 the subgraph is a
            # live view and iterating it while deleting from the parent can raise a RuntimeError
            assembly.remove_nodes_from(list(component))

    # initial data cleaning
    assembly.trim_forks_by_freq(assembly_min_nc_edge_weight)
    assembly.trim_noncutting_paths_by_freq(assembly_min_nc_edge_weight)
    assembly.trim_tails_by_freq(assembly_min_nc_edge_weight)

    path_scores = {}
    for component in digraph_connected_components(assembly):
        # pull the path scores
        path_scores.update(pull_contigs_from_component(
            assembly, component,
            assembly_min_nc_edge_weight=assembly_min_nc_edge_weight,
            assembly_max_paths=assembly_max_paths,
            log=log
        ))

    # keep paths which are long enough and are not simply one of the input sequences
    input_seqs = set(sequences)  # built once so the membership test is constant time
    contigs = {}
    for seq, score in path_scores.items():
        if seq not in input_seqs and len(seq) >= assembly_min_contig_length:
            contigs[seq] = Contig(seq, score)

    # keep only one of a contig and its reverse complement (the first in sorted order)
    filtered_contigs = {}
    for seq in sorted(contigs):
        rseq = reverse_complement(seq)
        if seq not in filtered_contigs and rseq not in filtered_contigs:
            filtered_contigs[seq] = contigs[seq]
    contigs = list(filtered_contigs.values())

    # remap the input reads
    log('remapping reads to {} contigs'.format(len(contigs)))
    for input_seq in sequences:
        maps_to = {}  # contig, score
        for contig in contigs:
            a = nsb_align(
                contig.seq,
                input_seq,
                min_overlap_percent=assembly_min_read_mapping_overlap / len(contig.seq),
                min_match=assembly_min_match_quality,
                min_consecutive_match=assembly_min_exact_match_to_remap
            )
            # only uniquely placed alignments of sufficient quality are used
            if len(a) != 1:
                continue
            if cigar_tools.match_percent(a[0].cigar) < assembly_min_match_quality:
                continue
            maps_to[contig] = a[0]
        if maps_to:
            scores = []
            for contig, read in maps_to.items():
                score = calculate_alignment_score(read)
                scores.append((contig, read, score, read.reference_end - read.reference_start))
            # best is the highest alignment score, ties broken by the longest reference span
            max_score = max((t[2], t[3]) for t in scores)
            best_alignments = [
                (contig, read) for contig, read, score1, score2 in scores
                if (score1, score2) == max_score
            ]
            assert len(best_alignments) >= 1
            # a read which maps equally well to several contigs is split evenly between them
            for contig, read in best_alignments:
                contig.add_mapped_sequence(read, len(best_alignments))
    log('assemblies complete. scores (build, remap):', [(c.score, round(c.remap_score(), 1)) for c in contigs])
    return contigs
def kmers(s, size):
    """
    for a sequence, compute and return a list of all kmers of a specified size

    Args:
        s (str): the input sequence
        size (int): the size of the kmers

    Returns:
        :class:`list` of :class:`str`: the list of kmers, in order of their position in the
        sequence. Empty when size is larger than the sequence

    Example:
        >>> kmers('abcdef', 2)
        ['ab', 'bc', 'cd', 'de', 'ef']
    """
    # capping at len(s) keeps the existing result for size < 1 (one slice per position)
    window_count = min(len(s), len(s) - size + 1)
    return [s[start:start + size] for start in range(window_count)]