Skip to content

Commit

Permalink
feat: Add support for output_frequency to write less output
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Jan 19, 2025
1 parent e89974c commit fccad0e
Show file tree
Hide file tree
Showing 17 changed files with 137 additions and 79 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
- Add CONTRIBUTORS.md file (#36)
- Add sanetise command
- Add support for huggingface
- Add support for `output_frequency` to write less output

### Changed
- Change `write_initial_state` default value to `true`
Expand Down
3 changes: 1 addition & 2 deletions src/anemoi/inference/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def run(self, args):

input_state = input.create_input_state(date=config.date)

if config.write_initial_state:
output.write_initial_state(input_state)
output.write_initial_state(input_state)

for state in runner.run(input_state=input_state, lead_time=config.lead_time):
output.write_state(state)
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class Config:
"""Whether to write the initial state to the output file. If the model is multi-step, only fields at the forecast reference date are
written."""

output_frequency: Optional[str] = None
"""The frequency at which to write the output. This can be a string or an integer. If a string, it is parsed by :func:`anemoi.utils.dates.as_timedelta`."""

env: Dict[str, str | int] = {}
"""Environment variables to set before running the model. This may be useful to control some packages
such as `eccodes`. In certain cases, the variables may be set too late, if the package for which they are intended
Expand Down
10 changes: 10 additions & 0 deletions src/anemoi/inference/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ class Context(ABC):
verbosity = 0
development_hacks = {} # For testing purposes, don't use in production

# Some runners will set these values, which can be queried by Output objects,
# but may remain as None

reference_date = None
time_step = None
lead_time = None
output_frequency = None

##################################################################

@property
@abstractmethod
def checkpoint(self):
Expand Down
53 changes: 49 additions & 4 deletions src/anemoi/inference/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,67 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#
import logging
from abc import ABC
from abc import abstractmethod

LOG = logging.getLogger(__name__)


class Output(ABC):
"""_summary_"""

def __init__(self, context):
def __init__(self, context, output_frequency=None, write_initial_step=False):
from anemoi.utils.dates import as_timedelta

self.context = context
self.checkpoint = context.checkpoint
self.reference_date = None

self.write_step_zero = write_initial_step and context.write_initial_step

self.output_frequency = output_frequency or context.output_frequency
if self.output_frequency is not None:
self.output_frequency = as_timedelta(self.output_frequency)

def __repr__(self):
return f"{self.__class__.__name__}()"

@abstractmethod
def write_initial_state(self, state):
pass
self._init(state)
if self.write_step_zero:
return self.write_initial_step(state, state["date"] - self.reference_date)

@abstractmethod
def write_state(self, state):
self._init(state)

step = state["date"] - self.reference_date
if self.output_frequency is not None:
if (step % self.output_frequency).total_seconds() != 0:
return

return self.write_step(state, step)

def _init(self, state):
if self.reference_date is not None:
return

self.reference_date = state["date"]

self.open(state)

def write_initial_step(self, state, step):
"""This method should not be called directly
call `write_initial_state` instead.
"""
reduced_state = self.reduce(state)
self.write_step(reduced_state, step)

@abstractmethod
def write_step(self, state, step):
"""This method should not be called directly
call `write_state` instead.
"""
pass

def reduce(self, state):
Expand All @@ -36,5 +77,9 @@ def reduce(self, state):
reduced_state["fields"][field] = values[-1, :]
return reduced_state

def open(self, state):
# Override this method when initialisation is needed
pass

def close(self):
pass
12 changes: 6 additions & 6 deletions src/anemoi/inference/outputs/apply_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@
class ApplyMaskOutput(Output):
"""_summary_"""

def __init__(self, context, *, mask, output):
super().__init__(context)
def __init__(self, context, *, mask, output, output_frequency=None, write_initial_step=False):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
self.mask = self.checkpoint.load_supporting_array(mask)
self.output = create_output(context, output)

def __repr__(self):
return f"ApplyMaskOutput({self.mask}, {self.output})"

def write_initial_state(self, state):
self.output.write_initial_state(self._apply_mask(state))
def write_initial_step(self, state, step):
self.output.write_initial_step(self._apply_mask(state), step)

def write_state(self, state):
self.output.write_state(self._apply_mask(state))
def write_step(self, state, step):
self.output.write_step(self._apply_mask(state), step)

def _apply_mask(self, state):
state = state.copy()
Expand Down
13 changes: 7 additions & 6 deletions src/anemoi/inference/outputs/extract_lam.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
class ExtractLamOutput(Output):
"""_summary_"""

def __init__(self, context, *, output, points="cutout_mask"):
super().__init__(context)
def __init__(self, context, *, output, points="cutout_mask", output_frequency=None, write_initial_step=False):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)

if isinstance(points, str):
mask = self.checkpoint.load_supporting_array(points)
points = -np.sum(mask) # This is the global, we want the lam
Expand All @@ -34,11 +35,11 @@ def __init__(self, context, *, output, points="cutout_mask"):
def __repr__(self):
return f"ExtractLamOutput({self.points}, {self.output})"

def write_initial_state(self, state):
self.output.write_initial_state(self._apply_mask(state))
def write_initial_step(self, state, step):
self.output.write_initial_step(self._apply_mask(state), step)

def write_state(self, state):
self.output.write_state(self._apply_mask(state))
def write_step(self, state, step):
self.output.write_step(self._apply_mask(state), step)

def _apply_mask(self, state):

Expand Down
21 changes: 16 additions & 5 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,19 @@ class GribOutput(Output):
Handles grib
"""

def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None, modifiers=None):
super().__init__(context)
def __init__(
self,
context,
*,
encoding=None,
templates=None,
grib1_keys=None,
grib2_keys=None,
modifiers=None,
output_frequency=None,
write_initial_step=False,
):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
self._first = True
self.typed_variables = self.checkpoint.typed_variables
self.quiet = set()
Expand All @@ -88,7 +99,7 @@ def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, g
self.use_closest_template = False # Off for now
self.modifiers = modifier_factory(modifiers)

