Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
tendencies function ok. yaml todo
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Feb 13, 2024
1 parent 92f6b9d commit c39abf9
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 108 deletions.
7 changes: 1 addition & 6 deletions ecml_tools/create/functions/mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@
from copy import deepcopy

from climetlab import load_source
from climetlab.core.temporary import temp_file
from climetlab.readers.grib.output import new_grib_output
from climetlab.utils.availability import Availability

from ecml_tools.create.functions import assert_is_fieldset
from ecml_tools.create.utils import to_datetime_list

DEBUG = False
DEBUG = True


def to_list(x):
Expand Down Expand Up @@ -119,5 +116,3 @@ def mars(dates, *requests, **kwargs):
DEBUG = True
for f in mars(None, dates, *config):
print(f, f.to_numpy().mean())

# "[2022-12-30 12:00, 2022-12-31 00:00, 2022-12-31 12:00, 2023-01-01 00:00, 2023-01-01 12:00]"
148 changes: 63 additions & 85 deletions ecml_tools/create/functions/tendencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# nor does it submit to any jurisdiction.
#
import datetime
from collections import defaultdict
from copy import deepcopy

from climetlab.core.temporary import temp_file
Expand All @@ -33,77 +34,75 @@ def normalise_time_delta(t):
return t


def tendencies(dates, time_increment, **kwargs):
for d in dates:
assert isinstance(d, datetime.datetime), (type(d), d)
def group_by_field(ds):
    """Group the fields of *ds* by their MARS request minus date/time/step.

    Fields are visited in ``valid_datetime`` order; each group key is the
    tuple of remaining (key, value) items of the field's MARS request, so
    fields that differ only in date/time/step land in the same bucket.
    Returns a defaultdict mapping key tuple -> list of fields.
    """
    groups = defaultdict(list)
    for field in ds.order_by("valid_datetime"):
        request = field.as_mars()
        # Drop the temporal coordinates so grouping is purely by
        # param/level/number/etc.
        for temporal_key in ("date", "time", "step"):
            request.pop(temporal_key, None)
        groups[tuple(request.items())].append(field)
    return groups


def tendencies(dates, time_increment, **kwargs):
print("✅", kwargs)
time_increment = normalise_time_delta(time_increment)

assert len(kwargs) == 1, kwargs
assert "function" in kwargs, kwargs
func_kwargs = deepcopy(kwargs["function"])
assert func_kwargs.pop("name") == "mars", kwargs
shifted_dates = [d - time_increment for d in dates]
all_dates = sorted(list(set(dates + shifted_dates)))

# from .mars import execute as mars
from ecml_tools.create.functions.mars import execute as mars

current_dates = [d.isoformat() for d in dates]
shifted_dates = [(d - time_increment).isoformat() for d in dates]
all_dates = sorted(list(set(current_dates + shifted_dates)))
ds = mars(dates=all_dates, **kwargs)

padded_source = mars(dates=all_dates, **func_kwargs)
assert_is_fieldset(padded_source)
assert len(padded_source)
dates_in_data = ds.unique_values("valid_datetime")["valid_datetime"]
for d in all_dates:
assert d.isoformat() in dates_in_data, d

dates_in_data = padded_source.unique_values("valid_datetime")["valid_datetime"]
print(dates_in_data)
for d in current_dates:
assert d in dates_in_data, d
for d in shifted_dates:
assert d in dates_in_data, d
ds1 = ds.sel(valid_datetime=[d.isoformat() for d in dates])
ds2 = ds.sel(valid_datetime=[d.isoformat() for d in shifted_dates])

keys = ["valid_datetime", "date", "time", "step", "param", "level", "number"]
current = padded_source.sel(valid_datetime=current_dates).order_by(*keys)
before = padded_source.sel(valid_datetime=shifted_dates).order_by(*keys)
assert len(ds1) == len(ds2), (len(ds1), len(ds2))

assert len(current), (current, current_dates)
group1 = group_by_field(ds1)
group2 = group_by_field(ds2)

assert len(current) == len(before), (
len(current),
len(before),
time_increment,
len(dates),
)
assert group1.keys() == group2.keys(), (group1.keys(), group2.keys())

# prepare output tmp file so we can read it back
tmp = temp_file()
path = tmp.path
out = new_grib_output(path)

