Skip to content

Commit

Permalink
Fix bug with initial state output
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Jan 20, 2025
1 parent fccad0e commit 39497f3
Show file tree
Hide file tree
Showing 12 changed files with 23 additions and 27 deletions.
1 change: 1 addition & 0 deletions src/anemoi/inference/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Context(ABC):
time_step = None
lead_time = None
output_frequency = None
write_initial_state = True

##################################################################

Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/inference/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
class Output(ABC):
"""_summary_"""

def __init__(self, context, output_frequency=None, write_initial_step=False):
def __init__(self, context, output_frequency=None, write_initial_state=True):
from anemoi.utils.dates import as_timedelta

self.context = context
self.checkpoint = context.checkpoint
self.reference_date = None

self.write_step_zero = write_initial_step and context.write_initial_step
self.write_step_zero = write_initial_state and context.write_initial_state

self.output_frequency = output_frequency or context.output_frequency
if self.output_frequency is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/inference/outputs/apply_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
class ApplyMaskOutput(Output):
"""_summary_"""

def __init__(self, context, *, mask, output, output_frequency=None, write_initial_step=False):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
def __init__(self, context, *, mask, output, output_frequency=None, write_initial_state=True):
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
self.mask = self.checkpoint.load_supporting_array(mask)
self.output = create_output(context, output)

Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/inference/outputs/extract_lam.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
class ExtractLamOutput(Output):
"""_summary_"""

def __init__(self, context, *, output, points="cutout_mask", output_frequency=None, write_initial_step=False):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
def __init__(self, context, *, output, points="cutout_mask", output_frequency=None, write_initial_state=True):
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)

if isinstance(points, str):
mask = self.checkpoint.load_supporting_array(points)
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ def __init__(
grib2_keys=None,
modifiers=None,
output_frequency=None,
write_initial_step=False,
write_initial_state=True,
):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
self._first = True
self.typed_variables = self.checkpoint.typed_variables
self.quiet = set()
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
grib2_keys=None,
modifiers=None,
output_frequency=None,
write_initial_step=False,
write_initial_state=True,
**kwargs,
):
super().__init__(
Expand All @@ -102,7 +102,7 @@ def __init__(
grib2_keys=grib2_keys,
modifiers=modifiers,
output_frequency=output_frequency,
write_initial_step=write_initial_step,
write_initial_state=write_initial_state,
)
self.path = path
self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/inference/outputs/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
class NetCDFOutput(Output):
"""_summary_"""

def __init__(self, context, path, output_frequency=None, write_initial_step=False):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
def __init__(self, context, path, output_frequency=None, write_initial_state=True):
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
self.path = path
self.ncfile = None
self.float_size = "f4"
Expand Down
8 changes: 2 additions & 6 deletions src/anemoi/inference/outputs/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def __init__(
dpi=300,
format="png",
output_frequency=None,
write_initial_step=False,
write_initial_state=True,
):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
self.path = path
self.format = format
self.variables = variables
Expand All @@ -50,10 +50,6 @@ def __init__(
if not isinstance(self.variables, (list, tuple)):
self.variables = [self.variables]

def write_initial_step(self, state):
reduced_state = self.reduce(state)
self.write_step(reduced_state)

def write_step(self, state, step):
import cartopy.crs as ccrs
import cartopy.feature as cfeature
Expand Down
8 changes: 2 additions & 6 deletions src/anemoi/inference/outputs/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,16 @@ def __init__(
template="{date}.npz",
strftime="%Y%m%d%H%M%S",
output_frequency=None,
write_initial_step=False,
write_initial_state=True,
):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
self.path = path
self.template = template
self.strftime = strftime

def __repr__(self):
return f"RawOutput({self.path})"

def write_initial_step(self, state):
reduced_state = self.reduce(state)
self.write_step(reduced_state)

def write_step(self, state, step):
os.makedirs(self.path, exist_ok=True)
date = state["date"].strftime(self.strftime)
Expand Down
6 changes: 3 additions & 3 deletions src/anemoi/inference/outputs/tee.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
class TeeOutput(Output):
"""_summary_"""

def __init__(self, context, *args, outputs=None, output_frequency=None, write_initial_step=False, **kwargs):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
def __init__(self, context, *args, outputs=None, output_frequency=None, write_initial_state=True, **kwargs):
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
if outputs is None:
outputs = args
assert isinstance(outputs, (list, tuple)), outputs
self.outputs = [create_output(context, output) for output in outputs]

def write_initial_step(self, state):
def write_initial_step(self, state, step):
for output in self.outputs:
output.write_initial_state(state)

Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
patch_metadata={},
development_hacks={}, # For testing purposes, don't use in production
output_frequency=None,
write_initial_state=True,
):
self._checkpoint = Checkpoint(checkpoint, patch_metadata=patch_metadata)

Expand All @@ -79,6 +80,7 @@ def __init__(
self.development_hacks = development_hacks
self.hacks = bool(development_hacks)
self.output_frequency = output_frequency
self.write_initial_state = write_initial_state

# This could also be passed as an argument

Expand Down
1 change: 1 addition & 0 deletions src/anemoi/inference/runners/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self, config):
patch_metadata=config.patch_metadata,
development_hacks=config.development_hacks,
output_frequency=config.output_frequency,
write_initial_state=config.write_initial_state,
)

def create_input(self):
Expand Down

0 comments on commit 39497f3

Please sign in to comment.