diff --git a/src/pheval/analyse/disease_prioritisation_analysis.py b/src/pheval/analyse/disease_prioritisation_analysis.py index 88bcbeb6..0b6d5c47 100644 --- a/src/pheval/analyse/disease_prioritisation_analysis.py +++ b/src/pheval/analyse/disease_prioritisation_analysis.py @@ -38,11 +38,16 @@ def assess_disease_prioritisation( for _i, row in df.iterrows(): result = ( self.conn.execute( - f"SELECT * FROM '{standardised_disease_result_path}' " - f"WHERE contains_entity_function(CAST(COALESCE(disease_identifier, '') AS VARCHAR)," - f" '{row['disease_identifier']}') " - f"OR contains_entity_function(CAST(COALESCE(disease_name, '') AS VARCHAR), " - f"'{row['disease_name']}')" + ( + f"SELECT * FROM '{standardised_disease_result_path}' " + f"WHERE contains_entity_function(CAST(COALESCE(disease_identifier, '') AS VARCHAR)," + f" '{row['disease_identifier']}') " + f"OR contains_entity_function(CAST(COALESCE(disease_name, '') AS VARCHAR), " + f"'{row['disease_name']}')" + ) + if standardised_disease_result_path.exists() + and standardised_disease_result_path.stat().st_size > 0 + else "SELECT NULL WHERE FALSE" ) .fetchdf() .to_dict(orient="records") @@ -56,9 +61,15 @@ def assess_disease_prioritisation( f'UPDATE {self.table_name} SET "{self.column}" = ? WHERE identifier = ?', (disease_match, primary_key), ) + elif len(result) == 0: + relevant_ranks.append(0) binary_classification_stats.add_classification( - self.db_connection.parse_table_into_dataclass( - str(standardised_disease_result_path), RankedPhEvalDiseaseResult + ( + self.db_connection.parse_table_into_dataclass( + str(standardised_disease_result_path), RankedPhEvalDiseaseResult + ) + if standardised_disease_result_path.exists() + else [] ), relevant_ranks, ) diff --git a/src/pheval/analyse/gene_prioritisation_analysis.py b/src/pheval/analyse/gene_prioritisation_analysis.py index 407ed82f..55b47ff9 100644 --- a/src/pheval/analyse/gene_prioritisation_analysis.py +++ b/src/pheval/analyse/gene_prioritisation_analysis.py @@ -36,11 +36,16 @@ def assess_gene_prioritisation( for _i, row in df.iterrows(): result = ( self.conn.execute( - f"SELECT * FROM '{standardised_gene_result_path}' " - f"WHERE contains_entity_function(CAST(COALESCE(gene_identifier, '') AS VARCHAR)," - f" '{row['gene_identifier']}') " - f"OR contains_entity_function(CAST(COALESCE(gene_symbol, '') AS VARCHAR), " - f"'{row['gene_symbol']}')" + ( + f"SELECT * FROM '{standardised_gene_result_path}' " + f"WHERE contains_entity_function(CAST(COALESCE(gene_identifier, '') AS VARCHAR), " + f"'{row['gene_identifier']}') " + f"OR contains_entity_function(CAST(COALESCE(gene_symbol, '') AS VARCHAR), " + f"'{row['gene_symbol']}')" + ) + if standardised_gene_result_path.exists() + and standardised_gene_result_path.stat().st_size > 0 + else "SELECT NULL WHERE FALSE" ) .fetchdf() .to_dict(orient="records") @@ -53,9 +58,15 @@ def assess_gene_prioritisation( f'UPDATE {self.table_name} SET "{self.column}" = ? WHERE identifier = ?', (gene_match, primary_key), ) + if not result: + relevant_ranks.append(0) binary_classification_stats.add_classification( - self.db_connection.parse_table_into_dataclass( - str(standardised_gene_result_path), RankedPhEvalGeneResult + ( + self.db_connection.parse_table_into_dataclass( + str(standardised_gene_result_path), RankedPhEvalGeneResult + ) + if standardised_gene_result_path.exists() + else [] ), relevant_ranks, ) diff --git a/src/pheval/analyse/variant_prioritisation_analysis.py b/src/pheval/analyse/variant_prioritisation_analysis.py index 39a6f639..0f26bf72 100644 --- a/src/pheval/analyse/variant_prioritisation_analysis.py +++ b/src/pheval/analyse/variant_prioritisation_analysis.py @@ -44,12 +44,16 @@ def assess_variant_prioritisation( ) result = ( self.conn.execute( - f"SELECT * FROM '{standardised_variant_result_path}' " - f"WHERE " - f"chromosome == '{causative_variant.chrom}' AND " - f"start == {causative_variant.pos} AND " - f"ref == '{causative_variant.ref}' AND " - f"alt == '{causative_variant.alt}'" + ( + f"SELECT * FROM '{standardised_variant_result_path}' " + f"WHERE " + f"chromosome == '{causative_variant.chrom}' AND " + f"start == {causative_variant.pos} AND " + f"ref == '{causative_variant.ref}' AND " + f"alt == '{causative_variant.alt}'" + ) + if standardised_variant_result_path.exists() + else "SELECT NULL WHERE FALSE" ) .fetchdf() .to_dict(orient="records") @@ -66,10 +70,15 @@ def assess_variant_prioritisation( f'UPDATE {self.table_name} SET "{self.column}" = ? WHERE identifier = ?', (variant_match, primary_key), ) - + elif len(result) == 0: + relevant_ranks.append(0) binary_classification_stats.add_classification( - self.db_connection.parse_table_into_dataclass( - str(standardised_variant_result_path), RankedPhEvalVariantResult + ( + self.db_connection.parse_table_into_dataclass( + str(standardised_variant_result_path), RankedPhEvalVariantResult + ) + if standardised_variant_result_path.exists() + else [] ), relevant_ranks, ) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 28f8b6dd..7e6bbc8f 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -1,7 +1,7 @@ import unittest from copy import copy from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import duckdb @@ -130,9 +130,13 @@ def test_assess_gene_with_threshold_meets_cutoff(self): ) def test_assess_gene_prioritisation_no_threshold(self): + mock_path = MagicMock(spec=Path) + mock_path.exists.return_value = True + mock_path.stat.return_value.st_size = 100 + mock_path.__str__.return_value = "result" self.db_connector.add_contains_function() self.assess_gene_prioritisation.assess_gene_prioritisation( - "result", + mock_path, Path("/path/to/phenopacket_1.json"), self.binary_classification_stats, ) @@ -150,7 +154,7 @@ def test_assess_gene_prioritisation_no_threshold(self): true_positives=1, true_negatives=3, false_positives=0, - false_negatives=0, + false_negatives=1, labels=[1, 0, 0, 0], scores=[0.8764, 0.5777, 0.5777, 0.3765], ), @@ -282,9 +286,13 @@ def test_assess_variant_with_threshold_meets_cutoff(self): ) def test_assess_variant_prioritisation(self): + mock_path = MagicMock(spec=Path) + mock_path.exists.return_value = True + mock_path.stat.return_value.st_size = 100 + mock_path.__str__.return_value = "result" self.db_connector.add_contains_function() self.assess_variant_prioritisation.assess_variant_prioritisation( - "result", + mock_path, Path("/path/to/phenopacket_1.json"), self.binary_classification_stats, ) @@ -318,7 +326,7 @@ def test_assess_variant_prioritisation(self): true_positives=0, true_negatives=0, false_positives=2, - false_negatives=1, + false_negatives=2, labels=[0, 0, 1], scores=[0.0484, 0.0484, 0.0484], ), @@ -439,9 +447,13 @@ def test_assess_disease_with_threshold_meets_cutoff(self): ) def test_assess_disease_prioritisation(self): + mock_path = MagicMock(spec=Path) + mock_path.exists.return_value = True + mock_path.stat.return_value.st_size = 100 + mock_path.__str__.return_value = "result" self.db_connector.add_contains_function() self.assess_disease_prioritisation.assess_disease_prioritisation( - "result", + mock_path, Path("/path/to/phenopacket_1.json"), self.binary_classification_stats, )