Skip to content

Commit

Permalink
Update MsprimeSimulator.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xin-huang committed Jun 6, 2024
1 parent 1cc089d commit 3420a0e
Showing 1 changed file with 62 additions and 12 deletions.
74 changes: 62 additions & 12 deletions gaia/utils/simulators/MsprimeSimulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,21 +234,71 @@ def _get_true_tracts(self, ts: tskit.TreeSequence, tgt_id: str, src_id: str,
except IndexError:
raise ValueError(f'Population {tgt_id} is not found.')

for m in ts.migrations():
if (m.dest==src_id) and (m.source==tgt_id):
tgt_samples = ts.samples(tgt_id)
ts, migtable = self._simplify_ts(ts=ts, tgt_id=tgt_id, src_id=src_id)

#for m in ts.migrations():
# if (m.dest==src_id) and (m.source==tgt_id):
for m in migtable:
# For simulations with a long sequence, large sample size, and/or deep generation time
# This function may become slow
# Use new arguments from https://github.com/tskit-dev/tskit/pull/2762
# for t in ts.trees(left=m.left, right=m.right):
for t in ts.trees():
if m.left >= t.interval.right: continue
if m.right <= t.interval.left: break # [l, r)
for n in ts.samples(tgt_id):
if t.is_descendant(n, m.node):
left = m.left if m.left > t.interval.left else t.interval.left
right = m.right if m.right < t.interval.right else t.interval.right
if is_phased: sample_id = f'tsk_{ts.node(n).individual}_{int(n%ploidy+1)}'
else: sample_id = f'tsk_{ts.node(n).individual}'
tracts += f'1\t{int(left)}\t{int(right)}\t{sample_id}\n'
for t in ts.trees():
if m.left >= t.interval.right: continue
if m.right <= t.interval.left: break # [l, r)
#for n in ts.samples(tgt_id):
for n in tgt_samples:
if t.is_descendant(n, m.node):
left = m.left if m.left > t.interval.left else t.interval.left
right = m.right if m.right < t.interval.right else t.interval.right
if is_phased: sample_id = f'tsk_{ts.node(n).individual}_{int(n%ploidy+1)}'
else: sample_id = f'tsk_{ts.node(n).individual}'
tracts += f'1\t{int(left)}\t{int(right)}\t{sample_id}\n'

return tracts


def _simplify_ts(self, ts: tskit.TreeSequence, tgt_id: str, src_id: str) -> "tuple[tskit.TreeSequence, tskit.MigrationTable]":
    """
    Reduce a tree sequence to the nodes of the source and target populations
    and collect the migration records between those two populations.

    Parameters
    ----------
    ts : tskit.TreeSequence
        Tree sequence to reduce. It is not modified; all work is done on a
        copy of its tables.
    tgt_id : str
        Identifier of the target population. It is compared with the
        ``source`` field of each migration row and with the ``population``
        field of each node.
    src_id : str
        Identifier of the source population. It is compared with the
        ``dest`` field of each migration row and with the ``population``
        field of each node.

    Returns
    -------
    simplified_ts : tskit.TreeSequence
        Copy of ``ts`` with an empty migration table, simplified with
        respect to every node whose population is ``src_id`` or ``tgt_id``.
        Populations, individuals, sites and nodes are not filtered.
    relevant_migrations : tskit.MigrationTable
        Stand-alone table holding only the migration rows of ``ts`` with
        ``dest == src_id`` and ``source == tgt_id``, in their original order.

    Notes
    -----
    NOTE(review): ``tgt_id`` and ``src_id`` are annotated as ``str`` but are
    compared with the population fields of migration rows and nodes, which
    are integer IDs in tskit -- confirm that callers pass integer IDs.
    """
    # A single copy of the tables is enough: the relevant migration rows are
    # collected first, and only afterwards is the migration table emptied.
    tables = ts.dump_tables()

    # Start from an emptied copy of the migration table, then append the
    # rows that connect the target and the source populations.
    relevant_migrations = tables.migrations.copy()
    relevant_migrations.clear()
    for row in tables.migrations:
        if (row.dest == src_id) and (row.source == tgt_id):
            relevant_migrations.append(row)

    # The migration table is emptied before the tree sequence is simplified.
    tables.migrations.clear()
    ts_without_migrations = tables.tree_sequence()

    # IDs of all nodes (not only samples) that belong to either population.
    populations_to_keep = (src_id, tgt_id)
    nodes_to_keep = [
        node.id for node in ts.nodes()
        if node.population in populations_to_keep
    ]

    # All filters are switched off, so nothing is dropped from the
    # population, individual, site and node tables.
    simplified_ts = ts_without_migrations.simplify(
        nodes_to_keep,
        filter_populations=False,
        filter_individuals=False,
        filter_sites=False,
        filter_nodes=False,
    )

    return simplified_ts, relevant_migrations

0 comments on commit 3420a0e

Please sign in to comment.