Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for output_frequency to write less output #109

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
- Add CONTRIBUTORS.md file (#36)
- Add sanetise command
- Add support for huggingface
- Add support for `output_frequency` to write less output

### Changed
- Change `write_initial_state` default value to `true`
Expand Down
3 changes: 1 addition & 2 deletions src/anemoi/inference/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def run(self, args):

input_state = input.create_input_state(date=config.date)

if config.write_initial_state:
output.write_initial_state(input_state)
output.write_initial_state(input_state)

for state in runner.run(input_state=input_state, lead_time=config.lead_time):
output.write_state(state)
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class Config:
"""Wether to write the initial state to the output file. If the model is multi-step, only fields at the forecast reference date are
written."""

output_frequency: Optional[str] = None
"""The frequency at which to write the output. This can be a string or an integer. If a string, it is parsed by :func:`anemoi.utils.dates.as_timedelta`."""

env: Dict[str, str | int] = {}
"""Environment variables to set before running the model. This may be useful to control some packages
such as `eccodes`. In certain cases, the variables mey be set too late, if the package for which they are intended
Expand Down
10 changes: 10 additions & 0 deletions src/anemoi/inference/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ class Context(ABC):
verbosity = 0
development_hacks = {} # For testing purposes, don't use in production

# Some runners will set these values, which can be queried by Output objects,
# but may remain as None

reference_date = None
time_step = None
lead_time = None
output_frequency = None

##################################################################

@property
@abstractmethod
def checkpoint(self):
Expand Down
53 changes: 49 additions & 4 deletions src/anemoi/inference/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,67 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#
import logging
from abc import ABC
from abc import abstractmethod

LOG = logging.getLogger(__name__)


class Output(ABC):
"""_summary_"""

def __init__(self, context):
def __init__(self, context, output_frequency=None, write_initial_step=False):
from anemoi.utils.dates import as_timedelta

self.context = context
self.checkpoint = context.checkpoint
self.reference_date = None

self.write_step_zero = write_initial_step and context.write_initial_step

self.output_frequency = output_frequency or context.output_frequency
if self.output_frequency is not None:
self.output_frequency = as_timedelta(self.output_frequency)

def __repr__(self):
return f"{self.__class__.__name__}()"

@abstractmethod
def write_initial_state(self, state):
pass
self._init(state)
if self.write_step_zero:
return self.write_initial_step(state, state["date"] - self.reference_date)

@abstractmethod
def write_state(self, state):
self._init(state)

step = state["date"] - self.reference_date
if self.output_frequency is not None:
if (step % self.output_frequency).total_seconds() != 0:
return

return self.write_step(state, step)

def _init(self, state):
if self.reference_date is not None:
return

self.reference_date = state["date"]

self.open(state)

def write_initial_step(self, state, step):
"""This method should not be called directly
call `write_initial_state` instead.
"""
reduced_state = self.reduce(state)
self.write_step(reduced_state, step)

@abstractmethod
def write_step(self, state, step):
"""This method should be be called directly
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""This method should be be called directly
"""This method should not be called directly

call `write_state` instead.
"""
pass

def reduce(self, state):
Expand All @@ -36,5 +77,9 @@ def reduce(self, state):
reduced_state["fields"][field] = values[-1, :]
return reduced_state

def open(self, state):
# Override this method when initialisation is needed
pass

def close(self):
pass
12 changes: 6 additions & 6 deletions src/anemoi/inference/outputs/apply_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@
class ApplyMaskOutput(Output):
"""_summary_"""

def __init__(self, context, *, mask, output):
super().__init__(context)
def __init__(self, context, *, mask, output, output_frequency=None, write_initial_step=False):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
self.mask = self.checkpoint.load_supporting_array(mask)
self.output = create_output(context, output)

def __repr__(self):
return f"ApplyMaskOutput({self.mask}, {self.output})"

def write_initial_state(self, state):
self.output.write_initial_state(self._apply_mask(state))
def write_initial_step(self, state, step):
self.output.write_initial_step(self._apply_mask(state), step)

def write_state(self, state):
self.output.write_state(self._apply_mask(state))
def write_step(self, state, step):
self.output.write_step(self._apply_mask(state), step)

def _apply_mask(self, state):
state = state.copy()
Expand Down
13 changes: 7 additions & 6 deletions src/anemoi/inference/outputs/extract_lam.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
class ExtractLamOutput(Output):
"""_summary_"""

def __init__(self, context, *, output, points="cutout_mask"):
super().__init__(context)
def __init__(self, context, *, output, points="cutout_mask", output_frequency=None, write_initial_step=False):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)

if isinstance(points, str):
mask = self.checkpoint.load_supporting_array(points)
points = -np.sum(mask) # This is the global, we want the lam
Expand All @@ -34,11 +35,11 @@ def __init__(self, context, *, output, points="cutout_mask"):
def __repr__(self):
return f"ExtractLamOutput({self.points}, {self.output})"

def write_initial_state(self, state):
self.output.write_initial_state(self._apply_mask(state))
def write_initial_step(self, state, step):
self.output.write_initial_step(self._apply_mask(state), step)

def write_state(self, state):
self.output.write_state(self._apply_mask(state))
def write_step(self, state, step):
self.output.write_step(self._apply_mask(state), step)

def _apply_mask(self, state):

