Skip to content

Commit

Permalink
Refactor more data cleaning methods
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-nml committed Nov 19, 2023
1 parent de16772 commit 7131419
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 32 deletions.
28 changes: 6 additions & 22 deletions nannyml/performance_estimation/confidence_based/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import nannyml.sampling_error.binary_classification as bse
import nannyml.sampling_error.multiclass_classification as mse
from nannyml._typing import ModelOutputsType, ProblemType, class_labels
from nannyml.base import _remove_nans
from nannyml.chunk import Chunk, Chunker
from nannyml.exceptions import CalculatorException, InvalidArgumentsException
from nannyml.performance_estimation.confidence_based import SUPPORTED_METRIC_VALUES
Expand Down Expand Up @@ -234,30 +235,13 @@ def _common_cleaning(
)
y_pred_proba_column_name = self.y_pred_proba

clean_targets = self.y_true in data.columns and not data[self.y_true].isna().all()

y_pred_proba = data[y_pred_proba_column_name]
y_pred = data[self.y_pred]
y_true = data[self.y_true] if clean_targets else None

# Create mask to filter out NaN values
mask = ~(y_pred.isna() | y_pred_proba.isna())
if clean_targets:
mask = mask | ~(y_true.isna())
data = _remove_nans(data, [self.y_pred, y_pred_proba_column_name])

# Drop missing values (NaN/None)
y_pred_proba = y_pred_proba[mask]
y_pred = y_pred[mask]
if clean_targets:
y_true = y_true[mask]

# NaN values have been dropped. Try to infer types again
y_pred_proba = y_pred_proba.infer_objects()
y_pred = y_pred.infer_objects()
clean_targets = self.y_true in data.columns and not data[self.y_true].isna().all()
if clean_targets:
y_true = y_true.infer_objects()
data = _remove_nans(data, [self.y_true])

return y_pred_proba, y_pred, y_true
return data[y_pred_proba_column_name], data[self.y_pred], (data[self.y_true] if clean_targets else None)

def get_chunk_record(self, chunk_data: pd.DataFrame) -> Dict:
"""Returns a dictionary containing the performance metrics for a given chunk.
Expand Down Expand Up @@ -1584,7 +1568,7 @@ def _realized_performance(self, data: pd.DataFrame) -> float:
if y_true.shape[0] == 0:
warnings.warn("Calculated Business Value contains NaN values.")
return np.NaN

tp_value = self.business_value_matrix[1, 1]
tn_value = self.business_value_matrix[0, 0]
fp_value = self.business_value_matrix[0, 1]
Expand Down
16 changes: 6 additions & 10 deletions nannyml/performance_estimation/direct_loss_estimation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)

from nannyml._typing import ProblemType
from nannyml.base import _raise_exception_for_negative_values
from nannyml.base import _raise_exception_for_negative_values, _remove_nans
from nannyml.chunk import Chunk, Chunker
from nannyml.exceptions import InvalidArgumentsException
from nannyml.sampling_error.regression import (
Expand Down Expand Up @@ -271,18 +271,14 @@ def __eq__(self, other):
"""Establishes equality by comparing all properties."""
return self.display_name == other.display_name and self.column_name == other.column_name

def _common_cleaning(self, data: pd.DataFrame) -> Tuple[pd.Series, pd.Series]:
clean_targets = self.y_true in data.columns and not data[self.y_true].isna().all()
def _common_cleaning(self, data: pd.DataFrame) -> Tuple[pd.Series, Optional[pd.Series]]:
data = _remove_nans(data, [self.y_pred])

y_pred = data[self.y_pred]
clean_targets = self.y_true in data.columns and not data[self.y_true].isna().all()
if clean_targets:
y_true = data[self.y_true]
y_pred = y_pred[~y_true.isna()]
y_true.dropna(inplace=True)
else:
y_true = None
data = _remove_nans(data, [self.y_pred, self.y_true])

return y_pred, y_true
return data[self.y_pred], (data[self.y_true] if clean_targets else None)

def _train_direct_error_estimation_model(
self,
Expand Down

0 comments on commit 7131419

Please sign in to comment.