Skip to content

Commit

Permalink
Merge pull request #249 from IFCA/fix-permutation-test-p-values
Browse files Browse the repository at this point in the history
Fix p-values workaround computation
  • Loading branch information
jaime-cespedes-sisniega authored Jul 22, 2023
2 parents f04d4f8 + 7f55b86 commit 8368845
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 110 deletions.
148 changes: 76 additions & 72 deletions docs/source/examples/data_drift/MMD_advance.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion frouros/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Frouros."""

__version__ = "0.5.0"
__version__ = "0.5.1"
11 changes: 2 additions & 9 deletions frouros/callbacks/batch/permutation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np # type: ignore
from scipy.stats import norm # type: ignore

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.utils.stats import permutation, z_score
from frouros.utils.stats import permutation


class PermutationTestDistanceBased(BaseCallbackBatch):
Expand Down Expand Up @@ -122,13 +121,7 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
verbose=verbose,
)
permuted_statistic = np.array(permuted_statistic)
# Use z-score to calculate p-value
observed_z_score = z_score(
value=observed_statistic,
mean=permuted_statistic.mean(), # type: ignore
std=permuted_statistic.std(), # type: ignore
)
p_value = norm.sf(np.abs(observed_z_score)) * 2
p_value = (permuted_statistic >= observed_statistic).mean() # type: ignore
return permuted_statistic, p_value

def on_compare_end(self, **kwargs) -> None:
Expand Down
15 changes: 7 additions & 8 deletions frouros/tests/integration/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@
"detector_class, expected_distance, expected_p_value",
[
(BhattacharyyaDistance, 0.55516059, 0.0),
(EMD, 3.85346006, 9.21632493e-101),
(HellingerDistance, 0.74509099, 3.13808126e-50),
(HINormalizedComplement, 0.78, 1.31340683e-55),
(JS, 0.67010107, 2.30485343e-63),
(KL, np.inf, np.nan),
(MMD, 0.69509004, 2.53277069e-137),
(PSI, 461.20379435, 4.45088795e-238),
(EMD, 3.85346006, 0.0),
(HellingerDistance, 0.74509099, 0.0),
(HINormalizedComplement, 0.78, 0.0),
(JS, 0.67010107, 0.0),
(KL, np.inf, 0.06),
(MMD, 0.69509004, 0.0),
(PSI, 461.20379435, 0.0),
],
)
def test_batch_permutation_test_data_univariate_different_distribution(
Expand Down Expand Up @@ -103,7 +103,6 @@ def test_batch_permutation_test_data_univariate_different_distribution(
assert np.isclose(
callback_logs[permutation_test_name]["p_value"],
expected_p_value,
equal_nan=True,
)


Expand Down
19 changes: 0 additions & 19 deletions frouros/utils/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,22 +263,3 @@ def permutation( # pylint: disable=too-many-arguments,too-many-locals
).get()

return permuted_statistics


def z_score(
value: np.ndarray,
mean: float,
std: float,
) -> np.ndarray:
"""Z-score method.
:param value: value to use to compute the z-score
:type value: np.ndarray
:param mean: mean value
:type mean: float
:param std: standard deviation value
:type std: float
:return: z-score
:rtype: np.ndarray
"""
return (value - mean) / std
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "frouros"
version = "0.5.0"
version = "0.5.1"
description = "An open-source Python library for drift detection in machine learning systems"
authors = [
{name = "Jaime Céspedes Sisniega", email = "cespedes@ifca.unican.es"}
Expand Down

0 comments on commit 8368845

Please sign in to comment.