Skip to content

Commit

Permalink
chunk size 1 and stats std results (#363)
Browse files Browse the repository at this point in the history
* fix std sampling error so it no longer returns inf

* add test for chunk size 1

* raise an alert when the stats module produces a NaN result

* refactor stats std sampling error to better catch issues
  • Loading branch information
nikml authored Feb 16, 2024
1 parent 3e11cd7 commit 3e292ac
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
18 changes: 15 additions & 3 deletions nannyml/sampling_error/summary_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde, moment
from logging import getLogger

logger = getLogger(__name__)

def summary_stats_std_sampling_error_components(col: pd.Series) -> Tuple:
"""
Expand Down Expand Up @@ -38,8 +40,10 @@ def summary_stats_std_sampling_error(sampling_error_components, col) -> float:
Parameters
----------
sampling_error_components : a set of parameters that were derived from reference data.
col : the (analysis) column you want to calculate sampling error for.
sampling_error_components:
a set of parameters that were derived from reference data.
col:
the (analysis) column you want to calculate sampling error for.
Returns
-------
Expand All @@ -49,7 +53,15 @@ def summary_stats_std_sampling_error(sampling_error_components, col) -> float:
_std = sampling_error_components[0]
_mu4 = sampling_error_components[1]
_size = col.shape[0]
err_var = np.sqrt((1 / _size) * (_mu4 - ((_size - 3) * (_std**4) / (_size - 1))))

err_var_parenthesis_part = (_mu4 - ((_size - 3) * (_std**4) / (_size - 1)))
if not (
np.isfinite(err_var_parenthesis_part) and
err_var_parenthesis_part >= 0
):
logger.debug("Summary Stats sampling error calculation imputed to nan because of non finite positive parenthesis factor.")
return np.nan
err_var = np.sqrt((1 / _size) * err_var_parenthesis_part)
return (1 / (2 * _std)) * err_var


Expand Down
4 changes: 4 additions & 0 deletions nannyml/stats/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@

"""Module containing base classes for data quality calculations."""

from numpy import isnan


def _add_alert_flag(row_result: dict) -> bool:
flag = False
if isnan(row_result['value']):
flag = True
if row_result['upper_threshold'] is not None:
if row_result['value'] > row_result['upper_threshold']:
flag = True
Expand Down
37 changes: 37 additions & 0 deletions tests/stats/test_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
"""Tests for Drift package."""

import pytest
import pandas as pd
import numpy as np


from nannyml.datasets import load_synthetic_car_loan_dataset
from nannyml.stats import SummaryStatsStdCalculator
from nannyml.chunk import SizeBasedChunker

# @pytest.fixture(scope="module")
# def status_sum_result() -> Result:
Expand All @@ -33,3 +37,36 @@ def test_stats_std_calculator_with_default_params_should_not_fail(): # noqa: D1
_ = calc.calculate(data=analysis)
except Exception:
pytest.fail()


def test_stats_std_calculator_with_default_params_chunk_size_one():  # noqa: D103
    # Chunks of 5_000 rows over 5_001 analysis rows with incomplete='keep':
    # the expected frame below has a trailing chunk keyed '[5000:5000]',
    # i.e. a chunk that spans a single row.
    reference, analysis, _ = load_synthetic_car_loan_dataset()

    fitted_calc = SummaryStatsStdCalculator(
        column_names=['car_value'],
        chunker=SizeBasedChunker(chunk_size=5_000, incomplete='keep'),
    ).fit(reference)
    result = fitted_calc.calculate(data=analysis.head(5_001))

    # Column order matters for assert_frame_equal, so the keys are listed in
    # the order the result frame is expected to produce them.
    expected_columns = {
        ('chunk', 'key'): ['[0:4999]', '[5000:5000]'],
        ('chunk', 'chunk_index'): [0, 1],
        ('chunk', 'start_index'): [0, 5000],
        ('chunk', 'end_index'): [4999, 5000],
        ('chunk', 'start_date'): [None, None],
        ('chunk', 'end_date'): [None, None],
        ('chunk', 'period'): ['analysis', 'analysis'],
        # The single-row chunk is expected to report NaN for the value, its
        # sampling error and both confidence boundaries...
        ('car_value', 'value'): [20614.8926, np.nan],
        ('car_value', 'sampling_error'): [271.9917, np.nan],
        ('car_value', 'upper_confidence_boundary'): [21430.8679, np.nan],
        ('car_value', 'lower_confidence_boundary'): [19798.9174, np.nan],
        ('car_value', 'upper_threshold'): [20978.5658, 20978.5658],
        ('car_value', 'lower_threshold'): [19816.9091, 19816.9091],
        # ...and that NaN row is expected to be flagged as an alert.
        ('car_value', 'alert'): [False, True],
    }
    expected = pd.DataFrame(expected_columns)

    actual = result.filter(period='analysis').to_df().round(4)
    pd.testing.assert_frame_equal(expected, actual)

0 comments on commit 3e292ac

Please sign in to comment.