Skip to content

Commit

Permalink
fix: extract_lam output
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Jan 20, 2025
1 parent e89974c commit a9f869a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 12 deletions.
4 changes: 4 additions & 0 deletions src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ def variable_categories(self):
def load_supporting_array(self, name):
return self._metadata.load_supporting_array(name)

@property
def supporting_arrays(self):
return self._metadata.supporting_arrays

###########################################################################

@cached_property
Expand Down
4 changes: 4 additions & 0 deletions src/anemoi/inference/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,10 @@ def load_supporting_array(self, name):
raise ValueError(f"Supporting array `{name}` not found")
return self._supporting_arrays[name]

@property
def supporting_arrays(self):
return self._supporting_arrays

@property
def latitudes(self):
return self._supporting_arrays.get("latitudes")
Expand Down
35 changes: 23 additions & 12 deletions src/anemoi/inference/outputs/extract_lam.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,26 @@
class ExtractLamOutput(Output):
"""_summary_"""

def __init__(self, context, *, output, points="cutout_mask"):
def __init__(self, context, *, output, lam="lam_0"):
super().__init__(context)
if isinstance(points, str):
mask = self.checkpoint.load_supporting_array(points)
points = -np.sum(mask) # This is the global, we want the lam

LOG.info("context.checkpoint.supporting_arrays %s", list(context.checkpoint.supporting_arrays.keys()))
LOG.info("%s", len(context.checkpoint.supporting_arrays["grid_indices"]))

if "cutout_mask" in self.checkpoint.supporting_arrays:
# Backwards compatibility
mask = self.checkpoint.load_supporting_array("cutout_mask")
points = slice(None, -np.sum(mask))
else:
if lam != "lam_0":
raise NotImplementedError("Only lam_0 is supported")

if "lam_1/cutout_mask" in self.checkpoint.supporting_arrays:
raise NotImplementedError("Only lam_0 is supported")

mask = self.checkpoint.load_supporting_array(f"{lam}/cutout_mask")
assert len(mask) == np.sum(mask)
points = slice(None, np.sum(mask))

self.points = points
self.output = create_output(context, output)
Expand All @@ -42,21 +57,17 @@ def write_state(self, state):

def _apply_mask(self, state):

if self.points < 0:
# This is the global, we want the lam
self.points = state["latitudes"].size + self.points

state = state.copy()
state["fields"] = state["fields"].copy()
state["latitudes"] = state["latitudes"][: self.points]
state["longitudes"] = state["longitudes"][: self.points]
state["latitudes"] = state["latitudes"][self.points]
state["longitudes"] = state["longitudes"][self.points]

for field in state["fields"]:
data = state["fields"][field]
if data.ndim == 1:
data = data[: self.points]
data = data[self.points]
else:
data = data[..., : self.points]
data = data[..., self.points]
state["fields"][field] = data

return state
Expand Down

0 comments on commit a9f869a

Please sign in to comment.