From 3a122cc62eeeb185494cbf1e763b6a2c4cd437fb Mon Sep 17 00:00:00 2001 From: Joshua Eckels Date: Tue, 3 Dec 2024 18:04:36 -0700 Subject: [PATCH] fix: allow skipping nan in relative error --- src/amisc/system.py | 2 +- src/amisc/utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/amisc/system.py b/src/amisc/system.py index 9988a06..e598c17 100644 --- a/src/amisc/system.py +++ b/src/amisc/system.py @@ -849,7 +849,7 @@ def refine(self, targets: list = None, num_refine: int = 100, update_bounds: boo error = {} for var, arr in y_cand.items(): if var in targets: - error[var] = relative_error(arr, y_curr[var]) + error[var] = relative_error(arr, y_curr[var], skip_nan=True) if update_bounds and var in coupling_vars: y_min[var] = np.nanmin(np.concatenate((y_min[var], arr), axis=0), axis=0, keepdims=True) diff --git a/src/amisc/utils.py b/src/amisc/utils.py index 6de5499..4feb44a 100644 --- a/src/amisc/utils.py +++ b/src/amisc/utils.py @@ -529,16 +529,18 @@ def parse_function_string(call_string: str) -> tuple[str, list, dict]: return name, args, kwargs -def relative_error(pred, targ, axis=None): +def relative_error(pred, targ, axis=None, skip_nan=False): """Compute the relative L2 error between two vectors along the given axis. :param pred: the predicted values :param targ: the target values :param axis: the axis along which to compute the error + :param skip_nan: whether to skip NaN values in the error calculation :returns: the relative L2 error """ with np.errstate(divide='ignore', invalid='ignore'): - err = np.sqrt(np.sum((pred - targ)**2, axis=axis) / np.sum(targ**2, axis=axis)) + sum_func = np.nansum if skip_nan else np.sum + err = np.sqrt(sum_func((pred - targ)**2, axis=axis) / sum_func(targ**2, axis=axis)) return np.nan_to_num(err, nan=np.nan, posinf=np.nan, neginf=np.nan)