from __future__ import division

import itertools
from collections import namedtuple
from copy import copy

from ..breakpoint import Breakpoint, BreakpointPair
from ..constants import ORIENT, STRAND
from ..interval import Interval
from ..util import log


class BreakpointPairGroupKey(namedtuple('BreakpointPairGroupKey', [
    'chr1', 'chr2', 'orient1', 'orient2', 'strand1', 'strand2', 'opposing_strands', 'explicit_strand'
])):
    """hashable key grouping breakpoint pairs that could describe the same putative event"""

def __new__(cls, chr1, chr2, orient1, orient2, strand1, strand2, opposing_strands=None, explicit_strand=False):
if STRAND.NS in [strand1, strand2] and explicit_strand:
raise ValueError('cannot have unspecified strand when explicit_strand is set')
if not explicit_strand and opposing_strands is None:
raise ValueError('opposing_strands must be specified when explicit_strand is false')
if explicit_strand:
opp = (strand1 != strand2)
if opposing_strands is None:
opposing_strands = opp
elif opposing_strands != opp:
raise ValueError('strand1 v strand2 v opposing_strands conflict.', strand1, strand2, opposing_strands)
STRAND.enforce(strand1)
STRAND.enforce(strand2)
ORIENT.enforce(orient1)
ORIENT.enforce(orient2)
self = super(BreakpointPairGroupKey, cls).__new__(
cls, chr1, chr2, orient1, orient2, strand1, strand2, opposing_strands, explicit_strand)
return self
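
# Usage sketch (hypothetical values; STRAND and ORIENT are the constants imported above):
#
#   key = BreakpointPairGroupKey(
#       '1', '2', ORIENT.LEFT, ORIENT.RIGHT, STRAND.POS, STRAND.NEG, explicit_strand=True)
#   assert key.opposing_strands  # POS vs NEG implies opposing strands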


def weighted_mean(values, weights=None):
    """mean of the values where each value is scaled by its corresponding weight (uniform by default)"""
    if weights is None:
        weights = [1 for v in values]
    return sum(x * w for x, w in zip(values, weights)) / sum(weights)
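
# For example (hypothetical numbers):
#
#   weighted_mean([1, 2, 3])               # -> 2.0
#   weighted_mean([1, 3], weights=[3, 1])  # -> (1 * 3 + 3 * 1) / 4 = 1.5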


def merge_integer_intervals(*intervals, weight_adjustment=0):
    """
    Merges a set of integer intervals into a single interval where the center is the
    weighted mean of the input interval centers. Each interval is weighted inversely
    proportional to its length. The length of the final interval is the average of the
    lengths of the input intervals, capped so that it never extends beyond the union of
    the input intervals.

    Args:
        weight_adjustment (int): added to each length to dampen weighting differences between small intervals
    """
intervals = list(intervals)
centers = []
weights = []
lengths = []
if len(intervals) == 0:
raise AttributeError('cannot compute the weighted mean interval of an empty set of intervals')
    for i in range(0, len(intervals)):
        curr = intervals[i]
        # treat the interval as continuous; the +0.9 gives single-position intervals a non-zero width
        intervals[i] = Interval(curr[0], curr[1] + 0.9)
        for _ in range(0, intervals[i].freq):  # repeat by frequency so stacked intervals count multiple times
            centers.append(intervals[i].center)
            weights.append((weight_adjustment + 1) / (intervals[i].length() + weight_adjustment))
            lengths.append(intervals[i].length())
    center = round(weighted_mean(centers, weights=weights) * 2, 0) / 2  # round to the nearest half-position
    size = weighted_mean(lengths)
    start = max([round(center - size / 2, 0), min([i[0] for i in intervals])])
    end = min([round(center + size / 2, 0), max([i[1] for i in intervals])])
    # take the largest interval symmetric about the center that stays within the bounds
    offset = min([center - start, end - center])
    result = Interval(int(round(center - offset, 0)), int(round(center + offset, 0)))
    return result
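
# A minimal usage sketch (hypothetical coordinates):
#
#   merged = merge_integer_intervals(Interval(1, 10), Interval(6, 15), weight_adjustment=10)
#   # the result is centred on the weighted mean of the input centres and never
#   # extends past min(1, 6) on the left or max(10, 15) on the right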


def pair_key(pair):
    """hashable tuple of the attributes that uniquely identify a breakpoint pair"""
return (
pair.break1.chr,
pair.break2.chr,
pair.break1.start,
pair.break2.start,
pair.break1.end,
pair.break2.end,
pair.break1.orient,
pair.break2.orient,
pair.break1.strand if pair.stranded else STRAND.NS,
pair.break2.strand if pair.stranded else STRAND.NS,
pair.stranded,
pair.opposing_strands)
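
# For a hypothetical unstranded pair 1:100-110 (L) -> 1:500-510 (R) this returns
# something like ('1', '1', 100, 500, 110, 510, 'L', 'R', '?', '?', False, False)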


def all_pair_group_keys(pair, explicit_strand=False):
    """expand any ambiguous orientations (and strands, when explicit) into all group keys consistent with the pair"""
opt = [
[pair.break1.chr],
[pair.break2.chr],
ORIENT.expand(pair.break1.orient),
ORIENT.expand(pair.break2.orient),
[STRAND.NS] if not explicit_strand else STRAND.expand(pair.break1.strand),
[STRAND.NS] if not explicit_strand else STRAND.expand(pair.break2.strand),
[pair.opposing_strands]
]
result = []
    for c1, c2, o1, o2, s1, s2, opp in itertools.product(*opt):
        if explicit_strand and (s1 != s2) != opp:
            continue  # strand combination conflicts with opposing_strands
        elif opp == (o1 != o2):
            continue  # orientation combination conflicts with opposing_strands
result.append(BreakpointPairGroupKey(c1, c2, o1, o2, s1, s2, opp, explicit_strand=explicit_strand))
return result
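
# Sketch of the expansion (hypothetical pair): with orient1=ORIENT.LEFT and
# orient2=ORIENT.NS, orient2 expands to both L and R, and any combination where
# opposing_strands conflicts with the orientations (opp == (o1 != o2)) is dropped,
# so an opposing-strands pair keeps only the keys where the orientations match (LL/RR).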


def merge_by_union(input_pairs, group_key, weight_adjustment=10, cluster_radius=200):
    """
    For a given set of breakpoint pairs, merge the union of all pairs that fall
    within the given distance (cluster_radius) of one another
    """
pairs_by_start = sorted(input_pairs, key=lambda x: x.break1.start)
pairs_by_end = sorted(input_pairs, key=lambda x: x.break2.start)
edges = {pair_key(p): set() for p in input_pairs}
pairs_by_key = {}
    for i in range(0, len(input_pairs)):
        # try all combinations until the start distance alone is too far
        curr = pairs_by_start[i]
        ckey = pair_key(curr)
        pairs_by_key.setdefault(ckey, []).append(curr)
for j in range(i + 1, len(input_pairs)):
other = pairs_by_start[j]
okey = pair_key(other)
d = abs(Interval.dist(curr.break1, other.break1))
if d > cluster_radius:
break
d += abs(Interval.dist(curr.break2, other.break2))
if d <= cluster_radius:
edges[ckey].add(okey)
edges[okey].add(ckey)
        # now try all combinations until the end distance alone is too far
        curr = pairs_by_end[i]
        ckey = pair_key(curr)
for j in range(i + 1, len(input_pairs)):
other = pairs_by_end[j]
okey = pair_key(other)
d = abs(Interval.dist(curr.break2, other.break2))
if d > cluster_radius:
break
d += abs(Interval.dist(curr.break1, other.break1))
if d <= cluster_radius:
edges[okey].add(ckey)
edges[ckey].add(okey)
    # group the pairs into connected components of the edge graph built above
    merged = set()
    merge_nodes = []
    for node in edges:
        if node in merged:
            continue
        adj = edges[node] | {node}
        merged.add(node)
        unmerged = adj - merged
        while unmerged:
            for other in unmerged:
                adj = adj | edges[other]
                merged.add(other)
            unmerged = adj - merged
        merge_nodes.append(adj)
nodes = {}
for node_keys in merge_nodes:
pairs = []
for pkey in node_keys:
pairs.extend(pairs_by_key[pkey])
itvl1 = merge_integer_intervals(*[p.break1 for p in pairs], weight_adjustment=weight_adjustment)
itvl2 = merge_integer_intervals(*[p.break2 for p in pairs], weight_adjustment=weight_adjustment)
        if group_key.chr1 == group_key.chr2:
            # for intrachromosomal events, trim the intervals so the first does not extend past the second
            itvl1.end = min(itvl2.end, itvl1.end)
            itvl2.start = max(itvl2.start, itvl1.start)
            itvl1.start = min(itvl1.start, itvl1.end)
            itvl2.end = max(itvl2.end, itvl2.start)
b1 = Breakpoint(group_key.chr1, itvl1.start, itvl1.end, orient=group_key.orient1, strand=group_key.strand1)
b2 = Breakpoint(group_key.chr2, itvl2.start, itvl2.end, orient=group_key.orient2, strand=group_key.strand2)
# create the new bpp representing the merge of the input pairs
new_bpp = BreakpointPair(
b1, b2, opposing_strands=group_key.opposing_strands, stranded=group_key.explicit_strand)
nodes.setdefault(new_bpp, []).extend(pairs)
return nodes
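
# Usage sketch (hypothetical; assumes all pairs already share the same group key):
#
#   key = all_pair_group_keys(pairs[0])[0]
#   for node, inputs in merge_by_union(pairs, key, cluster_radius=200).items():
#       print(node, 'was merged from', len(inputs), 'input pairs')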


def merge_breakpoint_pairs(input_pairs, cluster_radius=200, cluster_initial_size_limit=25, verbose=False):
    """
    Two-step merging process:

    1. merge all 'small' (see cluster_initial_size_limit) events as the union of all events that
       fall within the cluster_radius
    2. for all remaining events, choose the 'best' merge for any event within cluster_radius of an
       existing node; otherwise the node is added unmerged. The events in the second phase are
       processed in order of smallest total breakpoint interval size to largest

    Args:
        input_pairs (list of BreakpointPair): the pairs to be merged
        cluster_radius (int): maximum distance allowed for a node to merge
        cluster_initial_size_limit (int): maximum size of breakpoint intervals allowed in the first merging phase

    Returns:
        dict of list of BreakpointPair by BreakpointPair: mapping of merged breakpoint pairs to the input pairs used in the merge
    """
    def pair_center_distance(pair1, pair2):
        """sum of the distances between the breakpoint centers of the two pairs"""
        d = abs(pair1.break1.center - pair2.break1.center)
        d += abs(pair1.break2.center - pair2.break2.center)
        return d
mapping = {}
groups = {} # split the groups by putative pairings
pair_weight = {}
    # group by strand only when at least one of the input pairs is stranded
    explicit_strand = False
    phase2_groups = {}
    for pair in input_pairs:
        if pair.stranded:
            explicit_strand = True
            break
    doubled = 0  # counts group-key assignments; a pair may fall into more than one group
    for i, old_pair in enumerate(input_pairs):
        pair = copy(old_pair)
        pair.data['tag'] = i  # tag each input so the merge can be checked for completeness
        k = pair_key(pair)
        pair_weight.setdefault(k, []).append(pair)
        putative_group_keys = all_pair_group_keys(pair, explicit_strand=explicit_strand)
        doubled += len(putative_group_keys)
        if len(putative_group_keys) < 1:
            raise NotImplementedError('bad breakpoint input does not fit any groups', pair)
        for key in putative_group_keys:
            # defer 'large' events to the second merging phase
            if len(pair.break1) + len(pair.break2) > cluster_initial_size_limit:
                phase2_groups.setdefault(key, []).append(pair)
            else:
                groups.setdefault(key, []).append(pair)
# now try all pairwise combinations within groups
for group_key in sorted(set(list(groups) + list(phase2_groups))):
count = len(groups.get(group_key, [])) + len(phase2_groups.get(group_key, []))
if verbose:
log(group_key, 'pairs:', count)
nodes = merge_by_union(
groups.get(group_key, []), group_key,
weight_adjustment=cluster_initial_size_limit, cluster_radius=cluster_radius)
        # phase 2. Sort the remaining breakpoint pairs by size and merge the smaller ones in first;
        # we assume that a larger breakpoint interval indicates less certainty in the call
phase2_pairs = sorted(
phase2_groups.get(group_key, []), key=lambda p: (len(p.break1) + len(p.break2), pair_key(p)))
for pair in phase2_pairs:
            # rank the existing nodes by their center distance to the current pair
            distances = sorted([(pair_center_distance(pair, node), node) for node in nodes], key=lambda x: x[0])
            merged = False
            if distances:
                best = distances[0]
                # merge into every node tied for the best distance, provided it is within the cluster radius
                for dist, node in distances:
                    if dist > best[0] or dist > cluster_radius:
                        break
pairs = nodes[node] + [pair]
itvl1 = merge_integer_intervals(
*[p.break1 for p in pairs], weight_adjustment=cluster_initial_size_limit)
itvl2 = merge_integer_intervals(
*[p.break2 for p in pairs], weight_adjustment=cluster_initial_size_limit)
                    if group_key.chr1 == group_key.chr2:
                        # for intrachromosomal events, trim the intervals so the first does not extend past the second
                        itvl1.end = min(itvl2.end, itvl1.end)
                        itvl2.start = max(itvl2.start, itvl1.start)
                        itvl1.start = min(itvl1.start, itvl1.end)
                        itvl2.end = max(itvl2.end, itvl2.start)
b1 = Breakpoint(
group_key.chr1, itvl1.start, itvl1.end, orient=group_key.orient1, strand=group_key.strand1)
b2 = Breakpoint(
group_key.chr2, itvl2.start, itvl2.end, orient=group_key.orient2, strand=group_key.strand2)
new_bpp = BreakpointPair(
b1, b2, opposing_strands=group_key.opposing_strands, stranded=explicit_strand)
del nodes[node]
nodes.setdefault(new_bpp, []).extend(pairs)
merged = True
            if not merged:
                # no existing node was close enough; add the pair as a new unmerged node
b1 = Breakpoint(
group_key.chr1, pair.break1.start, pair.break1.end,
orient=group_key.orient1, strand=group_key.strand1)
b2 = Breakpoint(
group_key.chr2, pair.break2.start, pair.break2.end,
orient=group_key.orient2, strand=group_key.strand2)
new_bpp = BreakpointPair(
b1, b2, opposing_strands=group_key.opposing_strands, stranded=explicit_strand)
nodes.setdefault(new_bpp, []).append(pair)
if verbose:
log('merged', count, 'down to', len(nodes))
for node, pairs in nodes.items():
if node in mapping:
raise KeyError('duplicate merge node', str(node), node, pair_key(node))
mapping[node] = pairs
    # sanity check that no input pairs were dropped during merging
    merge_sources = set()
    for merge_node, sources in mapping.items():
        merge_sources.update([p.data['tag'] for p in sources])
    if len(merge_sources) != len(input_pairs):
        raise AssertionError('the number of merged node inputs ({}) does not equal the number of input pairs ({})'.format(
            len(merge_sources), len(input_pairs)))
    return mapping
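
# End-to-end usage sketch (hypothetical coordinates; Breakpoint and BreakpointPair are
# the classes imported above):
#
#   pairs = [
#       BreakpointPair(Breakpoint('1', 100, 110, orient=ORIENT.LEFT),
#                      Breakpoint('1', 500, 510, orient=ORIENT.RIGHT), opposing_strands=False),
#       BreakpointPair(Breakpoint('1', 105, 115, orient=ORIENT.LEFT),
#                      Breakpoint('1', 505, 515, orient=ORIENT.RIGHT), opposing_strands=False),
#   ]
#   for merged, inputs in merge_breakpoint_pairs(pairs, cluster_radius=200).items():
#       print(merged, 'represents', len(inputs), 'input calls')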