From 39497f3692812376a62eba92c82a69a7cca17ac5 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 20 Jan 2025 19:51:02 +0000 Subject: [PATCH] Fix bug with initial state output --- src/anemoi/inference/context.py | 1 + src/anemoi/inference/output.py | 4 ++-- src/anemoi/inference/outputs/apply_mask.py | 4 ++-- src/anemoi/inference/outputs/extract_lam.py | 4 ++-- src/anemoi/inference/outputs/grib.py | 4 ++-- src/anemoi/inference/outputs/gribfile.py | 4 ++-- src/anemoi/inference/outputs/netcdf.py | 4 ++-- src/anemoi/inference/outputs/plot.py | 8 ++------ src/anemoi/inference/outputs/raw.py | 8 ++------ src/anemoi/inference/outputs/tee.py | 6 +++--- src/anemoi/inference/runner.py | 2 ++ src/anemoi/inference/runners/default.py | 1 + 12 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/anemoi/inference/context.py b/src/anemoi/inference/context.py index b3ecaff..e5a27e5 100644 --- a/src/anemoi/inference/context.py +++ b/src/anemoi/inference/context.py @@ -31,6 +31,7 @@ class Context(ABC): time_step = None lead_time = None output_frequency = None + write_initial_state = True ################################################################## diff --git a/src/anemoi/inference/output.py b/src/anemoi/inference/output.py index 382542c..dfc0e64 100644 --- a/src/anemoi/inference/output.py +++ b/src/anemoi/inference/output.py @@ -16,14 +16,14 @@ class Output(ABC): """_summary_""" - def __init__(self, context, output_frequency=None, write_initial_step=False): + def __init__(self, context, output_frequency=None, write_initial_state=True): from anemoi.utils.dates import as_timedelta self.context = context self.checkpoint = context.checkpoint self.reference_date = None - self.write_step_zero = write_initial_step and context.write_initial_step + self.write_step_zero = write_initial_state and context.write_initial_state self.output_frequency = output_frequency or context.output_frequency if self.output_frequency is not None: diff --git a/src/anemoi/inference/outputs/apply_mask.py b/src/anemoi/inference/outputs/apply_mask.py index 73dccd3..493e913 100644 --- a/src/anemoi/inference/outputs/apply_mask.py +++ b/src/anemoi/inference/outputs/apply_mask.py @@ -20,8 +20,8 @@ class ApplyMaskOutput(Output): """_summary_""" - def __init__(self, context, *, mask, output, output_frequency=None, write_initial_step=False): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + def __init__(self, context, *, mask, output, output_frequency=None, write_initial_state=True): + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) self.mask = self.checkpoint.load_supporting_array(mask) self.output = create_output(context, output) diff --git a/src/anemoi/inference/outputs/extract_lam.py b/src/anemoi/inference/outputs/extract_lam.py index e7bdbd6..19a4ccf 100644 --- a/src/anemoi/inference/outputs/extract_lam.py +++ b/src/anemoi/inference/outputs/extract_lam.py @@ -22,8 +22,8 @@ class ExtractLamOutput(Output): """_summary_""" - def __init__(self, context, *, output, points="cutout_mask", output_frequency=None, write_initial_step=False): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + def __init__(self, context, *, output, points="cutout_mask", output_frequency=None, write_initial_state=True): + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) if isinstance(points, str): mask = self.checkpoint.load_supporting_array(points) diff --git a/src/anemoi/inference/outputs/grib.py b/src/anemoi/inference/outputs/grib.py index a9f0269..0c92740 100644 --- a/src/anemoi/inference/outputs/grib.py +++ b/src/anemoi/inference/outputs/grib.py @@ -82,9 +82,9 @@ def __init__( grib2_keys=None, modifiers=None, output_frequency=None, - write_initial_step=False, + write_initial_state=True, ): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) self._first = True self.typed_variables = self.checkpoint.typed_variables self.quiet = set() diff --git a/src/anemoi/inference/outputs/gribfile.py b/src/anemoi/inference/outputs/gribfile.py index db36345..48b27a3 100644 --- a/src/anemoi/inference/outputs/gribfile.py +++ b/src/anemoi/inference/outputs/gribfile.py @@ -91,7 +91,7 @@ def __init__( grib2_keys=None, modifiers=None, output_frequency=None, - write_initial_step=False, + write_initial_state=True, **kwargs, ): super().__init__( @@ -102,7 +102,7 @@ def __init__( grib2_keys=grib2_keys, modifiers=modifiers, output_frequency=output_frequency, - write_initial_step=write_initial_step, + write_initial_state=write_initial_state, ) self.path = path self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs) diff --git a/src/anemoi/inference/outputs/netcdf.py b/src/anemoi/inference/outputs/netcdf.py index e3f5431..bd66a6a 100644 --- a/src/anemoi/inference/outputs/netcdf.py +++ b/src/anemoi/inference/outputs/netcdf.py @@ -24,8 +24,8 @@ class NetCDFOutput(Output): """_summary_""" - def __init__(self, context, path, output_frequency=None, write_initial_step=False): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + def __init__(self, context, path, output_frequency=None, write_initial_state=True): + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) self.path = path self.ncfile = None self.float_size = "f4" diff --git a/src/anemoi/inference/outputs/plot.py b/src/anemoi/inference/outputs/plot.py index 7748a27..ec3578c 100644 --- a/src/anemoi/inference/outputs/plot.py +++ b/src/anemoi/inference/outputs/plot.py @@ -36,9 +36,9 @@ def __init__( dpi=300, format="png", output_frequency=None, - write_initial_step=False, + write_initial_state=True, ): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) self.path = path self.format = format self.variables = variables @@ -50,10 +50,6 @@ def __init__( if not isinstance(self.variables, (list, tuple)): self.variables = [self.variables] - def write_initial_step(self, state): - reduced_state = self.reduce(state) - self.write_step(reduced_state) - def write_step(self, state, step): import cartopy.crs as ccrs import cartopy.feature as cfeature diff --git a/src/anemoi/inference/outputs/raw.py b/src/anemoi/inference/outputs/raw.py index db3502c..4df6b40 100644 --- a/src/anemoi/inference/outputs/raw.py +++ b/src/anemoi/inference/outputs/raw.py @@ -31,9 +31,9 @@ def __init__( template="{date}.npz", strftime="%Y%m%d%H%M%S", output_frequency=None, - write_initial_step=False, + write_initial_state=True, ): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) self.path = path self.template = template self.strftime = strftime @@ -41,10 +41,6 @@ def __init__( def __repr__(self): return f"RawOutput({self.path})" - def write_initial_step(self, state): - reduced_state = self.reduce(state) - self.write_step(reduced_state) - def write_step(self, state, step): os.makedirs(self.path, exist_ok=True) date = state["date"].strftime(self.strftime) diff --git a/src/anemoi/inference/outputs/tee.py b/src/anemoi/inference/outputs/tee.py index 0b237e5..62b016c 100644 --- a/src/anemoi/inference/outputs/tee.py +++ b/src/anemoi/inference/outputs/tee.py @@ -20,14 +20,14 @@ class TeeOutput(Output): """_summary_""" - def __init__(self, context, *args, outputs=None, output_frequency=None, write_initial_step=False, **kwargs): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + def __init__(self, context, *args, outputs=None, output_frequency=None, write_initial_state=True, **kwargs): + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) if outputs is None: outputs = args assert isinstance(outputs, (list, tuple)), outputs self.outputs = [create_output(context, output) for output in outputs] - def write_initial_step(self, state): + def write_initial_step(self, state, step): for output in self.outputs: output.write_initial_state(state) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index a5f6dbd..4d2563e 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -65,6 +65,7 @@ def __init__( patch_metadata={}, development_hacks={}, # For testing purposes, don't use in production output_frequency=None, + write_initial_state=True, ): self._checkpoint = Checkpoint(checkpoint, patch_metadata=patch_metadata) @@ -79,6 +80,7 @@ def __init__( self.development_hacks = development_hacks self.hacks = bool(development_hacks) self.output_frequency = output_frequency + self.write_initial_state = write_initial_state # This could also be passed as an argument diff --git a/src/anemoi/inference/runners/default.py b/src/anemoi/inference/runners/default.py index c43fd85..99df321 100644 --- a/src/anemoi/inference/runners/default.py +++ b/src/anemoi/inference/runners/default.py @@ -51,6 +51,7 @@ def __init__(self, config): patch_metadata=config.patch_metadata, development_hacks=config.development_hacks, output_frequency=config.output_frequency, + write_initial_state=config.write_initial_state, ) def create_input(self):