From fccad0efe8b4770dc3cd06dd92ec6aef08f39add Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 19 Jan 2025 10:31:27 +0000 Subject: [PATCH 1/2] feat: Add support for `output_frequency` to write less output --- CHANGELOG.md | 1 + src/anemoi/inference/commands/run.py | 3 +- src/anemoi/inference/config.py | 3 ++ src/anemoi/inference/context.py | 10 ++++ src/anemoi/inference/output.py | 53 +++++++++++++++++++-- src/anemoi/inference/outputs/apply_mask.py | 12 ++--- src/anemoi/inference/outputs/extract_lam.py | 13 ++--- src/anemoi/inference/outputs/grib.py | 21 ++++++-- src/anemoi/inference/outputs/gribfile.py | 4 ++ src/anemoi/inference/outputs/netcdf.py | 49 +++++-------------- src/anemoi/inference/outputs/none.py | 5 +- src/anemoi/inference/outputs/plot.py | 10 ++-- src/anemoi/inference/outputs/printer.py | 5 +- src/anemoi/inference/outputs/raw.py | 10 ++-- src/anemoi/inference/outputs/tee.py | 14 ++++-- src/anemoi/inference/runner.py | 2 + src/anemoi/inference/runners/default.py | 1 + 17 files changed, 137 insertions(+), 79 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 446bda1..f55ef1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you! - Add CONTRIBUTORS.md file (#36) - Add sanetise command - Add support for huggingface +- Add support for `output_frequency` to write less output ### Changed - Change `write_initial_state` default value to `true` diff --git a/src/anemoi/inference/commands/run.py b/src/anemoi/inference/commands/run.py index 2922d41..c94e396 100644 --- a/src/anemoi/inference/commands/run.py +++ b/src/anemoi/inference/commands/run.py @@ -42,8 +42,7 @@ def run(self, args): input_state = input.create_input_state(date=config.date) - if config.write_initial_state: - output.write_initial_state(input_state) + output.write_initial_state(input_state) for state in runner.run(input_state=input_state, lead_time=config.lead_time): output.write_state(state) diff --git a/src/anemoi/inference/config.py b/src/anemoi/inference/config.py index 0f39212..fc6e001 100644 --- a/src/anemoi/inference/config.py +++ b/src/anemoi/inference/config.py @@ -82,6 +82,9 @@ class Config: """Wether to write the initial state to the output file. If the model is multi-step, only fields at the forecast reference date are written.""" + output_frequency: Optional[str] = None + """The frequency at which to write the output. This can be a string or an integer. If a string, it is parsed by :func:`anemoi.utils.dates.as_timedelta`.""" + env: Dict[str, str | int] = {} """Environment variables to set before running the model. This may be useful to control some packages such as `eccodes`. In certain cases, the variables mey be set too late, if the package for which they are intended diff --git a/src/anemoi/inference/context.py b/src/anemoi/inference/context.py index 58116b4..b3ecaff 100644 --- a/src/anemoi/inference/context.py +++ b/src/anemoi/inference/context.py @@ -24,6 +24,16 @@ class Context(ABC): verbosity = 0 development_hacks = {} # For testing purposes, don't use in production + # Some runners will set these values, which can be queried by Output objects, + # but may remain as None + + reference_date = None + time_step = None + lead_time = None + output_frequency = None + + ################################################################## + @property @abstractmethod def checkpoint(self): diff --git a/src/anemoi/inference/output.py b/src/anemoi/inference/output.py index ddb58f1..382542c 100644 --- a/src/anemoi/inference/output.py +++ b/src/anemoi/inference/output.py @@ -6,26 +6,67 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. # +import logging from abc import ABC from abc import abstractmethod +LOG = logging.getLogger(__name__) + class Output(ABC): """_summary_""" - def __init__(self, context): + def __init__(self, context, output_frequency=None, write_initial_step=False): + from anemoi.utils.dates import as_timedelta + self.context = context self.checkpoint = context.checkpoint + self.reference_date = None + + self.write_step_zero = write_initial_step and context.write_initial_step + + self.output_frequency = output_frequency or context.output_frequency + if self.output_frequency is not None: + self.output_frequency = as_timedelta(self.output_frequency) def __repr__(self): return f"{self.__class__.__name__}()" - @abstractmethod def write_initial_state(self, state): - pass + self._init(state) + if self.write_step_zero: + return self.write_initial_step(state, state["date"] - self.reference_date) - @abstractmethod def write_state(self, state): + self._init(state) + + step = state["date"] - self.reference_date + if self.output_frequency is not None: + if (step % self.output_frequency).total_seconds() != 0: + return + + return self.write_step(state, step) + + def _init(self, state): + if self.reference_date is not None: + return + + self.reference_date = state["date"] + + self.open(state) + + def write_initial_step(self, state, step): + """This method should not be called directly + call `write_initial_state` instead. + """ + reduced_state = self.reduce(state) + self.write_step(reduced_state, step) + + @abstractmethod + def write_step(self, state, step): + """This method should be be called directly + call `write_state` instead. + """ pass def reduce(self, state): @@ -36,5 +77,9 @@ def reduce(self, state): reduced_state["fields"][field] = values[-1, :] return reduced_state + def open(self, state): + # Override this method when initialisation is needed + pass + def close(self): pass diff --git a/src/anemoi/inference/outputs/apply_mask.py b/src/anemoi/inference/outputs/apply_mask.py index f7c1ae4..73dccd3 100644 --- a/src/anemoi/inference/outputs/apply_mask.py +++ b/src/anemoi/inference/outputs/apply_mask.py @@ -20,19 +20,19 @@ class ApplyMaskOutput(Output): """_summary_""" - def __init__(self, context, *, mask, output): - super().__init__(context) + def __init__(self, context, *, mask, output, output_frequency=None, write_initial_step=False): + super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) self.mask = self.checkpoint.load_supporting_array(mask) self.output = create_output(context, output) def __repr__(self): return f"ApplyMaskOutput({self.mask}, {self.output})" - def write_initial_state(self, state): - self.output.write_initial_state(self._apply_mask(state)) + def write_initial_step(self, state, step): + self.output.write_initial_step(self._apply_mask(state), step) - def write_state(self, state): - self.output.write_state(self._apply_mask(state)) + def write_step(self, state, step): + self.output.write_step(self._apply_mask(state), step) def _apply_mask(self, state): state = state.copy() diff --git a/src/anemoi/inference/outputs/extract_lam.py b/src/anemoi/inference/outputs/extract_lam.py index a1ed2b7..e7bdbd6 100644 --- a/src/anemoi/inference/outputs/extract_lam.py +++ b/src/anemoi/inference/outputs/extract_lam.py @@ -22,8 +22,9 @@ class ExtractLamOutput(Output): """_summary_""" - def __init__(self, context, *, output, points="cutout_mask"): - super().__init__(context) + def __init__(self, context, *, output, points="cutout_mask", output_frequency=None, write_initial_step=False): + super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + if isinstance(points, str): mask = self.checkpoint.load_supporting_array(points) points = -np.sum(mask) # This is the global, we want the lam @@ -34,11 +35,11 @@ def __init__(self, context, *, output, points="cutout_mask"): def __repr__(self): return f"ExtractLamOutput({self.points}, {self.output})" - def write_initial_state(self, state): - self.output.write_initial_state(self._apply_mask(state)) + def write_initial_step(self, state, step): + self.output.write_initial_step(self._apply_mask(state), step) - def write_state(self, state): - self.output.write_state(self._apply_mask(state)) + def write_step(self, state, step): + self.output.write_step(self._apply_mask(state), step) def _apply_mask(self, state): diff --git a/src/anemoi/inference/outputs/grib.py b/src/anemoi/inference/outputs/grib.py index 81f0208..a9f0269 100644 --- a/src/anemoi/inference/outputs/grib.py +++ b/src/anemoi/inference/outputs/grib.py @@ -72,8 +72,19 @@ class GribOutput(Output): Handles grib """ - def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None, modifiers=None): - super().__init__(context) + def __init__( + self, + context, + *, + encoding=None, + templates=None, + grib1_keys=None, + grib2_keys=None, + modifiers=None, + output_frequency=None, + write_initial_step=False, + ): + super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) self._first = True self.typed_variables = self.checkpoint.typed_variables self.quiet = set() @@ -88,7 +99,7 @@ def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, g self.use_closest_template = False # Off for now self.modifiers = modifier_factory(modifiers) - def write_initial_state(self, state): + def write_initial_step(self, state, step): # We trust the GribInput class to provide the templates # matching the input state @@ -98,7 +109,7 @@ def write_initial_state(self, state): if template is None: # We can currently only write grib output if we have a grib input raise ValueError( - "GRIB output only works if the input is GRIB (for now). Set `write_initial_state` to `false`." + "GRIB output only works if the input is GRIB (for now). Set `write_initial_step` to `false`." ) variable = self.typed_variables[name] @@ -128,7 +139,7 @@ def write_initial_state(self, state): self.write_message(values, template=template, **keys) - def write_state(self, state): + def write_step(self, state, step): reference_date = self.context.reference_date date = state["date"] diff --git a/src/anemoi/inference/outputs/gribfile.py b/src/anemoi/inference/outputs/gribfile.py index a31b8bf..db36345 100644 --- a/src/anemoi/inference/outputs/gribfile.py +++ b/src/anemoi/inference/outputs/gribfile.py @@ -90,6 +90,8 @@ def __init__( grib1_keys=None, grib2_keys=None, modifiers=None, + output_frequency=None, + write_initial_step=False, **kwargs, ): super().__init__( @@ -99,6 +101,8 @@ def __init__( grib1_keys=grib1_keys, grib2_keys=grib2_keys, modifiers=modifiers, + output_frequency=output_frequency, + write_initial_step=write_initial_step, ) self.path = path self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs) diff --git a/src/anemoi/inference/outputs/netcdf.py b/src/anemoi/inference/outputs/netcdf.py index 2bef7ca..e3f5431 100644 --- a/src/anemoi/inference/outputs/netcdf.py +++ b/src/anemoi/inference/outputs/netcdf.py @@ -24,8 +24,8 @@ class NetCDFOutput(Output): """_summary_""" - def __init__(self, context, path): - super().__init__(context) + def __init__(self, context, path, output_frequency=None, write_initial_step=False): + super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) self.path = path self.ncfile = None self.float_size = "f4" @@ -37,12 +37,9 @@ def __del__(self): if self.ncfile is not None: self.ncfile.close() - def _init(self, state): + def open(self, state): from netCDF4 import Dataset - if self.ncfile is not None: - return self.ncfile - # If the file exists, we may get a 'Permission denied' error if os.path.exists(self.path): os.remove(self.path) @@ -53,23 +50,21 @@ def _init(self, state): values = len(state["latitudes"]) - time = 0 - self.reference_date = state["date"] - if hasattr(self.context, "time_step") and hasattr(self.context, "lead_time"): - time = self.context.lead_time // self.context.time_step - if hasattr(self.context, "reference_date"): - self.reference_date = self.context.reference_date - self.values_dim = self.ncfile.createDimension("values", values) - self.time_dim = self.ncfile.createDimension("time", time) - self.time_var = self.ncfile.createVariable("time", "i4", ("time",), **compression) + self.time_dim = self.ncfile.createDimension("time", size=None) # infinite dimension + self.time_var = self.ncfile.createVariable("time", "i8", ("time",), **compression) self.time_var.units = "seconds since {0}".format(self.reference_date) self.time_var.long_name = "time" self.time_var.calendar = "gregorian" latitudes = state["latitudes"] - self.latitude_var = self.ncfile.createVariable("latitude", self.float_size, ("values",), **compression) + self.latitude_var = self.ncfile.createVariable( + "latitude", + self.float_size, + ("values",), + **compression, + ) self.latitude_var.units = "degrees_north" self.latitude_var.long_name = "latitude" @@ -84,20 +79,7 @@ def _init(self, state): self.longitude_var.long_name = "longitude" self.vars = {} - - for name in state["fields"].keys(): - chunksizes = (1, values) - - while np.prod(chunksizes) > 1000000: - chunksizes = tuple(int(np.ceil(x / 2)) for x in chunksizes) - - self.vars[name] = self.ncfile.createVariable( - name, - self.float_size, - ("time", "values"), - chunksizes=chunksizes, - **compression, - ) + self.ensure_variables(state) self.latitude_var[:] = latitudes self.longitude_var[:] = longitudes @@ -126,15 +108,10 @@ def ensure_variables(self, state): **compression, ) - def write_initial_state(self, state): - reduced_state = self.reduce(state) - self.write_state(reduced_state) + def write_step(self, state, step): - def write_state(self, state): - self._init(state) self.ensure_variables(state) - step = state["date"] - self.reference_date self.time_var[self.n] = step.total_seconds() for name, value in state["fields"].items(): diff --git a/src/anemoi/inference/outputs/none.py b/src/anemoi/inference/outputs/none.py index d0bf158..c369447 100644 --- a/src/anemoi/inference/outputs/none.py +++ b/src/anemoi/inference/outputs/none.py @@ -19,8 +19,5 @@ class NoneOutput(Output): """_summary_""" - def write_initial_state(self, state): - pass - - def write_state(self, state): + def write_step(self, state, step): pass diff --git a/src/anemoi/inference/outputs/plot.py b/src/anemoi/inference/outputs/plot.py index ad38ad1..7748a27 100644 --- a/src/anemoi/inference/outputs/plot.py +++ b/src/anemoi/inference/outputs/plot.py @@ -35,8 +35,10 @@ def __init__( template="plot_{variable}_{date}.{format}", dpi=300, format="png", + output_frequency=None, + write_initial_step=False, ): - super().__init__(context) + super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) self.path = path self.format = format self.variables = variables @@ -48,11 +50,11 @@ def __init__( if not isinstance(self.variables, (list, tuple)): self.variables = [self.variables] - def write_initial_state(self, state): + def write_initial_step(self, state): reduced_state = self.reduce(state) - self.write_state(reduced_state) + self.write_step(reduced_state) - def write_state(self, state): + def write_step(self, state, step): import cartopy.crs as ccrs import cartopy.feature as cfeature import matplotlib.pyplot as plt diff --git a/src/anemoi/inference/outputs/printer.py b/src/anemoi/inference/outputs/printer.py index b8d6ff1..a74b35a 100644 --- a/src/anemoi/inference/outputs/printer.py +++ b/src/anemoi/inference/outputs/printer.py @@ -60,8 +60,5 @@ def print_state(state, print=print): class PrinterOutput(Output): """_summary_""" - def write_initial_state(self, state): - self.write_state(state) - - def write_state(self, state): + def write_step(self, state, step): print_state(state) diff --git a/src/anemoi/inference/outputs/raw.py b/src/anemoi/inference/outputs/raw.py index a46ec52..db3502c 100644 --- a/src/anemoi/inference/outputs/raw.py +++ b/src/anemoi/inference/outputs/raw.py @@ -30,8 +30,10 @@ def __init__( path, template="{date}.npz", strftime="%Y%m%d%H%M%S", + output_frequency=None, + write_initial_step=False, ): - super().__init__(context) + super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) self.path = path self.template = template self.strftime = strftime @@ -39,11 +41,11 @@ def __init__( def __repr__(self): return f"RawOutput({self.path})" - def write_initial_state(self, state): + def write_initial_step(self, state): reduced_state = self.reduce(state) - self.write_state(reduced_state) + self.write_step(reduced_state) - def write_state(self, state): + def write_step(self, state, step): os.makedirs(self.path, exist_ok=True) date = state["date"].strftime(self.strftime) fn_state = f"{self.path}/{self.template.format(date=date)}" diff --git a/src/anemoi/inference/outputs/tee.py b/src/anemoi/inference/outputs/tee.py index 82a6a8c..0b237e5 100644 --- a/src/anemoi/inference/outputs/tee.py +++ b/src/anemoi/inference/outputs/tee.py @@ -20,21 +20,27 @@ class TeeOutput(Output): """_summary_""" - def __init__(self, context, *args, outputs=None, **kwargs): - super().__init__(context) + def __init__(self, context, *args, outputs=None, output_frequency=None, write_initial_step=False, **kwargs): + super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) if outputs is None: outputs = args assert isinstance(outputs, (list, tuple)), outputs self.outputs = [create_output(context, output) for output in outputs] - def write_initial_state(self, state): + def write_initial_step(self, state): for output in self.outputs: output.write_initial_state(state) - def write_state(self, state): + def write_step(self, state, step): + # We call write_state instead of write_step + # so we can have a per-output `output_frequency` for output in self.outputs: output.write_state(state) + def open(self, state): + for output in self.outputs: + output.open(state) + def close(self): for output in self.outputs: output.close() diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 4ba0a8a..a5f6dbd 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -64,6 +64,7 @@ def __init__( inference_options=None, patch_metadata={}, development_hacks={}, # For testing purposes, don't use in production + output_frequency=None, ): self._checkpoint = Checkpoint(checkpoint, patch_metadata=patch_metadata) @@ -77,6 +78,7 @@ def __init__( self.use_grib_paramid = use_grib_paramid self.development_hacks = development_hacks self.hacks = bool(development_hacks) + self.output_frequency = output_frequency # This could also be passed as an argument diff --git a/src/anemoi/inference/runners/default.py b/src/anemoi/inference/runners/default.py index 7dac5cd..c43fd85 100644 --- a/src/anemoi/inference/runners/default.py +++ b/src/anemoi/inference/runners/default.py @@ -50,6 +50,7 @@ def __init__(self, config): use_grib_paramid=config.use_grib_paramid, patch_metadata=config.patch_metadata, development_hacks=config.development_hacks, + output_frequency=config.output_frequency, ) def create_input(self): From 39497f3692812376a62eba92c82a69a7cca17ac5 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 20 Jan 2025 19:51:02 +0000 Subject: [PATCH 2/2] Fix bug with initial state output --- src/anemoi/inference/context.py | 1 + src/anemoi/inference/output.py | 4 ++-- src/anemoi/inference/outputs/apply_mask.py | 4 ++-- src/anemoi/inference/outputs/extract_lam.py | 4 ++-- src/anemoi/inference/outputs/grib.py | 4 ++-- src/anemoi/inference/outputs/gribfile.py | 4 ++-- src/anemoi/inference/outputs/netcdf.py | 4 ++-- src/anemoi/inference/outputs/plot.py | 8 ++------ src/anemoi/inference/outputs/raw.py | 8 ++------ src/anemoi/inference/outputs/tee.py | 6 +++--- src/anemoi/inference/runner.py | 2 ++ src/anemoi/inference/runners/default.py | 1 + 12 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/anemoi/inference/context.py b/src/anemoi/inference/context.py index b3ecaff..e5a27e5 100644 --- a/src/anemoi/inference/context.py +++ b/src/anemoi/inference/context.py @@ -31,6 +31,7 @@ class Context(ABC): time_step = None lead_time = None output_frequency = None + write_initial_state = True ################################################################## diff --git a/src/anemoi/inference/output.py b/src/anemoi/inference/output.py index 382542c..dfc0e64 100644 --- a/src/anemoi/inference/output.py +++ b/src/anemoi/inference/output.py @@ -16,14 +16,14 @@ class Output(ABC): """_summary_""" - def __init__(self, context, output_frequency=None, write_initial_step=False): + def __init__(self, context, output_frequency=None, write_initial_state=True): from anemoi.utils.dates import as_timedelta self.context = context self.checkpoint = context.checkpoint self.reference_date = None - self.write_step_zero = write_initial_step and context.write_initial_step + self.write_step_zero = write_initial_state and context.write_initial_state self.output_frequency = output_frequency or context.output_frequency if self.output_frequency is not None: diff --git a/src/anemoi/inference/outputs/apply_mask.py b/src/anemoi/inference/outputs/apply_mask.py index 73dccd3..493e913 100644 --- a/src/anemoi/inference/outputs/apply_mask.py +++ b/src/anemoi/inference/outputs/apply_mask.py @@ -20,8 +20,8 @@ class ApplyMaskOutput(Output): """_summary_""" - def __init__(self, context, *, mask, output, output_frequency=None, write_initial_step=False): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + def __init__(self, context, *, mask, output, output_frequency=None, write_initial_state=True): + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) self.mask = self.checkpoint.load_supporting_array(mask) self.output = create_output(context, output) diff --git a/src/anemoi/inference/outputs/extract_lam.py b/src/anemoi/inference/outputs/extract_lam.py index e7bdbd6..19a4ccf 100644 --- a/src/anemoi/inference/outputs/extract_lam.py +++ b/src/anemoi/inference/outputs/extract_lam.py @@ -22,8 +22,8 @@ class ExtractLamOutput(Output): """_summary_""" - def __init__(self, context, *, output, points="cutout_mask", output_frequency=None, write_initial_step=False): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + def __init__(self, context, *, output, points="cutout_mask", output_frequency=None, write_initial_state=True): + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) if isinstance(points, str): mask = self.checkpoint.load_supporting_array(points) diff --git a/src/anemoi/inference/outputs/grib.py b/src/anemoi/inference/outputs/grib.py index a9f0269..0c92740 100644 --- a/src/anemoi/inference/outputs/grib.py +++ b/src/anemoi/inference/outputs/grib.py @@ -82,9 +82,9 @@ def __init__( grib2_keys=None, modifiers=None, output_frequency=None, - write_initial_step=False, + write_initial_state=True, ): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) self._first = True self.typed_variables = self.checkpoint.typed_variables self.quiet = set() diff --git a/src/anemoi/inference/outputs/gribfile.py b/src/anemoi/inference/outputs/gribfile.py index db36345..48b27a3 100644 --- a/src/anemoi/inference/outputs/gribfile.py +++ b/src/anemoi/inference/outputs/gribfile.py @@ -91,7 +91,7 @@ def __init__( grib2_keys=None, modifiers=None, output_frequency=None, - write_initial_step=False, + write_initial_state=True, **kwargs, ): super().__init__( @@ -102,7 +102,7 @@ def __init__( grib2_keys=grib2_keys, modifiers=modifiers, output_frequency=output_frequency, - write_initial_step=write_initial_step, + write_initial_state=write_initial_state, ) self.path = path self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs) diff --git a/src/anemoi/inference/outputs/netcdf.py b/src/anemoi/inference/outputs/netcdf.py index e3f5431..bd66a6a 100644 --- a/src/anemoi/inference/outputs/netcdf.py +++ b/src/anemoi/inference/outputs/netcdf.py @@ -24,8 +24,8 @@ class NetCDFOutput(Output): """_summary_""" - def __init__(self, context, path, output_frequency=None, write_initial_step=False): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + def __init__(self, context, path, output_frequency=None, write_initial_state=True): + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) self.path = path self.ncfile = None self.float_size = "f4" diff --git a/src/anemoi/inference/outputs/plot.py b/src/anemoi/inference/outputs/plot.py index 7748a27..ec3578c 100644 --- a/src/anemoi/inference/outputs/plot.py +++ b/src/anemoi/inference/outputs/plot.py @@ -36,9 +36,9 @@ def __init__( dpi=300, format="png", output_frequency=None, - write_initial_step=False, + write_initial_state=True, ): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) self.path = path self.format = format self.variables = variables @@ -50,10 +50,6 @@ def __init__( if not isinstance(self.variables, (list, tuple)): self.variables = [self.variables] - def write_initial_step(self, state): - reduced_state = self.reduce(state) - self.write_step(reduced_state) - def write_step(self, state, step): import cartopy.crs as ccrs import cartopy.feature as cfeature diff --git a/src/anemoi/inference/outputs/raw.py b/src/anemoi/inference/outputs/raw.py index db3502c..4df6b40 100644 --- a/src/anemoi/inference/outputs/raw.py +++ b/src/anemoi/inference/outputs/raw.py @@ -31,9 +31,9 @@ def __init__( template="{date}.npz", strftime="%Y%m%d%H%M%S", output_frequency=None, - write_initial_step=False, + write_initial_state=True, ): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) self.path = path self.template = template self.strftime = strftime @@ -41,10 +41,6 @@ def __init__( def __repr__(self): return f"RawOutput({self.path})" - def write_initial_step(self, state): - reduced_state = self.reduce(state) - self.write_step(reduced_state) - def write_step(self, state, step): os.makedirs(self.path, exist_ok=True) date = state["date"].strftime(self.strftime) diff --git a/src/anemoi/inference/outputs/tee.py b/src/anemoi/inference/outputs/tee.py index 0b237e5..62b016c 100644 --- a/src/anemoi/inference/outputs/tee.py +++ b/src/anemoi/inference/outputs/tee.py @@ -20,14 +20,14 @@ class TeeOutput(Output): """_summary_""" - def __init__(self, context, *args, outputs=None, output_frequency=None, write_initial_step=False, **kwargs): - super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step) + def __init__(self, context, *args, outputs=None, output_frequency=None, write_initial_state=True, **kwargs): + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) if outputs is None: outputs = args assert isinstance(outputs, (list, tuple)), outputs self.outputs = [create_output(context, output) for output in outputs] - def write_initial_step(self, state): + def write_initial_step(self, state, step): for output in self.outputs: output.write_initial_state(state) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index a5f6dbd..4d2563e 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -65,6 +65,7 @@ def __init__( patch_metadata={}, development_hacks={}, # For testing purposes, don't use in production output_frequency=None, + write_initial_state=True, ): self._checkpoint = Checkpoint(checkpoint, patch_metadata=patch_metadata) @@ -79,6 +80,7 @@ def __init__( self.development_hacks = development_hacks self.hacks = bool(development_hacks) self.output_frequency = output_frequency + self.write_initial_state = write_initial_state # This could also be passed as an argument diff --git a/src/anemoi/inference/runners/default.py b/src/anemoi/inference/runners/default.py index c43fd85..99df321 100644 --- a/src/anemoi/inference/runners/default.py +++ b/src/anemoi/inference/runners/default.py @@ -51,6 +51,7 @@ def __init__(self, config): patch_metadata=config.patch_metadata, development_hacks=config.development_hacks, output_frequency=config.output_frequency, + write_initial_state=config.write_initial_state, ) def create_input(self):