def write_initial_state(self, state):
def write_initial_step(self, state, step):
# We trust the GribInput class to provide the templates
# matching the input state

Expand All @@ -98,7 +109,7 @@ def write_initial_state(self, state):
if template is None:
# We can currently only write grib output if we have a grib input
raise ValueError(
"GRIB output only works if the input is GRIB (for now). Set `write_initial_state` to `false`."
"GRIB output only works if the input is GRIB (for now). Set `write_initial_step` to `false`."
)

variable = self.typed_variables[name]
Expand Down Expand Up @@ -128,7 +139,7 @@ def write_initial_state(self, state):

self.write_message(values, template=template, **keys)

def write_state(self, state):
def write_step(self, state, step):

reference_date = self.context.reference_date
date = state["date"]
Expand Down
4 changes: 4 additions & 0 deletions src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def __init__(
grib1_keys=None,
grib2_keys=None,
modifiers=None,
output_frequency=None,
write_initial_step=False,
**kwargs,
):
super().__init__(
Expand All @@ -99,6 +101,8 @@ def __init__(
grib1_keys=grib1_keys,
grib2_keys=grib2_keys,
modifiers=modifiers,
output_frequency=output_frequency,
write_initial_step=write_initial_step,
)
self.path = path
self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs)
Expand Down
49 changes: 13 additions & 36 deletions src/anemoi/inference/outputs/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
class NetCDFOutput(Output):
"""_summary_"""

def __init__(self, context, path):
super().__init__(context)
def __init__(self, context, path, output_frequency=None, write_initial_step=False):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
self.path = path
self.ncfile = None
self.float_size = "f4"
Expand All @@ -37,12 +37,9 @@ def __del__(self):
if self.ncfile is not None:
self.ncfile.close()

