Skip to content

Commit

Permalink
Change the return type annotation of BaseEmulator.predict_values() to…
Browse files Browse the repository at this point in the history
… tuple[np.array, np.array] from np.array. Applied formatting script
MatthieuSchaller committed Aug 10, 2024
1 parent 6bc0f31 commit b65c35b
Showing 8 changed files with 21 additions and 15 deletions.
2 changes: 1 addition & 1 deletion swiftemulator/emulators/base.py
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ def fit_model(

def predict_values(
self, independent: np.array, model_parameters: Dict[str, float]
) -> np.array:
) -> tuple[np.array, np.array]:
"""
Predict values from the trained emulator contained within this object.
2 changes: 1 addition & 1 deletion swiftemulator/emulators/gaussian_process.py
Original file line number Diff line number Diff line change
@@ -229,7 +229,7 @@ def grad_negative_log_likelihood(p):

def predict_values(
self, independent: np.array, model_parameters: Dict[str, float]
) -> np.array:
) -> tuple[np.array, np.array]:
"""
Predict values from the trained emulator contained within this object.
2 changes: 1 addition & 1 deletion swiftemulator/emulators/gaussian_process_bins.py
Original file line number Diff line number Diff line change
@@ -267,7 +267,7 @@ def grad_negative_log_likelihood(p):

def predict_values(
self, independent: np.array, model_parameters: Dict[str, float]
) -> np.array:
) -> tuple[np.array, np.array]:
"""
Predict values from the trained emulator contained within this object.
2 changes: 1 addition & 1 deletion swiftemulator/emulators/gaussian_process_mcmc.py
Original file line number Diff line number Diff line change
@@ -339,7 +339,7 @@ def predict_values(
self,
independent: np.array,
model_parameters: Dict[str, float],
) -> np.array:
) -> tuple[np.array, np.array]:
"""
Predict values from the trained emulator contained within this object.
4 changes: 3 additions & 1 deletion swiftemulator/emulators/gaussian_process_one_dim.py
Original file line number Diff line number Diff line change
@@ -226,7 +226,9 @@ def grad_negative_log_likelihood(p):

return

def predict_values(self, model_parameters: Dict[str, float]) -> np.array:
def predict_values(
self, model_parameters: Dict[str, float]
) -> tuple[np.array, np.array]:
"""
Predict a value from the trained emulator contained within this object.
returns the value at the input model parameters.
2 changes: 1 addition & 1 deletion swiftemulator/emulators/linear_model.py
Original file line number Diff line number Diff line change
@@ -164,7 +164,7 @@ def fit_model(

def predict_values(
self, independent: np.array, model_parameters: Dict[str, float]
) -> np.array:
) -> tuple[np.array, np.array]:
"""
Predict values from the trained emulator contained within this object.
18 changes: 11 additions & 7 deletions swiftemulator/emulators/multi_gaussian_process.py
Original file line number Diff line number Diff line change
@@ -182,7 +182,7 @@ def predict_values(
self,
independent: np.array,
model_parameters: Dict[str, float],
) -> np.array:
) -> tuple[np.array, np.array]:
"""
Predict values from the trained emulator contained within this object.
@@ -231,12 +231,16 @@ def predict_values(

for index, (low, high) in enumerate(self.independent_regions):
mask = np.logical_and(
independent > low
if low is not None
else np.ones_like(independent).astype(bool),
independent < high
if high is not None
else np.ones_like(independent).astype(bool),
(
independent > low
if low is not None
else np.ones_like(independent).astype(bool)
),
(
independent < high
if high is not None
else np.ones_like(independent).astype(bool)
),
)

predicted, errors = self.emulators[index].predict_values(
4 changes: 2 additions & 2 deletions swiftemulator/io/swift.py
Original file line number Diff line number Diff line change
@@ -103,8 +103,8 @@ def load_pipeline_outputs(
"adaptive_mass_function",
"histogram",
]
recursive_search = (
lambda d, k: d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None
recursive_search = lambda d, k: (
d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None
)
line_search = lambda d: recursive_search(d, line_types)

0 comments on commit b65c35b

Please sign in to comment.