Skip to content

Commit

Permalink
Merge pull request #358 from monarch-initiative/357-remove-duplicated…
Browse files Browse the repository at this point in the history
…-code-in-variantgenedisease-analysis

357 remove duplicated code in variant/gene/disease analysis
  • Loading branch information
yaseminbridges authored Sep 27, 2024
2 parents ee41166 + 5a529b2 commit e8e1591
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 297 deletions.
108 changes: 108 additions & 0 deletions src/pheval/analyse/assess_prioritisation_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import Union

from pheval.analyse.benchmark_db_manager import BenchmarkDBManager
from pheval.post_processing.post_processing import (
RankedPhEvalDiseaseResult,
RankedPhEvalGeneResult,
RankedPhEvalVariantResult,
)


class AssessPrioritisationBase:
def __init__(
self,
db_connection: BenchmarkDBManager,
table_name: str,
column: str,
threshold: float,
score_order: str,
):
"""
Initialise AssessPrioritisationBase class
Args:
db_connection (BenchmarkDBManager): DB connection.
table_name (str): Table name.
column (str): Column name.
threshold (float): Threshold for scores
score_order (str): Score order for results, either ascending or descending
"""
self.threshold = threshold
self.score_order = score_order
self.db_connection = db_connection
self.conn = db_connection.conn
self.column = column
self.table_name = table_name
db_connection.add_column_integer_default(
table_name=table_name, column=self.column, default=0
)

def _assess_with_threshold_ascending_order(
self,
result_entry: Union[
RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult
],
) -> int:
"""
Record the prioritisation rank if it meets the ascending order threshold.
Args:
result_entry (Union[RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult]):
Ranked PhEval result entry
Returns:
int: Recorded prioritisation rank
"""
if float(self.threshold) > float(result_entry.score):
return result_entry.rank
else:
return 0

def _assess_with_threshold(
self,
result_entry: Union[
RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult
],
) -> int:
"""
Record the prioritisation rank if it meets the score threshold.
Args:
result_entry (Union[RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult]):
Ranked PhEval result entry
Returns:
int: Recorded prioritisation rank
"""
if float(self.threshold) < float(result_entry.score):
return result_entry.rank
else:
return 0

def _record_matched_entity(
self,
standardised_result: Union[
RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult
],
) -> int:
"""
Return the rank result - handling the specification of a threshold.
Args:
standardised_result (Union[RankedPhEvalGeneResult, RankedPhEvalDiseaseResult, RankedPhEvalVariantResult]):
Ranked PhEval disease result entry
Returns:
int: Recorded entity prioritisation rank
"""
if float(self.threshold) == 0.0:
return standardised_result.rank
else:
return (
self._assess_with_threshold(standardised_result)
if self.score_order != "ascending"
else self._assess_with_threshold_ascending_order(
standardised_result,
)
)
104 changes: 3 additions & 101 deletions src/pheval/analyse/disease_prioritisation_analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

from pheval.analyse.assess_prioritisation_base import AssessPrioritisationBase
from pheval.analyse.benchmark_db_manager import BenchmarkDBManager
from pheval.analyse.benchmarking_data import BenchmarkRunResults
from pheval.analyse.binary_classification_stats import BinaryClassificationStats
Expand All @@ -9,108 +10,9 @@
from pheval.utils.file_utils import all_files


class AssessDiseasePrioritisation:
class AssessDiseasePrioritisation(AssessPrioritisationBase):
"""Class for assessing disease prioritisation based on thresholds and scoring orders."""

def __init__(
self,
db_connection: BenchmarkDBManager,
table_name: str,
column: str,
threshold: float,
score_order: str,
):
"""
Initialise AssessDiseasePrioritisation class
Args:
db_connection (BenchmarkDBManager): Database connection
table_name (str): Table name
column (Path): Column name
threshold (float): Threshold for scores
score_order (str): Score order for results, either ascending or descending
"""
self.threshold = threshold
self.score_order = score_order
self.db_connection = db_connection
self.conn = db_connection.conn
self.column = column
self.table_name = table_name
db_connection.add_column_integer_default(
table_name=table_name, column=self.column, default=0
)

def _assess_disease_with_threshold_ascending_order(
self,
result_entry: RankedPhEvalDiseaseResult,
) -> int:
"""
Record the disease prioritisation rank if it meets the ascending order threshold.
This method checks if the disease prioritisation rank meets the ascending order threshold.
If the score of the result entry is less than the threshold, it records the disease rank.
Args:
result_entry (RankedPhEvalDiseaseResult): Ranked PhEval disease result entry
Returns:
int: Recorded disease prioritisation rank
"""
if float(self.threshold) > float(result_entry.score):
return result_entry.rank
else:
return 0

def _assess_disease_with_threshold(
self,
result_entry: RankedPhEvalDiseaseResult,
) -> int:
"""
Record the disease prioritisation rank if it meets the score threshold.
This method checks if the disease prioritisation rank meets the score threshold.
If the score of the result entry is greater than the threshold, it records the disease rank.
Args:
result_entry (RankedPhEvalDiseaseResult): Ranked PhEval disease result entry
Returns:
int: Recorded disease prioritisation rank
"""
if float(self.threshold) < float(result_entry.score):
return result_entry.rank
else:
return 0

