From 808913cbce0eee3294e2a14de15b607be3adc471 Mon Sep 17 00:00:00 2001 From: Niels Nuyttens Date: Tue, 13 Feb 2024 13:27:10 +0100 Subject: [PATCH] Deal with median calculation failing due to NaNs --- nannyml/sampling_error/summary_stats.py | 1 - nannyml/stats/median/calculator.py | 4 +++- tests/stats/test_median.py | 14 +++++++++++++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/nannyml/sampling_error/summary_stats.py b/nannyml/sampling_error/summary_stats.py index 2e61f30da..8b072c3af 100644 --- a/nannyml/sampling_error/summary_stats.py +++ b/nannyml/sampling_error/summary_stats.py @@ -92,7 +92,6 @@ def summary_stats_median_sampling_error(sampling_error_components, col) -> float sampling_error: float """ - # median = sampling_error_components[0] # TODO: check if this can be removed with Nikoss fmedian = sampling_error_components[1] _size = col.shape[0] err = np.sqrt(1 / (4 * _size * (fmedian**2))) diff --git a/nannyml/stats/median/calculator.py b/nannyml/stats/median/calculator.py index c9ce3c87e..1a5410f2a 100644 --- a/nannyml/stats/median/calculator.py +++ b/nannyml/stats/median/calculator.py @@ -120,7 +120,9 @@ def _fit(self, reference_data: pd.DataFrame, *args, **kwargs): ) for col in self.column_names: - self._sampling_error_components[col] = summary_stats_median_sampling_error_components(reference_data[col]) + self._sampling_error_components[col] = summary_stats_median_sampling_error_components( + reference_data[col].dropna() + ) for column in self.column_names: reference_chunk_results = np.asarray( diff --git a/tests/stats/test_median.py b/tests/stats/test_median.py index 03aed69b9..615d826e9 100644 --- a/tests/stats/test_median.py +++ b/tests/stats/test_median.py @@ -4,7 +4,7 @@ # License: Apache Software License 2.0 """Tests for Drift package.""" - +import numpy as np import pytest from nannyml.datasets import load_synthetic_car_loan_dataset @@ -33,3 +33,15 @@ def test_stats_median_calculator_with_default_params_should_not_fail(): # noqa: _ = calc.calculate(data=analysis) except Exception: pytest.fail() + + +def test_stats_median_calculator_should_not_fail_given_nan_values(): # noqa: D103 + reference, analysis, _ = load_synthetic_car_loan_dataset() + reference.loc[20000:30000, 'car_value'] = np.NaN + try: + calc = SummaryStatsMedianCalculator( + column_names=['car_value'], + ).fit(reference) + _ = calc.calculate(data=analysis) + except Exception: + pytest.fail()