Skip to content

Commit

Permalink
fix: allow skipping nan in relative error
Browse files Browse the repository at this point in the history
  • Loading branch information
eckelsjd committed Dec 4, 2024
1 parent 8ee6a0a commit 3a122cc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/amisc/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ def refine(self, targets: list = None, num_refine: int = 100, update_bounds: boo
error = {}
for var, arr in y_cand.items():
if var in targets:
error[var] = relative_error(arr, y_curr[var])
error[var] = relative_error(arr, y_curr[var], skip_nan=True)

if update_bounds and var in coupling_vars:
y_min[var] = np.nanmin(np.concatenate((y_min[var], arr), axis=0), axis=0, keepdims=True)
Expand Down
6 changes: 4 additions & 2 deletions src/amisc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,16 +529,18 @@ def parse_function_string(call_string: str) -> tuple[str, list, dict]:
return name, args, kwargs


def relative_error(pred, targ, axis=None):
def relative_error(pred, targ, axis=None, skip_nan=False):
"""Compute the relative L2 error between two vectors along the given axis.
:param pred: the predicted values
:param targ: the target values
:param axis: the axis along which to compute the error
:param skip_nan: whether to skip NaN values in the error calculation
:returns: the relative L2 error
"""
with np.errstate(divide='ignore', invalid='ignore'):
err = np.sqrt(np.sum((pred - targ)**2, axis=axis) / np.sum(targ**2, axis=axis))
sum_func = np.nansum if skip_nan else np.sum
err = np.sqrt(sum_func((pred - targ)**2, axis=axis) / sum_func(targ**2, axis=axis))
return np.nan_to_num(err, nan=np.nan, posinf=np.nan, neginf=np.nan)


Expand Down

0 comments on commit 3a122cc

Please sign in to comment.