def _init(self, state):
def open(self, state):
from netCDF4 import Dataset

if self.ncfile is not None:
return self.ncfile

# If the file exists, we may get a 'Permission denied' error
if os.path.exists(self.path):
os.remove(self.path)
Expand All @@ -53,23 +50,21 @@ def _init(self, state):

values = len(state["latitudes"])

time = 0
self.reference_date = state["date"]
if hasattr(self.context, "time_step") and hasattr(self.context, "lead_time"):
time = self.context.lead_time // self.context.time_step
if hasattr(self.context, "reference_date"):
self.reference_date = self.context.reference_date

self.values_dim = self.ncfile.createDimension("values", values)
self.time_dim = self.ncfile.createDimension("time", time)
self.time_var = self.ncfile.createVariable("time", "i4", ("time",), **compression)
self.time_dim = self.ncfile.createDimension("time", size=None) # infinite dimension
self.time_var = self.ncfile.createVariable("time", "i8", ("time",), **compression)

self.time_var.units = "seconds since {0}".format(self.reference_date)
self.time_var.long_name = "time"
self.time_var.calendar = "gregorian"

latitudes = state["latitudes"]
self.latitude_var = self.ncfile.createVariable("latitude", self.float_size, ("values",), **compression)
self.latitude_var = self.ncfile.createVariable(
"latitude",
self.float_size,
("values",),
**compression,
)
self.latitude_var.units = "degrees_north"
self.latitude_var.long_name = "latitude"

Expand All @@ -84,20 +79,7 @@ def _init(self, state):
self.longitude_var.long_name = "longitude"

self.vars = {}

for name in state["fields"].keys():
chunksizes = (1, values)

while np.prod(chunksizes) > 1000000:
chunksizes = tuple(int(np.ceil(x / 2)) for x in chunksizes)

self.vars[name] = self.ncfile.createVariable(
name,
self.float_size,
("time", "values"),
chunksizes=chunksizes,
**compression,
)
self.ensure_variables(state)

self.latitude_var[:] = latitudes
self.longitude_var[:] = longitudes
Expand Down Expand Up @@ -126,15 +108,10 @@ def ensure_variables(self, state):
**compression,
)

def write_initial_state(self, state):
reduced_state = self.reduce(state)
self.write_state(reduced_state)
def write_step(self, state, step):

def write_state(self, state):
self._init(state)
self.ensure_variables(state)

step = state["date"] - self.reference_date
self.time_var[self.n] = step.total_seconds()

for name, value in state["fields"].items():
Expand Down
5 changes: 1 addition & 4 deletions src/anemoi/inference/outputs/none.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,5 @@
class NoneOutput(Output):
"""_summary_"""

def write_initial_state(self, state):
pass

def write_state(self, state):
def write_step(self, state, step):
pass
10 changes: 6 additions & 4 deletions src/anemoi/inference/outputs/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ def __init__(
template="plot_{variable}_{date}.{format}",
dpi=300,
format="png",
output_frequency=None,
write_initial_step=False,
):
super().__init__(context)
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
self.path = path
self.format = format
self.variables = variables
Expand All @@ -48,11 +50,11 @@ def __init__(
if not isinstance(self.variables, (list, tuple)):
self.variables = [self.variables]

def write_initial_state(self, state):
def write_initial_step(self, state):
reduced_state = self.reduce(state)
self.write_state(reduced_state)
self.write_step(reduced_state)

def write_state(self, state):
def write_step(self, state, step):
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
Expand Down
5 changes: 1 addition & 4 deletions src/anemoi/inference/outputs/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,5 @@ def print_state(state, print=print):
class PrinterOutput(Output):
"""_summary_"""

def write_initial_state(self, state):
self.write_state(state)

def write_state(self, state):
def write_step(self, state, step):
print_state(state)
Loading

0 comments on commit fccad0e

Please sign in to comment.