diff --git a/frouros/detectors/data_drift/batch/distance_based/mmd.py b/frouros/detectors/data_drift/batch/distance_based/mmd.py index 61708ba..928ff35 100644 --- a/frouros/detectors/data_drift/batch/distance_based/mmd.py +++ b/frouros/detectors/data_drift/batch/distance_based/mmd.py @@ -51,10 +51,6 @@ def __init__( ) self.kernel = kernel self.chunk_size = chunk_size - self._chunk_size_x = None - self.X_chunks_combinations = None - self.X_num_samples = None - self.expected_k_xx = None @property def chunk_size(self) -> Optional[int]: @@ -108,55 +104,16 @@ def _distance_measure( X: np.ndarray, # noqa: N803 **kwargs, ) -> DistanceResult: - mmd = self._mmd(X=X_ref, Y=X, kernel=self.kernel, **kwargs) + mmd = self._mmd( + X=X_ref, + Y=X, + kernel=self.kernel, + chunk_size=self.chunk_size, + **kwargs, + ) distance_test = DistanceResult(distance=mmd) return distance_test - def _fit( - self, - X: np.ndarray, # noqa: N803 - **kwargs, - ) -> None: - super()._fit(X=X) - # Add dimension only for the kernel calculation (if dim == 1) - if X.ndim == 1: - X = np.expand_dims(X, axis=1) # noqa: N806 - self.X_num_samples = len(self.X_ref) # type: ignore # noqa: N806 - - self._chunk_size_x = ( - self.X_num_samples - if self.chunk_size is None - else self.chunk_size # type: ignore - ) - - X_chunks = self._get_chunks( # noqa: N806 - data=X, - chunk_size=self._chunk_size_x, # type: ignore - ) - xx_chunks_combinations = itertools.product(X_chunks, repeat=2) # noqa: N806 - - if kwargs.get("verbose", False): - num_chunks = ( - math.ceil(self.X_num_samples / self._chunk_size_x) ** 2 # type: ignore - ) - xx_chunks_combinations = tqdm.tqdm( - xx_chunks_combinations, - total=num_chunks, - ) - - k_xx_sum = ( - self._compute_kernel( - chunk_combinations=xx_chunks_combinations, # type: ignore - kernel=self.kernel, - ) - # Remove diagonal (j!=i case) - - self.X_num_samples # type: ignore - ) - - self.expected_k_xx = k_xx_sum / ( # type: ignore - self.X_num_samples * (self.X_num_samples - 1) # type: ignore - ) - @staticmethod def _compute_kernel(chunk_combinations: Generator, kernel: Callable) -> float: k_sum = np.array([kernel(*chunk).sum() for chunk in chunk_combinations]).sum() @@ -170,8 +127,8 @@ def _get_chunks(data: np.ndarray, chunk_size: int) -> Generator: ) return chunks + @staticmethod def _mmd( # pylint: disable=too-many-locals - self, X: np.ndarray, # noqa: N803 Y: np.ndarray, *, @@ -183,33 +140,56 @@ def _mmd( # pylint: disable=too-many-locals X = np.expand_dims(X, axis=1) # noqa: N806 Y = np.expand_dims(Y, axis=1) # noqa: N806 - X_chunks = self._get_chunks( # noqa: N806 - data=X, - chunk_size=self._chunk_size_x, # type: ignore + x_num_samples = len(X) # noqa: N806 + chunk_size_x = ( + kwargs["chunk_size"] + if "chunk_size" in kwargs and kwargs["chunk_size"] is not None + else x_num_samples + ) + x_chunks, x_chunks_copy = itertools.tee( # noqa: N806 + MMD._get_chunks( + data=X, + chunk_size=chunk_size_x, # type: ignore + ), + 2, ) y_num_samples = len(Y) # noqa: N806 - chunk_size_y = y_num_samples if self.chunk_size is None else self.chunk_size + chunk_size_y = ( + kwargs["chunk_size"] + if "chunk_size" in kwargs and kwargs["chunk_size"] is not None + else y_num_samples + ) y_chunks, y_chunks_copy = itertools.tee( # noqa: N806 - self._get_chunks( + MMD._get_chunks( data=Y, chunk_size=chunk_size_y, # type: ignore ), 2, ) + x_chunks_combinations = itertools.product( # noqa: N806 + x_chunks, + repeat=2, + ) y_chunks_combinations = itertools.product( # noqa: N806 y_chunks, repeat=2, ) xy_chunks_combinations = itertools.product( # noqa: N806 - X_chunks, + x_chunks_copy, y_chunks_copy, ) if kwargs.get("verbose", False): + num_chunks_x = math.ceil(x_num_samples / chunk_size_x) # type: ignore num_chunks_y = math.ceil(y_num_samples / chunk_size_y) # type: ignore + num_chunks_x_combinations = num_chunks_x**2 num_chunks_y_combinations = num_chunks_y**2 num_chunks_xy = ( - math.ceil(len(X) / self._chunk_size_x) * num_chunks_y # type: ignore + math.ceil(len(X) / chunk_size_x) * num_chunks_y # type: ignore + ) + x_chunks_combinations = tqdm.tqdm( + x_chunks_combinations, + total=num_chunks_x_combinations, ) y_chunks_combinations = tqdm.tqdm( y_chunks_combinations, @@ -220,21 +200,29 @@ def _mmd( # pylint: disable=too-many-locals total=num_chunks_xy, ) + k_xx_sum = ( + MMD._compute_kernel( + chunk_combinations=x_chunks_combinations, # type: ignore + kernel=kernel, + ) + # Remove diagonal (j!=i case) + - x_num_samples # type: ignore + ) k_yy_sum = ( - self._compute_kernel( + MMD._compute_kernel( chunk_combinations=y_chunks_combinations, # type: ignore kernel=kernel, ) # Remove diagonal (j!=i case) - y_num_samples # type: ignore ) - k_xy_sum = self._compute_kernel( + k_xy_sum = MMD._compute_kernel( chunk_combinations=xy_chunks_combinations, # type: ignore kernel=kernel, ) mmd = ( - self.expected_k_xx # type: ignore + +k_xx_sum / (x_num_samples * (x_num_samples - 1)) + k_yy_sum / (y_num_samples * (y_num_samples - 1)) - - 2 * k_xy_sum / (self.X_num_samples * y_num_samples) # type: ignore + - 2 * k_xy_sum / (x_num_samples * y_num_samples) # type: ignore ) return mmd