for field, b_field in zip(current, before):
for k in ["param", "level", "number", "grid", "shape"]:
assert field.metadata(k) == b_field.metadata(k), (
k,
field.metadata(k),
b_field.metadata(k),
for k in group1:
assert len(group1[k]) == len(group2[k]), k
print()
print("❌", k)

for field, b_field in zip(group1[k], group2[k]):
for k in ["param", "level", "number", "grid", "shape"]:
assert field.metadata(k) == b_field.metadata(k), (
k,
field.metadata(k),
b_field.metadata(k),
)

c = field.to_numpy()
b = b_field.to_numpy()
assert c.shape == b.shape, (c.shape, b.shape)

################
# Actual computation happens here
x = c - b
################

assert x.shape == c.shape, c.shape
print(
f"Computing data for {field.metadata('valid_datetime')}={field}-{b_field}"
)
c = field.to_numpy()
b = b_field.to_numpy()
assert c.shape == b.shape, (c.shape, b.shape)

################
# Actual computation happens here
x = c - b
################

assert x.shape == c.shape, c.shape
print(
"computing data for",
field.metadata("valid_datetime"),
"=",
field,
"-",
b_field,
)
out.write(x, template=field)
out.write(x, template=field)

out.close()

Expand All @@ -115,10 +114,6 @@ def tendencies(dates, time_increment, **kwargs):
# only when the dataset is not used anymore
ds._tmp = tmp

len(ds)
len(padded_source)
assert len(ds) == len(current), (len(ds), len(current))

return ds


Expand Down Expand Up @@ -158,37 +153,20 @@ def tendencies(dates, time_increment, **kwargs):
"""
config:
# name: tendencies
# dates: $dates
time_increment: 12h
function:
name: mars
database: marser
class: ea
# date: computed automatically
# time: computed automatically
expver: "0001"
grid: 20.0/20.0
levtype: sfc
param: [2t]
# levtype: pl
# param: [10u, 10v, 2d, 2t, lsm, msl, sdor, skt, slor, sp, tcw, z]
# number: [0, 1]
# name: mars
# database: marser
# class: ea
# expver: '0001'
# grid: 20.0/20.0
# levtype: sfc
# param: [2t]
# # levtype: pl
# # param: [10u, 10v, 2d, 2t, lsm, msl, sdor, skt, slor, sp, tcw, z]
# # number: [0, 1]
time_increment: 12h
database: marser
class: ea
# date: computed automatically
# time: computed automatically
expver: "0001"
grid: 20.0/20.0
levtype: sfc
param: [2t]
"""
)["config"]

dates = yaml.safe_load(
"[2022-12-30 12:00, 2022-12-31 00:00, 2022-12-31 12:00, 2023-01-01 00:00, 2023-01-01 12:00]"
"[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]"
)
dates = to_datetime_list(dates)

Expand Down
21 changes: 4 additions & 17 deletions ecml_tools/create/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def sort(old_dic):
params_steps = sort(params_steps)
params_levels = sort(params_levels)

return dict(
param_level=params_levels, param_step=params_steps, area=area, grid=grid
)
return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid)


class Cache:
Expand Down Expand Up @@ -116,14 +114,7 @@ def _build_coords(self):
ensembles_key = list(from_config.keys())[2]

if isinstance(from_config[variables_key], (list, tuple)):
assert all(
[
v == w
for v, w in zip(
from_data[variables_key], from_config[variables_key]
)
]
), (
assert all([v == w for v, w in zip(from_data[variables_key], from_config[variables_key])]), (
from_data[variables_key],
from_config[variables_key],
)
Expand Down Expand Up @@ -424,9 +415,7 @@ class BaseFunctionAction(Action):
def __repr__(self):
content = ""
content += ",".join([self._short_str(a) for a in self.args])
content += " ".join(
[self._short_str(f"{k}={v}") for k, v in self.kwargs.items()]
)
content += " ".join([self._short_str(f"{k}={v}") for k, v in self.kwargs.items()])
content = self._short_str(content)
return super().__repr__(_inline_=content, _indent_=" ")

Expand Down Expand Up @@ -656,9 +645,7 @@ def action_factory(config, context):
)

if len(config) != 1:
raise ValueError(
f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}"
)
raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}")

config = deepcopy(config)
key = list(config.keys())[0]
Expand Down

0 comments on commit c39abf9

Please sign in to comment.