diff --git a/swiftemulator/emulators/base.py b/swiftemulator/emulators/base.py index df2cbbc..6b1e4c8 100644 --- a/swiftemulator/emulators/base.py +++ b/swiftemulator/emulators/base.py @@ -78,7 +78,7 @@ def fit_model( def predict_values( self, independent: np.array, model_parameters: Dict[str, float] - ) -> np.array: + ) -> tuple[np.array, np.array]: """ Predict values from the trained emulator contained within this object. diff --git a/swiftemulator/emulators/gaussian_process.py b/swiftemulator/emulators/gaussian_process.py index 3db612e..7d40183 100644 --- a/swiftemulator/emulators/gaussian_process.py +++ b/swiftemulator/emulators/gaussian_process.py @@ -229,7 +229,7 @@ def grad_negative_log_likelihood(p): def predict_values( self, independent: np.array, model_parameters: Dict[str, float] - ) -> np.array: + ) -> tuple[np.array, np.array]: """ Predict values from the trained emulator contained within this object. diff --git a/swiftemulator/emulators/gaussian_process_bins.py b/swiftemulator/emulators/gaussian_process_bins.py index 90380ef..90e4110 100644 --- a/swiftemulator/emulators/gaussian_process_bins.py +++ b/swiftemulator/emulators/gaussian_process_bins.py @@ -267,7 +267,7 @@ def grad_negative_log_likelihood(p): def predict_values( self, independent: np.array, model_parameters: Dict[str, float] - ) -> np.array: + ) -> tuple[np.array, np.array]: """ Predict values from the trained emulator contained within this object. diff --git a/swiftemulator/emulators/gaussian_process_mcmc.py b/swiftemulator/emulators/gaussian_process_mcmc.py index 82a3891..b0609a7 100644 --- a/swiftemulator/emulators/gaussian_process_mcmc.py +++ b/swiftemulator/emulators/gaussian_process_mcmc.py @@ -339,7 +339,7 @@ def predict_values( self, independent: np.array, model_parameters: Dict[str, float], - ) -> np.array: + ) -> tuple[np.array, np.array]: """ Predict values from the trained emulator contained within this object. diff --git a/swiftemulator/emulators/gaussian_process_one_dim.py b/swiftemulator/emulators/gaussian_process_one_dim.py index 439ddbf..8f5a9b2 100644 --- a/swiftemulator/emulators/gaussian_process_one_dim.py +++ b/swiftemulator/emulators/gaussian_process_one_dim.py @@ -226,7 +226,9 @@ def grad_negative_log_likelihood(p): return - def predict_values(self, model_parameters: Dict[str, float]) -> np.array: + def predict_values( + self, model_parameters: Dict[str, float] + ) -> tuple[np.array, np.array]: """ Predict a value from the trained emulator contained within this object. returns the value at the input model parameters. diff --git a/swiftemulator/emulators/linear_model.py b/swiftemulator/emulators/linear_model.py index 671e409..94c5507 100644 --- a/swiftemulator/emulators/linear_model.py +++ b/swiftemulator/emulators/linear_model.py @@ -164,7 +164,7 @@ def fit_model( def predict_values( self, independent: np.array, model_parameters: Dict[str, float] - ) -> np.array: + ) -> tuple[np.array, np.array]: """ Predict values from the trained emulator contained within this object. diff --git a/swiftemulator/emulators/multi_gaussian_process.py b/swiftemulator/emulators/multi_gaussian_process.py index 2b5f45e..764ff13 100644 --- a/swiftemulator/emulators/multi_gaussian_process.py +++ b/swiftemulator/emulators/multi_gaussian_process.py @@ -182,7 +182,7 @@ def predict_values( self, independent: np.array, model_parameters: Dict[str, float], - ) -> np.array: + ) -> tuple[np.array, np.array]: """ Predict values from the trained emulator contained within this object. @@ -231,12 +231,16 @@ def predict_values( for index, (low, high) in enumerate(self.independent_regions): mask = np.logical_and( - independent > low - if low is not None - else np.ones_like(independent).astype(bool), - independent < high - if high is not None - else np.ones_like(independent).astype(bool), + ( + independent > low + if low is not None + else np.ones_like(independent).astype(bool) + ), + ( + independent < high + if high is not None + else np.ones_like(independent).astype(bool) + ), ) predicted, errors = self.emulators[index].predict_values( diff --git a/swiftemulator/io/swift.py b/swiftemulator/io/swift.py index e9b3247..8429525 100644 --- a/swiftemulator/io/swift.py +++ b/swiftemulator/io/swift.py @@ -103,8 +103,8 @@ def load_pipeline_outputs( "adaptive_mass_function", "histogram", ] - recursive_search = ( - lambda d, k: d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None + recursive_search = lambda d, k: ( + d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None ) line_search = lambda d: recursive_search(d, line_types)