Skip to content

Commit

Permalink
Merge pull request #321 from monarch-initiative/320-index-error-in-cr…
Browse files Browse the repository at this point in the history
…eate-spiked-vcf-command

Handle IndexError for spiking variants into a VCF
  • Loading branch information
yaseminbridges authored May 7, 2024
2 parents 5e0f43f + 79fb20c commit f10b502
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
34 changes: 25 additions & 9 deletions src/pheval/prepare/create_spiked_vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,22 +328,35 @@ def construct_variant_entry(self, proband_variant_data: ProbandCausativeVariant)
genotype_codes[proband_variant_data.genotype.lower()] + "\n",
]

def construct_vcf_records(self) -> List[str]:
def construct_vcf_records(self, template_vcf_name: str) -> List[str]:
"""
Construct updated VCF records by inserting spiked variants into the correct positions within the VCF.
Args:
template_vcf_name (str): Name of the template VCF file.
Returns:
List[str]: Updated VCF records containing the spiked variants.
"""
updated_vcf_records = copy(self.vcf_contents)
for variant in self.proband_causative_variants:
variant = self.construct_variant_entry(variant)
variant_entry_position = [
variant_entry = self.construct_variant_entry(variant)
matching_indices = [
i
for i, val in enumerate(updated_vcf_records)
if val.split("\t")[0] == variant[0] and int(val.split("\t")[1]) < int(variant[1])
][-1] + 1
updated_vcf_records.insert(variant_entry_position, "\t".join(variant))
if val.split("\t")[0] == variant_entry[0]
and int(val.split("\t")[1]) < int(variant_entry[1])
]
if matching_indices:
variant_entry_position = matching_indices[-1] + 1
else:
info_log.warning(
f"Could not find entry position for {variant.variant.chrom}-{variant.variant.pos}-"
f"{variant.variant.ref}-{variant.variant.alt} in {template_vcf_name}, "
"inserting at end of VCF contents."
)
variant_entry_position = len(updated_vcf_records)
updated_vcf_records.insert(variant_entry_position, "\t".join(variant_entry))
return updated_vcf_records

def construct_header(self, updated_vcf_records: List[str]) -> List[str]:
Expand All @@ -365,14 +378,17 @@ def construct_header(self, updated_vcf_records: List[str]) -> List[str]:
updated_vcf_file.append(text)
return updated_vcf_file

def construct_vcf(self) -> List[str]:
def construct_vcf(self, template_vcf_name: str) -> List[str]:
"""
Construct the entire spiked VCF file by incorporating the spiked variants into the VCF.
Args:
template_vcf_name (str): Name of the template VCF file.
Returns:
List[str]: The complete spiked VCF file content as a list of strings.
"""
return self.construct_header(self.construct_vcf_records())
return self.construct_header(self.construct_vcf_records(template_vcf_name))


class VcfWriter:
Expand Down Expand Up @@ -454,7 +470,7 @@ def spike_vcf_contents(
chosen_template_vcf.vcf_contents,
phenopacket_causative_variants,
chosen_template_vcf.vcf_header,
).construct_vcf(),
).construct_vcf(chosen_template_vcf.vcf_file_name),
)


Expand Down
23 changes: 21 additions & 2 deletions tests/test_create_spiked_vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,18 @@ def setUpClass(cls) -> None:
],
VcfHeader("TEMPLATE", "GRCh37", True),
)
cls.vcf_spiker_new_variant_chrom = VcfSpiker(
hg19_vcf,
[
ProbandCausativeVariant(
"TEST1",
"GRCh37",
GenomicVariant("X", 123450, "G", "A"),
"heterozygous",
)
],
VcfHeader("TEMPLATE", "GRCh37", True),
)
cls.vcf_spiker_multiple_variants = VcfSpiker(
hg19_vcf,
[
Expand Down Expand Up @@ -324,12 +336,12 @@ def test_construct_variant_structural_variant(self):

def test_construct_vcf_records_single_variant(self):
self.assertEqual(
self.vcf_spiker.construct_vcf_records()[40],
self.vcf_spiker.construct_vcf_records("template.vcf")[40],
"chr1\t886190\t.\tG\tA\t100\tPASS\t.\t" "GT\t0/1\n",
)

def test_construct_vcf_records_multiple_variants(self):
updated_records = self.vcf_spiker_multiple_variants.construct_vcf_records()
updated_records = self.vcf_spiker_multiple_variants.construct_vcf_records("template.vcf")
self.assertEqual(
updated_records[40],
"chr1\t886190\t.\tG\tA\t100\tPASS\t.\t" "GT\t0/1\n",
Expand All @@ -339,6 +351,13 @@ def test_construct_vcf_records_multiple_variants(self):
"chr3\t61580860\t.\tG\tA\t100\tPASS\t.\t" "GT\t1/1\n",
)

def test_construct_vcf_records_new_variant_pos(self):
updated_records = self.vcf_spiker_new_variant_chrom.construct_vcf_records("template.vcf")
self.assertEqual(
updated_records[48],
"chrX\t123450\t.\tG\tA\t100\tPASS\t.\tGT\t0/1\n",
)

def test_construct_header(self):
self.assertEqual(
[
Expand Down

0 comments on commit f10b502

Please sign in to comment.