From 0d0e6d91d4b8aa2bb06b6948b61ace15d9062704 Mon Sep 17 00:00:00 2001 From: Kobi Felton Date: Fri, 2 Dec 2022 17:37:54 +0000 Subject: [PATCH] Improve TSEMO categorical (#220) --- summit/domain.py | 7 +++++++ summit/strategies/tsemo.py | 10 ++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/summit/domain.py b/summit/domain.py index 45f86910..48adac6d 100644 --- a/summit/domain.py +++ b/summit/domain.py @@ -551,6 +551,13 @@ def output_variables(self): pass return output_variables + def num_categorical_variables(self): + k = 0 + for v in self.variables: + if v.variable_type == "categorical": + k += 1 + return k + def get_categorical_combinations(self): """Get all combinations of categoricals using full factorial design diff --git a/summit/strategies/tsemo.py b/summit/strategies/tsemo.py index aa8a695b..2b8cb899 100644 --- a/summit/strategies/tsemo.py +++ b/summit/strategies/tsemo.py @@ -231,15 +231,17 @@ def suggest_experiments(self, num_experiments, prev_res: DataSet = None, **kwarg # NSGAII internal optimisation on spectrally sampled functions self.logger.info("Optimizing models using NSGAII.") + # Categorical only domain + if (self.domain.num_continuous_dimensions() == 0) and ( + self.domain.num_categorical_variables() == 1 + ): + X, y = self._categorical_enumerate(models) # Mixed domains - if self.categorical_combos is not None and len(self.input_columns) > 0: + elif self.categorical_combos is not None and len(self.input_columns) > 1: X, y = self._nsga_optimize_mixed(models) # Continous domains elif self.categorical_combos is None and len(self.input_columns) > 0: X, y = self._nsga_optimize(models) - # Categorical only domain - else: - X, y = self._categorical_enumerate(models) # Return if no suggestiosn found if X.shape[0] == 0 and y.shape[0] == 0: