Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
Merge branch 'develop' of https://github.com/ecmwf-lab/ecml-tools int…
Browse files Browse the repository at this point in the history
…o develop
  • Loading branch information
floriankrb committed Mar 6, 2024
2 parents b6a0699 + ab74c0e commit b1bccb0
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 82 deletions.
17 changes: 13 additions & 4 deletions ecml_tools/create/functions/actions/mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@ def _expand_mars_request(request, date):
"step": s,
}
)

for pproc in ("grid", "rotation", "frame", "area", "bitmap", "resol"):
if pproc in r:
if isinstance(r[pproc], (list, tuple)):
r[pproc] = "/".join(str(x) for x in r[pproc])

requests.append(r)

return requests


Expand All @@ -68,7 +75,11 @@ def factorise_requests(dates, *requests):
updates += _expand_mars_request(req, date=d)

compressed = Availability(updates)
return compressed.iterate()
for r in compressed.iterate():
for k, v in r.items():
if isinstance(v, (list, tuple)) and len(v) == 1:
r[k] = v[0]
yield r


def mars(context, dates, *requests, **kwargs):
Expand Down Expand Up @@ -109,9 +120,7 @@ def mars(context, dates, *requests, **kwargs):
"""
)
dates = yaml.safe_load(
"[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]"
)
dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]")
dates = to_datetime_list(dates)

DEBUG = True
Expand Down
18 changes: 8 additions & 10 deletions ecml_tools/create/functions/steps/unrotate_winds.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def rotate_winds(
new_x = np.zeros_like(x_wind)
new_y = np.zeros_like(y_wind)

for i, (vx, vy, lat, lon, raw_lat, raw_lon) in enumerate(
zip(x_wind, y_wind, lats, lons, raw_lats, raw_lons)
):
for i, (vx, vy, lat, lon, raw_lat, raw_lon) in enumerate(zip(x_wind, y_wind, lats, lons, raw_lats, raw_lons)):
lonRotated = south_pole_longitude - lon
lon_rotated = normalise_longitude(lonRotated, -180)
lon_unrotated = raw_lon
Expand Down Expand Up @@ -79,13 +77,13 @@ def __getattr__(self, name):
return getattr(self.field, name)


def execute(context, input, x_wind, y_wind):
def execute(context, input, u, v):
"""
Unrotate the wind components of a GRIB file.
"""
result = FieldArray()

wind_params = (x_wind, y_wind)
wind_params = (u, v)
wind_pairs = defaultdict(dict)

for f in input:
Expand All @@ -107,15 +105,15 @@ def execute(context, input, x_wind, y_wind):
if len(pairs) != 2:
raise ValueError("Missing wind component")

x = pairs[x_wind]
y = pairs[y_wind]
x = pairs[u]
y = pairs[v]

lats, lons = x.grid_points()
raw_lats, raw_longs = x.grid_points_raw()

assert x.rotation == y.rotation

x_new, y_new = rotate_winds(
u_new, v_new = rotate_winds(
lats,
lons,
raw_lats,
Expand All @@ -125,8 +123,8 @@ def execute(context, input, x_wind, y_wind):
*x.rotation,
)

result.append(NewDataField(x, x_new))
result.append(NewDataField(y, y_new))
result.append(NewDataField(x, u_new))
result.append(NewDataField(y, v_new))

return result

Expand Down
74 changes: 67 additions & 7 deletions ecml_tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
import os
import re
import textwrap
import warnings
from functools import cached_property, wraps
from pathlib import PurePath
Expand Down Expand Up @@ -83,6 +84,32 @@ class MissingDate(Exception):
pass


class Node:
    """Printable tree node describing one dataset and the datasets it wraps.

    A node keeps the dataset it stands for, the child nodes below it and
    free-form keyword details that are listed under the dataset's class name
    when the node is rendered with ``repr()``.
    """

    def __init__(self, dataset, kids, **kwargs):
        self.dataset = dataset
        self.kids = kids
        self.kwargs = kwargs

    def _put(self, indent, result):
        """Append this node's lines, then its children's lines, to ``result``.

        ``indent`` is the number of leading spaces used for the class-name
        line; details and children are indented by two more spaces.
        """
        pad = " " * indent
        inner_pad = " " * (indent + 2)

        result.append(f"{pad}{self.dataset.__class__.__name__}")

        for key, value in self.kwargs.items():
            if isinstance(value, (list, tuple)):
                # Sequences are flattened onto one comma-separated line and
                # truncated so a long list cannot swamp the output.
                joined = ", ".join(map(str, value))
                value = textwrap.shorten(joined, width=40, placeholder="...")
            result.append(f"{inner_pad}{key}: {value}")

        for child in self.kids:
            child._put(indent + 2, result)

    def __repr__(self):
        lines = []
        self._put(0, lines)
        return "\n".join(lines)


class Dataset:
arguments = {}

Expand All @@ -106,15 +133,15 @@ def _subset(self, **kwargs):

if "select" in kwargs:
select = kwargs.pop("select")
return Select(self, self._select_to_columns(select))._subset(**kwargs)
return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs)

if "drop" in kwargs:
drop = kwargs.pop("drop")
return Select(self, self._drop_to_columns(drop))._subset(**kwargs)
return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs)

if "reorder" in kwargs:
reorder = kwargs.pop("reorder")
return Select(self, self._reorder_to_columns(reorder))._subset(**kwargs)
return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs)

if "rename" in kwargs:
rename = kwargs.pop("rename")
Expand Down Expand Up @@ -499,6 +526,9 @@ def mutate(self):
return ZarrWithMissingDates(self.z if self.was_zarr else self.path)
return self

def tree(self):
    """Describe this dataset as a childless ``Node`` labelled with its ``path``."""
    return Node(self, kids=[], path=self.path)


class ZarrWithMissingDates(Zarr):
def __init__(self, path):
Expand Down Expand Up @@ -550,6 +580,9 @@ def __getitem__(self, n):
def _report_missing(self, n):
raise MissingDate(f"Date {self.missing_to_dates[n]} is missing (index={n})")

def tree(self):
    """Describe this dataset as a childless ``Node`` with its ``path`` and sorted ``missing`` entries."""
    path = self.path
    missing = sorted(self.missing)
    return Node(self, [], path=path, missing=missing)


class Forwards(Dataset):
def __init__(self, forward):
Expand Down Expand Up @@ -768,6 +801,9 @@ def dates(self):
def shape(self):
return (len(self),) + self.datasets[0].shape[1:]

def tree(self):
    """Return a ``Node`` for this dataset whose children are the trees of ``self.datasets``."""
    children = [member.tree() for member in self.datasets]
    return Node(self, children)


class GivenAxis(Combined):
"""Given a given axis, combine the datasets along that axis."""
Expand Down Expand Up @@ -820,7 +856,9 @@ def __getitem__(self, n):


class Ensemble(GivenAxis):
pass

def tree(self):
    """Return a ``Node`` for this ensemble with one child tree per entry of ``self.datasets``."""
    subtrees = [dataset.tree() for dataset in self.datasets]
    return Node(self, subtrees)


class Grids(GivenAxis):
Expand Down Expand Up @@ -851,6 +889,9 @@ def grids(self):
result.extend(d.grids)
return tuple(result)

def tree(self):
    """Return a ``Node`` over the trees of ``self.datasets``, tagged ``mode="concat"``."""
    children = [member.tree() for member in self.datasets]
    return Node(self, children, mode="concat")


class CutoutGrids(Grids):
def __init__(self, datasets, axis):
Expand Down Expand Up @@ -926,6 +967,9 @@ def grids(self):
shape = self.lam.shape
return (shape[-1], self.shape[-1] - shape[-1])

def tree(self):
    """Return a ``Node`` over the trees of ``self.datasets``, tagged ``mode="cutout"``."""
    children = [member.tree() for member in self.datasets]
    return Node(self, children, mode="cutout")


class Join(Combined):
"""Join the datasets along the variables axis."""
Expand Down Expand Up @@ -996,7 +1040,7 @@ def _overlay(self):
if not ok:
LOG.warning("Dataset %r completely overridden.", d)

return Select(self, indices)
return Select(self, indices, {"overlay": True})

@cached_property
def variables(self):
Expand Down Expand Up @@ -1036,6 +1080,9 @@ def missing(self):
result = result | d.missing
return result

def tree(self):
    """Return a ``Node`` for this dataset with one child tree per entry of ``self.datasets``."""
    subtrees = [dataset.tree() for dataset in self.datasets]
    return Node(self, subtrees)


class Subset(Forwards):
"""Select a subset of the dates."""
Expand Down Expand Up @@ -1108,24 +1155,28 @@ def source(self, index):
return Source(self, index, self.forward.source(index))

def __repr__(self):
return f"Subset({self.dates[0]}, {self.dates[-1]}, {self.frequency})"
return f"Subset({self.dataset},{self.dates[0]}...{self.dates[-1]}/{self.frequency})"

@cached_property
def missing(self):
return {self.indices[i] for i in self.dataset.missing if i in self.indices}

def tree(self):
    """Return a ``Node`` for this subset whose single child is the tree of ``self.dataset``."""
    wrapped = self.dataset.tree()
    return Node(self, [wrapped])


class Select(Forwards):
"""Select a subset of the variables."""

def __init__(self, dataset, indices):
def __init__(self, dataset, indices, title):
while isinstance(dataset, Select):
indices = [dataset.indices[i] for i in indices]
dataset = dataset.dataset

self.dataset = dataset
self.indices = list(indices)
assert len(self.indices) > 0
self.title = title or {"indices": self.indices}

# Forward other properties to the main dataset
super().__init__(dataset)
Expand Down Expand Up @@ -1174,6 +1225,9 @@ def metadata_specific(self, **kwargs):
def source(self, index):
return Source(self, index, self.dataset.source(self.indices[index]))

def tree(self):
    """Return a ``Node`` wrapping the tree of ``self.dataset``, with ``self.title`` as its details."""
    wrapped = self.dataset.tree()
    return Node(self, [wrapped], **self.title)


class Rename(Forwards):
def __init__(self, dataset, rename):
Expand All @@ -1194,6 +1248,9 @@ def name_to_index(self):
def metadata_specific(self, **kwargs):
return super().metadata_specific(rename=self.rename, **kwargs)

def tree(self):
    """Return a ``Node`` wrapping the tree of ``self.forward``, with ``self.rename`` as a detail."""
    wrapped = self.forward.tree()
    return Node(self, [wrapped], rename=self.rename)


class Statistics(Forwards):
def __init__(self, dataset, statistic):
Expand All @@ -1210,6 +1267,9 @@ def metadata_specific(self, **kwargs):
**kwargs,
)

def tree(self):
    """Return a ``Node`` for this dataset whose single child is the tree of ``self.forward``."""
    wrapped = self.forward.tree()
    return Node(self, [wrapped])


def _name_to_path(name, zarr_root):
_, ext = os.path.splitext(name)
Expand Down
Loading

0 comments on commit b1bccb0

Please sign in to comment.