Skip to content

Commit

Permalink
Fix threshold computation handling nan values (#333)
Browse files Browse the repository at this point in the history
* replace std and mean with nanstd and nanmean

* fix test

* Added test to properly check outcome

---------

Co-authored-by: Giovanni Davoli <giovanni.davoli@axyon.ai>
Co-authored-by: Niels Nuyttens <niels@nannyml.com>
  • Loading branch information
3 people authored Nov 9, 2023
1 parent 58d44bc commit fcc72e7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
6 changes: 3 additions & 3 deletions nannyml/thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(
self,
std_lower_multiplier: Optional[Union[float, int]] = 3,
std_upper_multiplier: Optional[Union[float, int]] = 3,
offset_from: Callable[[np.ndarray], Any] = np.mean,
offset_from: Callable[[np.ndarray], Any] = np.nanmean,
):
"""Creates a new StandardDeviationThreshold instance.
Expand All @@ -166,7 +166,7 @@ def __init__(
The number the standard deviation of the input array will be multiplied with to form the upper offset.
This value will be added to the aggregate of the input array.
Defaults to 3.
offset_from: Callable[[np.ndarray], Any], default=np.mean
offset_from: Callable[[np.ndarray], Any], default=np.nanmean
A function that will be applied to the input array to aggregate it into a single value.
Adding the upper offset to this value will yield the upper threshold, subtracting the lower offset
will yield the lower threshold.
Expand All @@ -180,7 +180,7 @@ def __init__(

def thresholds(self, data: np.ndarray, **kwargs) -> Tuple[Optional[float], Optional[float]]:
aggregate = self.offset_from(data)
std = np.std(data)
std = np.nanstd(data)

lower_threshold = aggregate - std * self.std_lower_multiplier if self.std_lower_multiplier is not None else None

Expand Down
10 changes: 9 additions & 1 deletion tests/test_thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_standard_deviation_threshold_init_sets_default_instance_attributes():

assert sut.std_lower_multiplier == 3
assert sut.std_upper_multiplier == 3
assert sut.offset_from == np.mean
assert sut.offset_from == np.nanmean


@pytest.mark.parametrize(
Expand Down Expand Up @@ -153,3 +153,11 @@ def test_standard_deviation_threshold_raises_threshold_exception_when_negative_l
def test_standard_deviation_threshold_raises_threshold_exception_when_negative_upper_multiplier_given():
    """A negative upper multiplier must be rejected at construction time."""
    expected_message = "'std_upper_multiplier' should be greater than 0 but got value -1"

    with pytest.raises(ThresholdException, match=expected_message):
        # The lower multiplier is 0 so that only the upper multiplier is at fault.
        StandardDeviationThreshold(std_lower_multiplier=0, std_upper_multiplier=-1)


def test_standard_deviation_threshold_deals_with_nan_values():
    """NaN entries in the input must be ignored when computing thresholds.

    The thresholds computed on data containing NaNs are expected to be finite
    and equal to the thresholds of the same data with the NaNs removed.
    """
    clean = np.asarray([-1, 1, 1])
    with_nans = np.asarray([-1, 1, np.nan, 1, np.nan])

    sut = StandardDeviationThreshold()
    # NOTE(review): thresholds() is assumed to return (lower, upper); the value
    # assertions below will fail if the order is actually reversed.
    lower, upper = sut.thresholds(with_nans)

    assert not np.isnan(lower)
    assert not np.isnan(upper)

    # Default instance: both multipliers are 3 and the aggregate is the
    # NaN-ignoring mean, so the result must match mean +/- 3 * std of the
    # NaN-free data.
    center = np.mean(clean)
    spread = np.std(clean)
    assert np.isclose(lower, center - 3 * spread)
    assert np.isclose(upper, center + 3 * spread)

0 comments on commit fcc72e7

Please sign in to comment.