diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f537085..d3d0cd6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,7 +35,6 @@ repos: - id: ruff args: - --line-length=120 - - --ignore=E203 - --fix - --exit-non-zero-on-fix - --preview diff --git a/ecml_tools/__init__.py b/ecml_tools/__init__.py index d64a42f..1fb3780 100644 --- a/ecml_tools/__init__.py +++ b/ecml_tools/__init__.py @@ -5,4 +5,4 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -__version__ = "0.6.0" +__version__ = "0.6.1" diff --git a/ecml_tools/commands/copy.py b/ecml_tools/commands/copy.py index 2a458ad..881d58d 100644 --- a/ecml_tools/commands/copy.py +++ b/ecml_tools/commands/copy.py @@ -154,9 +154,26 @@ def copy_group(self, source, target, transfers, block_size, _copy, progress, rec for name in sorted(source.keys()): if isinstance(source[name], zarr.hierarchy.Group): group = target[name] if name in target else target.create_group(name) - self.copy_group(source[name], group, transfers, block_size, _copy, progress, rechunking) + self.copy_group( + source[name], + group, + transfers, + block_size, + _copy, + progress, + rechunking, + ) else: - self.copy_array(name, source, target, transfers, block_size, _copy, progress, rechunking) + self.copy_array( + name, + source, + target, + transfers, + block_size, + _copy, + progress, + rechunking, + ) def copy(self, source, target, transfers, block_size, progress, rechunking): import zarr diff --git a/ecml_tools/commands/scan.py b/ecml_tools/commands/scan.py index bc45304..6574f62 100644 --- a/ecml_tools/commands/scan.py +++ b/ecml_tools/commands/scan.py @@ -29,7 +29,6 @@ def add_arguments(self, command_parser): command_parser.add_argument("paths", nargs="+", help="Paths to scan") def run(self, args): - def match(path): return fnmatch.fnmatch(path, args.match) diff --git a/ecml_tools/create/__init__.py b/ecml_tools/create/__init__.py index 6498d19..3d85d2a 100644 --- a/ecml_tools/create/__init__.py +++ b/ecml_tools/create/__init__.py @@ -51,7 +51,10 @@ def load(self, parts=None): with self._cache_context(): loader = ContentLoader.from_dataset_config( - path=self.path, statistics_tmp=self.statistics_tmp, print=self.print, parts=parts + path=self.path, + statistics_tmp=self.statistics_tmp, + print=self.print, + parts=parts, ) loader.load() diff --git a/ecml_tools/create/config.py b/ecml_tools/create/config.py index 61d140d..99a1511 100644 --- a/ecml_tools/create/config.py +++ b/ecml_tools/create/config.py @@ -139,7 +139,6 @@ def statistics(self): class LoadersConfig(Config): - def __init__(self, config, *args, **kwargs): if "build" not in config: config["build"] = {} diff --git a/ecml_tools/create/functions/actions/acc.py b/ecml_tools/create/functions/actions/acc.py new file mode 100644 index 0000000..c1fc7b6 --- /dev/null +++ b/ecml_tools/create/functions/actions/acc.py @@ -0,0 +1,304 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# +import datetime +from copy import deepcopy + +import climetlab as cml +from climetlab.core.temporary import temp_file +from climetlab.readers.grib.output import new_grib_output +from climetlab.utils.availability import Availability + +from ecml_tools.create.utils import to_datetime_list + +DEBUG = True + + +class Accumulation: + def __init__(self, out, param, date, time, number, stepping): + self.out = out + self.param = param + self.date = date + self.time = time + self.number = number + self.values = None + self.seen = set() + self.startStep = None + self.endStep = None + self.done = False + self.stepping = stepping + + @property + def key(self): + return (self.param, self.date, self.time, self.number) + + +class AccumulationFromStart(Accumulation): + def add(self, field, values): + step = field.metadata("step") + # if step not in self.steps: + # return + + assert not self.done, (self.key, step) + assert step not in self.seen, (self.key, step) + + startStep = field.metadata("startStep") + endStep = field.metadata("endStep") + + assert startStep == 0 or (startStep == endStep), (startStep, endStep, step) + assert step == endStep, (startStep, endStep, step) + + if self.values is None: + import numpy as np + + self.values = np.copy(values) + self.startStep = 0 + self.endStep = endStep + ready = False + + else: + assert endStep != self.endStep, (self.endStep, endStep) + + if endStep > self.endStep: + # assert endStep - self.endStep == self.stepping, (self.endStep, endStep, self.stepping) + self.values = values - self.values + self.endStep = endStep + else: + # assert self.endStep - endStep == self.stepping, (self.endStep, endStep, self.stepping) + self.values = self.values - values + + ready = True + + self.seen.add(step) + + if ready: + self.out.write( + self.values, + template=field, + startStep=self.startStep, + endStep=self.endStep, + ) + self.values = None + self.done = True + + +class AccumulationFromLastStep(Accumulation): + def add(self, field, values): + step = field.metadata("step") + + assert not self.done, (self.key, step) + assert step not in self.seen, (self.key, step) + + startStep = field.metadata("startStep") + endStep = field.metadata("endStep") + + assert endStep == step, (startStep, endStep, step) + assert step not in self.seen, (self.key, step) + + assert endStep - startStep == self.stepping, (startStep, endStep) + + if self.startStep is None: + self.startStep = startStep + else: + self.startStep = min(self.startStep, startStep) + + if self.endStep is None: + self.endStep = endStep + else: + self.endStep = max(self.endStep, endStep) + + if self.values is None: + import numpy as np + + self.values = np.zeros_like(values) + + self.values += values + + self.seen.add(step) + + if len(self.seen) == len(self.steps): + self.out.write( + self.values, + template=field, + startStep=self.startStep, + endStep=self.endStep, + ) + self.values = None + self.done = True + + +def accumulations_from_start(dates, step1, step2): + for valid_date in dates: + base_date = valid_date - datetime.timedelta(hours=step2) + + yield ( + base_date.year * 10000 + base_date.month * 100 + base_date.day, + base_date.hour * 100 + base_date.minute, + step1, + ) + yield ( + base_date.year * 10000 + base_date.month * 100 + base_date.day, + base_date.hour * 100 + base_date.minute, + step2, + ) + + +def accumulations_from_last_step(dates, step1, step2, frequency): + for valid_date in dates: + date1 = valid_date - datetime.timedelta(hours=step1 + frequency) + + for step in range(step1, step2, frequency): + date = date1 + datetime.timedelta(hours=step) + yield ( + date.year * 10000 + date.month * 100 + date.day, + date.hour * 100 + date.minute, + step, + ) + + +def identity(x): + return x + + +def accumulations( + dates, + data_accumulation_period, + user_accumulation_period, + request, + patch=identity, +): + if not isinstance(user_accumulation_period, (list, tuple)): + user_accumulation_period = (0, user_accumulation_period) + + assert len(user_accumulation_period) == 2, user_accumulation_period + step1, step2 = user_accumulation_period + assert step1 < step2, user_accumulation_period + + if data_accumulation_period == 0: + mars_date_time_step = accumulations_from_start(dates, step1, step2) + else: + mars_date_time_step = accumulations_from_last_step(dates, step1, step2, data_accumulation_period) + + request = deepcopy(request) + + param = request["param"] + if not isinstance(param, (list, tuple)): + param = [param] + + for p in param: + assert p in ["cp", "lsp", "tp"], p + + number = request.get("number", [0]) + assert isinstance(number, (list, tuple)) + + stepping = data_accumulation_period + + type_ = request.get("type", "an") + if type_ == "an": + type_ = "fc" + + request.update({"type": type_, "levtype": "sfc"}) + + tmp = temp_file() + path = tmp.path + out = new_grib_output(path) + + requests = [] + + AccumulationClass = AccumulationFromStart if data_accumulation_period == 0 else AccumulationFromLastStep + + accumulations = {} + + for date, time, step in mars_date_time_step: + for p in param: + for n in number: + requests.append( + patch( + { + "param": p, + "date": date, + "time": time, + "step": step, + "number": n, + } + ) + ) + + key = (p, date, time, n) + if key not in accumulations: + accumulations[key] = AccumulationClass( + out, + stepping=stepping, + param=p, + date=date, + time=time, + number=number, + ) + + compressed = Availability(requests) + ds = cml.load_source("empty") + for r in compressed.iterate(): + request.update(r) + ds = ds + cml.load_source("mars", **request) + + for field in ds: + print(field) + key = ( + field.metadata("param"), + field.metadata("date"), + field.metadata("time"), + field.metadata("number"), + ) + values = field.values # optimisation + accumulations[key].add(field, values) + + for a in accumulations.values(): + assert a.done, (a.key, a.seen) + + out.close() + + ds = cml.load_source("file", path) + + assert len(ds) / len(param) / len(number) == len(dates), ( + len(ds), + len(param), + len(dates), + ) + ds._tmp = tmp + + return ds + + +if __name__ == "__main__": + import yaml + + config = yaml.safe_load( + """ + class: od + expver: '0001' + grid: 20./20. + levtype: sfc + param: tp + """ + ) + dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") + dates = to_datetime_list(dates) + + print(dates) + + def scda(request): + if request["time"] in (600, 1800): + request["stream"] = "scda" + else: + request["stream"] = "oper" + return request + + ds = accumulations(dates, 0, (0, 6), config, scda) + print() + for f in ds: + print(f.valid_datetime()) diff --git a/ecml_tools/create/functions/actions/grib.py b/ecml_tools/create/functions/actions/grib.py index 7c5ff36..d1b0027 100644 --- a/ecml_tools/create/functions/actions/grib.py +++ b/ecml_tools/create/functions/actions/grib.py @@ -13,7 +13,6 @@ def check(ds, paths, **kwargs): - count = 1 for k, v in kwargs.items(): if isinstance(v, (tuple, list)): diff --git a/ecml_tools/create/functions/steps/unrotate_winds.py b/ecml_tools/create/functions/steps/unrotate_winds.py index e10ac85..c5fcde0 100644 --- a/ecml_tools/create/functions/steps/unrotate_winds.py +++ b/ecml_tools/create/functions/steps/unrotate_winds.py @@ -38,7 +38,6 @@ def rotate_winds( south_pole_longitude, south_pole_rotation_angle=0, ): - # Code from MIR assert south_pole_rotation_angle == 0 C = np.deg2rad(90 - south_pole_latitude) diff --git a/ecml_tools/create/loaders.py b/ecml_tools/create/loaders.py index 0dfdb4b..a70a1c7 100644 --- a/ecml_tools/create/loaders.py +++ b/ecml_tools/create/loaders.py @@ -392,7 +392,10 @@ def load_result(self, result): dates = result.dates cube = result.get_cube() - assert cube.extended_user_shape[0] == len(dates), (cube.extended_user_shape[0], len(dates)) + assert cube.extended_user_shape[0] == len(dates), ( + cube.extended_user_shape[0], + len(dates), + ) shape = cube.extended_user_shape dates_in_data = cube.user_coords["valid_datetime"] @@ -445,7 +448,12 @@ def load_cube(self, cube, array): load += time.time() - now name = self.variables_names[local_indexes[1]] - check_data_values(data[:], name=name, log=[i, data.shape, local_indexes], allow_nan=self.allow_nan) + check_data_values( + data[:], + name=name, + log=[i, data.shape, local_indexes], + allow_nan=self.allow_nan, + ) now = time.time() array[local_indexes] = data @@ -493,7 +501,10 @@ def _get_statistics_dates(self): # remove missing dates if self.missing_dates: - assert type(self.missing_dates[0]) is dtype, (type(self.missing_dates[0]), dtype) + assert type(self.missing_dates[0]) is dtype, ( + type(self.missing_dates[0]), + dtype, + ) dates = [d for d in dates if d not in self.missing_dates] # filter dates according the the start and end dates in the metadata diff --git a/ecml_tools/create/statistics.py b/ecml_tools/create/statistics.py index 8fbc068..17981c4 100644 --- a/ecml_tools/create/statistics.py +++ b/ecml_tools/create/statistics.py @@ -97,7 +97,13 @@ def compute_statistics(array, check_variables_names=None, allow_nan=False): squares[i] = np.nansum(values * values, axis=1) count[i] = np.sum(~np.isnan(values), axis=1) - return {"minimum": minimum, "maximum": maximum, "sums": sums, "squares": squares, "count": count} + return { + "minimum": minimum, + "maximum": maximum, + "sums": sums, + "squares": squares, + "count": count, + } class TempStatistics: @@ -201,7 +207,10 @@ def check_type(a, b): offset = 0 for _, _dates, stats in self.owner._gather_data(): assert isinstance(stats, dict), stats - assert stats["minimum"].shape[0] == len(_dates), (stats["minimum"].shape, len(_dates)) + assert stats["minimum"].shape[0] == len(_dates), ( + stats["minimum"].shape, + len(_dates), + ) assert stats["minimum"].shape[1] == len(self.variables_names), ( stats["minimum"].shape, len(self.variables_names), @@ -226,13 +235,19 @@ def check_type(a, b): for k in self.NAMES: stats[k] = stats[k][bitmap] - assert stats["minimum"].shape[0] == len(dates), (stats["minimum"].shape, len(dates)) + assert stats["minimum"].shape[0] == len(dates), ( + stats["minimum"].shape, + len(dates), + ) # store data in self found |= set(dates) for name in self.NAMES: array = getattr(self, name) - assert stats[name].shape[0] == len(dates), (stats[name].shape, len(dates)) + assert stats[name].shape[0] == len(dates), ( + stats[name].shape, + len(dates), + ) array[offset : offset + len(dates)] = stats[name] offset += len(dates) @@ -243,7 +258,6 @@ def check_type(a, b): print(f"Statistics for {len(found)} dates found.") def aggregate(self): - minimum = np.nanmin(self.minimum, axis=0) maximum = np.nanmax(self.maximum, axis=0) sums = np.nansum(self.sums, axis=0) diff --git a/ecml_tools/data/dataset.py b/ecml_tools/data/dataset.py index d1db030..f9be152 100644 --- a/ecml_tools/data/dataset.py +++ b/ecml_tools/data/dataset.py @@ -12,15 +12,10 @@ from .debug import debug_indexing from .indexing import expand_list_indexing -# from .misc import _as_first_date -# from .misc import _as_last_date -# from .misc import _frequency_to_hours - LOG = logging.getLogger(__name__) class Dataset: - arguments = {} @cached_property @@ -82,6 +77,12 @@ def _subset(self, **kwargs): method = kwargs.pop("method", "every-nth") return Thinning(self, thinning, method)._subset(**kwargs) + if "area" in kwargs: + from .masked import Cropping + + bbox = kwargs.pop("area") + return Cropping(self, bbox)._subset(**kwargs) + raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs)) def _frequency_to_indices(self, frequency): diff --git a/ecml_tools/data/debug.py b/ecml_tools/data/debug.py index 48c7599..e394b41 100644 --- a/ecml_tools/data/debug.py +++ b/ecml_tools/data/debug.py @@ -28,7 +28,6 @@ def __init__(self, dataset, kids, **kwargs): self.kwargs = kwargs def _put(self, indent, result): - def _spaces(indent): return " " * indent if indent else "" diff --git a/ecml_tools/data/ensemble.py b/ecml_tools/data/ensemble.py index 342748d..dad5080 100644 --- a/ecml_tools/data/ensemble.py +++ b/ecml_tools/data/ensemble.py @@ -16,7 +16,6 @@ class Ensemble(GivenAxis): - def tree(self): return Node(self, [d.tree() for d in self.datasets]) diff --git a/ecml_tools/data/masked.py b/ecml_tools/data/masked.py index de4c2c2..3cd8aa1 100644 --- a/ecml_tools/data/masked.py +++ b/ecml_tools/data/masked.py @@ -10,9 +10,11 @@ import numpy as np -from . import Forwards +from ..grids import cropping_mask +from .dataset import Dataset from .debug import Node from .debug import debug_indexing +from .forewards import Forwards from .indexing import apply_index_to_slices_changes from .indexing import expand_list_indexing from .indexing import index_to_slices @@ -22,7 +24,6 @@ class Masked(Forwards): - def __init__(self, forward, mask): super().__init__(forward) assert len(forward.shape) == 4, "Grids must be 1D for now" @@ -65,7 +66,6 @@ def _get_tuple(self, index): class Thinning(Masked): - def __init__(self, forward, thinning, method): self.thinning = thinning self.method = method @@ -82,3 +82,22 @@ def __init__(self, forward, thinning, method): def tree(self): return Node(self, [self.forward.tree()], thinning=self.thinning, method=self.method) + + +class Cropping(Masked): + def __init__(self, forward, area): + self.area = area + + if isinstance(area, Dataset): + north = np.amax(area.latitudes) + south = np.amin(area.latitudes) + east = np.amax(area.longitudes) + west = np.amin(area.longitudes) + area = (north, west, south, east) + + mask = cropping_mask(forward.latitudes, forward.longitudes, *area) + + super().__init__(forward, mask) + + def tree(self): + return Node(self, [self.forward.tree()], area=self.area) diff --git a/tests/test_data.py b/tests/test_data.py index 1112e71..e73cc95 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -256,7 +256,7 @@ def run( assert isinstance(self.ds, expected_class) assert len(self.ds) == expected_length assert len([row for row in self.ds]) == len(self.ds) - assert self.ds.shape == expected_shape + assert self.ds.shape == expected_shape, (self.ds.shape, expected_shape) assert self.ds.variables == expected_variables assert self.ds.name_to_index == expected_name_to_index assert self.ds.dates[0] == start_date @@ -1096,5 +1096,14 @@ def test_statistics(): ) +@mockup_open_zarr +def test_cropping(): + test = DatasetTester( + "test-2021-2021-6h-o96-abcd", + area=(18, 11, 11, 18), + ) + assert test.ds.shape == (365 * 4, 4, 1, 8) + + if __name__ == "__main__": - test_simple() + test_cropping()