Expand Down
21 changes: 16 additions & 5 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,19 @@ class GribOutput(Output):
Handles grib
"""

def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None, modifiers=None):
super().__init__(context)
def __init__(
self,
context,
*,
encoding=None,
templates=None,
grib1_keys=None,
grib2_keys=None,
modifiers=None,
output_frequency=None,
write_initial_step=False,
):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
self._first = True
self.typed_variables = self.checkpoint.typed_variables
self.quiet = set()
Expand All @@ -88,7 +99,7 @@ def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, g
self.use_closest_template = False # Off for now
self.modifiers = modifier_factory(modifiers)

def write_initial_state(self, state):
def write_initial_step(self, state, step):
# We trust the GribInput class to provide the templates
# matching the input state

Expand All @@ -98,7 +109,7 @@ def write_initial_state(self, state):
if template is None:
# We can currently only write grib output if we have a grib input
raise ValueError(
"GRIB output only works if the input is GRIB (for now). Set `write_initial_state` to `false`."
"GRIB output only works if the input is GRIB (for now). Set `write_initial_step` to `false`."
)

variable = self.typed_variables[name]
Expand Down Expand Up @@ -128,7 +139,7 @@ def write_initial_state(self, state):

self.write_message(values, template=template, **keys)

def write_state(self, state):
def write_step(self, state, step):

reference_date = self.context.reference_date
date = state["date"]
Expand Down
4 changes: 4 additions & 0 deletions src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def __init__(
grib1_keys=None,
grib2_keys=None,
modifiers=None,
output_frequency=None,
write_initial_step=False,
**kwargs,
):
super().__init__(
Expand All @@ -99,6 +101,8 @@ def __init__(
grib1_keys=grib1_keys,
grib2_keys=grib2_keys,
modifiers=modifiers,
output_frequency=output_frequency,
write_initial_step=write_initial_step,
)
self.path = path
self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs)
Expand Down
49 changes: 13 additions & 36 deletions src/anemoi/inference/outputs/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
class NetCDFOutput(Output):
"""_summary_"""

def __init__(self, context, path):
super().__init__(context)
def __init__(self, context, path, output_frequency=None, write_initial_step=False):
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
self.path = path
self.ncfile = None
self.float_size = "f4"
Expand All @@ -37,12 +37,9 @@ def __del__(self):
if self.ncfile is not None:
self.ncfile.close()

def _init(self, state):
def open(self, state):
from netCDF4 import Dataset

if self.ncfile is not None:
return self.ncfile

# If the file exists, we may get a 'Permission denied' error
if os.path.exists(self.path):
os.remove(self.path)
Expand All @@ -53,23 +50,21 @@ def _init(self, state):

values = len(state["latitudes"])

time = 0
self.reference_date = state["date"]
if hasattr(self.context, "time_step") and hasattr(self.context, "lead_time"):
time = self.context.lead_time // self.context.time_step
if hasattr(self.context, "reference_date"):
self.reference_date = self.context.reference_date

self.values_dim = self.ncfile.createDimension("values", values)
self.time_dim = self.ncfile.createDimension("time", time)
self.time_var = self.ncfile.createVariable("time", "i4", ("time",), **compression)
self.time_dim = self.ncfile.createDimension("time", size=None) # infinite dimension
self.time_var = self.ncfile.createVariable("time", "i8", ("time",), **compression)

self.time_var.units = "seconds since {0}".format(self.reference_date)
self.time_var.long_name = "time"
self.time_var.calendar = "gregorian"

latitudes = state["latitudes"]
self.latitude_var = self.ncfile.createVariable("latitude", self.float_size, ("values",), **compression)
self.latitude_var = self.ncfile.createVariable(
"latitude",
self.float_size,
("values",),
**compression,
)
self.latitude_var.units = "degrees_north"
self.latitude_var.long_name = "latitude"

Expand All @@ -84,20 +79,7 @@ def _init(self, state):
self.longitude_var.long_name = "longitude"

self.vars = {}

for name in state["fields"].keys():
chunksizes = (1, values)

while np.prod(chunksizes) > 1000000:
chunksizes = tuple(int(np.ceil(x / 2)) for x in chunksizes)

self.vars[name] = self.ncfile.createVariable(
name,
self.float_size,
("time", "values"),
chunksizes=chunksizes,
**compression,
)
self.ensure_variables(state)

self.latitude_var[:] = latitudes
self.longitude_var[:] = longitudes
Expand Down Expand Up @@ -126,15 +108,10 @@ def ensure_variables(self, state):
**compression,
)

def write_initial_state(self, state):
reduced_state = self.reduce(state)
self.write_state(reduced_state)
def write_step(self, state, step):

def write_state(self, state):
self._init(state)
self.ensure_variables(state)

step = state["date"] - self.reference_date
self.time_var[self.n] = step.total_seconds()

for name, value in state["fields"].items():
Expand Down
5 changes: 1 addition & 4 deletions src/anemoi/inference/outputs/none.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,5 @@
class NoneOutput(Output):
"""_summary_"""

def write_initial_state(self, state):
pass

def write_state(self, state):
def write_step(self, state, step):
pass
10 changes: 6 additions & 4 deletions src/anemoi/inference/outputs/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ def __init__(
template="plot_{variable}_{date}.{format}",
dpi=300,
format="png",
output_frequency=None,
write_initial_step=False,
):
super().__init__(context)
super().__init__(context, output_frequency=output_frequency, write_initial_step=write_initial_step)
self.path = path
self.format = format
self.variables = variables
Expand All @@ -48,11 +50,11 @@ def __init__(
if not isinstance(self.variables, (list, tuple)):
self.variables = [self.variables]

def write_initial_state(self, state):
def write_initial_step(self, state):
reduced_state = self.reduce(state)
self.write_state(reduced_state)
self.write_step(reduced_state)

def write_state(self, state):
def write_step(self, state, step):
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
Expand Down
5 changes: 1 addition & 4 deletions src/anemoi/inference/outputs/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,5 @@ def print_state(state, print=print):
class PrinterOutput(Output):
"""_summary_"""

def write_initial_state(self, state):
self.write_state(state)

def write_state(self, state):
def write_step(self, state, step):
print_state(state)
Loading
Loading