do not use query_table for variant id queries #4518

Merged · 2 commits · Dec 6, 2024
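At a high level, this change stops materializing variant ids into a keyed lookup table (the `query_table`-style path) for the locus-keyed backends and instead prefilters the table by 1-bp intervals, then matches locus/alleles directly. Below is a minimal sketch of those two building blocks, assuming a table keyed by `locus`/`alleles`; the table path, coordinates, and alleles are illustrative, not the real seqr data layout:

```python
import hail as hl

# Illustrative inputs: parsed variant ids as (chrom, pos, ref, alt) tuples.
variant_ids = [('chr7', 117559590, 'G', 'A'), ('chr1', 55051215, 'G', 'GA')]

ht = hl.read_table('annotations.ht')  # hypothetical path, keyed by (locus, alleles)

# Step 1: prefilter to 1-bp intervals so only the touched partitions are read.
intervals = [
    hl.locus_interval(chrom, pos, pos + 1, reference_genome='GRCh38')
    for chrom, pos, _, _ in variant_ids
]
ht = hl.filter_intervals(ht, intervals)

# Step 2: keep only exact locus/allele matches within those intervals.
ht = ht.filter(hl.any([
    (ht.locus == hl.locus(chrom, pos, reference_genome='GRCh38')) & (ht.alleles == [ref, alt])
    for chrom, pos, ref, alt in variant_ids
]))
```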
15 changes: 11 additions & 4 deletions hail_search/queries/base.py
@@ -697,7 +697,7 @@ def passes_quality_field(gt):
def _passes_vcf_filters(gt):
return hl.is_missing(gt.filters) | (gt.filters.length() < 1)

def _parse_variant_keys(self, variant_keys=None, **kwargs):
def _parse_variant_keys(self, variant_keys):
return [hl.struct(**{self.KEY_FIELD[0]: key}) for key in (variant_keys or [])]

def _prefilter_entries_table(self, ht, **kwargs):
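For context, the (otherwise unchanged) base `_parse_variant_keys` builds one single-field row per key so the caller can materialize them into a keyed table, as in the next hunk. A hedged sketch with a hypothetical key field name and illustrative keys:

```python
import hail as hl

KEY_FIELD = ('variant_id',)  # hypothetical; the real value comes from self.KEY_FIELD
variant_keys = ['cohort_DEL_chr1_123', 'cohort_DUP_chr2_456']  # illustrative SV-style ids

parsed_variant_keys = [hl.struct(**{KEY_FIELD[0]: key}) for key in variant_keys]
variant_ht = hl.Table.parallelize(parsed_variant_keys).key_by(*KEY_FIELD)
```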
@@ -728,12 +728,15 @@ def _filter_rs_ids(self, ht, rs_ids):
rs_id_set = hl.set(rs_ids)
return ht.filter(rs_id_set.contains(ht.rsid))

def _parse_intervals(self, intervals, gene_ids=None, **kwargs):
parsed_variant_keys = self._parse_variant_keys(**kwargs)
def _parse_intervals(self, intervals, gene_ids=None, variant_keys=None, variant_ids=None, **kwargs):
parsed_variant_keys = self._parse_variant_keys(variant_keys)
if parsed_variant_keys:
self._load_table_kwargs['variant_ht'] = hl.Table.parallelize(parsed_variant_keys).key_by(*self.KEY_FIELD)
return intervals

if variant_ids:
intervals = [(chrom, pos, pos+1) for chrom, pos, _, _ in variant_ids]

is_x_linked = self._inheritance_mode == X_LINKED_RECESSIVE
if not (intervals or is_x_linked):
return intervals
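The new `variant_ids` branch above collapses each variant id to a 1-bp `(chrom, pos, pos + 1)` tuple; downstream, those tuples are presumably parsed into locus intervals for interval prefiltering. A hedged sketch of that conversion (the interval-parsing step is assumed, not shown in this diff):

```python
import hail as hl

variant_ids = [('chr7', 117559590, 'G', 'A')]  # (chrom, pos, ref, alt), illustrative

# Only the position matters for prefiltering, so each variant becomes a
# half-open 1-bp interval [pos, pos + 1).
interval_tuples = [(chrom, pos, pos + 1) for chrom, pos, _, _ in variant_ids]
parsed_intervals = [
    hl.locus_interval(chrom, start, end, reference_genome='GRCh38')
    for chrom, start, end in interval_tuples
]
```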
@@ -1216,15 +1219,19 @@ def gene_counts(self):
ht.gene_ids, hl.struct(total=hl.agg.count(), families=hl.agg.counter(ht.families))
))

def _filter_variant_ids(self, ht, variant_ids):
return ht

def lookup_variants(self, variant_ids, include_project_data=False, **kwargs):
self._parse_intervals(intervals=None, variant_ids=variant_ids, variant_keys=variant_ids)
ht = self._read_table('annotations.ht', drop_globals=['versions'])
ht = self._filter_variant_ids(ht, variant_ids)
ht = ht.filter(hl.is_defined(ht[XPOS]))

annotation_fields = self.annotation_fields(include_genotype_overrides=False)
include_sample_annotations = False
if include_project_data:
ht, include_sample_annotations = self._add_project_lookup_data(ht, annotation_fields, **kwargs)
ht, include_sample_annotations = self._add_project_lookup_data(ht, annotation_fields, variant_ids=variant_ids, **kwargs)
if not include_sample_annotations:
annotation_fields = {
k: v for k, v in annotation_fields.items()
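One small but load-bearing line kept in `lookup_variants` is the `hl.is_defined(ht[XPOS])` guard, which drops any rows whose annotation payload is missing after the prefilter. A self-contained toy illustration of that pattern (the `xpos` field here stands in for the repo's `XPOS` constant):

```python
import hail as hl

# Toy stand-in for the annotations table.
ht = hl.utils.range_table(5)
ht = ht.annotate(
    xpos=hl.if_else(ht.idx < 3, hl.int64(1_000_000_000) + ht.idx, hl.missing(hl.tint64))
)

# Rows with a missing payload are excluded before formatting results.
ht = ht.filter(hl.is_defined(ht.xpos))
print(ht.count())  # 3
```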
27 changes: 12 additions & 15 deletions hail_search/queries/mito.py
@@ -394,30 +394,25 @@ def _filter_variant_ids(self, ht, variant_ids):
variant_id_q = ht.alleles == [variant_ids[0][2], variant_ids[0][3]]
else:
variant_id_q = hl.any([
(ht.locus == hl.locus(chrom, pos, reference_genome=self.GENOME_VERSION)) &
(ht.locus == hl.locus(f'chr{chrom}' if self._should_add_chr_prefix() else chrom, pos, reference_genome=self.GENOME_VERSION)) &
(ht.alleles == [ref, alt])
for chrom, pos, ref, alt in variant_ids
])
return ht.filter(variant_id_q)
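Two details in this method are worth calling out: when only one variant id is requested, only the alleles are compared, presumably because the 1-bp interval prefilter has already pinned the locus; and the added `chr` prefix handling matters because Hail's built-in GRCh38 reference uses `chr`-prefixed contig names while GRCh37 does not. A small illustration:

```python
import hail as hl

hl.eval(hl.locus('chr7', 117559590, reference_genome='GRCh38'))  # valid
hl.eval(hl.locus('7', 117559590, reference_genome='GRCh37'))     # valid
# hl.locus('7', 117559590, reference_genome='GRCh38') raises, because '7' is not
# a GRCh38 contig -- hence the conditional f'chr{chrom}' in the filter above.
```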

def _parse_variant_keys(self, variant_ids=None, **kwargs):
if not variant_ids:
return variant_ids
def _parse_variant_keys(self, variant_keys):
return None

return [
hl.struct(
locus=hl.locus(f'chr{chrom}' if self._should_add_chr_prefix() else chrom, pos, reference_genome=self.GENOME_VERSION),
alleles=[ref, alt],
) for chrom, pos, ref, alt in variant_ids
]

def _prefilter_entries_table(self, ht, parsed_intervals=None, exclude_intervals=False, **kwargs):
def _prefilter_entries_table(self, ht, parsed_intervals=None, exclude_intervals=False, variant_ids=None, **kwargs):
num_intervals = len(parsed_intervals or [])
if exclude_intervals and parsed_intervals:
ht = hl.filter_intervals(ht, parsed_intervals, keep=False)
elif num_intervals >= MAX_LOAD_INTERVALS:
ht = hl.filter_intervals(ht, parsed_intervals)

if variant_ids:
ht = self._filter_variant_ids(ht, variant_ids)

if '_n_partitions' not in self._load_table_kwargs and num_intervals > self._n_partitions:
ht = ht.naive_coalesce(self._n_partitions)

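With this hunk, the prefilter applies its narrowing steps in a fixed order. A condensed, hedged sketch of the resulting flow (constant values are illustrative, `filter_variant_ids` stands in for the method above, and the `_load_table_kwargs` check on partitions is omitted for brevity):

```python
import hail as hl

def prefilter_sketch(ht, filter_variant_ids, parsed_intervals=None,
                     exclude_intervals=False, variant_ids=None,
                     max_load_intervals=100, n_partitions=16):
    num_intervals = len(parsed_intervals or [])
    if exclude_intervals and parsed_intervals:
        # Location-exclusion searches drop everything inside the intervals.
        ht = hl.filter_intervals(ht, parsed_intervals, keep=False)
    elif num_intervals >= max_load_intervals:
        # Too many intervals to push into the table read; filter after reading.
        ht = hl.filter_intervals(ht, parsed_intervals)
    if variant_ids:
        # New in this PR: exact locus/allele matching on top of the 1-bp intervals.
        ht = filter_variant_ids(ht, variant_ids)
    if num_intervals > n_partitions:
        # Interval filtering can leave many tiny partitions; merge them cheaply.
        ht = ht.naive_coalesce(n_partitions)
    return ht
```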
@@ -513,10 +508,11 @@ def _omim_sort(cls, r, omim_gene_set):
def _gene_rank_sort(cls, r, gene_ranks):
return [gene_ranks.get(r.selected_transcript.gene_id)] + super()._gene_rank_sort(r, gene_ranks)

def _add_project_lookup_data(self, ht, annotation_fields, *args, **kwargs):
def _add_project_lookup_data(self, ht, annotation_fields, *args, variant_ids=None, **kwargs):
# Get all the project-families for the looked up variant formatted as a dict of dicts:
# {<project_guid>: {<sample_type>: {<family_guid>: True}, <sample_type_2>: {<family_guid_2>: True}}, <project_guid_2>: ...}
lookup_ht = self._read_table('lookup.ht', skip_missing_field='project_stats')
lookup_ht = self._filter_variant_ids(lookup_ht, variant_ids)
if lookup_ht is None:
raise HTTPNotFound()
variant_projects = lookup_ht.aggregate(hl.agg.take(
@@ -536,11 +532,12 @@ def _add_project_lookup_data(self, ht, annotation_fields, *args, **kwargs):
lambda project_data: hl.dict(project_data.starmap(
lambda project_key, families: (project_key[1], families)
)))), 1)
)[0]
)

# Variant can be present in the lookup table with only ref calls, so is still not present in any projects
if not variant_projects:
if not (variant_projects and variant_projects[0]):
raise HTTPNotFound()
variant_projects = variant_projects[0]

self._has_both_sample_types = True
annotation_fields.update({
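The trickiest part of the project-lookup change is the indexing: `aggregate(hl.agg.take(..., 1))` returns a plain Python list that is empty when the variant-id filter removed every row, so indexing it with `[0]` before checking (the old pattern) would raise. The `[0]` now happens after the guard, which also covers the ref-calls-only case where the first element is present but empty. A self-contained toy version of the pattern:

```python
import hail as hl

lookup_ht = hl.utils.range_table(10)
lookup_ht = lookup_ht.filter(lookup_ht.idx > 100)  # emulate a variant with no lookup row

variant_projects = lookup_ht.aggregate(hl.agg.take(lookup_ht.idx, 1))  # [] when nothing matched

# Old pattern: `...aggregate(hl.agg.take(...))[0]` raises IndexError on the empty list.
# New pattern: guard first (also catching a present-but-empty first element), then unpack.
if not (variant_projects and variant_projects[0]):
    raise LookupError('variant not found in any project')  # stands in for HTTPNotFound
variant_projects = variant_projects[0]
```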