class AssessPrioritisationBase:
    """
    Shared threshold handling for prioritisation assessment.

    AssessGenePrioritisation, AssessVariantPrioritisation and
    AssessDiseasePrioritisation inherit from this class, so the logic that
    decides whether a ranked result counts towards the benchmark lives in one
    place instead of being duplicated per entity type.
    """

    def __init__(
        self,
        db_connection: BenchmarkDBManager,
        table_name: str,
        column: str,
        threshold: float,
        score_order: str,
    ):
        """
        Initialise AssessPrioritisationBase class

        Args:
            db_connection (BenchmarkDBManager): DB connection.
            table_name (str): Table name.
            column (str): Column name.
            threshold (float): Threshold for scores. A value of 0.0 means that no threshold
                is applied and every matched rank is recorded.
            score_order (str): Score order for results, either ascending or descending.
                Any value other than "ascending" is handled as descending.
        """
        self.threshold = threshold
        self.score_order = score_order
        self.db_connection = db_connection
        self.conn = db_connection.conn
        self.column = column
        self.table_name = table_name
        # NOTE(review): presumably creates the integer rank column (defaulting to 0)
        # in the table if it is missing - confirm against BenchmarkDBManager.
        db_connection.add_column_integer_default(
            table_name=table_name, column=self.column, default=0
        )

    def _assess_with_threshold_ascending_order(
        self,
        result_entry: Union[
            RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult
        ],
    ) -> int:
        """
        Record the prioritisation rank if it meets the ascending order threshold.

        The rank is recorded only when the score of the result entry is strictly
        less than the threshold.

        Args:
            result_entry (Union[RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult]):
                Ranked PhEval result entry

        Returns:
            int: Recorded prioritisation rank, or 0 if the score does not meet the threshold
        """
        if float(self.threshold) > float(result_entry.score):
            return result_entry.rank
        return 0

    def _assess_with_threshold(
        self,
        result_entry: Union[
            RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult
        ],
    ) -> int:
        """
        Record the prioritisation rank if it meets the score threshold.

        The rank is recorded only when the score of the result entry is strictly
        greater than the threshold.

        Args:
            result_entry (Union[RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult]):
                Ranked PhEval result entry

        Returns:
            int: Recorded prioritisation rank, or 0 if the score does not meet the threshold
        """
        if float(self.threshold) < float(result_entry.score):
            return result_entry.rank
        return 0

    def _record_matched_entity(
        self,
        standardised_result: Union[
            RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult
        ],
    ) -> int:
        """
        Return the rank result - handling the specification of a threshold.

        If the threshold is 0.0 the rank is recorded directly. Otherwise the result is
        assessed against the threshold according to the score order.

        Args:
            standardised_result (Union[RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult]):
                Ranked PhEval result entry

        Returns:
            int: Recorded entity prioritisation rank, or 0 if the threshold is not met
        """
        # A threshold of 0.0 is the "no threshold" sentinel: always keep the rank.
        if float(self.threshold) == 0.0:
            return standardised_result.rank
        if self.score_order == "ascending":
            return self._assess_with_threshold_ascending_order(standardised_result)
        # Every other score order is treated as descending.
        return self._assess_with_threshold(standardised_result)