def _record_matched_disease(
self,
standardised_disease_result: RankedPhEvalDiseaseResult,
) -> int:
"""
Return the disease rank result - handling the specification of a threshold.
This method determines and returns the disease rank result based on the specified threshold
and score order. If the threshold is 0.0, it records the disease rank directly.
Otherwise, it assesses the disease with the threshold based on the score order.
Args:
standardised_disease_result (RankedPhEvalDiseaseResult): Ranked PhEval disease result entry
Returns:
int: Recorded disease prioritisation rank
"""
if float(self.threshold) == 0.0:
return standardised_disease_result.rank
else:
return (
self._assess_disease_with_threshold(standardised_disease_result)
if self.score_order != "ascending"
else self._assess_disease_with_threshold_ascending_order(
standardised_disease_result,
)
)

def assess_disease_prioritisation(
self,
standardised_disease_result_path: Path,
Expand Down Expand Up @@ -147,7 +49,7 @@ def assess_disease_prioritisation(
)

if len(result) > 0:
disease_match = self._record_matched_disease(RankedPhEvalDiseaseResult(**result[0]))
disease_match = self._record_matched_entity(RankedPhEvalDiseaseResult(**result[0]))
relevant_ranks.append(disease_match)
primary_key = f"{phenopacket_path.name}-{row['disease_identifier']}"
self.conn.execute(
Expand Down
91 changes: 3 additions & 88 deletions src/pheval/analyse/gene_prioritisation_analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

from pheval.analyse.assess_prioritisation_base import AssessPrioritisationBase
from pheval.analyse.benchmark_db_manager import BenchmarkDBManager
from pheval.analyse.benchmarking_data import BenchmarkRunResults
from pheval.analyse.binary_classification_stats import BinaryClassificationStats
Expand All @@ -9,95 +10,9 @@
from pheval.utils.file_utils import all_files


class AssessGenePrioritisation:
class AssessGenePrioritisation(AssessPrioritisationBase):
"""Class for assessing gene prioritisation based on thresholds and scoring orders."""

def __init__(
self,
db_connection: BenchmarkDBManager,
table_name: str,
column: str,
threshold: float,
score_order: str,
):
"""
Initialise AssessGenePrioritisation class.
Args:
db_connection (BenchmarkDBManager): Database connection
table_name (str): Table name
column (Path): Column name
threshold (float): Threshold for scores
score_order (str): Score order for results, either ascending or descending
"""
self.threshold = threshold
self.score_order = score_order
self.db_connection = db_connection
self.conn = db_connection.conn
self.column = column
self.table_name = table_name
db_connection.add_column_integer_default(
table_name=table_name, column=self.column, default=0
)

def _assess_gene_with_threshold_ascending_order(
self,
result_entry: RankedPhEvalGeneResult,
) -> int:
"""
Record the gene prioritisation rank if it meets the ascending order threshold.
Args:
result_entry (RankedPhEvalGeneResult): Ranked PhEval gene result entry
Returns:
int: Recorded gene prioritisation rank.
"""
if float(self.threshold) > float(result_entry.score):
return result_entry.rank
else:
return 0

def _assess_gene_with_threshold(
self,
result_entry: RankedPhEvalGeneResult,
) -> int:
"""
Record the gene prioritisation rank if it meets the score threshold.
Args:
result_entry (RankedPhEvalResult): Ranked PhEval gene result entry
Returns:
int: Recorded correct gene prioritisation rank.
"""
if float(self.threshold) < float(result_entry.score):
return result_entry.rank
else:
return 0

def _record_matched_gene(
self,
standardised_gene_result: RankedPhEvalGeneResult,
) -> int:
"""
Return the gene rank result - handling the specification of a threshold.
This method determines and returns the gene rank result based on the specified threshold
and score order. If the threshold is 0.0, it records the gene rank directly.
Otherwise, it assesses the gene with the threshold based on the score order.
Args:
standardised_gene_result (RankedPhEvalGeneResult): Ranked PhEval gene result entry
Returns:
GenePrioritisationResult: Recorded correct gene prioritisation rank result
"""
if float(self.threshold) == 0.0:
return standardised_gene_result.rank
else:
return (
self._assess_gene_with_threshold(standardised_gene_result)
if self.score_order != "ascending"
else self._assess_gene_with_threshold_ascending_order(
standardised_gene_result,
)
)

def assess_gene_prioritisation(
self,
standardised_gene_result_path: Path,
Expand Down Expand Up @@ -131,7 +46,7 @@ def assess_gene_prioritisation(
.to_dict(orient="records")
)
if len(result) > 0:
gene_match = self._record_matched_gene(RankedPhEvalGeneResult(**result[0]))
gene_match = self._record_matched_entity(RankedPhEvalGeneResult(**result[0]))
relevant_ranks.append(gene_match)
primary_key = f"{phenopacket_path.name}-{row['gene_symbol']}"
self.conn.execute(
Expand Down
Loading

0 comments on commit e8e1591

Please sign in to comment.