Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Scienfitz committed Jul 3, 2024
1 parent 56eccbe commit 528c40a
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 18 deletions.
10 changes: 10 additions & 0 deletions baybe/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
"""Custom exceptions."""


### WARNINGS ###
class NoSearchspaceMatchWarning(UserWarning):
"""The provided input has no match in the searchspace."""


class TooManySearchspaceMatchesWarning(UserWarning):
"""The provided input has multiple matches in the searchspace."""


### EXCEPTIONS ###
class NotEnoughPointsLeftError(Exception):
"""
More recommendations are requested than there are viable parameter configurations
Expand Down
14 changes: 8 additions & 6 deletions baybe/utils/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Iterable, Iterator, Sequence
from typing import (
TYPE_CHECKING,
Expand All @@ -13,6 +14,7 @@
import numpy as np
import pandas as pd

from baybe.exceptions import NoSearchspaceMatchWarning, TooManySearchspaceMatchesWarning
from baybe.targets.enum import TargetMode
from baybe.utils.numerical import DTypeFloatNumpy

Expand Down Expand Up @@ -417,17 +419,17 @@ def fuzzy_row_match(
# We expect exactly one match. If that's not the case, print a warning.
inds_found = left_df.index[match].to_list()
if len(inds_found) == 0 and len(num_cols) > 0:
_logger.warning(
"Input row with index %s could not be matched to the search space. "
warnings.warn(
f"Input row with index {ind} could not be matched to the search space. "
"This could indicate that something went wrong.",
ind,
NoSearchspaceMatchWarning,
)
elif len(inds_found) > 1:
_logger.warning(
"Input row with index %s has multiple matches with "
warnings.warn(
f"Input row with index {ind} has multiple matches with "
"the search space. This could indicate that something went wrong. "
"Matching only first occurrence.",
ind,
TooManySearchspaceMatchesWarning,
)
inds_matched.append(inds_found[0])
else:
Expand Down
152 changes: 140 additions & 12 deletions tests/test_input_output.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
"""Tests for basic input-output and iterative loop."""
import warnings

import numpy as np
import pandas as pd
import pytest

from baybe.constraints import DiscreteNoLabelDuplicatesConstraint
from baybe.exceptions import NoSearchspaceMatchWarning
from baybe.utils.augmentation import (
df_apply_dependency_augmentation,
df_apply_permutation_augmentation,
)
from baybe.utils.dataframe import add_fake_results

# List of tests that are expected to fail (still missing implementation etc)
param_xfails = []
target_xfails = []


@pytest.mark.parametrize(
"bad_val",
Expand All @@ -16,9 +21,6 @@
)
def test_bad_parameter_input_value(campaign, good_reference_values, bad_val, request):
"""Test attempting to read in an invalid parameter value."""
if request.node.callspec.id in param_xfails:
pytest.xfail()

rec = campaign.recommend(batch_size=3)
add_fake_results(
rec,
Expand All @@ -27,7 +29,11 @@ def test_bad_parameter_input_value(campaign, good_reference_values, bad_val, req
)

# Add an invalid value
rec.Num_disc_1.iloc[0] = bad_val
with warnings.catch_warnings():
# Ignore warning about incompatible data type assignment
warnings.simplefilter("ignore", FutureWarning)
rec.iloc[0, rec.columns.get_loc("Num_disc_1")] = bad_val

with pytest.raises((ValueError, TypeError)):
campaign.add_measurements(rec)

Expand All @@ -39,9 +45,6 @@ def test_bad_parameter_input_value(campaign, good_reference_values, bad_val, req
)
def test_bad_target_input_value(campaign, good_reference_values, bad_val, request):
"""Test attempting to read in an invalid target value."""
if request.node.callspec.id in target_xfails:
pytest.xfail()

rec = campaign.recommend(batch_size=3)
add_fake_results(
rec,
Expand All @@ -50,6 +53,131 @@ def test_bad_target_input_value(campaign, good_reference_values, bad_val, reques
)

# Add an invalid value
rec.Target_max.iloc[0] = bad_val
with warnings.catch_warnings():
# Ignore warning about incompatible data type assignment
warnings.simplefilter("ignore", FutureWarning)
rec.iloc[0, rec.columns.get_loc("Target_max")] = bad_val

with pytest.raises((ValueError, TypeError)):
campaign.add_measurements(rec)


# Reused parameter names for the mixture mock example
_mixture_columns = [
"Solvent_1",
"Solvent_2",
"Solvent_3",
"Fraction_1",
"Fraction_2",
"Fraction_3",
]


@pytest.mark.parametrize("n_grid_points", [5])
@pytest.mark.parametrize(
"entry",
[
pd.DataFrame.from_records(
[["THF", "Water", "DMF", 0.0, 25.0, 75.0]], columns=_mixture_columns
),
],
)
@pytest.mark.parametrize("parameter_names", [_mixture_columns])
@pytest.mark.parametrize(
"constraint_names", [["Constraint_7", "Constraint_11", "Constraint_12"]]
)
def test_permutation_invariant_input(campaign, entry):
"""Test whether permutation invariant measurements can be added."""
add_fake_results(entry, campaign)

# Create augmented combinations
entries = df_apply_permutation_augmentation(
entry,
columns=["Solvent_1", "Solvent_2", "Solvent_3"],
dependents=["Fraction_1", "Fraction_2", "Fraction_3"],
)

for _, row in entries.iterrows():
# Reset searchspace metadata
campaign.searchspace.discrete.metadata["was_measured"] = False

# Assert that not NoSearchspaceMatchWarning is thrown
with warnings.catch_warnings():
print(row.to_frame().T)
warnings.simplefilter("error", category=NoSearchspaceMatchWarning)
campaign.add_measurements(pd.DataFrame([row]))

# Assert exactly one searchspace entry has been marked
num_nonzero = campaign.searchspace.discrete.metadata["was_measured"].sum()
assert num_nonzero == 1, (
"Measurement ingestion was successful, but did not correctly update the "
f"searchspace metadata. Number of non-zero entries: {num_nonzero} "
f"(expected 1)"
)


@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"])
@pytest.mark.parametrize(
"entry",
[
pd.DataFrame.from_records(
[["THF", "Water", "DMF", 0.0, 25.0, 75.0]],
columns=_mixture_columns,
),
pd.DataFrame.from_records(
[["THF", "Water", "DMF", 0.0, 0.0, 50.0]],
columns=_mixture_columns,
),
],
ids=["single_degen", "double_degen"],
)
@pytest.mark.parametrize("parameter_names", [_mixture_columns])
@pytest.mark.parametrize(
"constraint_names", [["Constraint_7", "Constraint_11", "Constraint_12"]]
)
def test_dependency_invariant_input(campaign, entry):
"""Test whether dependency invariant measurements can be added."""
# Get an entry from the searchspace
add_fake_results(entry, campaign)
sol_vals = campaign.searchspace.get_parameters_by_name(["Solvent_1"])[0].values

# Create augmented combinations
entries = df_apply_dependency_augmentation(
entry, causing=("Fraction_1", [0.0]), affected=[("Solvent_1", sol_vals)]
)
entries = df_apply_dependency_augmentation(
entries, causing=("Fraction_2", [0.0]), affected=[("Solvent_2", sol_vals)]
)
entries = df_apply_dependency_augmentation(
entries, causing=("Fraction_3", [0.0]), affected=[("Solvent_3", sol_vals)]
)

# Remove falsely created label duplicates
entries.reset_index(drop=True, inplace=True)
for c in campaign.searchspace.discrete.constraints:
if isinstance(c, DiscreteNoLabelDuplicatesConstraint):
entries.drop(index=c.get_invalid(entries), inplace=True)

# Add nan entries for testing nan input in the invariant parameters
entry_nan = entry.copy()
entry_nan.loc[entry_nan["Fraction_1"] == 0.0, "Solvent_1"] = np.nan
entry_nan.loc[entry_nan["Fraction_2"] == 0.0, "Solvent_2"] = np.nan
entry_nan.loc[entry_nan["Fraction_3"] == 0.0, "Solvent_3"] = np.nan

for _, row in pd.concat([entries, entry_nan]).iterrows():
# Reset searchspace metadata
campaign.searchspace.discrete.metadata["was_measured"] = False

# Assert that not NoSearchspaceMatchWarning is thrown
with warnings.catch_warnings():
print(row.to_frame().T)
warnings.simplefilter("error", category=NoSearchspaceMatchWarning)
campaign.add_measurements(pd.DataFrame([row]))

# Assert exactly one searchspace entry has been marked
num_nonzero = campaign.searchspace.discrete.metadata["was_measured"].sum()
assert num_nonzero == 1, (
"Measurement ingestion was successful, but did not correctly update the "
f"searchspace metadata. Number of non-zero entries: {num_nonzero} "
f"(expected 1)"
)

0 comments on commit 528c40a

Please sign in to comment.