Merge pull request #4005 from broadinstitute/dev
Dev
hanars authored Apr 4, 2024
2 parents ce90cf0 + ac4a471 commit 6ecc75a
Showing 52 changed files with 1,584 additions and 726 deletions.
@@ -1,4 +1,4 @@
-name: hail search dataset version release
+name: hail search persistent volume snapshot release
 on:
   workflow_dispatch:
     inputs:
@@ -8,12 +8,12 @@ on:
           - dev
           - prod
       reference_genome:
-          type: choice
-          description: Reference Genome
-          options:
-            - GRCh37
-            - GRCh38
-          required: true
+        type: choice
+        description: Reference Genome
+        options:
+          - GRCh37
+          - GRCh38
+        required: true
       dataset_type:
         type: choice
         description: Dataset Type
@@ -25,6 +25,8 @@ on:
         required: true
       version:
         required: true
+      volume_handle:
+        required: true

 jobs:
   helm_update:
@@ -45,6 +47,12 @@ jobs:
          cmd: >
            yq -i '.global.hail_search.datasetVersions.${{ inputs.reference_genome }}/${{ inputs.dataset_type }} = "${{ inputs.version }}"' charts/broad-seqr/values-${{ inputs.environment }}.yaml
+      - name: update volume handle in the broad-seqr chart
+        uses: mikefarah/yq@v4.22.1
+        with:
+          cmd: >
+            yq -i '.hail-search.persistentVolume.volumeHandle = "${{ inputs.volume_handle }}"' charts/broad-seqr/values-${{ inputs.environment }}.yaml
      - name: Commit and Push changes
        uses: Andro999b/push@v1.3
        with:
@@ -53,4 +61,4 @@ jobs:
          github_token: ${{ secrets.SEQR_VERSION_UPDATE_TOKEN }}
          author_email: ${{ github.actor }}@users.noreply.github.com
          author_name: tgg-automation
-          message: "Updating ${{ inputs.reference_genome }}/${{ inputs.dataset_type }} dataset version to ${{ inputs.version }}"
+          message: "Updating ${{ inputs.environment }} ${{ inputs.reference_genome }}/${{ inputs.dataset_type }} dataset version to ${{ inputs.version }} and volume handle to ${{ inputs.volume_handle }} "
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,16 @@

 ## dev

+## 4/4/24
+* Add ability to import project metadata from gregor metadata
+    * Only enabled for a project if tag is first created via
+    ```
+    ./manage.py add_project_tag --name="GREGoR Finding" --order=0.5 --color=#c25fc4 --project=<project>
+    ```
+* Support FRASER2 data (REQUIRES DB MIGRATION)
+* Add solve_status to Individual model (REQUIRES DB MIGRATION)
+* Update data deployment for hail backend to disk snapshots
+
 ## 3/13/24
 * Add "Probably Solved" analysis status (REQUIRES DB MIGRATION)
188 changes: 94 additions & 94 deletions hail_search/queries/base.py
@@ -256,10 +256,10 @@ def _read_table(self, path, drop_globals=None, use_ssd_dir=False, skip_missing_f
             ht = self._query_table_annotations(self._load_table_kwargs['variant_ht'], table_path)
             if skip_missing_field and not ht.any(hl.is_defined(ht[skip_missing_field])):
                 return None
-            ht_globals = hl.read_table(table_path).globals
+            ht_globals = hl.read_table(table_path).index_globals()
             if drop_globals:
                 ht_globals = ht_globals.drop(*drop_globals)
-            return ht.annotate_globals(**hl.eval(ht_globals))
+            return ht.annotate_globals(**ht_globals)
         return hl.read_table(table_path, **self._load_table_kwargs)

     @staticmethod
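
The switch from `hl.eval(ht.globals)` to `index_globals()` above replaces an eager client-side evaluation of the source table's globals with a lazy expression Hail can attach directly to another table. A minimal sketch of the difference, assuming a local Hail installation (the tables here are synthetic, not seqr data):

```python
import hail as hl

source = hl.utils.range_table(5).annotate_globals(version='v1', paths=['a', 'b'])
target = hl.utils.range_table(5)

# Eager: globals are evaluated to local Python values, then re-attached.
eager = target.annotate_globals(**hl.eval(source.globals))
# Lazy: globals stay in Hail's expression language; no local round trip.
lazy = target.annotate_globals(**source.index_globals())

assert hl.eval(eager.globals) == hl.eval(lazy.globals)
```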
@@ -278,39 +278,48 @@ def _parse_sample_data(self, sample_data):
         logger.info(f'Loading {self.DATA_TYPE} data for {num_families} families in {len(project_samples)} projects')
         return project_samples, num_families

-    def _load_filtered_project_hts(self, project_samples, skip_all_missing=False, **kwargs):
+    def _load_filtered_project_hts(self, project_samples, skip_all_missing=False, n_partitions=MAX_PARTITIONS, **kwargs):
         if len(project_samples) == 1:
             project_guid = list(project_samples.keys())[0]
             project_ht = self._read_table(f'projects/{project_guid}.ht', use_ssd_dir=True)
             return self._filter_entries_table(project_ht, project_samples[project_guid], **kwargs)

+        # Need to chunk tables or else evaluating table globals throws LineTooLong exception
+        # However, minimizing number of chunks minimizes number of aggregations/ evals and improves performance
+        # Adapted from https://discuss.hail.is/t/importing-many-sample-specific-vcfs/2002/8
+        chunk_size = 64
         filtered_project_hts = []
-        exception_messages = set()
-        for i, (project_guid, project_sample_data) in enumerate(project_samples.items()):
+        filtered_comp_het_project_hts = []
+        project_hts = []
+        sample_data = {}
+        for project_guid, project_sample_data in project_samples.items():
             project_ht = self._read_table(
                 f'projects/{project_guid}.ht',
                 use_ssd_dir=True,
-                skip_missing_field='family_entries' if skip_all_missing or i > 0 else None,
+                skip_missing_field='family_entries' if skip_all_missing else None,
             )
             if project_ht is None:
                 continue
-            try:
-                filtered_project_hts.append(
-                    (*self._filter_entries_table(project_ht, project_sample_data, **kwargs), len(project_sample_data))
+            project_hts.append(project_ht.select_globals('sample_type', 'family_guids', 'family_samples'))
+            sample_data.update(project_sample_data)
+
+            if len(project_hts) >= chunk_size:
+                self._filter_merged_project_hts(
+                    project_hts, sample_data, filtered_project_hts, filtered_comp_het_project_hts, n_partitions, **kwargs,
                 )
-            except HTTPBadRequest as e:
-                exception_messages.add(e.reason)
+                project_hts = []
+                sample_data = {}

+        self._filter_merged_project_hts(
+            project_hts, sample_data, filtered_project_hts, filtered_comp_het_project_hts, n_partitions, **kwargs,
+        )
+
-        if exception_messages:
-            raise HTTPBadRequest(reason='; '.join(exception_messages))
+        ht = self._merge_project_hts(filtered_project_hts, n_partitions)
+        comp_het_ht = self._merge_project_hts(filtered_comp_het_project_hts, n_partitions)

-        if len(project_samples) > len(filtered_project_hts):
-            logger.info(f'Found {len(filtered_project_hts)} {self.DATA_TYPE} projects with matched entries')
-        return filtered_project_hts
+        return ht, comp_het_ht

     def import_filtered_table(self, project_samples, num_families, intervals=None, **kwargs):
         use_annotations_ht_first = len(project_samples) > 1 and (kwargs.get('parsed_intervals') or kwargs.get('padded_interval'))
         if use_annotations_ht_first:
             # For multi-project interval search, faster to first read and filter the annotation table and then add entries
             ht = self._read_table('annotations.ht')
             ht = self._filter_annotated_table(ht, **kwargs, is_comp_het=self._has_comp_het_search)
             self._load_table_kwargs['variant_ht'] = ht.select()

         if num_families == 1:
             family_sample_data = list(project_samples.values())[0]
             family_guid = list(family_sample_data.keys())[0]
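
The chunk-and-flush loop above, in isolation: accumulate up to `chunk_size` tables, flush them through one merged filter pass, and reset, with a final flush for the remainder. A generic sketch of the control flow only (the names and integer payload are invented; the real code flushes Hail tables through `_filter_merged_project_hts`, which no-ops on an empty batch):

```python
def process_in_chunks(items, flush, chunk_size=64):
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) >= chunk_size:
            flush(batch)
            batch = []
    flush(batch)  # remainder; flush() must tolerate an empty batch

chunks = []
process_in_chunks(range(150), lambda batch: chunks.append(len(batch)) if batch else None)
assert chunks == [64, 64, 22]
```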
@@ -321,60 +330,50 @@ def import_filtered_table(self, project_samples, num_families, intervals=None, *
             )
             families_ht, comp_het_families_ht = self._filter_entries_table(family_ht, family_sample_data, **kwargs)
         else:
-            filtered_project_hts = self._load_filtered_project_hts(project_samples, **kwargs)
-            families_ht, comp_het_families_ht, num_families = filtered_project_hts[0]
-            main_ht = comp_het_families_ht if families_ht is None else families_ht
-            entry_type = main_ht.family_entries.dtype.element_type
-            num_projects_added = 1
-            for project_ht, comp_het_project_ht, num_project_families in filtered_project_hts[1:]:
-                ht_added = False
-                if families_ht is not None:
-                    families_ht, ht_added = self._add_project_ht(families_ht, project_ht, default=hl.empty_array(entry_type))
-                if comp_het_families_ht is not None:
-                    comp_het_families_ht, ht_added = self._add_project_ht(
-                        comp_het_families_ht, comp_het_project_ht,
-                        default=hl.range(num_families).map(lambda i: hl.missing(entry_type)),
-                        default_1=hl.range(num_project_families).map(lambda i: hl.missing(entry_type)),
-                    )
-                if ht_added:
-                    num_families += num_project_families
-                    num_projects_added += 1
-
-            if len(filtered_project_hts) > num_projects_added:
-                logger.info(f'Found {num_projects_added} {self.DATA_TYPE} projects with matched entries')
+            families_ht, comp_het_families_ht = self._load_filtered_project_hts(project_samples, **kwargs)

         if comp_het_families_ht is not None:
             self._comp_het_ht = self._query_table_annotations(comp_het_families_ht, self._get_table_path('annotations.ht'))
-            if not use_annotations_ht_first:
-                self._comp_het_ht = self._filter_annotated_table(self._comp_het_ht, is_comp_het=True, **kwargs)
+            self._comp_het_ht = self._filter_annotated_table(self._comp_het_ht, is_comp_het=True, **kwargs)
             self._comp_het_ht = self._filter_compound_hets()

         if families_ht is not None:
             self._ht = self._query_table_annotations(families_ht, self._get_table_path('annotations.ht'))
-            if not use_annotations_ht_first:
-                self._ht = self._filter_annotated_table(self._ht, **kwargs)
-            elif self._has_comp_het_search:
-                self._ht = self._filter_by_annotations(self._ht, **(kwargs.get('parsed_annotations') or {}))
+            self._ht = self._filter_annotated_table(self._ht, **kwargs)

+    def _filter_merged_project_hts(self, project_hts, sample_data, filtered_project_hts, filtered_comp_het_project_hts, n_partitions, **kwargs):
+        if not project_hts:
+            return
+        ht = self._merge_project_hts(project_hts, n_partitions, include_all_globals=True)
+        ht, comp_het_ht = self._filter_entries_table(ht, sample_data, **kwargs)
+        if ht is not None:
+            filtered_project_hts.append(ht)
+        if comp_het_ht is not None:
+            filtered_comp_het_project_hts.append(comp_het_ht)
+
-    def _add_project_ht(self, families_ht, project_ht, default, default_1=None):
-        if default_1 is None:
-            default_1 = default
-
-        if not project_ht.any(project_ht.family_entries.any(hl.is_defined)):
-            return families_ht, False
-
-        families_ht = families_ht.join(project_ht, how='outer')
-        families_ht = families_ht.select_globals(
-            family_guids=families_ht.family_guids.extend(families_ht.family_guids_1)
+    @staticmethod
+    def _merge_project_hts(project_hts, n_partitions, include_all_globals=False):
+        if not project_hts:
+            return None
+        ht = hl.Table.multi_way_zip_join(project_hts, 'project_entries', 'project_globals')
+        ht = ht.repartition(n_partitions)
+        ht = ht.transmute(
+            filters=ht.project_entries.fold(lambda f, x: f.union(x.filters), hl.empty_set(hl.tstr)),
+            family_entries=hl.enumerate(ht.project_entries).starmap(lambda i, x: hl.or_else(
+                x.family_entries,
+                ht.project_globals[i].family_guids.map(lambda f: hl.missing(x.family_entries.dtype.element_type)),
+            )).flatmap(lambda x: x),
         )
-        return families_ht.select(
-            filters=families_ht.filters.union(families_ht.filters_1),
-            family_entries=hl.bind(
-                lambda a1, a2: a1.extend(a2),
-                hl.or_else(families_ht.family_entries, default),
-                hl.or_else(families_ht.family_entries_1, default_1),
-            ),
-        ), True
+        global_expressions = {
+            'family_guids': ht.project_globals.flatmap(lambda x: x.family_guids),
+        }
+        if include_all_globals:
+            global_expressions.update({
+                'sample_types': ht.project_globals.flatmap(lambda x: x.family_guids.map(lambda _: x.sample_type)),
+                'family_samples': hl.dict(ht.project_globals.flatmap(lambda x: x.family_samples.items())),
+            })
+
+        return ht.transmute_globals(**global_expressions)

     def _filter_entries_table(self, ht, sample_data, inheritance_filter=None, quality_filter=None, **kwargs):
         ht = self._prefilter_entries_table(ht, **kwargs)
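
The `_merge_project_hts` refactor above replaces N-1 pairwise outer joins (`_add_project_ht`) with a single `hl.Table.multi_way_zip_join`, then backfills rows missing from a given project using that project's `family_guids` global so the flattened `family_entries` array stays aligned. A self-contained toy version of the pattern, assuming a local Hail installation (schemas and values are invented and far simpler than seqr's):

```python
import hail as hl

schema = hl.tstruct(locus=hl.tint32, family_entries=hl.tarray(hl.tarray(hl.tint32)))
# Two "project" tables keyed by locus; each carries its family list in globals.
t1 = hl.Table.parallelize(
    [{'locus': 1, 'family_entries': [[10]]}, {'locus': 2, 'family_entries': [[11]]}],
    schema, key='locus',
).annotate_globals(family_guids=['F1'])
t2 = hl.Table.parallelize(
    [{'locus': 2, 'family_entries': [[20], [21]]}],
    schema, key='locus',
).annotate_globals(family_guids=['F2', 'F3'])

# One join over all tables instead of N-1 pairwise outer joins.
ht = hl.Table.multi_way_zip_join([t1, t2], 'project_entries', 'project_globals')
ht = ht.transmute(
    # A locus absent from a project yields a missing struct; substitute one
    # missing entry per family in that project before flattening.
    family_entries=hl.enumerate(ht.project_entries).starmap(lambda i, x: hl.or_else(
        x.family_entries,
        ht.project_globals[i].family_guids.map(lambda f: hl.missing(hl.tarray(hl.tint32))),
    )).flatmap(lambda x: x),
)
ht = ht.transmute_globals(family_guids=ht.project_globals.flatmap(lambda x: x.family_guids))
ht.show()  # locus 1 -> [[10], missing, missing]; locus 2 -> [[11], [20], [21]]
```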
@@ -411,13 +410,16 @@ def _add_entry_sample_families(self, ht, sample_data):
             if missing_family_samples:
                 missing_samples.update(missing_family_samples)
             else:
+                family_index = ht_globals.family_guids.index(family_guid)
+                family_entry_data = {
+                    'sampleType': self._get_sample_type(family_index, ht_globals),
+                    'familyGuid': family_guid,
+                }
                 sample_index_data = [
-                    (ht_family_samples.index(s['sample_id']), self._sample_entry_data(s, family_guid, ht_globals))
+                    (ht_family_samples.index(s['sample_id']), {**family_entry_data, **self._sample_entry_data(s)})
                     for s in samples
                 ]
-                family_sample_index_data.append(
-                    (ht_globals.family_guids.index(family_guid), sample_index_data)
-                )
+                family_sample_index_data.append((family_index, sample_index_data))
                 self.entry_samples_by_family_guid[family_guid] = [s['sampleId'] for _, s in sample_index_data]

         if missing_samples:
@@ -437,18 +439,18 @@ def _add_entry_sample_families(self, ht, sample_data):
         return ht, sorted_family_sample_data

     @classmethod
-    def _sample_entry_data(cls, sample, family_guid, ht_globals):
+    def _sample_entry_data(cls, sample):
         return dict(
             sampleId=sample['sample_id'],
-            sampleType=cls._get_sample_type(ht_globals),
             individualGuid=sample['individual_guid'],
-            familyGuid=family_guid,
             affected_id=AFFECTED_ID_MAP.get(sample['affected']),
             is_male='sex' in sample and sample['sex'] == MALE,
         )

     @classmethod
-    def _get_sample_type(cls, ht_globals):
+    def _get_sample_type(cls, family_index, ht_globals):
+        if 'sample_types' in ht_globals:
+            return ht_globals.sample_types[family_index]
         return ht_globals.sample_type

     def _filter_inheritance(self, ht, inheritance_filter, sorted_family_sample_data):
@@ -1061,15 +1063,20 @@ def gene_counts(self):
             ht.gene_ids, hl.struct(total=hl.agg.count(), families=hl.agg.counter(ht.families))
         ))

-    def lookup_variants(self, variant_ids, annotation_fields=None):
+    def lookup_variants(self, variant_ids, project_samples=None, annotation_overrides=None):
         self._parse_intervals(intervals=None, variant_ids=variant_ids, variant_keys=variant_ids)
         ht = self._read_table('annotations.ht', drop_globals=['paths', 'versions'])
+        self._load_table_kwargs['_n_partitions'] = 1
         ht = ht.filter(hl.is_defined(ht[XPOS]))

-        if not annotation_fields:
+        annotation_fields = self.annotation_fields(include_genotype_overrides=False)
+        if project_samples:
+            projects_ht, _ = self._load_filtered_project_hts(project_samples, skip_all_missing=True, n_partitions=1)
+            ht = ht.annotate(**projects_ht[ht.key])
+        elif annotation_overrides:
+            annotation_fields.update(annotation_overrides)
+        else:
             annotation_fields = {
-                k: v for k, v in self.annotation_fields(include_genotype_overrides=False).items()
+                k: v for k, v in annotation_fields.items()
                 if k not in {FAMILY_GUID_FIELD, GENOTYPES_FIELD, 'genotypeFilters'}
             }

@@ -1078,24 +1085,17 @@ def lookup_variants(self, variant_ids, annotation_fields=None):
         return formatted.aggregate(hl.agg.take(formatted.row, len(variant_ids)))

     def lookup_variant(self, variant_id, sample_data=None):
-        annotation_fields = self.annotation_fields(include_genotype_overrides=False)
-        entry_annotations = {k: annotation_fields[k] for k in [FAMILY_GUID_FIELD, GENOTYPES_FIELD]}
-        annotation_fields.update({
+        annotation_overrides = {
             FAMILY_GUID_FIELD: lambda ht: hl.empty_array(hl.tstr),
             GENOTYPES_FIELD: lambda ht: hl.empty_dict(hl.tstr, hl.tstr),
             'genotypeFilters': lambda ht: hl.str(''),
-        })
-
-        variants = self.lookup_variants([variant_id], annotation_fields=annotation_fields)
-        if not variants:
-            raise HTTPNotFound()
-        variant = dict(variants[0])
+        }

+        project_samples = None
         if sample_data:
             project_samples, _ = self._parse_sample_data(sample_data)
-            for pht, _, _ in self._load_filtered_project_hts(project_samples, skip_all_missing=True):
-                project_entries = pht.aggregate(hl.agg.take(hl.struct(**{k: v(pht) for k, v in entry_annotations.items()}), 1))
-                variant[FAMILY_GUID_FIELD] += project_entries[0][FAMILY_GUID_FIELD]
-                variant[GENOTYPES_FIELD].update(project_entries[0][GENOTYPES_FIELD])

-        return variant
+        variants = self.lookup_variants([variant_id], project_samples=project_samples, annotation_overrides=annotation_overrides)
+        if not variants:
+            raise HTTPNotFound()
+        return dict(variants[0])
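
In the refactored flow, `lookup_variant` no longer aggregates project entries itself; `lookup_variants` joins the merged project table onto the annotations table by key (`ht.annotate(**projects_ht[ht.key])`). That join is Hail's standard annotate-by-key idiom, sketched here on invented stand-in tables (assumes a local Hail installation):

```python
import hail as hl

variants = hl.utils.range_table(5)  # stands in for annotations.ht
entries = hl.utils.range_table(3).annotate(genotypes=hl.str('entry data'))  # stands in for the merged project ht

# For each variants row, pull the matching entries row by key; rows with no
# match (idx 3 and 4 here) get missing values rather than being dropped.
joined = variants.annotate(**entries[variants.key])
joined.show()
```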
2 changes: 1 addition & 1 deletion requirements.txt
@@ -99,7 +99,7 @@ oauthlib==3.2.2
     #   social-auth-core
 openpyxl==3.1.1
     # via -r requirements.in
-pillow==10.2.0
+pillow==10.3.0
     # via -r requirements.in
 protobuf==3.20.2
     # via