diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index bfaaeb1..486492c 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -216,6 +216,10 @@ def variable_categories(self): def load_supporting_array(self, name): return self._metadata.load_supporting_array(name) + @property + def supporting_arrays(self): + return self._metadata.supporting_arrays + ########################################################################### @cached_property diff --git a/src/anemoi/inference/metadata.py b/src/anemoi/inference/metadata.py index 25f29f6..ef59e98 100644 --- a/src/anemoi/inference/metadata.py +++ b/src/anemoi/inference/metadata.py @@ -741,6 +741,10 @@ def load_supporting_array(self, name): raise ValueError(f"Supporting array `{name}` not found") return self._supporting_arrays[name] + @property + def supporting_arrays(self): + return self._supporting_arrays + @property def latitudes(self): return self._supporting_arrays.get("latitudes") diff --git a/src/anemoi/inference/outputs/extract_lam.py b/src/anemoi/inference/outputs/extract_lam.py index a1ed2b7..8c23a7f 100644 --- a/src/anemoi/inference/outputs/extract_lam.py +++ b/src/anemoi/inference/outputs/extract_lam.py @@ -22,11 +22,26 @@ class ExtractLamOutput(Output): """_summary_""" - def __init__(self, context, *, output, points="cutout_mask"): + def __init__(self, context, *, output, lam="lam_0"): super().__init__(context) - if isinstance(points, str): - mask = self.checkpoint.load_supporting_array(points) - points = -np.sum(mask) # This is the global, we want the lam + + LOG.info("context.checkpoint.supporting_arrays %s", list(context.checkpoint.supporting_arrays.keys())) + LOG.info("%s", len(context.checkpoint.supporting_arrays["grid_indices"])) + + if "cutout_mask" in self.checkpoint.supporting_arrays: + # Backwards compatibility + mask = self.checkpoint.load_supporting_array("cutout_mask") + points = slice(None, -np.sum(mask)) + else: + if lam != "lam_0": + raise NotImplementedError("Only lam_0 is supported") + + if "lam_1/cutout_mask" in self.checkpoint.supporting_arrays: + raise NotImplementedError("Only lam_0 is supported") + + mask = self.checkpoint.load_supporting_array(f"{lam}/cutout_mask") + assert len(mask) == np.sum(mask) + points = slice(None, np.sum(mask)) self.points = points self.output = create_output(context, output) @@ -42,21 +57,17 @@ def write_state(self, state): def _apply_mask(self, state): - if self.points < 0: - # This is the global, we want the lam - self.points = state["latitudes"].size + self.points - state = state.copy() state["fields"] = state["fields"].copy() - state["latitudes"] = state["latitudes"][: self.points] - state["longitudes"] = state["longitudes"][: self.points] + state["latitudes"] = state["latitudes"][self.points] + state["longitudes"] = state["longitudes"][self.points] for field in state["fields"]: data = state["fields"][field] if data.ndim == 1: - data = data[: self.points] + data = data[self.points] else: - data = data[..., : self.points] + data = data[..., self.points] state["fields"][field] = data return state