diff --git a/hail_search/queries/base.py b/hail_search/queries/base.py index db7a02c8e2..3f9784b80d 100644 --- a/hail_search/queries/base.py +++ b/hail_search/queries/base.py @@ -697,7 +697,7 @@ def passes_quality_field(gt): def _passes_vcf_filters(gt): return hl.is_missing(gt.filters) | (gt.filters.length() < 1) - def _parse_variant_keys(self, variant_keys=None, **kwargs): + def _parse_variant_keys(self, variant_keys): return [hl.struct(**{self.KEY_FIELD[0]: key}) for key in (variant_keys or [])] def _prefilter_entries_table(self, ht, **kwargs): @@ -728,12 +728,15 @@ def _filter_rs_ids(self, ht, rs_ids): rs_id_set = hl.set(rs_ids) return ht.filter(rs_id_set.contains(ht.rsid)) - def _parse_intervals(self, intervals, gene_ids=None, **kwargs): - parsed_variant_keys = self._parse_variant_keys(**kwargs) + def _parse_intervals(self, intervals, gene_ids=None, variant_keys=None, variant_ids=None, **kwargs): + parsed_variant_keys = self._parse_variant_keys(variant_keys) if parsed_variant_keys: self._load_table_kwargs['variant_ht'] = hl.Table.parallelize(parsed_variant_keys).key_by(*self.KEY_FIELD) return intervals + if variant_ids: + intervals = [(chrom, pos, pos+1) for chrom, pos, _, _ in variant_ids] + is_x_linked = self._inheritance_mode == X_LINKED_RECESSIVE if not (intervals or is_x_linked): return intervals @@ -1216,15 +1219,19 @@ def gene_counts(self): ht.gene_ids, hl.struct(total=hl.agg.count(), families=hl.agg.counter(ht.families)) )) + def _filter_variant_ids(self, ht, variant_ids): + return ht + def lookup_variants(self, variant_ids, include_project_data=False, **kwargs): self._parse_intervals(intervals=None, variant_ids=variant_ids, variant_keys=variant_ids) ht = self._read_table('annotations.ht', drop_globals=['versions']) + ht = self._filter_variant_ids(ht, variant_ids) ht = ht.filter(hl.is_defined(ht[XPOS])) annotation_fields = self.annotation_fields(include_genotype_overrides=False) include_sample_annotations = False if include_project_data: - ht, include_sample_annotations = self._add_project_lookup_data(ht, annotation_fields, **kwargs) + ht, include_sample_annotations = self._add_project_lookup_data(ht, annotation_fields, variant_ids=variant_ids, **kwargs) if not include_sample_annotations: annotation_fields = { k: v for k, v in annotation_fields.items() diff --git a/hail_search/queries/mito.py b/hail_search/queries/mito.py index 0cfe9fbdbb..ebadd65883 100644 --- a/hail_search/queries/mito.py +++ b/hail_search/queries/mito.py @@ -394,30 +394,25 @@ def _filter_variant_ids(self, ht, variant_ids): variant_id_q = ht.alleles == [variant_ids[0][2], variant_ids[0][3]] else: variant_id_q = hl.any([ - (ht.locus == hl.locus(chrom, pos, reference_genome=self.GENOME_VERSION)) & + (ht.locus == hl.locus(f'chr{chrom}' if self._should_add_chr_prefix() else chrom, pos, reference_genome=self.GENOME_VERSION)) & (ht.alleles == [ref, alt]) for chrom, pos, ref, alt in variant_ids ]) return ht.filter(variant_id_q) - def _parse_variant_keys(self, variant_ids=None, **kwargs): - if not variant_ids: - return variant_ids + def _parse_variant_keys(self, variant_keys): + return None - return [ - hl.struct( - locus=hl.locus(f'chr{chrom}' if self._should_add_chr_prefix() else chrom, pos, reference_genome=self.GENOME_VERSION), - alleles=[ref, alt], - ) for chrom, pos, ref, alt in variant_ids - ] - - def _prefilter_entries_table(self, ht, parsed_intervals=None, exclude_intervals=False, **kwargs): + def _prefilter_entries_table(self, ht, parsed_intervals=None, exclude_intervals=False, variant_ids=None, **kwargs): num_intervals = len(parsed_intervals or []) if exclude_intervals and parsed_intervals: ht = hl.filter_intervals(ht, parsed_intervals, keep=False) elif num_intervals >= MAX_LOAD_INTERVALS: ht = hl.filter_intervals(ht, parsed_intervals) + if variant_ids: + ht = self._filter_variant_ids(ht, variant_ids) + if '_n_partitions' not in self._load_table_kwargs and num_intervals > self._n_partitions: ht = ht.naive_coalesce(self._n_partitions) @@ -513,10 +508,11 @@ def _omim_sort(cls, r, omim_gene_set): def _gene_rank_sort(cls, r, gene_ranks): return [gene_ranks.get(r.selected_transcript.gene_id)] + super()._gene_rank_sort(r, gene_ranks) - def _add_project_lookup_data(self, ht, annotation_fields, *args, **kwargs): + def _add_project_lookup_data(self, ht, annotation_fields, *args, variant_ids=None, **kwargs): # Get all the project-families for the looked up variant formatted as a dict of dicts: # {: {: {: True}, : {: True}}, : ...} lookup_ht = self._read_table('lookup.ht', skip_missing_field='project_stats') + lookup_ht = self._filter_variant_ids(lookup_ht, variant_ids) if lookup_ht is None: raise HTTPNotFound() variant_projects = lookup_ht.aggregate(hl.agg.take( @@ -536,11 +532,12 @@ def _add_project_lookup_data(self, ht, annotation_fields, *args, **kwargs): lambda project_data: hl.dict(project_data.starmap( lambda project_key, families: (project_key[1], families) )))), 1) - )[0] + ) # Variant can be present in the lookup table with only ref calls, so is still not present in any projects - if not variant_projects: + if not (variant_projects and variant_projects[0]): raise HTTPNotFound() + variant_projects = variant_projects[0] self._has_both_sample_types = True annotation_fields.update({