Skip to content

Commit

Permalink
Merge pull request #248 from IFCA/fix-mmd
Browse files (browse the repository at this point in the history)
Fix _mmd from MMD to be a static method
  • Loading branch information
jaime-cespedes-sisniega authored Jul 22, 2023
2 parents ea25e65 + 6dc1536 commit f04d4f8
Showing 1 changed file with 50 additions and 62 deletions.
112 changes: 50 additions & 62 deletions frouros/detectors/data_drift/batch/distance_based/mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ def __init__(
)
self.kernel = kernel
self.chunk_size = chunk_size
self._chunk_size_x = None
self.X_chunks_combinations = None
self.X_num_samples = None
self.expected_k_xx = None

@property
def chunk_size(self) -> Optional[int]:
Expand Down Expand Up @@ -108,55 +104,16 @@ def _distance_measure(
X: np.ndarray, # noqa: N803
**kwargs,
) -> DistanceResult:
mmd = self._mmd(X=X_ref, Y=X, kernel=self.kernel, **kwargs)
mmd = self._mmd(
X=X_ref,
Y=X,
kernel=self.kernel,
chunk_size=self.chunk_size,
**kwargs,
)
distance_test = DistanceResult(distance=mmd)
return distance_test

def _fit(
self,
X: np.ndarray, # noqa: N803
**kwargs,
) -> None:
super()._fit(X=X)
# Add dimension only for the kernel calculation (if dim == 1)
if X.ndim == 1:
X = np.expand_dims(X, axis=1) # noqa: N806
self.X_num_samples = len(self.X_ref) # type: ignore # noqa: N806

self._chunk_size_x = (
self.X_num_samples
if self.chunk_size is None
else self.chunk_size # type: ignore
)

X_chunks = self._get_chunks( # noqa: N806
data=X,
chunk_size=self._chunk_size_x, # type: ignore
)
xx_chunks_combinations = itertools.product(X_chunks, repeat=2) # noqa: N806

if kwargs.get("verbose", False):
num_chunks = (
math.ceil(self.X_num_samples / self._chunk_size_x) ** 2 # type: ignore
)
xx_chunks_combinations = tqdm.tqdm(
xx_chunks_combinations,
total=num_chunks,
)

k_xx_sum = (
self._compute_kernel(
chunk_combinations=xx_chunks_combinations, # type: ignore
kernel=self.kernel,
)
# Remove diagonal (j!=i case)
- self.X_num_samples # type: ignore
)

self.expected_k_xx = k_xx_sum / ( # type: ignore
self.X_num_samples * (self.X_num_samples - 1) # type: ignore
)

@staticmethod
def _compute_kernel(chunk_combinations: Generator, kernel: Callable) -> float:
k_sum = np.array([kernel(*chunk).sum() for chunk in chunk_combinations]).sum()
Expand All @@ -170,8 +127,8 @@ def _get_chunks(data: np.ndarray, chunk_size: int) -> Generator:
)
return chunks

@staticmethod
def _mmd( # pylint: disable=too-many-locals
self,
X: np.ndarray, # noqa: N803
Y: np.ndarray,
*,
Expand All @@ -183,33 +140,56 @@ def _mmd( # pylint: disable=too-many-locals
X = np.expand_dims(X, axis=1) # noqa: N806
Y = np.expand_dims(Y, axis=1) # noqa: N806

X_chunks = self._get_chunks( # noqa: N806
data=X,
chunk_size=self._chunk_size_x, # type: ignore
x_num_samples = len(X) # noqa: N806
chunk_size_x = (
kwargs["chunk_size"]
if "chunk_size" in kwargs and kwargs["chunk_size"] is not None
else x_num_samples
)
x_chunks, x_chunks_copy = itertools.tee( # noqa: N806
MMD._get_chunks(
data=X,
chunk_size=chunk_size_x, # type: ignore
),
2,
)
y_num_samples = len(Y) # noqa: N806
chunk_size_y = y_num_samples if self.chunk_size is None else self.chunk_size
chunk_size_y = (
kwargs["chunk_size"]
if "chunk_size" in kwargs and kwargs["chunk_size"] is not None
else y_num_samples
)
y_chunks, y_chunks_copy = itertools.tee( # noqa: N806
self._get_chunks(
MMD._get_chunks(
data=Y,
chunk_size=chunk_size_y, # type: ignore
),
2,
)
x_chunks_combinations = itertools.product( # noqa: N806
x_chunks,
repeat=2,
)
y_chunks_combinations = itertools.product( # noqa: N806
y_chunks,
repeat=2,
)
xy_chunks_combinations = itertools.product( # noqa: N806
X_chunks,
x_chunks_copy,
y_chunks_copy,
)

if kwargs.get("verbose", False):
num_chunks_x = math.ceil(x_num_samples / chunk_size_x) # type: ignore
num_chunks_y = math.ceil(y_num_samples / chunk_size_y) # type: ignore
num_chunks_x_combinations = num_chunks_x**2
num_chunks_y_combinations = num_chunks_y**2
num_chunks_xy = (
math.ceil(len(X) / self._chunk_size_x) * num_chunks_y # type: ignore
math.ceil(len(X) / chunk_size_x) * num_chunks_y # type: ignore
)
x_chunks_combinations = tqdm.tqdm(
x_chunks_combinations,
total=num_chunks_x_combinations,
)
y_chunks_combinations = tqdm.tqdm(
y_chunks_combinations,
Expand All @@ -220,21 +200,29 @@ def _mmd( # pylint: disable=too-many-locals
total=num_chunks_xy,
)

k_xx_sum = (
MMD._compute_kernel(
chunk_combinations=x_chunks_combinations, # type: ignore
kernel=kernel,
)
# Remove diagonal (j!=i case)
- x_num_samples # type: ignore
)
k_yy_sum = (
self._compute_kernel(
MMD._compute_kernel(
chunk_combinations=y_chunks_combinations, # type: ignore
kernel=kernel,
)
# Remove diagonal (j!=i case)
- y_num_samples # type: ignore
)
k_xy_sum = self._compute_kernel(
k_xy_sum = MMD._compute_kernel(
chunk_combinations=xy_chunks_combinations, # type: ignore
kernel=kernel,
)
mmd = (
self.expected_k_xx # type: ignore
+k_xx_sum / (x_num_samples * (x_num_samples - 1))
+ k_yy_sum / (y_num_samples * (y_num_samples - 1))
- 2 * k_xy_sum / (self.X_num_samples * y_num_samples) # type: ignore
- 2 * k_xy_sum / (x_num_samples * y_num_samples) # type: ignore
)
return mmd

0 comments on commit f04d4f8

Please sign in to comment.