db_connection.add_column_integer_default( - table_name=table_name, column=self.column, default=0 - ) - - def _assess_disease_with_threshold_ascending_order( - self, - result_entry: RankedPhEvalDiseaseResult, - ) -> int: - """ - Record the disease prioritisation rank if it meets the ascending order threshold. - - This method checks if the disease prioritisation rank meets the ascending order threshold. - If the score of the result entry is less than the threshold, it records the disease rank. - - Args: - result_entry (RankedPhEvalDiseaseResult): Ranked PhEval disease result entry - - Returns: - int: Recorded disease prioritisation rank - """ - if float(self.threshold) > float(result_entry.score): - return result_entry.rank - else: - return 0 - - def _assess_disease_with_threshold( - self, - result_entry: RankedPhEvalDiseaseResult, - ) -> int: - """ - Record the disease prioritisation rank if it meets the score threshold. - - This method checks if the disease prioritisation rank meets the score threshold. - If the score of the result entry is greater than the threshold, it records the disease rank. - - Args: - result_entry (RankedPhEvalDiseaseResult): Ranked PhEval disease result entry - - Returns: - int: Recorded disease prioritisation rank - """ - if float(self.threshold) < float(result_entry.score): - return result_entry.rank - else: - return 0 - - def _record_matched_disease( - self, - standardised_disease_result: RankedPhEvalDiseaseResult, - ) -> int: - """ - Return the disease rank result - handling the specification of a threshold. - - This method determines and returns the disease rank result based on the specified threshold - and score order. If the threshold is 0.0, it records the disease rank directly. - Otherwise, it assesses the disease with the threshold based on the score order. 
- - Args: - standardised_disease_result (RankedPhEvalDiseaseResult): Ranked PhEval disease result entry - - Returns: - int: Recorded disease prioritisation rank - """ - if float(self.threshold) == 0.0: - return standardised_disease_result.rank - else: - return ( - self._assess_disease_with_threshold(standardised_disease_result) - if self.score_order != "ascending" - else self._assess_disease_with_threshold_ascending_order( - standardised_disease_result, - ) - ) - def assess_disease_prioritisation( self, standardised_disease_result_path: Path, @@ -147,7 +49,7 @@ def assess_disease_prioritisation( ) if len(result) > 0: - disease_match = self._record_matched_disease(RankedPhEvalDiseaseResult(**result[0])) + disease_match = self._record_matched_entity(RankedPhEvalDiseaseResult(**result[0])) relevant_ranks.append(disease_match) primary_key = f"{phenopacket_path.name}-{row['disease_identifier']}" self.conn.execute( diff --git a/src/pheval/analyse/gene_prioritisation_analysis.py b/src/pheval/analyse/gene_prioritisation_analysis.py index 61edd319..407ed82f 100644 --- a/src/pheval/analyse/gene_prioritisation_analysis.py +++ b/src/pheval/analyse/gene_prioritisation_analysis.py @@ -1,5 +1,6 @@ from pathlib import Path +from pheval.analyse.assess_prioritisation_base import AssessPrioritisationBase from pheval.analyse.benchmark_db_manager import BenchmarkDBManager from pheval.analyse.benchmarking_data import BenchmarkRunResults from pheval.analyse.binary_classification_stats import BinaryClassificationStats @@ -9,95 +10,9 @@ from pheval.utils.file_utils import all_files -class AssessGenePrioritisation: +class AssessGenePrioritisation(AssessPrioritisationBase): """Class for assessing gene prioritisation based on thresholds and scoring orders.""" - def __init__( - self, - db_connection: BenchmarkDBManager, - table_name: str, - column: str, - threshold: float, - score_order: str, - ): - """ - Initialise AssessGenePrioritisation class. 
- - Args: - db_connection (BenchmarkDBManager): Database connection - table_name (str): Table name - column (Path): Column name - threshold (float): Threshold for scores - score_order (str): Score order for results, either ascending or descending - """ - self.threshold = threshold - self.score_order = score_order - self.db_connection = db_connection - self.conn = db_connection.conn - self.column = column - self.table_name = table_name - db_connection.add_column_integer_default( - table_name=table_name, column=self.column, default=0 - ) - - def _assess_gene_with_threshold_ascending_order( - self, - result_entry: RankedPhEvalGeneResult, - ) -> int: - """ - Record the gene prioritisation rank if it meets the ascending order threshold. - Args: - result_entry (RankedPhEvalGeneResult): Ranked PhEval gene result entry - Returns: - int: Recorded gene prioritisation rank. - """ - if float(self.threshold) > float(result_entry.score): - return result_entry.rank - else: - return 0 - - def _assess_gene_with_threshold( - self, - result_entry: RankedPhEvalGeneResult, - ) -> int: - """ - Record the gene prioritisation rank if it meets the score threshold. - Args: - result_entry (RankedPhEvalResult): Ranked PhEval gene result entry - - Returns: - int: Recorded correct gene prioritisation rank. - """ - if float(self.threshold) < float(result_entry.score): - return result_entry.rank - else: - return 0 - - def _record_matched_gene( - self, - standardised_gene_result: RankedPhEvalGeneResult, - ) -> int: - """ - Return the gene rank result - handling the specification of a threshold. - This method determines and returns the gene rank result based on the specified threshold - and score order. If the threshold is 0.0, it records the gene rank directly. - Otherwise, it assesses the gene with the threshold based on the score order. 
- Args: - standardised_gene_result (RankedPhEvalGeneResult): Ranked PhEval gene result entry - Returns: - GenePrioritisationResult: Recorded correct gene prioritisation rank result - """ - if float(self.threshold) == 0.0: - return standardised_gene_result.rank - else: - return ( - self._assess_gene_with_threshold(standardised_gene_result) - if self.score_order != "ascending" - else self._assess_gene_with_threshold_ascending_order( - standardised_gene_result, - ) - ) - def assess_gene_prioritisation( self, standardised_gene_result_path: Path, @@ -131,7 +46,7 @@ def assess_gene_prioritisation( .to_dict(orient="records") ) if len(result) > 0: - gene_match = self._record_matched_gene(RankedPhEvalGeneResult(**result[0])) + gene_match = self._record_matched_entity(RankedPhEvalGeneResult(**result[0])) relevant_ranks.append(gene_match) primary_key = f"{phenopacket_path.name}-{row['gene_symbol']}" self.conn.execute( diff --git a/src/pheval/analyse/variant_prioritisation_analysis.py b/src/pheval/analyse/variant_prioritisation_analysis.py index e4c1e349..39a6f639 100644 --- a/src/pheval/analyse/variant_prioritisation_analysis.py +++ b/src/pheval/analyse/variant_prioritisation_analysis.py @@ -1,5 +1,6 @@ from pathlib import Path +from pheval.analyse.assess_prioritisation_base import AssessPrioritisationBase from pheval.analyse.benchmark_db_manager import BenchmarkDBManager from pheval.analyse.benchmarking_data import BenchmarkRunResults from pheval.analyse.binary_classification_stats import BinaryClassificationStats @@ -10,103 +11,9 @@ from pheval.utils.phenopacket_utils import GenomicVariant -class AssessVariantPrioritisation: +class AssessVariantPrioritisation(AssessPrioritisationBase): """Class for assessing variant prioritisation based on thresholds and scoring orders.""" - def __init__( - self, - db_connection: BenchmarkDBManager, - table_name: str, - column: str, - threshold: float, - score_order: str, - ): - """ - Initialise AssessVariantPrioritisation class - - Args: - 
db_connection (BenchmarkDBManager): DB connection. - table_name (str): Table name. - column (str): Column name. - threshold (float): Threshold for scores - score_order (str): Score order for results, either ascending or descending - - """ - self.threshold = threshold - self.score_order = score_order - self.db_connection = db_connection - self.conn = db_connection.conn - self.column = column - self.table_name = table_name - db_connection.add_column_integer_default( - table_name=table_name, column=self.column, default=0 - ) - - def _assess_variant_with_threshold_ascending_order( - self, result_entry: RankedPhEvalVariantResult - ) -> int: - """ - Record the variant prioritisation rank if it meets the ascending order threshold. - - This method checks if the variant prioritisation rank meets the ascending order threshold. - If the score of the result entry is less than the threshold, it records the variant rank. - - Args: - result_entry (RankedPhEvalVariantResult): Ranked PhEval variant result entry - - Returns: - int: Recorded variant prioritisation rank - """ - if float(self.threshold) > float(result_entry.score): - return result_entry.rank - else: - return 0 - - def _assess_variant_with_threshold(self, result_entry: RankedPhEvalVariantResult) -> int: - """ - Record the variant prioritisation rank if it meets the score threshold. - - This method checks if the variant prioritisation rank meets the score threshold. - If the score of the result entry is greater than the threshold, it records the variant rank. - - Args: - result_entry (RankedPhEvalVariantResult): Ranked PhEval variant result entry - - Returns: - int: Recorded variant prioritisation rank - """ - if float(self.threshold) < float(result_entry.score): - return result_entry.rank - else: - return 0 - - def _record_matched_variant( - self, standardised_variant_result: RankedPhEvalVariantResult - ) -> int: - """ - Return the variant rank result - handling the specification of a threshold. 
- - This method determines and returns the variant rank result based on the specified threshold - and score order. If the threshold is 0.0, it records the variant rank directly. - Otherwise, it assesses the variant with the threshold based on the score order. - - Args: - standardised_variant_result (RankedPhEvalVariantResult): Ranked PhEval variant result entry - - Returns: - int: Recorded variant prioritisation rank - """ - if float(self.threshold) == 0.0: - return standardised_variant_result.rank - else: - return ( - self._assess_variant_with_threshold(standardised_variant_result) - if self.score_order != "ascending" - else self._assess_variant_with_threshold_ascending_order( - standardised_variant_result, - ) - ) - def assess_variant_prioritisation( self, standardised_variant_result_path: Path, @@ -149,7 +56,7 @@ def assess_variant_prioritisation( ) if len(result) > 0: - variant_match = self._record_matched_variant(RankedPhEvalVariantResult(**result[0])) + variant_match = self._record_matched_entity(RankedPhEvalVariantResult(**result[0])) relevant_ranks.append(variant_match) primary_key = ( f"{phenopacket_path.name}-{causative_variant.chrom}-{causative_variant.pos}-"