Skip to content

Commit

Permalink
Add MMD kernel initialization tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jaime-cespedes-sisniega committed Dec 6, 2024
1 parent 762c17b commit 32fe6b3
Showing 1 changed file with 57 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial
from typing import (
Any,
Callable,
Optional,
Tuple,
)
Expand Down Expand Up @@ -174,7 +175,7 @@ def test_mmd_chunk_size_equivalence(
2,
],
)
def test_mmd_chunk_size_initialization_valid(
def test_mmd_chunk_size_valid(
chunk_size: Optional[int],
) -> None:
"""Test MMD initialization with valid chunk sizes.
Expand Down Expand Up @@ -230,3 +231,58 @@ def test_mmd_chunk_size_invalid(
kernel=kernel,
chunk_size=chunk_size,
)


@pytest.mark.parametrize(
"kernel",
[
partial(
rbf_kernel,
sigma=DEFAULT_SIGMA,
),
lambda X, Y: X + Y, # simple kernel
],
)
def test_mmd_kernel_valid(
kernel: Callable, # type: ignore
) -> None:
"""Test MMD initialization with valid kernels.
:param kernel: kernel to test
:type kernel: Callable
"""
np.random.seed(seed=RANDOM_SEED)
X_ref = np.random.normal(0, 1, 100)
X_test = np.random.normal(0, 1, 100)

detector = MMD(
kernel=kernel,
)
_ = detector.fit(X=X_ref)
result = detector.compare(X=X_test)[0]

assert result is not None


@pytest.mark.parametrize(
"kernel",
[
None,
"invalid",
123,
[1, 2],
{1: 2},
],
)
def test_mmd_kernel_invalid(
kernel: Any,
) -> None:
"""Test MMD initialization with invalid kernels.
:param kernel: kernel to test
:type kernel: Any
"""
with pytest.raises((TypeError, ValueError)):
MMD(
kernel=kernel,
)

0 comments on commit 32fe6b3

Please sign in to comment.