diff --git a/baybe/exceptions.py b/baybe/exceptions.py
index dd3010ace3..340d8b8e4f 100644
--- a/baybe/exceptions.py
+++ b/baybe/exceptions.py
@@ -1,6 +1,16 @@
 """Custom exceptions."""
 
 
+### WARNINGS ###
+class NoSearchspaceMatchWarning(UserWarning):
+    """The provided input has no match in the searchspace."""
+
+
+class TooManySearchspaceMatchesWarning(UserWarning):
+    """The provided input has multiple matches in the searchspace."""
+
+
+### EXCEPTIONS ###
 class NotEnoughPointsLeftError(Exception):
     """
     More recommendations are requested than there are viable parameter configurations
diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py
index 5bc09c2709..b8d4ed02ee 100644
--- a/baybe/utils/dataframe.py
+++ b/baybe/utils/dataframe.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import logging
+import warnings
 from collections.abc import Iterable, Iterator, Sequence
 from typing import (
     TYPE_CHECKING,
@@ -13,6 +14,7 @@
 import numpy as np
 import pandas as pd
 
+from baybe.exceptions import NoSearchspaceMatchWarning, TooManySearchspaceMatchesWarning
 from baybe.targets.enum import TargetMode
 from baybe.utils.numerical import DTypeFloatNumpy
 
@@ -417,17 +419,17 @@ def fuzzy_row_match(
         # We expect exactly one match. If that's not the case, print a warning.
         inds_found = left_df.index[match].to_list()
         if len(inds_found) == 0 and len(num_cols) > 0:
-            _logger.warning(
-                "Input row with index %s could not be matched to the search space. "
+            warnings.warn(
+                f"Input row with index {ind} could not be matched to the search space. "
                 "This could indicate that something went wrong.",
-                ind,
+                NoSearchspaceMatchWarning,
             )
         elif len(inds_found) > 1:
-            _logger.warning(
-                "Input row with index %s has multiple matches with "
+            warnings.warn(
+                f"Input row with index {ind} has multiple matches with "
                 "the search space. This could indicate that something went wrong. "
" "Matching only first occurrence.", - ind, + TooManySearchspaceMatchesWarning, ) inds_matched.append(inds_found[0]) else: diff --git a/tests/test_input_output.py b/tests/test_input_output.py index cc10607959..4ec96184d0 100644 --- a/tests/test_input_output.py +++ b/tests/test_input_output.py @@ -1,13 +1,18 @@ """Tests for basic input-output and iterative loop.""" +import warnings + import numpy as np +import pandas as pd import pytest +from baybe.constraints import DiscreteNoLabelDuplicatesConstraint +from baybe.exceptions import NoSearchspaceMatchWarning +from baybe.utils.augmentation import ( + df_apply_dependency_augmentation, + df_apply_permutation_augmentation, +) from baybe.utils.dataframe import add_fake_results -# List of tests that are expected to fail (still missing implementation etc) -param_xfails = [] -target_xfails = [] - @pytest.mark.parametrize( "bad_val", @@ -16,9 +21,6 @@ ) def test_bad_parameter_input_value(campaign, good_reference_values, bad_val, request): """Test attempting to read in an invalid parameter value.""" - if request.node.callspec.id in param_xfails: - pytest.xfail() - rec = campaign.recommend(batch_size=3) add_fake_results( rec, @@ -27,7 +29,11 @@ def test_bad_parameter_input_value(campaign, good_reference_values, bad_val, req ) # Add an invalid value - rec.Num_disc_1.iloc[0] = bad_val + with warnings.catch_warnings(): + # Ignore warning about incompatible data type assignment + warnings.simplefilter("ignore", FutureWarning) + rec.iloc[0, rec.columns.get_loc("Num_disc_1")] = bad_val + with pytest.raises((ValueError, TypeError)): campaign.add_measurements(rec) @@ -39,9 +45,6 @@ def test_bad_parameter_input_value(campaign, good_reference_values, bad_val, req ) def test_bad_target_input_value(campaign, good_reference_values, bad_val, request): """Test attempting to read in an invalid target value.""" - if request.node.callspec.id in target_xfails: - pytest.xfail() - rec = campaign.recommend(batch_size=3) add_fake_results( rec, @@ -50,6 +53,131 @@ def test_bad_target_input_value(campaign, good_reference_values, bad_val, reques ) # Add an invalid value - rec.Target_max.iloc[0] = bad_val + with warnings.catch_warnings(): + # Ignore warning about incompatible data type assignment + warnings.simplefilter("ignore", FutureWarning) + rec.iloc[0, rec.columns.get_loc("Target_max")] = bad_val + with pytest.raises((ValueError, TypeError)): campaign.add_measurements(rec) + + +# Reused parameter names for the mixture mock example +_mixture_columns = [ + "Solvent_1", + "Solvent_2", + "Solvent_3", + "Fraction_1", + "Fraction_2", + "Fraction_3", +] + + +@pytest.mark.parametrize("n_grid_points", [5]) +@pytest.mark.parametrize( + "entry", + [ + pd.DataFrame.from_records( + [["THF", "Water", "DMF", 0.0, 25.0, 75.0]], columns=_mixture_columns + ), + ], +) +@pytest.mark.parametrize("parameter_names", [_mixture_columns]) +@pytest.mark.parametrize( + "constraint_names", [["Constraint_7", "Constraint_11", "Constraint_12"]] +) +def test_permutation_invariant_input(campaign, entry): + """Test whether permutation invariant measurements can be added.""" + add_fake_results(entry, campaign) + + # Create augmented combinations + entries = df_apply_permutation_augmentation( + entry, + columns=["Solvent_1", "Solvent_2", "Solvent_3"], + dependents=["Fraction_1", "Fraction_2", "Fraction_3"], + ) + + for _, row in entries.iterrows(): + # Reset searchspace metadata + campaign.searchspace.discrete.metadata["was_measured"] = False + + # Assert that not NoSearchspaceMatchWarning is thrown + with 
+            print(row.to_frame().T)
+            warnings.simplefilter("error", category=NoSearchspaceMatchWarning)
+            campaign.add_measurements(pd.DataFrame([row]))
+
+        # Assert exactly one searchspace entry has been marked
+        num_nonzero = campaign.searchspace.discrete.metadata["was_measured"].sum()
+        assert num_nonzero == 1, (
+            "Measurement ingestion was successful, but did not correctly update the "
+            f"searchspace metadata. Number of non-zero entries: {num_nonzero} "
+            f"(expected 1)"
+        )
+
+
+@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"])
+@pytest.mark.parametrize(
+    "entry",
+    [
+        pd.DataFrame.from_records(
+            [["THF", "Water", "DMF", 0.0, 25.0, 75.0]],
+            columns=_mixture_columns,
+        ),
+        pd.DataFrame.from_records(
+            [["THF", "Water", "DMF", 0.0, 0.0, 50.0]],
+            columns=_mixture_columns,
+        ),
+    ],
+    ids=["single_degen", "double_degen"],
+)
+@pytest.mark.parametrize("parameter_names", [_mixture_columns])
+@pytest.mark.parametrize(
+    "constraint_names", [["Constraint_7", "Constraint_11", "Constraint_12"]]
+)
+def test_dependency_invariant_input(campaign, entry):
+    """Test whether dependency invariant measurements can be added."""
+    # Get an entry from the searchspace
+    add_fake_results(entry, campaign)
+    sol_vals = campaign.searchspace.get_parameters_by_name(["Solvent_1"])[0].values
+
+    # Create augmented combinations
+    entries = df_apply_dependency_augmentation(
+        entry, causing=("Fraction_1", [0.0]), affected=[("Solvent_1", sol_vals)]
+    )
+    entries = df_apply_dependency_augmentation(
+        entries, causing=("Fraction_2", [0.0]), affected=[("Solvent_2", sol_vals)]
+    )
+    entries = df_apply_dependency_augmentation(
+        entries, causing=("Fraction_3", [0.0]), affected=[("Solvent_3", sol_vals)]
+    )
+
+    # Remove falsely created label duplicates
+    entries.reset_index(drop=True, inplace=True)
+    for c in campaign.searchspace.discrete.constraints:
+        if isinstance(c, DiscreteNoLabelDuplicatesConstraint):
+            entries.drop(index=c.get_invalid(entries), inplace=True)
+
+    # Add NaN entries for testing NaN input in the invariant parameters
+    entry_nan = entry.copy()
+    entry_nan.loc[entry_nan["Fraction_1"] == 0.0, "Solvent_1"] = np.nan
+    entry_nan.loc[entry_nan["Fraction_2"] == 0.0, "Solvent_2"] = np.nan
+    entry_nan.loc[entry_nan["Fraction_3"] == 0.0, "Solvent_3"] = np.nan
+
+    for _, row in pd.concat([entries, entry_nan]).iterrows():
+        # Reset searchspace metadata
+        campaign.searchspace.discrete.metadata["was_measured"] = False
+
+        # Assert that no NoSearchspaceMatchWarning is thrown
+        with warnings.catch_warnings():
+            print(row.to_frame().T)
+            warnings.simplefilter("error", category=NoSearchspaceMatchWarning)
+            campaign.add_measurements(pd.DataFrame([row]))
+
+        # Assert exactly one searchspace entry has been marked
+        num_nonzero = campaign.searchspace.discrete.metadata["was_measured"].sum()
+        assert num_nonzero == 1, (
+            "Measurement ingestion was successful, but did not correctly update the "
+            f"searchspace metadata. Number of non-zero entries: {num_nonzero} "
+            f"(expected 1)"
+        )