Skip to content

Commit

Permalink
Deal with median calculation failing due to NaNs
Browse files Browse the repository at this point in the history
  • Loading branch information
nnansters committed Feb 13, 2024
1 parent 926b0a5 commit 808913c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
1 change: 0 additions & 1 deletion nannyml/sampling_error/summary_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def summary_stats_median_sampling_error(sampling_error_components, col) -> float
sampling_error: float
"""
# median = sampling_error_components[0] # TODO: check if this can be removed with Nikoss
fmedian = sampling_error_components[1]
_size = col.shape[0]
err = np.sqrt(1 / (4 * _size * (fmedian**2)))
Expand Down
4 changes: 3 additions & 1 deletion nannyml/stats/median/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def _fit(self, reference_data: pd.DataFrame, *args, **kwargs):
)

for col in self.column_names:
self._sampling_error_components[col] = summary_stats_median_sampling_error_components(reference_data[col])
self._sampling_error_components[col] = summary_stats_median_sampling_error_components(
reference_data[col].dropna()
)

for column in self.column_names:
reference_chunk_results = np.asarray(
Expand Down
14 changes: 13 additions & 1 deletion tests/stats/test_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# License: Apache Software License 2.0

"""Tests for Drift package."""

import numpy as np
import pytest

from nannyml.datasets import load_synthetic_car_loan_dataset
Expand Down Expand Up @@ -33,3 +33,15 @@ def test_stats_median_calculator_with_default_params_should_not_fail(): # noqa:
_ = calc.calculate(data=analysis)
except Exception:
pytest.fail()


def test_stats_median_calculator_should_not_fail_given_nan_values():
    """Fit and calculate the median on a column that contains NaN values.

    Regression test: a block of ``car_value`` rows in the reference set is
    blanked out with NaN before fitting. Both ``fit`` and ``calculate`` must
    complete without raising.
    """
    reference, analysis, _ = load_synthetic_car_loan_dataset()
    # ``np.nan`` is the canonical spelling; the ``np.NaN`` alias was removed
    # in NumPy 2.0 and would raise AttributeError before the test even runs.
    reference.loc[20000:30000, 'car_value'] = np.nan
    try:
        calc = SummaryStatsMedianCalculator(
            column_names=['car_value'],
        ).fit(reference)
        _ = calc.calculate(data=analysis)
    except Exception as exc:
        # Surface the underlying error so a failure is diagnosable from the
        # test report instead of showing an empty failure reason.
        pytest.fail(f"median calculation failed on data containing NaN values: {exc!r}")

0 comments on commit 808913c

Please sign in to comment.