Skip to content

Commit

Permalink
CDS input (#97)
Browse files Browse the repository at this point in the history
* CDS Input Support
  • Loading branch information
HCookie authored Jan 13, 2025
1 parent 38cb2b5 commit 92eb869
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 0 deletions.
36 changes: 36 additions & 0 deletions docs/configs/inputs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,39 @@ You can change that to use ERA5 reanalysis data (``class=ea``).

The ``mars`` input also accepts the ``namer`` parameter of the GRIB
input.

*****
cds
*****

You can also specify the input as ``cds`` to read the data from the
`Climate Data Store <https://cds.climate.copernicus.eu/>`_. This
requires the `cdsapi` package to be installed, and the user to have a
CDS account.

.. literalinclude:: inputs_8.yaml
:language: yaml

As the CDS contains a plethora of `datasets
<https://cds.climate.copernicus.eu/datasets>`_, you can specify the
dataset you want to use with the key `dataset`.

This can be a str in which case the dataset is used for all requests, or
a dict of any number of levels which will be descended based on the
key/values for each request.

You can use `*` to represent any not given value for a key, i.e. set a
dataset for `param: 2t`. and `param: *` to represent any other param.

.. literalinclude:: inputs_9.yaml
:language: yaml

In the above example, the dataset `reanalysis-era5-pressure-levels` is
used for all with `levtype: pl` and `reanalysis-era5-single-levels` used
for all with `levtype: sfc`.

Additionally, any kwarg can be passed to be added to all requests, i.e.
for ERA5 data, `product_type: 'reanalysis'` is needed.

.. literalinclude:: inputs_10.yaml
:language: yaml
7 changes: 7 additions & 0 deletions docs/configs/inputs_10.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
input:
cds:
dataset:
levtype:
pl: reanalysis-era5-pressure-levels
sfc: reanalysis-era5-single-levels
product_type: 'reanalysis'
3 changes: 3 additions & 0 deletions docs/configs/inputs_8.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
input:
cds:
dataset: ???
24 changes: 24 additions & 0 deletions docs/configs/inputs_9.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
input:
cds:
# Dataset examples
## As a string
dataset:
'reanalysis-era5-pressure-levels'

## As a simple dictionary
dataset:
levtype:
pl: reanalysis-era5-pressure-levels
sfc: reanalysis-era5-single-levels

## As a complex dictionary
dataset:
stream:
oper:
levtype:
pl: reanalysis-era5-pressure-levels
sfc: reanalysis-era5-single-levels
an:
# ... Other datasets
'*': # Any other stream
# ... Other datasets
118 changes: 118 additions & 0 deletions src/anemoi/inference/inputs/cds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging

from earthkit.data.utils.dates import to_datetime

from . import input_registry
from .grib import GribInput
from .mars import postproc

LOG = logging.getLogger(__name__)


def retrieve(requests, grid, area, dataset, **kwargs):
import earthkit.data as ekd

def _(r):
mars = r.copy()
for k, v in r.items():
if isinstance(v, (list, tuple)):
mars[k] = "/".join(str(x) for x in v)
else:
mars[k] = str(v)

return ",".join(f"{k}={v}" for k, v in mars.items())

pproc = postproc(grid, area)

result = ekd.from_source("empty")
for r in requests:
if isinstance(dataset, str):
d = dataset
elif isinstance(dataset, dict):
# Get dataset from intersection of keys between request and dataset dict
search_dataset = dataset.copy()
while isinstance(search_dataset, dict):
keys = set(r.keys()).intersection(set(search_dataset.keys()))
if len(keys) == 0:
raise KeyError(
f"While searching for dataset, could not find any valid key in dictionary: {r.keys()}, {search_dataset}"
)
key = list(keys)[0]
if r[key] not in search_dataset[key]:
if "*" in search_dataset[key]:
search_dataset = search_dataset[key]["*"]
continue

raise KeyError(
f"Dataset dictionary does not contain key {r[key]!r} in {key!r}: {dict(search_dataset[key])}."
)
search_dataset = search_dataset[key][r[key]]

d = search_dataset

r.update(pproc)
r.update(kwargs)

LOG.debug("%s", _(r))
result += ekd.from_source("cds", d, r)

return result


@input_registry.register("cds")
class CDSInput(GribInput):
"""Get input fields from CDS"""

def __init__(self, context, *, dataset, namer=None, **kwargs):
super().__init__(context, namer=namer)

self.variables = self.checkpoint.variables_from_input(include_forcings=False)
self.dataset = dataset
self.kwargs = kwargs

def create_input_state(self, *, date):
if date is None:
date = to_datetime(-1)
LOG.warning("CDSInput: `date` parameter not provided, using yesterday's date: %s", date)

date = to_datetime(date)

return self._create_input_state(
self.retrieve(
self.variables,
[date + h for h in self.checkpoint.lagged],
),
variables=self.variables,
date=date,
)

def retrieve(self, variables, dates):

requests = self.checkpoint.mars_requests(
variables=variables,
dates=dates,
use_grib_paramid=self.context.use_grib_paramid,
)

if not requests:
raise ValueError("No requests for %s (%s)" % (variables, dates))

return retrieve(
requests, self.checkpoint.grid, self.checkpoint.area, dataset=self.dataset, expver="0001", **self.kwargs
)

def template(self, variable, date, **kwargs):
return self.retrieve([variable], [date])[0]

def load_forcings(self, variables, dates):
return self._load_forcings(self.retrieve(variables, dates), variables, dates)

0 comments on commit 92eb869

Please sign in to comment.