From bac0f46e880b40e4355e6cdc2841667b31aeb673 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 26 Feb 2024 13:46:46 +0000 Subject: [PATCH] merge develop --- .gitignore | 2 + .pre-commit-config.yaml | 9 +- ecml_tools/__init__.py | 2 +- ecml_tools/__main__.py | 74 ++ ecml_tools/commands/__init__.py | 84 +++ ecml_tools/commands/create.py | 24 + ecml_tools/commands/inspect/__init__.py | 45 ++ ecml_tools/commands/inspect/zarr.py | 639 ++++++++++++++++ ecml_tools/commands/scan.py | 115 +++ ecml_tools/create/__init__.py | 19 +- ecml_tools/create/check.py | 65 +- ecml_tools/create/config.py | 39 +- ecml_tools/create/expand.py | 183 ----- ecml_tools/create/functions/__init__.py | 55 ++ .../create/functions/actions/__init__.py | 8 + .../create/functions/actions/accumulations.py | 86 +++ .../create/functions/actions/constants.py | 17 + ecml_tools/create/functions/actions/empty.py | 14 + .../{ => actions}/ensemble_perturbations.py | 57 +- ecml_tools/create/functions/actions/grib.py | 50 ++ ecml_tools/create/functions/actions/mars.py | 119 +++ ecml_tools/create/functions/actions/netcdf.py | 57 ++ .../create/functions/actions/opendap.py | 14 + ecml_tools/create/functions/actions/source.py | 51 ++ .../create/functions/actions/tendencies.py | 147 ++++ ecml_tools/create/functions/steps/__init__.py | 8 + ecml_tools/create/functions/steps/empty.py | 16 + ecml_tools/create/functions/steps/noop.py | 12 + ecml_tools/create/functions/steps/rename.py | 30 + .../create/functions/steps/rotate_winds.py | 141 ++++ ecml_tools/create/group.py | 123 --- ecml_tools/create/input.py | 712 +++++++++++------- ecml_tools/create/loaders.py | 75 +- ecml_tools/create/template.py | 276 ++++--- ecml_tools/create/writer.py | 34 +- ecml_tools/data.py | 159 +++- ecml_tools/grids.py | 266 +++++++ ecml_tools/utils/__init__.py | 0 ecml_tools/utils/dates/__init__.py | 113 +++ ecml_tools/utils/dates/groups.py | 85 +++ ecml_tools/utils/humanize.py | 377 ++++++++++ ecml_tools/utils/text.py | 241 ++++++ setup.py | 3 +- tests/_test_create.py | 26 +- tests/create-1.yaml | 11 +- tests/create-concat.yaml | 54 +- tests/create-join.yaml | 69 +- tests/create-perturbations-full.yaml | 21 +- tests/create-perturbations.yaml | 17 +- tests/create-pipe.yaml | 82 +- tests/create-shift.yaml | 51 ++ tests/test_data.py | 208 ++--- 52 files changed, 4001 insertions(+), 1184 deletions(-) create mode 100644 ecml_tools/__main__.py create mode 100644 ecml_tools/commands/__init__.py create mode 100644 ecml_tools/commands/create.py create mode 100644 ecml_tools/commands/inspect/__init__.py create mode 100644 ecml_tools/commands/inspect/zarr.py create mode 100644 ecml_tools/commands/scan.py delete mode 100644 ecml_tools/create/expand.py create mode 100644 ecml_tools/create/functions/actions/__init__.py create mode 100644 ecml_tools/create/functions/actions/accumulations.py create mode 100644 ecml_tools/create/functions/actions/constants.py create mode 100644 ecml_tools/create/functions/actions/empty.py rename ecml_tools/create/functions/{ => actions}/ensemble_perturbations.py (78%) create mode 100644 ecml_tools/create/functions/actions/grib.py create mode 100644 ecml_tools/create/functions/actions/mars.py create mode 100644 ecml_tools/create/functions/actions/netcdf.py create mode 100644 ecml_tools/create/functions/actions/opendap.py create mode 100644 ecml_tools/create/functions/actions/source.py create mode 100644 ecml_tools/create/functions/actions/tendencies.py create mode 100644 ecml_tools/create/functions/steps/__init__.py create mode 100644 
ecml_tools/create/functions/steps/empty.py create mode 100644 ecml_tools/create/functions/steps/noop.py create mode 100644 ecml_tools/create/functions/steps/rename.py create mode 100644 ecml_tools/create/functions/steps/rotate_winds.py delete mode 100644 ecml_tools/create/group.py create mode 100644 ecml_tools/grids.py create mode 100644 ecml_tools/utils/__init__.py create mode 100644 ecml_tools/utils/dates/__init__.py create mode 100644 ecml_tools/utils/dates/groups.py create mode 100644 ecml_tools/utils/humanize.py create mode 100644 ecml_tools/utils/text.py mode change 100644 => 100755 tests/_test_create.py create mode 100644 tests/create-shift.yaml diff --git a/.gitignore b/.gitignore index 2272c2b..8ee614c 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,5 @@ bar *.zarr/ ~$images.pptx test.py +cutout.png +*.out diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb2664a..eb9748a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,10 +10,11 @@ repos: - id: debug-statements # Check for debugger imports and py37+ breakpoint() - id: end-of-file-fixer # Ensure files end in a newline - id: trailing-whitespace # Trailing whitespace checker -- repo: https://github.com/asottile/reorder-python-imports # Reorder imports - rev: v3.10.0 - hooks: - - id: reorder-python-imports +# conflicting with "isort" +# - repo: https://github.com/asottile/reorder-python-imports # Reorder imports +# rev: v3.10.0 +# hooks: +# - id: reorder-python-imports - repo: https://github.com/asottile/pyupgrade # Upgrade Python syntax rev: v3.7.0 hooks: diff --git a/ecml_tools/__init__.py b/ecml_tools/__init__.py index 4d1f504..db7a161 100644 --- a/ecml_tools/__init__.py +++ b/ecml_tools/__init__.py @@ -5,4 +5,4 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -__version__ = "0.4.5" +__version__ = "0.4.6" diff --git a/ecml_tools/__main__.py b/ecml_tools/__main__.py new file mode 100644 index 0000000..bad1f0f --- /dev/null +++ b/ecml_tools/__main__.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +import argparse +import logging +import sys +import traceback + +from . 
import __version__ +from .commands import COMMANDS + +LOG = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + + parser.add_argument( + "--version", + "-V", + action="store_true", + help="show the version and exit", + ) + parser.add_argument( + "--debug", + "-d", + action="store_true", + help="Debug mode", + ) + + subparsers = parser.add_subparsers(help="commands:", dest="command") + for name, command in COMMANDS.items(): + command_parser = subparsers.add_parser(name, help=command.__doc__) + command.add_arguments(command_parser) + + args = parser.parse_args() + + if args.version: + print(__version__) + return + + if args.command is None: + parser.print_help() + return + + cmd = COMMANDS[args.command] + + logging.basicConfig( + format="%(asctime)s %(levelname)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.DEBUG if args.debug else logging.INFO, + ) + + try: + cmd.run(args) + except ValueError as e: + traceback.print_exc() + LOG.error("\nπŸ’£ %s", str(e).lstrip()) + LOG.error("πŸ’£ Exiting") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/ecml_tools/commands/__init__.py b/ecml_tools/commands/__init__.py new file mode 100644 index 0000000..41bf43f --- /dev/null +++ b/ecml_tools/commands/__init__.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import argparse +import importlib +import logging +import os +import sys + +LOG = logging.getLogger(__name__) + + +def register(here, package, select, fail=None): + result = {} + not_available = {} + + for p in os.listdir(here): + full = os.path.join(here, p) + if p.startswith("_"): + continue + if not ( + p.endswith(".py") + or ( + os.path.isdir(full) + and os.path.exists(os.path.join(full, "__init__.py")) + ) + ): + continue + + name, _ = os.path.splitext(p) + + try: + imported = importlib.import_module( + f".{name}", + package=package, + ) + except ImportError as e: + not_available[name] = e + continue + + obj = select(imported) + if obj is not None: + result[name] = obj + + for name, e in not_available.items(): + if fail is None: + pass + if callable(fail): + result[name] = fail(name, e) + + return result + + +class Command: + def run(self, args): + raise NotImplementedError(f"Command not implemented: {args.command}") + + +class Failed(Command): + def __init__(self, name, error): + self.name = name + self.error = error + + def add_arguments(self, command_parser): + command_parser.add_argument("x", nargs=argparse.REMAINDER) + + def run(self, args): + print(f"Command '{self.name}' not available: {self.error}") + sys.exit(1) + + +COMMANDS = register( + os.path.dirname(__file__), + __name__, + lambda x: x.command(), + lambda name, error: Failed(name, error), +) diff --git a/ecml_tools/commands/create.py b/ecml_tools/commands/create.py new file mode 100644 index 0000000..2aff08c --- /dev/null +++ b/ecml_tools/commands/create.py @@ -0,0 +1,24 @@ +from ecml_tools.create import Creator + +from . 
import Command + + +class Create(Command): + internal = True + timestamp = True + + def add_arguments(self, command_parser): + command_parser.add_argument( + "--overwrite", action="store_true", help="Overwrite existing files" + ) + command_parser.add_argument("config", help="Configuration file") + command_parser.add_argument("path", help="Path to store the created data") + + def run(self, args): + kwargs = vars(args) + + c = Creator(**kwargs) + c.create() + + +command = Create diff --git a/ecml_tools/commands/inspect/__init__.py b/ecml_tools/commands/inspect/__init__.py new file mode 100644 index 0000000..93b6b65 --- /dev/null +++ b/ecml_tools/commands/inspect/__init__.py @@ -0,0 +1,45 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import os + +from .. import Command + +# from .checkpoint import InspectCheckpoint +from .zarr import InspectZarr + + +class Inspect(Command, InspectZarr): + # class Inspect(Command, InspectCheckpoint, InspectZarr): + """Inspect a checkpoint or zarr file.""" + + def add_arguments(self, command_parser): + # g = command_parser.add_mutually_exclusive_group() + # g.add_argument("--inspect", action="store_true", help="Inspect weights") + command_parser.add_argument("path", metavar="PATH", nargs="+") + command_parser.add_argument("--detailed", action="store_true") + # command_parser.add_argument("--probe", action="store_true") + command_parser.add_argument("--progress", action="store_true") + command_parser.add_argument("--statistics", action="store_true") + command_parser.add_argument("--size", action="store_true", help="Print size") + + def run(self, args): + dic = vars(args) + for path in dic.pop("path"): + if ( + os.path.isdir(path) + or path.endswith(".zarr.zip") + or path.endswith(".zarr") + ): + self.inspect_zarr(path=path, **dic) + else: + raise ValueError(f"Unknown file type: {path}") + # self.inspect_checkpoint(path=path, **dic) + + +command = Inspect diff --git a/ecml_tools/commands/inspect/zarr.py b/ecml_tools/commands/inspect/zarr.py new file mode 100644 index 0000000..1737492 --- /dev/null +++ b/ecml_tools/commands/inspect/zarr.py @@ -0,0 +1,639 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +import datetime +import logging +import os +from copy import deepcopy +from functools import cached_property + +import numpy as np +import semantic_version +import tqdm + +from ecml_tools.data import open_dataset, open_zarr +from ecml_tools.utils.humanize import bytes, number, when +from ecml_tools.utils.text import dotted_line, progress, table + +LOG = logging.getLogger(__name__) + + +def compute_directory_size(path): + if not os.path.isdir(path): + return None, None + size = 0 + n = 0 + for dirpath, _, filenames in tqdm.tqdm( + os.walk(path), desc="Computing size", leave=False + ): + for filename in filenames: + file_path = os.path.join(dirpath, filename) + size += os.path.getsize(file_path) + n += 1 + return size, n + + +def local_time_bug(lon, date): + delta = date - datetime.datetime(date.year, date.month, date.day) + hours_since_midnight = delta.days + delta.seconds / 86400.0 # * 24 is missing + return (lon / 360.0 * 24.0 + hours_since_midnight) % 24 + + +def cos_local_time_bug(lon, date): + radians = local_time_bug(lon, date) / 24 * np.pi * 2 + return np.cos(radians) + + +def find(config, name): + if isinstance(config, dict): + if name in config: + return config[name] + + for k, v in config.items(): + r = find(v, name) + if r is not None: + return r + + if isinstance(config, list): + for v in config: + r = find(v, name) + if r is not None: + return r + + return None + + +class Version: + def __init__(self, path, zarr, metadata, version): + self.path = path + self.zarr = zarr + self.metadata = metadata + self.version = version + self.dataset = None + # try: + self.dataset = open_dataset(self.path) + # except Exception as e: + # LOG.error("Error opening dataset '%s': %s", self.path, e) + + def describe(self): + print(f"πŸ“¦ Path : {self.path}") + print(f"πŸ”’ Format version: {self.version}") + + def probe(self): + if "cos_local_time" not in self.name_to_index: + print("⚠️ probe: no cos_local_time") + return + + try: + lon = self.longitudes + except AttributeError: + print("⚠️ probe: no longitudes") + return + # print(json.dumps(self.metadata, indent=4)) + cos_local_time = self.name_to_index["cos_local_time"] + data = self.data + start, end, frequency = self.first_date, self.last_date, self.frequency + date = start + same = 0 + for i in range(10): + field = data[i, cos_local_time] + buggy = cos_local_time_bug(lon, date).reshape(field.hape) + diff = np.abs(field - buggy) + if np.max(diff) < 1e-5: + same += 1 + date += datetime.timedelta(hours=frequency) + if date > end: + break + if same > 1: + print("❌ probe: cos_local_time is buggy") + return + + print("βœ… probe: cos_local_time is fixed") + + @property + def name_to_index(self): + return find(self.metadata, "name_to_index") + + @property + def longitudes(self): + try: + return self.zarr.longitudes[:] + except (KeyError, AttributeError): + return self.zarr.longitude[:] + + @property + def data(self): + try: + return self.zarr.data + except AttributeError: + return self.zarr + + @property + def first_date(self): + return datetime.datetime.fromisoformat(self.metadata["first_date"]) + + @property + def last_date(self): + return datetime.datetime.fromisoformat(self.metadata["last_date"]) + + @property + def frequency(self): + return self.metadata["frequency"] + + @property + def resolution(self): + return self.metadata["resolution"] + + @property + def shape(self): + if self.data and hasattr(self.data, "shape"): + return self.data.shape + + @property + def uncompressed_data_size(self): + if self.data and hasattr(self.data, "dtype") 
and hasattr(self.data, "size"): + return self.data.dtype.itemsize * self.data.size + + def info(self, detailed, size): + print() + print(f'πŸ“… Start : {self.first_date.strftime("%Y-%m-%d %H:%M")}') + print(f'πŸ“… End : {self.last_date.strftime("%Y-%m-%d %H:%M")}') + print(f"⏰ Frequency : {self.frequency}h") + print(f"🌎 Resolution: {self.resolution}") + + print() + shape_str = "πŸ“ Shape : " + if self.shape: + shape_str += " Γ— ".join(["{:,}".format(s) for s in self.shape]) + if self.uncompressed_data_size: + shape_str += f" ({bytes(self.uncompressed_data_size)})" + print(shape_str) + self.print_sizes(size) + print() + rows = [] + + if self.statistics_ready: + stats = self.statistics + else: + stats = [["-"] * len(self.variables)] * 4 + + for i, v in enumerate(self.variables): + rows.append([i, v] + [x[i] for x in stats]) + + print( + table( + rows, + header=["Index", "Variable", "Min", "Max", "Mean", "Stdev"], + align=[">", "<", ">", ">", ">", ">"], + margin=3, + ) + ) + + if detailed: + self.details() + + self.progress() + if self.ready(): + self.probe() + + print() + + @property + def variables(self): + return [v[0] for v in sorted(self.name_to_index.items(), key=lambda x: x[1])] + + @property + def total_size(self): + return self.zarr.attrs.get("total_size") + + @property + def total_number_of_files(self): + return self.zarr.attrs.get("total_number_of_files") + + def print_sizes(self, size): + total_size = self.total_size + n = self.total_number_of_files + + if total_size is None: + if not size: + return + + total_size, n = compute_directory_size(self.path) + + if total_size is not None: + print(f"πŸ’½ Size : {bytes(total_size)} ({number(total_size)})") + if n is not None: + print(f"πŸ“ Files : {number(n)}") + + @property + def statistics(self): + try: + if self.dataset is not None: + stats = self.dataset.statistics + return stats["minimum"], stats["maximum"], stats["mean"], stats["stdev"] + except AttributeError: + return [["-"] * len(self.variables)] * 4 + + @property + def statistics_ready(self): + for d in reversed(self.metadata.get("history", [])): + if d["action"] == "compute_statistics_end": + return True + return False + + @property + def statistics_started(self): + for d in reversed(self.metadata.get("history", [])): + if d["action"] == "compute_statistics_start": + return datetime.datetime.fromisoformat(d["timestamp"]) + return None + + @property + def build_flags(self): + return self.zarr.get("_build_flags") + + @cached_property + def copy_flags(self): + if "_copy" not in self.zarr: + return None + return self.zarr["_copy"][:] + + @property + def copy_in_progress(self): + if "_copy" not in self.zarr: + return False + + start = self.zarr["_copy"].attrs.get("copy_start_timestamp") + end = self.zarr["_copy"].attrs.get("copy_end_timestamp") + if start and end: + return False + + return not all(self.copy_flags) + + @property + def build_lengths(self): + return self.zarr.get("_build_lengths") + + def progress(self): + if self.copy_in_progress: + copy_flags = self.copy_flags + print("πŸͺ« Dataset not ready, copy in progress.") + assert isinstance(copy_flags, np.ndarray) + total = len(copy_flags) + built = copy_flags.sum() + print( + "πŸ“ˆ Progress:", + progress(built, total, width=50), + "{:.0f}%".format(built / total * 100), + ) + return + + if self.build_flags is None: + print("πŸͺ« Dataset not initialized") + return + + build_flags = self.build_flags + + build_lengths = self.build_lengths + assert build_flags.size == build_lengths.size + + latest_write_timestamp = 
self.zarr.attrs.get("latest_write_timestamp") + latest = ( + datetime.datetime.fromisoformat(latest_write_timestamp) + if latest_write_timestamp + else None + ) + + if not all(build_flags): + if latest: + print(f"πŸͺ« Dataset not ready, last update {when(latest)}.") + else: + print("πŸͺ« Dataset not ready.") + total = sum(build_lengths) + built = sum( + ln if flag else 0 for ln, flag in zip(build_lengths, build_flags) + ) + print( + "πŸ“ˆ Progress:", + progress(built, total, width=50), + "{:.0f}%".format(built / total * 100), + ) + start = self.initialised + if self.initialised: + print(f"πŸ•°οΈ Dataset initialized {when(start)}.") + if built and latest: + speed = (latest - start) / built + eta = datetime.datetime.utcnow() + speed * (total - built) + print(f"🏁 ETA {when(eta)}.") + else: + if latest: + print(f"πŸ”‹ Dataset ready, last update {when(latest)}.") + else: + print("πŸ”‹ Dataset ready.") + if self.statistics_ready: + print("πŸ“Š Statistics ready.") + else: + started = self.statistics_started + if started: + print(f"⏳ Statistics not ready, started {when(started)}.") + else: + print("⏳ Statistics not ready.") + + def brute_force_statistics(self): + if self.dataset is None: + return + print("πŸ“Š Computing statistics...") + # np.seterr(all="raise") + + nvars = self.dataset.shape[1] + + count = np.zeros(nvars, dtype=np.int64) + sums = np.zeros(nvars, dtype=np.float32) + squares = np.zeros(nvars, dtype=np.float32) + + minimum = np.full((nvars,), np.inf, dtype=np.float32) + maximum = np.full((nvars,), -np.inf, dtype=np.float32) + + for i, chunk in enumerate(tqdm.tqdm(self.dataset, total=len(self.dataset))): + values = chunk.reshape((nvars, -1)) + minimum = np.minimum(minimum, np.min(values, axis=1)) + maximum = np.maximum(maximum, np.max(values, axis=1)) + sums += np.sum(values, axis=1) + squares += np.sum(values * values, axis=1) + count += values.shape[1] + + mean = sums / count + stats = [ + minimum, + maximum, + mean, + np.sqrt(squares / count - mean * mean), + ] + + rows = [] + + for i, v in enumerate(self.variables): + rows.append([i, v] + [x[i] for x in stats]) + + print( + table( + rows, + header=["Index", "Variable", "Min", "Max", "Mean", "Stdev"], + align=[">", "<", ">", ">", ">", ">"], + margin=3, + ) + ) + + +class NoVersion(Version): + @property + def first_date(self): + monthly = find(self.metadata, "monthly") + return datetime.datetime.fromisoformat(monthly["start"]) + + @property + def last_date(self): + monthly = find(self.metadata, "monthly") + time = max([int(t) for t in find(self.metadata["climetlab"], "time")]) + assert isinstance(time, int), (time, type(time)) + if time > 100: + time = time // 100 + return datetime.datetime.fromisoformat(monthly["stop"]) + datetime.timedelta( + hours=time + ) + + @property + def frequency(self): + time = find(self.metadata["climetlab"], "time") + return 24 // len(time) + + @property + def statistics(self): + stats = find(self.metadata, "statistics_by_index") + return stats["minimum"], stats["maximum"], stats["mean"], stats["stdev"] + + @property + def statistics_ready(self): + return find(self.metadata, "statistics_by_index") is not None + + @property + def resolution(self): + return find(self.metadata, "grid") + + def details(self): + pass + + def progress(self): + pass + + def ready(self): + return True + + +class Version0_4(Version): + def details(self): + pass + + @property + def initialised(self): + return datetime.datetime.fromisoformat(self.metadata["creation_timestamp"]) + + def statistics_ready(self): + if not 
self.ready(): + return False + build_flags = self.zarr["_build_flags"] + return build_flags.attrs.get("_statistics_computed") + + def ready(self): + if "_build_flags" not in self.zarr: + return False + + build_flags = self.zarr["_build_flags"] + if not build_flags.attrs.get("_initialised"): + return False + + return all(build_flags) + + def _info(self, verbose, history, statistics, **kwargs): + z = self.zarr + + # for backward compatibility + if "climetlab" in z.attrs: + climetlab_version = ( + z.attrs["climetlab"].get("versions", {}).get("climetlab", "unkwown") + ) + print( + f"climetlab version used to create this zarr: {climetlab_version}. Not supported." + ) + return + + version = z.attrs.get("version") + versions = z.attrs.get("versions") + if not versions: + print(" Cannot find metadata information about versions.") + else: + print(f"Zarr format (version {version})", end="") + print(f" created by climetlab={versions.pop('climetlab')}", end="") + timestamp = z.attrs.get("creation_timestamp") + timestamp = datetime.datetime.fromisoformat(timestamp) + print(f" on {timestamp}", end="") + versions = ", ".join([f"{k}={v}" for k, v in versions.items()]) + print(f" using {versions}", end="") + print() + + +class Version0_6(Version): + @property + def initialised(self): + for record in self.metadata.get("history", []): + if record["action"] == "initialised": + return datetime.datetime.fromisoformat(record["timestamp"]) + + # Sometimes the first record is missing + timestamps = sorted( + [ + datetime.datetime.fromisoformat(d["timestamp"]) + for d in self.metadata.get("history", []) + ] + ) + if timestamps: + return timestamps[0] + + return None + + def details(self): + print() + for d in self.metadata.get("history", []): + d = deepcopy(d) + timestamp = d.pop("timestamp") + timestamp = datetime.datetime.fromisoformat(timestamp) + action = d.pop("action") + versions = d.pop("versions") + versions = ", ".join(f"{k}={v}" for k, v in versions.items()) + more = ", ".join(f"{k}={v}" for k, v in d.items()) + print(f" {timestamp} : {action} ({versions}) {more}") + print() + + def ready(self): + if "_build_flags" not in self.zarr: + return False + + build_flags = self.zarr["_build_flags"] + return all(build_flags) + + @property + def name_to_index(self): + return {n: i for i, n in enumerate(self.metadata["variables"])} + + @property + def variables(self): + return self.metadata["variables"] + + +class Version0_12(Version0_6): + def details(self): + print() + for d in self.metadata.get("history", []): + d = deepcopy(d) + timestamp = d.pop("timestamp") + timestamp = datetime.datetime.fromisoformat(timestamp) + action = d.pop("action") + more = ", ".join(f"{k}={v}" for k, v in d.items()) + if more: + more = f" ({more})" + print(f" {timestamp} : {action}{more}") + print() + + @property + def first_date(self): + return datetime.datetime.fromisoformat(self.metadata["start_date"]) + + @property + def last_date(self): + return datetime.datetime.fromisoformat(self.metadata["end_date"]) + + +class Version0_13(Version0_12): + @property + def build_flags(self): + if "_build" not in self.zarr: + return None + build = self.zarr["_build"] + return build.get("flags") + + @property + def build_lengths(self): + if "_build" not in self.zarr: + return None + build = self.zarr["_build"] + return build.get("lengths") + + +VERSIONS = { + "0.0.0": NoVersion, + "0.4.0": Version0_4, + "0.6.0": Version0_6, + "0.12.0": Version0_12, + "0.13.0": Version0_13, +} + + +class InspectZarr: + """Inspect a checkpoint or zarr file.""" + + 
def inspect_zarr(self, path, **kwargs): + version = self._info(path) + + # try: + # with open("/tmp/probe.json", "w") as f: + # json.dump(version.metadata, f, indent=4, sort_keys=True) + # except Exception: + # pass + + dotted_line() + version.describe() + + try: + if kwargs.get("probe"): + return version.probe() + if kwargs.get("progress"): + return version.progress() + if kwargs.get("statistics"): + return version.brute_force_statistics() + version.info(kwargs.get("detailed"), kwargs.get("size")) + except Exception as e: + LOG.error("Error inspecting zarr file '%s': %s", path, e) + + print(type(version)) + raise + + def _info(self, path): + if path.endswith("/"): + path = path[:-1] + + try: + z = open_zarr(path) + except Exception as e: + LOG.error("Error opening zarr file '%s': %s", path, e) + raise + + metadata = dict(z.attrs) + version = metadata.get("version", "0.0.0") + if isinstance(version, int): + version = f"0.{version}" + + version = semantic_version.Version.coerce(version) + + versions = {semantic_version.Version.coerce(k): v for k, v in VERSIONS.items()} + + candidate = None + for v, klass in sorted(versions.items()): + if version >= v: + candidate = klass + + return candidate(path, z, metadata, version) diff --git a/ecml_tools/commands/scan.py b/ecml_tools/commands/scan.py new file mode 100644 index 0000000..e7bbf40 --- /dev/null +++ b/ecml_tools/commands/scan.py @@ -0,0 +1,115 @@ +import os +import sys +from collections import defaultdict + +import climetlab as cml +import tqdm +import yaml + +from . import Command + +KEYS = ("class", "type", "stream", "expver", "levtype", "domain") + + +class Scan(Command): + internal = True + timestamp = True + + def add_arguments(self, command_parser): + command_parser.add_argument( + "--extension", default=".grib", help="Extension of the files to scan" + ) + command_parser.add_argument( + "--magic", + help="File 'magic' to use to identify the file type. 
Overrides --extension", + ) + command_parser.add_argument("paths", nargs="+", help="Paths to scan") + + def run(self, args): + EXTENSIONS = { + ".grib": "grib", + ".grib1": "grib", + ".grib2": "grib", + ".grb": "grib", + ".nc": "netcdf", + ".nc4": "netcdf", + } + + MAGICS = { + "GRIB": "grib", + } + + if args.magic: + what = MAGICS[args.magic] + args.magic = args.magic.encode() + else: + what = EXTENSIONS[args.extension] + + def match(path): + if args.magic: + with open(path, "rb") as f: + return args.magic == f.read(len(args.magic)) + else: + return path.endswith(args.extension) + + paths = [] + for path in args.paths: + if os.path.isfile(path): + + paths.append(path) + else: + for root, _, files in os.walk(path): + for file in files: + full = os.path.join(root, file) + paths.append(full) + + dates = set() + gribs = defaultdict(set) + unique = defaultdict(lambda: defaultdict(set)) + + for path in tqdm.tqdm(paths, leave=False): + if not match(path): + continue + for field in tqdm.tqdm(cml.load_source("file", path), leave=False): + dates.add(field.valid_datetime()) + mars = field.as_mars() + keys = tuple(mars.get(k) for k in KEYS) + gribs[keys].add(path) + for k, v in mars.items(): + if k not in KEYS + ("date", "time", "step"): + unique[keys][k].add(v) + + config = dict( + description=f"Generated by {sys.argv})", + dataset_status="experimental", + purpose="aifs", + name="test", + config_format_version=2, + dates=dict(values=sorted(dates)), + input=dict(join=[]), + output=dict( + chunking=dict(dates=1, ensembles=1), + dtype="float32", + flatten_grid=True, + order_by=["valid_datetime", "param_level", "number"], + statistics="param_level", + statistics_end=2020, + remapping=dict(param_level="{param}_{levelist}"), + ), + ) + + for k, v in sorted(gribs.items()): + request = {what: dict(path=sorted(v), **dict(zip(KEYS, k)))} + for k, v in sorted(unique[k].items()): + if len(v) == 1: + request[what][k] = list(v)[0] + else: + request[what][k] = sorted(v) + + config["input"]["join"].append(request) + + with open("scan-config.yaml", "w") as f: + print(yaml.dump(config, sort_keys=False), file=f) + + +command = Scan diff --git a/ecml_tools/create/__init__.py b/ecml_tools/create/__init__.py index 0afa998..0c22af9 100644 --- a/ecml_tools/create/__init__.py +++ b/ecml_tools/create/__init__.py @@ -21,14 +21,12 @@ def __init__( overwrite=False, **kwargs, ): - self.path = path + self.path = path # Output path self.config = config self.cache = cache self.print = print self.statistics_tmp = statistics_tmp self.overwrite = overwrite - # if kwargs: - # raise ValueError(f"Unknown arguments {kwargs}") def init(self, check_name=False): # check path @@ -41,10 +39,8 @@ def init(self, check_name=False): f"{self.path} already exists. Use overwrite=True to overwrite." 
) - cls = InitialiseLoader - with self._cache_context(): - obj = cls.from_config( + obj = InitialiseLoader.from_config( path=self.path, config=self.config, statistics_tmp=self.statistics_tmp, @@ -103,6 +99,16 @@ def size(self): loader = SizeLoader.from_dataset(path=self.path, print=self.print) loader.add_total_size() + def cleanup(self): + from .loaders import CleanupLoader + + loader = CleanupLoader.from_dataset( + path=self.path, + print=self.print, + statistics_tmp=self.statistics_tmp, + ) + loader.run() + def patch(self, **kwargs): from .patch import apply_patch @@ -116,6 +122,7 @@ def create(self): self.init() self.load() self.finalise() + self.cleanup() def _cache_context(self): from .utils import cache_context diff --git a/ecml_tools/create/check.py b/ecml_tools/create/check.py index c1ff16c..acc5be9 100644 --- a/ecml_tools/create/check.py +++ b/ecml_tools/create/check.py @@ -44,6 +44,9 @@ def __init__( ): self.name = name self.parsed = self._parse(name) + print("---------------") + print(self.parsed) + print("---------------") self.messages = [] @@ -59,10 +62,6 @@ def __init__( + "/".join(f"{k}={v}" for k, v in self.parsed.items()) ) - @property - def is_valid(self): - return not self.messages - @property def error_message(self): out = " And ".join(self.messages) @@ -71,40 +70,34 @@ def error_message(self): return out def raise_if_not_valid(self, print=print): - if not self.is_valid: + if self.messages: for m in self.messages: print(m) raise ValueError(self.error_message) def _parse(self, name): - pattern = r"^(\w+)-(\w+)-(\w+)-(\w+)-(\w\w\w\w)-(\w+)-(\w+)-([\d\-]+)-(\d+h)-v(\d+)-?(.*)$" + pattern = ( + r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?(.*)$" + ) match = re.match(pattern, name) + assert match, (name, pattern) + parsed = {} if match: keys = [ - "use_case", - "class_", - "type_", - "stream", - "expver", + "purpose", + "labelling", "source", "resolution", - "period", + "start_date", + "end_date", "frequency", "version", "additional", ] parsed = {k: v for k, v in zip(keys, match.groups())} - period = parsed["period"].split("-") - assert len(period) in (1, 2), (name, period) - parsed["start_date"] = period[0] - if len(period) == 1: - parsed["end_date"] = period[0] - if len(period) == 2: - parsed["end_date"] = period[1] - return parsed def __str__(self): @@ -113,11 +106,9 @@ def __str__(self): def check_parsed(self): if not self.parsed: self.messages.append( - ( - f"the dataset name {self} does not follow naming convention. " - "See here for details: " - "https://confluence.ecmwf.int/display/DWF/Datasets+available+as+zarr" - ) + f"the dataset name {self} does not follow naming convention. " + "See here for details: " + "https://confluence.ecmwf.int/display/DWF/Datasets+available+as+zarr" ) def check_resolution(self, resolution): @@ -126,10 +117,8 @@ def check_resolution(self, resolution): and self.parsed["resolution"][0] not in "0123456789on" ): self.messages.append( - ( - f"the resolution {self.parsed['resolution'] } should start " - f"with a number or 'o' or 'n' in the dataset name {self}." - ) + f"the resolution {self.parsed['resolution'] } should start " + f"with a number or 'o' or 'n' in the dataset name {self}." 
) if resolution is None: @@ -149,7 +138,7 @@ def check_start_date(self, start_date): if start_date is None: return start_date_str = str(start_date.year) - self._check_missing("first date", start_date_str) + self._check_missing("start_date", start_date_str) self._check_mismatch("start_date", start_date_str) def check_end_date(self, end_date): @@ -162,13 +151,13 @@ def check_end_date(self, end_date): def _check_missing(self, key, value): if value not in self.name: self.messages.append( - (f"the {key} is {value}, but is missing in {self.name}.") + f"the {key} is {value}, but is missing in {self.name}." ) def _check_mismatch(self, key, value): if self.parsed.get(key) and self.parsed[key] != value: self.messages.append( - (f"the {key} is {value}, but is {self.parsed[key]} in {self.name}.") + f"the {key} is {value}, but is {self.parsed[key]} in {self.name}." ) @@ -195,17 +184,13 @@ def check_data_values(arr, *, name: str, log=[]): if name in limits: if min < limits[name]["minimum"]: raise StatisticsValueError( - ( - f"For {name}: minimum value in the data is {min}. " - "Not in acceptable range [{limits[name]['minimum']} ; {limits[name]['maximum']}]" - ) + f"For {name}: minimum value in the data is {min}. " + "Not in acceptable range [{limits[name]['minimum']} ; {limits[name]['maximum']}]" ) if max > limits[name]["maximum"]: raise StatisticsValueError( - ( - f"For {name}: maximum value in the data is {max}. " - "Not in acceptable range [{limits[name]['minimum']} ; {limits[name]['maximum']}]" - ) + f"For {name}: maximum value in the data is {max}. " + "Not in acceptable range [{limits[name]['minimum']} ; {limits[name]['maximum']}]" ) diff --git a/ecml_tools/create/config.py b/ecml_tools/create/config.py index 210432e..3092afe 100644 --- a/ecml_tools/create/config.py +++ b/ecml_tools/create/config.py @@ -9,7 +9,6 @@ import datetime import logging import os -import warnings from copy import deepcopy import yaml @@ -143,11 +142,25 @@ def __init__(self, config, *args, **kwargs): # deprecated/obsolete if "order" in self.output: - raise ValueError(f"Do not use 'order'. Use order_by in {self}") + raise ValueError( + f"Do not use 'order'. Use order_by instead. {list(self.keys())}" + ) if "loops" in self: - assert "loop" not in self - warnings.warn("Should use loop instead of loops in config") - self.loop = self.pop("loops") + raise ValueError( + f"Do not use 'loops'. Use dates instead. {list(self.keys())}" + ) + if "loop" in self: + raise ValueError( + f"Do not use 'loop'. Use dates instead. {list(self.keys())}" + ) + + if not isinstance(self.dates, dict): + raise ValueError(f"Dates must be a dict. 
Got {self.dates}") + + if "licence" not in self: + raise ValueError("Must provide a licence in the config.") + if "copyright" not in self: + raise ValueError("Must provide a copyright in the config.") self.normalise() @@ -155,10 +168,6 @@ def normalise(self): if isinstance(self.input, (tuple, list)): self.input = dict(concat=self.input) - if not isinstance(self.loop, list): - assert isinstance(self.loop, dict), self.loop - self.loop = [dict(loop_a=self.loop)] - if "order_by" in self.output: self.output.order_by = normalize_order_by(self.output.order_by) @@ -217,6 +226,12 @@ def ensure_element_in_list(self, lst, elt, index): return lst[:index] + [elt] + lst[index:] + def get_serialisable_dict(self): + return _prepare_serialisation(self) + + def get_variables_names(self): + return self.output.order_by[self.output.statistics] + class UnknownPurposeConfig(LoadersConfig): purpose = "unknown" @@ -253,12 +268,6 @@ def normalise(self): super().normalise() # must be called last - def get_serialisable_dict(self): - return _prepare_serialisation(self) - - def get_variables_names(self): - return self.output.order_by[self.output.statistics] - def _prepare_serialisation(o): if isinstance(o, dict): diff --git a/ecml_tools/create/expand.py b/ecml_tools/create/expand.py deleted file mode 100644 index 6de5c50..0000000 --- a/ecml_tools/create/expand.py +++ /dev/null @@ -1,183 +0,0 @@ -# (C) Copyright 2023 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# -import datetime -import itertools -from functools import cached_property - -from .utils import to_datetime - - -class GroupByDays: - def __init__(self, days): - self.days = days - - def __call__(self, dt): - year = dt.year - days = (dt - datetime.datetime(year, 1, 1)).days - x = (year, days // self.days) - return x - - -class Expand(list): - """ - This class is used to expand loops. - It creates a list of list in self.groups. - Flatten values are in self.values. - """ - - def __init__(self, config): - assert isinstance(config, dict), config - for k, v in config.items(): - assert not isinstance(v, dict), (k, v) - self._config = config - if "stop" in self._config: - raise ValueError(f"Use 'end' not 'stop' in loop. {self._config}") - - @property - def values(self): - raise NotImplementedError() - - @property - def groups(self): - raise NotImplementedError(type(self)) - - -class ValuesExpand(Expand): - def __init__(self, config): - super().__init__(config) - if not isinstance(self._config, dict): - raise ValueError(f"Config must be a dict. {self._config}") - if not isinstance(self._config["values"], list): - raise ValueError(f"Values must be a list. 
{self._config}") - - @property - def values(self): - return self._config["values"] - - -class StartStopExpand(Expand): - def __init__(self, config): - super().__init__(config) - self.start = self._config["start"] - self.end = self._config["end"] - - @property - def values(self): - x = self.start - all = [] - while x <= self.end: - all.append(x) - yield x - x += self.step - - def format(self, x): - return x - - @cached_property - def groups(self): - groups = [] - for _, g in itertools.groupby(self.values, key=self.grouper_key): - g = [self.format(x) for x in g] - groups.append(g) - return groups - - -class DateStartStopExpand(StartStopExpand): - def __init__(self, config): - super().__init__(config) - - self.start = to_datetime(self.start) - self.end = to_datetime(self.end) - assert isinstance(self.start, datetime.date), (type(self.start), self.start) - assert isinstance(self.end, datetime.date), (type(self.end), self.end) - - frequency = self._config.get("frequency", "24h") - - if frequency.lower().endswith("h"): - freq = int(frequency[:-1]) - elif frequency.lower().endswith("d"): - freq = int(frequency[:-1]) * 24 - else: - raise ValueError( - f"Frequency must be in hours or days (12h or 2d). {frequency}" - ) - - if freq > 24 and freq % 24 != 0: - raise ValueError( - f"Frequency must be less than 24h or a multiple of 24h. {frequency}" - ) - - self.step = datetime.timedelta(hours=freq) - - @property - def grouper_key(self): - group_by = self._config.get("group_by") - if isinstance(group_by, int) and group_by > 0: - return GroupByDays(group_by) - return { - None: lambda dt: 0, # only one group - 0: lambda dt: 0, # only one group - "monthly": lambda dt: (dt.year, dt.month), - "daily": lambda dt: (dt.year, dt.month, dt.day), - "weekly": lambda dt: (dt.weekday(),), - "MMDD": lambda dt: (dt.month, dt.day), - }[group_by] - - def format(self, x): - assert isinstance(x, datetime.date), (type(x), x) - return x - - -class IntegerStartStopExpand(StartStopExpand): - def __init__(self, config): - super().__init__(config) - self.step = self._config.get("step", 1) - assert isinstance(self.step, int), config - assert isinstance(self.start, int), config - assert isinstance(self.end, int), config - - def grouper_key(self, x): - group_by = self._config["group_by"] - return { - 1: lambda x: 0, # only one group - None: lambda x: x, # one group per value - }[group_by](x) - - -def expand_class(config): - if isinstance(config, list): - config = {"values": config} - - assert isinstance(config, dict), config - - if "start" not in config and "values" not in config: - raise ValueError(f"Cannot expand loop from {config}") - - if isinstance(config.get("values"), list): - assert len(config) == 1, f"No other config keys implemented. {config}" - return ValuesExpand - - if ( - config.get("group_by") - in [ - "monthly", - "daily", - "weekly", - ] - or isinstance(config["start"], datetime.datetime) - or isinstance(config["end"], datetime.datetime) - or "frequency" in config - or (config.get("kind") == "dates" and "start" in config) - ): - return DateStartStopExpand - - if isinstance(config["start"], int) or isinstance(config["end"], int): - return IntegerStartStopExpand - - raise ValueError(f"Cannot expand loop from {config}") diff --git a/ecml_tools/create/functions/__init__.py b/ecml_tools/create/functions/__init__.py index e69de29..9dd5da9 100644 --- a/ecml_tools/create/functions/__init__.py +++ b/ecml_tools/create/functions/__init__.py @@ -0,0 +1,55 @@ +# (C) Copyright 2020 ECMWF. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# +from collections import defaultdict + + +def assert_is_fieldset(obj): + from climetlab.readers.grib.index import FieldSet + + assert isinstance(obj, FieldSet), type(obj) + + +def wrapped_mars_source(name, param, **kwargs): + from climetlab import load_source + + assert name == "mars", name # untested with other sources + + for_accumlated = dict( + ea="era5-accumulations", + oper="oper-accumulations", + ei="oper-accumulations", + )[kwargs["class"]] + + param_to_source = defaultdict(lambda: "mars") + param_to_source.update( + dict( + tp=for_accumlated, + cp=for_accumlated, + lsp=for_accumlated, + ) + ) + + source_names = defaultdict(list) + for p in param: + source_names[param_to_source[p]].append(p) + + sources = [] + for n, params in source_names.items(): + sources.append(load_source(n, param=params, **patch_time_to_hours(kwargs))) + return load_source("multi", sources) + + +def patch_time_to_hours(dic): + # era5-accumulations requires time in hours + if "time" not in dic: + return dic + time = dic["time"] + assert isinstance(time, (tuple, list)), time + time = [f"{int(t[:2]):02d}" for t in time] + return {**dic, "time": time} diff --git a/ecml_tools/create/functions/actions/__init__.py b/ecml_tools/create/functions/actions/__init__.py new file mode 100644 index 0000000..33d7fa0 --- /dev/null +++ b/ecml_tools/create/functions/actions/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# diff --git a/ecml_tools/create/functions/actions/accumulations.py b/ecml_tools/create/functions/actions/accumulations.py new file mode 100644 index 0000000..75f0359 --- /dev/null +++ b/ecml_tools/create/functions/actions/accumulations.py @@ -0,0 +1,86 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# +from copy import deepcopy + +from climetlab import load_source + +from ecml_tools.create.functions.actions.mars import factorise_requests +from ecml_tools.create.utils import to_datetime_list + +DEBUG = True + + +def to_list(x): + if isinstance(x, (list, tuple)): + return x + return [x] + + +def normalise_time_to_hours(r): + r = deepcopy(r) + if "time" not in r: + return r + + times = [] + for t in to_list(r["time"]): + assert len(t) == 4, r + assert t.endswith("00"), r + times.append(int(t) // 100) + r["time"] = tuple(times) + return r + + +def accumulations(context, dates, **request): + to_list(request["param"]) + class_ = request["class"] + + source_name = dict( + ea="era5-accumulations", + oper="oper-accumulations", + ei="oper-accumulations", + )[class_] + + requests = factorise_requests(dates, request) + + ds = load_source("empty") + for r in requests: + r = {k: v for k, v in r.items() if v != ("-",)} + r = normalise_time_to_hours(r) + + if DEBUG: + print(f"load_source({source_name}, {r}") + ds = ds + load_source(source_name, **r) + return ds + + +execute = accumulations + +if __name__ == "__main__": + import yaml + + config = yaml.safe_load( + """ + class: ea + expver: '0001' + grid: 20./20. + levtype: sfc +# number: [0, 1] +# stream: enda + param: [cp, tp] +# accumulation_period: 6h + """ + ) + dates = yaml.safe_load( + "[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]" + ) + dates = to_datetime_list(dates) + + DEBUG = True + for f in accumulations(None, dates, **config): + print(f, f.to_numpy().mean()) diff --git a/ecml_tools/create/functions/actions/constants.py b/ecml_tools/create/functions/actions/constants.py new file mode 100644 index 0000000..2ea043a --- /dev/null +++ b/ecml_tools/create/functions/actions/constants.py @@ -0,0 +1,17 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# +from climetlab import load_source + + +def constants(context, dates, template, param): + context.trace("βœ…", f"load_source(constants, {template}, {param}") + return load_source("constants", source_or_dataset=template, date=dates, param=param) + + +execute = constants diff --git a/ecml_tools/create/functions/actions/empty.py b/ecml_tools/create/functions/actions/empty.py new file mode 100644 index 0000000..1076f5d --- /dev/null +++ b/ecml_tools/create/functions/actions/empty.py @@ -0,0 +1,14 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +import climetlab as cml + + +def execute(context, dates, **kwargs): + return cml.load_source("empty") diff --git a/ecml_tools/create/functions/ensemble_perturbations.py b/ecml_tools/create/functions/actions/ensemble_perturbations.py similarity index 78% rename from ecml_tools/create/functions/ensemble_perturbations.py rename to ecml_tools/create/functions/actions/ensemble_perturbations.py index dc61526..1ff3fdf 100644 --- a/ecml_tools/create/functions/ensemble_perturbations.py +++ b/ecml_tools/create/functions/actions/ensemble_perturbations.py @@ -1,4 +1,4 @@ -# (C) Copyright 2020 ECMWF. +# (C) Copyright 2024 ECMWF. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -7,51 +7,65 @@ # nor does it submit to any jurisdiction. # import warnings +from copy import deepcopy import numpy as np import tqdm -from climetlab import load_source from climetlab.core.temporary import temp_file from climetlab.readers.grib.output import new_grib_output from ecml_tools.create.check import check_data_values +from ecml_tools.create.functions import assert_is_fieldset, wrapped_mars_source -def get_unique_field(ds, selection): - ds = ds.sel(**selection) - assert len(ds) == 1, (ds, selection) - return ds[0] +def to_list(x): + if isinstance(x, (list, tuple)): + return x + if isinstance(x, str): + return x.split("/") + return [x] def normalise_number(number): - if isinstance(number, (tuple, list, int)): - return number + number = to_list(number) - assert isinstance(number, str), (type(number), number) - - number = number.split("/") if len(number) > 4 and (number[1] == "to" and number[3] == "by"): return list(range(int(number[0]), int(number[2]) + 1, int(number[4]))) if len(number) > 2 and number[1] == "to": return list(range(int(number[0]), int(number[2]) + 1)) - assert isinstance(number, list), (type(number), number) return number +def normalise_request(request): + request = deepcopy(request) + if "number" in request: + request["number"] = normalise_number(request["number"]) + if "time" in request: + request["time"] = to_list(request["time"]) + request["param"] = to_list(request["param"]) + return request + + def ensembles_perturbations(ensembles, center, mean, remapping={}, patches={}): - number_list = normalise_number(ensembles["number"]) + from climetlab import load_source + + ensembles = normalise_request(ensembles) + center = normalise_request(center) + mean = normalise_request(mean) + + number_list = ensembles["number"] n_numbers = len(number_list) keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"] print(f"Retrieving ensemble data with {ensembles}") - ensembles = load_source(**ensembles).order_by(*keys) + ensembles = wrapped_mars_source(**ensembles).order_by(*keys) print(f"Retrieving center data with {center}") - center = load_source(**center).order_by(*keys) + center = wrapped_mars_source(**center).order_by(*keys) print(f"Retrieving mean data with {mean}") - mean = load_source(**mean).order_by(*keys) + mean = wrapped_mars_source(**mean).order_by(*keys) assert len(mean) * n_numbers == len(ensembles), ( len(mean), @@ -99,11 +113,17 @@ def ensembles_perturbations(ensembles, center, mean, remapping={}, patches={}): c = center_field.to_numpy() assert m.shape == c.shape, (m.shape, c.shape) + FORCED_POSITIVE = [ + "q", + "cp", + "lsp", + "tp", + ] # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ? 
################################# # Actual computation happens here x = c - m + e - if param == "q": - warnings.warn("Clipping q") + if param in FORCED_POSITIVE: + warnings.warn(f"Clipping {param} to be positive") x = np.maximum(x, 0) ################################# @@ -115,6 +135,7 @@ def ensembles_perturbations(ensembles, center, mean, remapping={}, patches={}): out.close() ds = load_source("file", path) + assert_is_fieldset(ds) # save a reference to the tmp file so it is deleted # only when the dataset is not used anymore ds._tmp = tmp diff --git a/ecml_tools/create/functions/actions/grib.py b/ecml_tools/create/functions/actions/grib.py new file mode 100644 index 0000000..0607c22 --- /dev/null +++ b/ecml_tools/create/functions/actions/grib.py @@ -0,0 +1,50 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +from climetlab import load_source +from climetlab.utils.patterns import Pattern + + +def check(ds, paths, **kwargs): + count = 1 + for k, v in kwargs.items(): + if isinstance(v, (tuple, list)): + count *= len(v) + + if len(ds) != count: + raise ValueError( + f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, paths={paths})" + ) + + +def execute(context, dates, path, *args, **kwargs): + given_paths = path if isinstance(path, list) else [path] + + ds = load_source("empty") + dates = [d.isoformat() for d in dates] + + for path in given_paths: + paths = Pattern(path, ignore_missing_keys=True).substitute( + *args, date=dates, **kwargs + ) + + for name in ("grid", "area", "rotation", "frame", "resol", "bitmap"): + if name in kwargs: + raise ValueError(f"MARS interpolation parameter '{name}' not supported") + + for path in paths: + context.trace("πŸ“", "PATH", path) + s = load_source("file", path) + s = s.sel(valid_datetime=dates, **kwargs) + ds = ds + s + + check(ds, given_paths, valid_datetime=dates, **kwargs) + + return ds diff --git a/ecml_tools/create/functions/actions/mars.py b/ecml_tools/create/functions/actions/mars.py new file mode 100644 index 0000000..525d9ee --- /dev/null +++ b/ecml_tools/create/functions/actions/mars.py @@ -0,0 +1,119 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# +import datetime +from copy import deepcopy + +from climetlab import load_source +from climetlab.utils.availability import Availability + +from ecml_tools.create.utils import to_datetime_list + +DEBUG = False + + +def to_list(x): + if isinstance(x, (list, tuple)): + return x + return [x] + + +def _date_to_datetime(d): + if isinstance(d, datetime.datetime): + return d + if isinstance(d, (list, tuple)): + return [_date_to_datetime(x) for x in d] + return datetime.datetime.fromisoformat(d) + + +def normalise_time_delta(t): + if isinstance(t, datetime.timedelta): + assert t == datetime.timedelta(hours=t.hours), t + + assert t.endswith("h"), t + + t = int(t[:-1]) + t = datetime.timedelta(hours=t) + return t + + +def _expand_mars_request(request, date): + requests = [] + step = to_list(request.get("step", [0])) + for s in step: + r = deepcopy(request) + base = date - datetime.timedelta(hours=int(s)) + r.update( + { + "date": base.strftime("%Y%m%d"), + "time": base.strftime("%H%M"), + "step": s, + } + ) + requests.append(r) + return requests + + +def factorise_requests(dates, *requests): + updates = [] + for req in requests: + # req = normalise_request(req) + + for d in dates: + updates += _expand_mars_request(req, date=d) + + compressed = Availability(updates) + return compressed.iterate() + + +def mars(context, dates, *requests, **kwargs): + if not requests: + requests = [kwargs] + + requests = factorise_requests(dates, *requests) + ds = load_source("empty") + for r in requests: + r = {k: v for k, v in r.items() if v != ("-",)} + if DEBUG: + context.trace("βœ…", f"load_source(mars, {r}") + ds = ds + load_source("mars", **r) + return ds + + +execute = mars + +if __name__ == "__main__": + import yaml + + config = yaml.safe_load( + """ + - class: ea + expver: '0001' + grid: 20.0/20.0 + levtype: sfc + param: [2t] + # param: [10u, 10v, 2d, 2t, lsm, msl, sdor, skt, slor, sp, tcw, z] + number: [0, 1] + + # - class: ea + # expver: '0001' + # grid: 20.0/20.0 + # levtype: pl + # param: [q] + # levelist: [1000, 850] + + """ + ) + dates = yaml.safe_load( + "[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]" + ) + dates = to_datetime_list(dates) + + DEBUG = True + for f in mars(None, dates, *config): + print(f, f.to_numpy().mean()) diff --git a/ecml_tools/create/functions/actions/netcdf.py b/ecml_tools/create/functions/actions/netcdf.py new file mode 100644 index 0000000..335bd92 --- /dev/null +++ b/ecml_tools/create/functions/actions/netcdf.py @@ -0,0 +1,57 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +from climetlab import load_source +from climetlab.utils.patterns import Pattern + + +def check(what, ds, paths, **kwargs): + count = 1 + for k, v in kwargs.items(): + if isinstance(v, (tuple, list)): + count *= len(v) + + if len(ds) != count: + raise ValueError( + f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})" + ) + + +def load_netcdfs(emoji, what, context, dates, path, *args, **kwargs): + given_paths = path if isinstance(path, list) else [path] + + dates = [d.isoformat() for d in dates] + ds = load_source("empty") + + for path in given_paths: + paths = Pattern(path, ignore_missing_keys=True).substitute( + *args, date=dates, **kwargs + ) + + levels = kwargs.get("level", kwargs.get("levelist")) + + for path in paths: + context.trace(emoji, what.upper(), path) + s = load_source("opendap", path) + s = s.sel( + valid_datetime=dates, + param=kwargs["param"], + step=kwargs.get("step", 0), + ) + if levels: + s = s.sel(levelist=levels) + ds = ds + s + + check(what, ds, given_paths, valid_datetime=dates, **kwargs) + + return ds + + +def execute(context, dates, path, *args, **kwargs): + return load_netcdfs("πŸ“", "path", context, dates, path, *args, **kwargs) diff --git a/ecml_tools/create/functions/actions/opendap.py b/ecml_tools/create/functions/actions/opendap.py new file mode 100644 index 0000000..ffbfc3e --- /dev/null +++ b/ecml_tools/create/functions/actions/opendap.py @@ -0,0 +1,14 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +from .netcdf import load_netcdfs + + +def execute(context, dates, url, *args, **kwargs): + return load_netcdfs("🌐", "url", context, dates, url, *args, **kwargs) diff --git a/ecml_tools/create/functions/actions/source.py b/ecml_tools/create/functions/actions/source.py new file mode 100644 index 0000000..13f00b5 --- /dev/null +++ b/ecml_tools/create/functions/actions/source.py @@ -0,0 +1,51 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
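Both netcdf.py and opendap.py above expand a single path template into one concrete path per date via climetlab's `Pattern`. A simplified sketch of that expansion (it ignores `Pattern`'s format specifiers and `ignore_missing_keys`; the template is illustrative):

    def substitute(template, **kwargs):
        # Cartesian expansion of list-valued keys into concrete paths.
        paths = [template]
        for key, values in kwargs.items():
            if not isinstance(values, (list, tuple)):
                values = [values]
            paths = [p.replace("{" + key + "}", str(v)) for p in paths for v in values]
        return paths

    print(substitute("/data/{param}_{date}.nc", param="2t", date=["2022-12-30", "2022-12-31"]))
    # ['/data/2t_2022-12-30.nc', '/data/2t_2022-12-31.nc']

Each resulting path is opened, narrowed with `.sel(...)`, and the same field-count check as in grib.py guards against missing data.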
+#
+from climetlab import load_source
+
+from ecml_tools.create.utils import to_datetime_list
+
+DEBUG = True
+
+
+def source(context, dates, **kwargs):
+    name = kwargs.pop("name")
+    context.trace("βœ…", f"load_source({name}, {dates}, {kwargs})")
+    if kwargs["date"] == "$from_dates":
+        kwargs["date"] = list({d.strftime("%Y%m%d") for d in dates})
+    if kwargs["time"] == "$from_dates":
+        kwargs["time"] = list({d.strftime("%H%M") for d in dates})
+    return load_source(name, **kwargs)
+
+
+execute = source
+
+if __name__ == "__main__":
+    import yaml
+
+    config = yaml.safe_load(
+        """
+    name: mars
+    class: ea
+    expver: '0001'
+    grid: 20.0/20.0
+    levtype: sfc
+    param: [2t]
+    number: [0, 1]
+    date: $from_dates
+    time: $from_dates
+    """
+    )
+    dates = yaml.safe_load(
+        "[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]"
+    )
+    dates = to_datetime_list(dates)
+
+    DEBUG = True
+    for f in source(None, dates, **config):
+        print(f, f.to_numpy().mean())
diff --git a/ecml_tools/create/functions/actions/tendencies.py b/ecml_tools/create/functions/actions/tendencies.py
new file mode 100644
index 0000000..b859146
--- /dev/null
+++ b/ecml_tools/create/functions/actions/tendencies.py
@@ -0,0 +1,147 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+import datetime
+from collections import defaultdict
+
+from climetlab.core.temporary import temp_file
+from climetlab.readers.grib.output import new_grib_output
+
+from ecml_tools.create.functions import assert_is_fieldset
+from ecml_tools.create.utils import to_datetime_list
+
+
+def _date_to_datetime(d):
+    if isinstance(d, (list, tuple)):
+        return [_date_to_datetime(x) for x in d]
+    return datetime.datetime.fromisoformat(d)
+
+
+def normalise_time_delta(t):
+    if isinstance(t, datetime.timedelta):
+        # already a timedelta: only whole hours are supported
+        assert t.total_seconds() % 3600 == 0, t
+        return t
+
+    assert t.endswith("h"), t
+
+    t = int(t[:-1])
+    t = datetime.timedelta(hours=t)
+    return t
+
+
+def group_by_field(ds):
+    d = defaultdict(list)
+    for field in ds.order_by("valid_datetime"):
+        m = field.as_mars()
+        for k in ("date", "time", "step"):
+            m.pop(k, None)
+        keys = tuple(m.items())
+        d[keys].append(field)
+    return d
+
+
+def tendencies(dates, time_increment, **kwargs):
+    print("βœ…", kwargs)
+    time_increment = normalise_time_delta(time_increment)
+
+    shifted_dates = [d - time_increment for d in dates]
+    all_dates = sorted(list(set(dates + shifted_dates)))
+
+    # from .mars import execute as mars
+    from ecml_tools.create.functions.actions.mars import execute as mars
+
+    ds = mars(dates=all_dates, **kwargs)
+
+    dates_in_data = ds.unique_values("valid_datetime")["valid_datetime"]
+    for d in all_dates:
+        assert d.isoformat() in dates_in_data, d
+
+    ds1 = ds.sel(valid_datetime=[d.isoformat() for d in dates])
+    ds2 = ds.sel(valid_datetime=[d.isoformat() for d in shifted_dates])
+
+    assert len(ds1) == len(ds2), (len(ds1), len(ds2))
+
+    group1 = group_by_field(ds1)
+    group2 = group_by_field(ds2)
+
+    assert group1.keys() == group2.keys(), (group1.keys(), group2.keys())
+
+    # prepare output tmp file so we can read it back
+    tmp = temp_file()
+    path = tmp.path
+    out = new_grib_output(path)
+
+    for k in group1:
+        assert len(group1[k]) == len(group2[k]), k
+        print()
+        print("❌", k)
+
+        for field, b_field in zip(group1[k], group2[k]):
+            for key in ["param", "level", "number", "grid", "shape"]:
+                assert field.metadata(key) == b_field.metadata(key), (
+                    key,
+                    field.metadata(key),
+                    b_field.metadata(key),
+                )
+
+            c = field.to_numpy()
+            b = b_field.to_numpy()
+            assert c.shape == b.shape, (c.shape, b.shape)
+
+            ################
+            # Actual computation happens here
+            x = c - b
+            ################
+
+            assert x.shape == c.shape, c.shape
+            print(
+                f"Computing data for {field.metadata('valid_datetime')}={field}-{b_field}"
+            )
+            out.write(x, template=field)
+
+    out.close()
+
+    from climetlab import load_source
+
+    ds = load_source("file", path)
+    assert_is_fieldset(ds)
+    # save a reference to the tmp file so it is deleted
+    # only when the dataset is not used anymore
+    ds._tmp = tmp
+
+    return ds
+
+
+execute = tendencies
+
+if __name__ == "__main__":
+    import yaml
+
+    config = yaml.safe_load(
+        """
+
+    config:
+        time_increment: 12h
+        database: marser
+        class: ea
+        # date: computed automatically
+        # time: computed automatically
+        expver: "0001"
+        grid: 20.0/20.0
+        levtype: sfc
+        param: [2t]
+    """
+    )["config"]
+
+    dates = yaml.safe_load(
+        "[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]"
+    )
+    dates = to_datetime_list(dates)
+
+    DEBUG = True
+    for f in tendencies(dates, **config):
+        print(f, f.to_numpy().mean())
diff --git a/ecml_tools/create/functions/steps/__init__.py b/ecml_tools/create/functions/steps/__init__.py
new file mode 100644
index 0000000..33d7fa0
--- /dev/null
+++ b/ecml_tools/create/functions/steps/__init__.py
@@ -0,0 +1,8 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
diff --git a/ecml_tools/create/functions/steps/empty.py b/ecml_tools/create/functions/steps/empty.py
new file mode 100644
index 0000000..384e344
--- /dev/null
+++ b/ecml_tools/create/functions/steps/empty.py
@@ -0,0 +1,16 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+import climetlab as cml
+
+
+def execute(context, input, **kwargs):
+    # Useful to create a pipeline that returns an empty result,
+    # so we can reference an earlier step in a function like 'constants'
+    return cml.load_source("empty")
diff --git a/ecml_tools/create/functions/steps/noop.py b/ecml_tools/create/functions/steps/noop.py
new file mode 100644
index 0000000..9dbdccd
--- /dev/null
+++ b/ecml_tools/create/functions/steps/noop.py
@@ -0,0 +1,12 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+
+def execute(context, input, *args, **kwargs):
+    return input
diff --git a/ecml_tools/create/functions/steps/rename.py b/ecml_tools/create/functions/steps/rename.py
new file mode 100644
index 0000000..41a8876
--- /dev/null
+++ b/ecml_tools/create/functions/steps/rename.py
@@ -0,0 +1,30 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+from climetlab.indexing.fieldset import FieldArray
+
+
+class RenamedField:
+    def __init__(self, field, what, renaming):
+        self.field = field
+        self.what = what
+        self.renaming = renaming
+
+    def metadata(self, key):
+        value = self.field.metadata(key)
+        if key == self.what:
+            return self.renaming.get(value, value)
+        return value
+
+    def __getattr__(self, name):
+        return getattr(self.field, name)
+
+
+def execute(context, input, what="param", **kwargs):
+    return FieldArray([RenamedField(fs, what, kwargs) for fs in input])
diff --git a/ecml_tools/create/functions/steps/rotate_winds.py b/ecml_tools/create/functions/steps/rotate_winds.py
new file mode 100644
index 0000000..4d9feb8
--- /dev/null
+++ b/ecml_tools/create/functions/steps/rotate_winds.py
@@ -0,0 +1,141 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+from collections import defaultdict
+
+from climetlab.indexing.fieldset import FieldArray
+
+
+def rotate_winds(lats, lons, x_wind, y_wind, source_projection, target_projection):
+    """
+    Code provided by MetNO
+    """
+    import numpy as np
+    import pyproj
+
+    if source_projection == target_projection:
+        return x_wind, y_wind
+
+    source_projection = pyproj.Proj(source_projection)
+    target_projection = pyproj.Proj(target_projection)
+
+    transformer = pyproj.transformer.Transformer.from_proj(
+        source_projection, target_projection
+    )
+
+    # To compute the new vector components:
+    # 1) perturb each position in the direction of the winds
+    # 2) convert the perturbed positions into the new coordinate system
+    # 3) measure the new x/y components.
+    #
+    # A complication occurs when using the longlat "projections": since this is not a
+    # cartesian grid (i.e. distances in each direction are not consistent), we need to
+    # deal with the fact that the width of a longitude varies with latitude
+    orig_speed = np.sqrt(x_wind**2 + y_wind**2)
+
+    x0, y0 = source_projection(lons, lats)
+
+    if source_projection.name != "longlat":
+        x1 = x0 + x_wind
+        y1 = y0 + y_wind
+    else:
+        # Reduce the perturbation, since x_wind and y_wind are in meters, which would
+        # create large perturbations in lat, lon. Also, deal with the fact that the
+        # width of longitude varies with latitude.
+        factor = 3600000.0
+        x1 = x0 + x_wind / factor / np.cos(np.deg2rad(lats))
+        y1 = y0 + y_wind / factor
+
+    X0, Y0 = transformer.transform(x0, y0)
+    X1, Y1 = transformer.transform(x1, y1)
+
+    new_x_wind = X1 - X0
+    new_y_wind = Y1 - Y0
+    if target_projection.name == "longlat":
+        new_x_wind *= np.cos(np.deg2rad(lats))
+
+    if target_projection.name == "longlat" or source_projection.name == "longlat":
+        # Ensure the wind speed is not changed (which might not be the case, since the
+        # units in longlat are degrees, not meters)
+        curr_speed = np.sqrt(new_x_wind**2 + new_y_wind**2)
+        new_x_wind *= orig_speed / curr_speed
+        new_y_wind *= orig_speed / curr_speed
+
+    return new_x_wind, new_y_wind
+
+
+class NewDataField:
+    def __init__(self, field, data):
+        self.field = field
+        self.data = data
+
+    def to_numpy(self, *args, **kwargs):
+        return self.data
+
+    def __getattr__(self, name):
+        return getattr(self.field, name)
+
+
+def execute(
+    context,
+    input,
+    x_wind,
+    y_wind,
+    source_projection=None,
+    target_projection="+proj=longlat",
+):
+    from pyproj import CRS
+
+    result = FieldArray()
+
+    wind_params = (x_wind, y_wind)
+    wind_pairs = defaultdict(dict)
+
+    for f in input:
+        key = f.as_mars()
+        param = key.pop("param")
+
+        if param not in wind_params:
+            result.append(f)
+            continue
+
+        key = tuple(key.items())
+
+        if param in wind_pairs[key]:
+            raise ValueError(f"Duplicate wind component {param} for {key}")
+
+        wind_pairs[key][param] = f
+
+    for _, pairs in wind_pairs.items():
+        if len(pairs) != 2:
+            raise ValueError("Missing wind component")
+
+        x = pairs[x_wind]
+        y = pairs[y_wind]
+
+        assert x.grid_mapping == y.grid_mapping
+
+        lats, lons = x.grid_points()
+        x_new, y_new = rotate_winds(
+            lats,
+            lons,
+            x.to_numpy(reshape=False),
+            y.to_numpy(reshape=False),
+            (
+                source_projection
+                if source_projection is not None
+                else CRS.from_cf(x.grid_mapping)
+            ),
+            target_projection,
+        )
+
+        result.append(NewDataField(x, x_new))
+        result.append(NewDataField(y, y_new))
+
+    return result
diff --git a/ecml_tools/create/group.py b/ecml_tools/create/group.py
deleted file mode 100644
index ac44c39..0000000
--- a/ecml_tools/create/group.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# (C) Copyright 2023 ECMWF.
-#
-# This software is licensed under the terms of the Apache Licence Version 2.0
-# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
-# In applying this licence, ECMWF does not waive the privileges and immunities
-# granted to it by virtue of its status as an intergovernmental organisation
-# nor does it submit to any jurisdiction.
-#
-import datetime
-from functools import cached_property
-
-from .expand import expand_class
-
-
-class Group(list):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        assert len(self) >= 1, self
-
-    def __repr__(self):
-        content = ",".join([str(_.strftime("%Y-%m-%d:%H")) for _ in self])
-        # content = ",".join([str(_.strftime("%Y-%m-%d:%H")) for _ in self[:10]]) + "..."
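A usage sketch for the `rotate_winds` helper in rotate_winds.py above: a single gridpoint with a 10 m/s wind along the grid's x axis is rotated from a Lambert conformal grid into plain east/north components. The projection string is illustrative (MEPS-like), not part of the patch, and pyproj must be installed:

    import numpy as np

    lats = np.array([60.0])
    lons = np.array([10.0])
    u_grid = np.array([10.0])   # wind along the grid x axis
    v_grid = np.array([0.0])

    u_east, v_north = rotate_winds(
        lats, lons, u_grid, v_grid,
        source_projection="+proj=lcc +lat_0=63.3 +lat_1=63.3 +lon_0=15 +R=6371000",
        target_projection="+proj=longlat",
    )
    # Because one side is longlat, the result is renormalised so that
    # np.hypot(u_east, v_north) == 10 to within floating point error.

The `execute` wrapper pairs up the two wind components per level/number/date before calling this, so a missing or duplicated component is reported instead of being silently mismatched.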
- return f"Group({len(self)}, {content})" - - -class BaseGroups: - def __repr__(self): - try: - content = "+".join([str(len(list(g))) for g in self.groups]) - for g in self.groups: - assert isinstance(g[0], datetime.datetime), g[0] - return f"{self.__class__.__name__}({content}={len(self.values)})({self.n_groups} groups)" - except: # noqa - return f"{self.__class__.__name__}({len(self.values)} dates)" - - def intersect(self, dates): - if dates is None: - return self - if not isinstance(dates, Groups): - dates = build_groups(dates) - return GroupsIntersection(self, dates) - - def empty(self): - return len(self.values) == 0 - - @cached_property - def n_groups(self): - return len(self.groups) - - @property - def frequency(self): - datetimes = self.values - freq = (datetimes[1] - datetimes[0]).total_seconds() / 3600 - assert round(freq) == freq, freq - assert int(freq) == freq, freq - frequency = int(freq) - return frequency - - -class Groups(BaseGroups): - def __init__(self, config): - assert isinstance(config, dict), config - for k, v in config.items(): - assert not isinstance(v, dict), (k, v) - - self._config = config - self.cls = expand_class(config) - - @cached_property - def values(self): - dates = list(self.cls(self._config).values) - assert isinstance(dates[0], datetime.datetime), dates[0] - return dates - - @property - def groups(self): - return [Group(g) for g in self.cls(self._config).groups] - - -class EmptyGroups(BaseGroups): - def __init__(self): - self.values = [] - self.groups = [] - - @property - def frequency(self): - return None - - -class GroupsIntersection(BaseGroups): - def __init__(self, a, b): - self.a = a - self.b = b - - @cached_property - def values(self): - intersection = [] - for e in self.a.values: - if e in self.b.values: - intersection.append(e) - return intersection - - -def build_groups(*objs): - if len(objs) > 1: - raise NotImplementedError() - obj = objs[0] - - if isinstance(obj, GroupsIntersection): - return obj - - if isinstance(obj, Groups): - return obj - - if isinstance(obj, list): - return Groups(dict(values=obj)) - - assert isinstance(obj, dict), obj - if "dates" in obj: - assert len(obj) == 1, obj - return Groups(dict(obj["dates"])) - - return Groups(obj) diff --git a/ecml_tools/create/input.py b/ecml_tools/create/input.py index eeba11e..4f2478d 100644 --- a/ecml_tools/create/input.py +++ b/ecml_tools/create/input.py @@ -9,50 +9,107 @@ import datetime import importlib import logging -import os import time from collections import defaultdict from copy import deepcopy -from functools import cached_property +from functools import cached_property, wraps +import numpy as np from climetlab.core.order import build_remapping - -from .group import build_groups -from .template import substitute +from climetlab.indexing.fieldset import FieldSet + +from ecml_tools.utils.dates import Dates + +from .template import ( + Context, + notify_result, + resolve, + substitute, + trace, + trace_datasource, + trace_select, +) from .utils import seconds LOG = logging.getLogger(__name__) -def merge_remappings(*remappings): - remapping = remappings[0] - for other in remappings[1:]: - if not other: - continue - assert other == remapping, ( - "Multiple inconsistent remappings not implemented", - other, - remapping, - ) - return remapping +def parse_function_name(name): + if "-" in name: + name, delta = name.split("-") + sign = -1 + elif "+" in name: + name, delta = name.split("+") + sign = 1 + + else: + return name, None + + assert delta[-1] == "h", (name, delta) + delta = sign 
* int(delta[:-1]) + return name, delta + + +def time_delta_to_string(delta): + assert isinstance(delta, datetime.timedelta), delta + seconds = delta.total_seconds() + hours = int(seconds // 3600) + assert hours * 3600 == seconds, delta + hours = abs(hours) + + if seconds > 0: + return f"plus_{hours}h" + if seconds == 0: + return "" + if seconds < 0: + return f"minus_{hours}h" + + +def import_function(name, kind): + return importlib.import_module( + f"..functions.{kind}.{name}", + package=__name__, + ).execute -def assert_is_fieldset(obj): - from climetlab.readers.grib.index import FieldSet +def is_function(name, kind): + name, delta = parse_function_name(name) + try: + import_function(name, kind) + return True + except ImportError: + return False + + +def assert_fieldset(method): + @wraps(method) + def wrapper(self, *args, **kwargs): + result = method(self, *args, **kwargs) + assert isinstance(result, FieldSet), type(result) + return result + + return wrapper + + +def assert_is_fieldset(obj): assert isinstance(obj, FieldSet), type(obj) -def _datasource_request(data): +def _data_request(data): date = None params_levels = defaultdict(set) params_steps = defaultdict(set) + area = grid = None + for field in data: if not hasattr(field, "as_mars"): continue + if date is None: date = field.valid_datetime() + if field.valid_datetime() != date: continue @@ -86,23 +143,12 @@ def sort(old_dic): ) -class Cache: - pass - - class Coords: def __init__(self, owner): self.owner = owner - self.cache = Cache() + @cached_property def _build_coords(self): - assert isinstance(self.owner.context, Context), type(self.owner.context) - assert isinstance(self.owner, Result), type(self.owner) - assert hasattr(self.owner, "context"), self.owner - assert hasattr(self.owner, "datasource"), self.owner - assert hasattr(self.owner, "get_cube"), self.owner - self.owner.datasource - from_data = self.owner.get_cube().user_coords from_config = self.owner.context.order_by @@ -123,65 +169,84 @@ def _build_coords(self): from_data[variables_key], from_config[variables_key] ) ] - ), (from_data[variables_key], from_config[variables_key]) + ), ( + from_data[variables_key], + from_config[variables_key], + ) - self.cache.variables = from_data[variables_key] # "param_level" - self.cache.ensembles = from_data[ensembles_key] # "number" + self._variables = from_data[variables_key] # "param_level" + self._ensembles = from_data[ensembles_key] # "number" first_field = self.owner.datasource[0] grid_points = first_field.grid_points() + + lats, lons = grid_points + north = np.amax(lats) + south = np.amin(lats) + east = np.amax(lons) + west = np.amin(lons) + + assert -90 <= south <= north <= 90, (south, north, first_field) + assert (-180 <= west <= east <= 180) or (0 <= west <= east <= 360), ( + west, + east, + first_field, + ) + grid_values = list(range(len(grid_points[0]))) - self.cache.grid_points = grid_points - self.cache.resolution = first_field.resolution - self.cache.grid_values = grid_values + self._grid_points = grid_points + self._resolution = first_field.resolution + self._grid_values = grid_values - def __getattr__(self, name): - if name in [ - "variables", - "ensembles", - "resolution", - "grid_values", - "grid_points", - ]: - if not hasattr(self.cache, name): - self._build_coords() - return getattr(self.cache, name) - raise AttributeError(name) + @cached_property + def variables(self): + self._build_coords + return self._variables + + @cached_property + def ensembles(self): + self._build_coords + return self._ensembles + + 
@cached_property + def resolution(self): + self._build_coords + return self._resolution + + @cached_property + def grid_values(self): + self._build_coords + return self._grid_values + + @cached_property + def grid_points(self): + self._build_coords + return self._grid_points class HasCoordsMixin: - @property + @cached_property def variables(self): return self._coords.variables - @property + @cached_property def ensembles(self): return self._coords.ensembles - @property + @cached_property def resolution(self): return self._coords.resolution - @property + @cached_property def grid_values(self): return self._coords.grid_values - @property + @cached_property def grid_points(self): return self._coords.grid_points - @property - def dates(self): - if self._dates is None: - raise ValueError(f"No dates for {self}") - return self._dates.values - - @property - def frequency(self): - return self._dates.frequency - - @property + @cached_property def shape(self): return [ len(self.dates), @@ -190,7 +255,7 @@ def shape(self): len(self.grid_values), ] - @property + @cached_property def coords(self): return { "dates": self.dates, @@ -201,7 +266,7 @@ def coords(self): class Action: - def __init__(self, context, /, *args, **kwargs): + def __init__(self, context, action_path, /, *args, **kwargs): if "args" in kwargs and "kwargs" in kwargs: """We have: args = [] @@ -213,10 +278,11 @@ def __init__(self, context, /, *args, **kwargs): args = kwargs.pop("args") kwargs = kwargs.pop("kwargs") - assert isinstance(context, Context), type(context) + assert isinstance(context, ActionContext), type(context) self.context = context self.kwargs = kwargs self.args = args + self.action_path = action_path @classmethod def _short_str(cls, x): @@ -241,29 +307,43 @@ def select(self, dates, **kwargs): def _raise_not_implemented(self): raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") + def _trace_select(self, dates): + return f"{self.__class__.__name__}({shorten(dates)})" + + +def shorten(dates): + if isinstance(dates, (list, tuple)): + dates = [d.isoformat() for d in dates] + if len(dates) > 5: + return f"{dates[0]}...{dates[-1]}" + return dates + class Result(HasCoordsMixin): empty = False - def __init__(self, context, dates=None): - assert isinstance(context, Context), type(context) + def __init__(self, context, action_path, dates): + assert isinstance(context, ActionContext), type(context) + assert isinstance(action_path, list), action_path + self.context = context self._coords = Coords(self) - self._dates = dates + self.dates = dates + self.action_path = action_path @property + @trace_datasource def datasource(self): self._raise_not_implemented() @property def data_request(self): - """Returns a dictionary with the parameters needed to retrieve the data""" - return _datasource_request(self.datasource) + """Returns a dictionary with the parameters needed to retrieve the data.""" + return _data_request(self.datasource) def get_cube(self): - print(f"getting cube from {self.__class__.__name__}") + trace("🧊", f"getting cube from {self.__class__.__name__}") ds = self.datasource - assert_is_fieldset(ds) remapping = self.context.remapping order_by = self.context.order_by @@ -287,7 +367,7 @@ def __repr__(self, *args, _indent_="\n", **kwargs): more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) dates = " no-dates" - if self._dates is not None: + if self.dates is not None: dates = f" {len(self.dates)} dates" dates += " (" dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.dates) @@ -304,14 
+384,19 @@ def __repr__(self, *args, _indent_="\n", **kwargs): def _raise_not_implemented(self): raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") + def _trace_datasource(self, *args, **kwargs): + return f"{self.__class__.__name__}({shorten(self.dates)})" + class EmptyResult(Result): empty = True - def __init__(self, context, dates=None): - super().__init__(context) + def __init__(self, context, action_path, dates): + super().__init__(context, action_path + ["empty"], dates) @cached_property + @assert_fieldset + @trace_datasource def datasource(self): from climetlab import load_source @@ -322,44 +407,39 @@ def variables(self): return [] -class ReferencesSolver(dict): - def __init__(self, context, dates): - self.context = context - self.dates = dates - - def __getitem__(self, key): - if key == "dates": - return self.dates.values - if key in self.context.references: - result = self.context.references[key] - return result.datasource - raise KeyError(key) - - class FunctionResult(Result): - def __init__(self, context, dates, action, previous_sibling=None): - super().__init__(context, dates) + def __init__(self, context, action_path, dates, action): + super().__init__(context, action_path, dates) assert isinstance(action, Action), type(action) self.action = action - _args = self.action.args - _kwargs = self.action.kwargs - - vars = ReferencesSolver(context, dates) + self.args, self.kwargs = substitute( + context, (self.action.args, self.action.kwargs) + ) - self.args = substitute(_args, vars) - self.kwargs = substitute(_kwargs, vars) + def _trace_datasource(self, *args, **kwargs): + return f"{self.action.name}({shorten(self.dates)})" @cached_property + @assert_fieldset + @notify_result + @trace_datasource def datasource(self): - print(f"loading source with {self.args} {self.kwargs}") - return self.action.function(*self.args, **self.kwargs) + args, kwargs = resolve(self.context, (self.args, self.kwargs)) - def __repr__(self): - content = " ".join([f"{v}" for v in self.args]) - content += " ".join([f"{k}={v}" for k, v in self.kwargs.items()]) + try: + return self.action.function( + FunctionContext(self), self.dates, *args, **kwargs + ) + except Exception: + LOG.error(f"Error in {self.action.function.__name__}", exc_info=True) + raise - return super().__repr__(content) + def __repr__(self): + try: + return f"{self.action.name}({shorten(self.dates)})" + except Exception: + return f"{self.__class__.__name__}(unitialised)" @property def function(self): @@ -367,16 +447,18 @@ def function(self): class JoinResult(Result): - def __init__(self, context, dates, results, **kwargs): - super().__init__(context, dates) + def __init__(self, context, action_path, dates, results, **kwargs): + super().__init__(context, action_path, dates) self.results = [r for r in results if not r.empty] - @property + @cached_property + @assert_fieldset + @notify_result + @trace_datasource def datasource(self): - ds = EmptyResult(self.context, self._dates).datasource + ds = EmptyResult(self.context, self.action_path, self.dates).datasource for i in self.results: ds += i.datasource - assert_is_fieldset(ds), i return ds def __repr__(self): @@ -384,78 +466,123 @@ def __repr__(self): return super().__repr__(content) -class LabelAction(Action): - def __init__(self, context, name, **kwargs): - super().__init__(context) - if len(kwargs) != 1: - raise ValueError(f"Invalid kwargs for label : {kwargs}") - self.name = name - self.content = action_factory(kwargs, context) +class DateShiftAction(Action): + def 
__init__(self, context, action_path, delta, **kwargs): + super().__init__(context, action_path, **kwargs) + + if isinstance(delta, str): + if delta[0] == "-": + delta, sign = int(delta[1:]), -1 + else: + delta, sign = int(delta), 1 + delta = datetime.timedelta(hours=sign * delta) + assert isinstance(delta, int), delta + delta = datetime.timedelta(hours=delta) + self.delta = delta + + self.content = action_factory(kwargs, context, self.action_path + ["shift"]) + @trace_select def select(self, dates): - result = self.content.select(dates) - self.context.register_reference(self.name, result) - return result + shifted_dates = [d + self.delta for d in dates] + result = self.content.select(shifted_dates) + return UnShiftResult(self.context, self.action_path, dates, result, action=self) def __repr__(self): - return super().__repr__(_inline_=self.name, _indent_=" ") + return super().__repr__(f"{self.delta}\n{self.content}") -class BaseFunctionAction(Action): - def __repr__(self): - content = "" - content += ",".join([self._short_str(a) for a in self.args]) - content += " ".join( - [self._short_str(f"{k}={v}") for k, v in self.kwargs.items()] - ) - content = self._short_str(content) - return super().__repr__(_inline_=content, _indent_=" ") +class UnShiftResult(Result): + def __init__(self, context, action_path, dates, result, action): + super().__init__(context, action_path, dates) + # dates are the actual requested dates + # result does not have the same dates + self.action = action + self.result = result - def select(self, dates): - return FunctionResult(self.context, dates, action=self) + def _trace_datasource(self, *args, **kwargs): + return f"{self.action.delta}({shorten(self.dates)})" + @cached_property + @assert_fieldset + @notify_result + @trace_datasource + def datasource(self): + from climetlab.indexing.fieldset import FieldArray + + class DateShiftedField: + def __init__(self, field, delta): + self.field = field + self.delta = delta + + def metadata(self, key): + value = self.field.metadata(key) + if key == "param": + return value + "_" + time_delta_to_string(self.delta) + if key == "valid_datetime": + dt = datetime.datetime.fromisoformat(value) + new_dt = dt - self.delta + new_value = new_dt.isoformat() + return new_value + if key in ["date", "time", "step", "hdate"]: + raise NotImplementedError( + f"metadata {key} not implemented when shifting dates" + ) + return value -class SourceAction(BaseFunctionAction): - @property - def function(self): - from climetlab import load_source + def __getattr__(self, name): + return getattr(self.field, name) + + ds = self.result.datasource + ds = FieldArray([DateShiftedField(fs, self.action.delta) for fs in ds]) + return ds - return load_source +class FunctionAction(Action): + def __init__(self, context, action_path, _name, **kwargs): + super().__init__(context, action_path, **kwargs) + self.name = _name -class FunctionAction(BaseFunctionAction): - def __init__(self, context, name, **kwargs): - super().__init__(context, **kwargs) - self.name = name + @trace_select + def select(self, dates): + return FunctionResult(self.context, self.action_path, dates, action=self) @property def function(self): - here = os.path.dirname(__file__) - path = os.path.join(here, "functions", f"{self.name}.py") - spec = importlib.util.spec_from_file_location(self.name, path) - module = spec.loader.load_module() - # TODO: this fails here, fix this. 
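How the `date_shift` round trip above behaves, in plain numbers (the parameter name "2t" is illustrative): `select()` shifts the requested dates by `delta` before delegating, and `UnShiftResult` shifts each field's metadata back so the caller sees the dates it asked for, with the shift recorded in the parameter name:

    import datetime

    delta = datetime.timedelta(hours=-6)          # a date_shift of "-6"
    requested = datetime.datetime(2022, 12, 31, 12)

    fetched = requested + delta                   # 2022-12-31 06:00 is retrieved
    # DateShiftedField then reports, per the metadata() override above:
    #   valid_datetime -> fetched - delta == requested (2022-12-31 12:00)
    #   param          -> "2t" + "_" + time_delta_to_string(delta) == "2t_minus_6h"
    assert fetched - delta == requested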
- # getattr(module, self.name) - # self.action.kwargs - return module.execute + name, delta = parse_function_name(self.name) + return import_function(self.name, "actions") + + def __repr__(self): + content = "" + content += ",".join([self._short_str(a) for a in self.args]) + content += " ".join( + [self._short_str(f"{k}={v}") for k, v in self.kwargs.items()] + ) + content = self._short_str(content) + return super().__repr__(_inline_=content, _indent_=" ") + + def _trace_select(self, dates): + return f"{self.name}({shorten(dates)})" class ConcatResult(Result): - def __init__(self, context, results): - super().__init__(context, dates=None) + def __init__(self, context, action_path, dates, results, **kwargs): + super().__init__(context, action_path, dates) self.results = [r for r in results if not r.empty] - @property + @cached_property + @assert_fieldset + @notify_result + @trace_datasource def datasource(self): - ds = EmptyResult(self.context, self.dates).datasource + ds = EmptyResult(self.context, self.action_path, self.dates).datasource for i in self.results: ds += i.datasource - assert_is_fieldset(ds), i return ds @property def variables(self): - """Check that all the results objects have the same variables""" + """Check that all the results objects have the same variables.""" variables = None for f in self.results: if f.empty: @@ -466,22 +593,6 @@ def variables(self): assert variables is not None, self.results return variables - @property - def dates(self): - """Merge the dates of all the results objects""" - dates = [] - for i in self.results: - d = i.dates - if d is None: - continue - dates += d - assert isinstance(dates[0], datetime.datetime), dates[0] - return sorted(dates) - - @property - def frequency(self): - return build_groups(self.dates).frequency - def __repr__(self): content = "\n".join([str(i) for i in self.results]) return super().__repr__(content) @@ -490,9 +601,12 @@ def __repr__(self): class ActionWithList(Action): result_class = None - def __init__(self, context, *configs): - super().__init__(context, *configs) - self.actions = [action_factory(c, context) for c in configs] + def __init__(self, context, action_path, *configs): + super().__init__(context, action_path, *configs) + self.actions = [ + action_factory(c, context, action_path + [str(i)]) + for i, c in enumerate(configs) + ] def __repr__(self): content = "\n".join([str(i) for i in self.actions]) @@ -500,133 +614,149 @@ def __repr__(self): class PipeAction(Action): - def __init__(self, context, *configs): - super().__init__(context, *configs) - current = action_factory(configs[0], context) - for c in configs[1:]: - current = step_factory(c, context, _upstream_action=current) - self.content = current - + def __init__(self, context, action_path, *configs): + super().__init__(context, action_path, *configs) + assert len(configs) > 1, configs + current = action_factory(configs[0], context, action_path + ["0"]) + for i, c in enumerate(configs[1:]): + current = step_factory( + c, context, action_path + [str(i + 1)], previous_step=current + ) + self.last_step = current + + @trace_select def select(self, dates): - return self.content.select(dates) + return self.last_step.select(dates) def __repr__(self): - return super().__repr__(self.content) + return super().__repr__(self.last_step) class StepResult(Result): - def __init__(self, upstream, context, dates, action): - super().__init__(context, dates) - assert isinstance(upstream, Result), type(upstream) - self.content = upstream + def __init__(self, context, action_path, 
dates, action, upstream_result): + super().__init__(context, action_path, dates) + assert isinstance(upstream_result, Result), type(upstream_result) + self.upstream_result = upstream_result self.action = action @property + @notify_result + @trace_datasource def datasource(self): - return self.content.datasource + raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") class StepAction(Action): result_class = None - def __init__(self, context, _upstream_action, **kwargs): - super().__init__(context, **kwargs) - self.content = _upstream_action + def __init__(self, context, action_path, previous_step, *args, **kwargs): + super().__init__(context, action_path, *args, **kwargs) + self.previous_step = previous_step + @trace_select def select(self, dates): return self.result_class( - self.content.select(dates), self.context, + self.action_path, dates, self, + self.previous_step.select(dates), ) def __repr__(self): - return super().__repr__(self.content, _inline_=str(self.kwargs)) + return super().__repr__(self.previous_step, _inline_=str(self.kwargs)) + + +class StepFunctionResult(StepResult): + @cached_property + @assert_fieldset + @notify_result + @trace_datasource + def datasource(self): + try: + return self.action.function( + FunctionContext(self), + self.upstream_result.datasource, + **self.action.kwargs, + ) + except Exception: + LOG.error(f"Error in {self.action.name}", exc_info=True) + raise -class FilterResult(StepResult): + def _trace_datasource(self, *args, **kwargs): + return f"{self.action.name}({shorten(self.dates)})" + + +class FilterStepResult(StepResult): @property + @notify_result + @assert_fieldset + @trace_datasource def datasource(self): ds = self.content.datasource - assert_is_fieldset(ds) ds = ds.sel(**self.action.kwargs) - assert_is_fieldset(ds) return ds -class FilterAction(StepAction): - result_class = FilterResult +class FilterStepAction(StepAction): + result_class = FilterStepResult -# class RenameResult(StepResult): -# @property -# def datasource(self): -# ds = self.content.datasource -# assert_is_fieldset(ds) -# ds = ds.rename(**self.action.kwargs) -# assert_is_fieldset(ds) -# return ds -# -# -# class RenameAction(StepAction): -# result_class = RenameResult +class FunctionStepAction(StepAction): + def __init__(self, context, action_path, previous_step, *args, **kwargs): + super().__init__(context, action_path, previous_step, *args, **kwargs) + self.name = args[0] + self.function = import_function(self.name, "steps") + + result_class = StepFunctionResult class ConcatAction(ActionWithList): + @trace_select def select(self, dates): - return ConcatResult(self.context, [a.select(dates) for a in self.actions]) + return ConcatResult( + self.context, + self.action_path, + dates, + [a.select(dates) for a in self.actions], + ) class JoinAction(ActionWithList): + @trace_select def select(self, dates): - return JoinResult(self.context, dates, [a.select(dates) for a in self.actions]) + return JoinResult( + self.context, + self.action_path, + dates, + [a.select(dates) for a in self.actions], + ) class DateAction(Action): - def __init__(self, context, **kwargs): - super().__init__(context, **kwargs) - - datesconfig = {} - subconfig = {} - for k, v in deepcopy(kwargs).items(): - if k in ["start", "end", "frequency"]: - datesconfig[k] = v - else: - subconfig[k] = v - - self._dates = build_groups(datesconfig) - self.content = action_factory(subconfig, context) + def __init__(self, context, action_path, start, end, frequency, **kwargs): + super().__init__(context, 
action_path, **kwargs) + self.filtering_dates = Dates.from_config( + start=start, end=end, frequency=frequency + ) + self.content = action_factory(kwargs, context, self.action_path + ["dates"]) + @trace_select def select(self, dates): - newdates = self._dates.intersect(dates) - if newdates.empty(): - return EmptyResult(self.context, dates=newdates) + newdates = sorted(set(dates) & set(self.filtering_dates)) + if not newdates: + return EmptyResult(self.context, self.action_path, newdates) return self.content.select(newdates) def __repr__(self): - return super().__repr__(f"{self._dates}\n{self.content}") - - -def merge_dicts(a, b): - if isinstance(a, dict): - assert isinstance(b, dict), (a, b) - a = deepcopy(a) - for k, v in b.items(): - if k not in a: - a[k] = v - else: - a[k] = merge_dicts(a[k], v) - return a - - return deepcopy(b) + return super().__repr__(f"{self.filtering_dates}\n{self.content}") -def action_factory(config, context): +def action_factory(config, context, action_path): assert isinstance(context, Context), (type, context) if not isinstance(config, dict): raise ValueError(f"Invalid input config {config}") - if len(config) != 1: raise ValueError( f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}" @@ -634,26 +764,33 @@ def action_factory(config, context): config = deepcopy(config) key = list(config.keys())[0] + + if isinstance(config[key], list): + args, kwargs = config[key], {} + if isinstance(config[key], dict): + args, kwargs = [], config[key] + cls = dict( + date_shift=DateShiftAction, + # date_filter=DateFilterAction, + # include=IncludeAction, concat=ConcatAction, join=JoinAction, - label=LabelAction, pipe=PipeAction, - source=SourceAction, function=FunctionAction, dates=DateAction, - )[key] + ).get(key) - if isinstance(config[key], list): - args, kwargs = config[key], {} + if cls is None: + if not is_function(key, "actions"): + raise ValueError(f"Unknown action '{key}' in {config}") + cls = FunctionAction + args = [key] + args - if isinstance(config[key], dict): - args, kwargs = [], config[key] - - return cls(context, *args, **kwargs) + return cls(context, action_path + [key], *args, **kwargs) -def step_factory(config, context, _upstream_action): +def step_factory(config, context, action_path, previous_step): assert isinstance(context, Context), (type, context) if not isinstance(config, dict): raise ValueError(f"Invalid input config {config}") @@ -663,10 +800,10 @@ def step_factory(config, context, _upstream_action): key = list(config.keys())[0] cls = dict( - filter=FilterAction, + filter=FilterStepAction, # rename=RenameAction, # remapping=RemappingAction, - )[key] + ).get(key) if isinstance(config[key], list): args, kwargs = config[key], {} @@ -674,53 +811,54 @@ def step_factory(config, context, _upstream_action): if isinstance(config[key], dict): args, kwargs = [], config[key] - if "_upstream_action" in kwargs: - raise ValueError(f"Reserverd keyword '_upsream_action' in {config}") - kwargs["_upstream_action"] = _upstream_action + if cls is None: + if not is_function(key, "steps"): + raise ValueError(f"Unknown step {key}") + cls = FunctionStepAction + args = [key] + args - return cls(context, *args, **kwargs) + return cls(context, action_path, previous_step, *args, **kwargs) -class Context: +class FunctionContext: + """A FunctionContext is passed to all functions, it will be used to pass information + to the functions from the other actions and steps and results.""" + + def __init__(self, owner): + self.owner = owner + + def 
trace(self, emoji, *args): + trace(emoji, *args) + + +class ActionContext(Context): def __init__(self, /, order_by, flatten_grid, remapping): + super().__init__() self.order_by = order_by self.flatten_grid = flatten_grid self.remapping = build_remapping(remapping) - self.references = {} - - def register_reference(self, name, obj): - assert isinstance(obj, Result), type(obj) - if name in self.references: - raise ValueError(f"Duplicate reference {name}") - self.references[name] = obj - - def find_reference(self, name): - if name in self.references: - return self.references[name] - # It can happend that the required name is not yet registered, - # even if it is defined in the config. - # Handling this case implies implementing a lazy inheritance resolution - # and would complexify the code. This is not implemented. - raise ValueError(f"Cannot find reference {name}") - class InputBuilder: def __init__(self, config, **kwargs): self.kwargs = kwargs self.config = config + self.action_path = ["input"] - @property - def _action(self): - context = Context(**self.kwargs) - return action_factory(self.config, context) - + @trace_select def select(self, dates): """This changes the context.""" - return self._action.select(dates) + context = ActionContext(**self.kwargs) + action = action_factory(self.config, context, self.action_path) + return action.select(dates) def __repr__(self): - return repr(self._action) + context = ActionContext(**self.kwargs) + a = action_factory(self.config, context, self.action_path) + return repr(a) + + def _trace_select(self, dates): + return f"InputBuilder({shorten(dates)})" build_input = InputBuilder diff --git a/ecml_tools/create/loaders.py b/ecml_tools/create/loaders.py index ba90107..aa15c20 100644 --- a/ecml_tools/create/loaders.py +++ b/ecml_tools/create/loaders.py @@ -4,7 +4,6 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
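The loaders in loaders.py below drive everything from `Groups(**config.dates)` (ecml_tools/utils/dates/groups.py, added by this patch but not shown in this excerpt): `len(groups)` is the number of chunks to build, iterating yields one list of datetimes per chunk, and `groups.dates.frequency` is the step in hours. A sketch of that interface; the keyword names in the constructor are assumptions about the dates config schema, not confirmed by this excerpt:

    from ecml_tools.utils.dates.groups import Groups

    # start/end/frequency/group_by are assumed config keys
    groups = Groups(start="2022-12-30 00:00", end="2022-12-31 18:00",
                    frequency=6, group_by="daily")

    for igroup, group in enumerate(groups):
        print(igroup, len(group), group[0], group[-1])   # one chunk per group
    print(len(groups), groups.dates.frequency)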
- import datetime import logging import os @@ -15,10 +14,10 @@ import zarr from ecml_tools.data import open_dataset +from ecml_tools.utils.dates.groups import Groups from .check import DatasetName from .config import build_output, loader_config -from .group import build_groups from .input import build_input from .statistics import ( StatisticsRegistry, @@ -43,7 +42,7 @@ class Loader: def __init__(self, *, path, print=print, **kwargs): # Catch all floating point errors, including overflow, sqrt(<0), etc - np.seterr(all="raise") + np.seterr(all="raise", under="warn") assert isinstance(path, str), path @@ -114,6 +113,7 @@ def initialise_dataset_backend(self): z.create_group("_build") def update_metadata(self, **kwargs): + print("Updating metadata", kwargs) z = zarr.open(self.path, mode="w+") for k, v in kwargs.items(): if isinstance(v, np.datetime64): @@ -145,49 +145,41 @@ def __init__(self, config, **kwargs): self.statistics_registry.delete() - self.groups = build_groups(*self.main_config.loop) + print(self.main_config.dates) + self.groups = Groups(**self.main_config.dates) + print("βœ… GROUPS") + self.output = build_output(self.main_config.output, parent=self) self.input = self.build_input() print(self.input) - self.inputs = self.input.select(dates=None) - all_dates = self.inputs.dates - self.minimal_input = self.input.select(dates=[all_dates[0]]) + all_dates = self.groups.dates + self.minimal_input = self.input.select([all_dates[0]]) print("βœ… GROUPS") print(self.groups) - print("βœ… ALL INPUTS") - print(self.inputs) print("βœ… MINIMAL INPUT") print(self.minimal_input) def initialise(self, check_name=True): - """Create empty dataset""" + """Create empty dataset.""" self.print("Config loaded ok:") print(self.main_config) print("-------------------------") - dates = self.inputs.dates - if self.groups.frequency != self.inputs.frequency: - raise ValueError( - f"Frequency mismatch: {self.groups.frequency} != {self.inputs.frequency}" - ) - if self.groups.values[0] != self.inputs.dates[0]: - raise ValueError( - f"First date mismatch: {self.groups.values[0]} != {self.inputs.dates[0]}" - ) + dates = self.groups.dates print("-------------------------") - frequency = self.inputs.frequency + frequency = dates.frequency assert isinstance(frequency, int), frequency self.print(f"Found {len(dates)} datetimes.") print( - f"Dates: Found {len(dates)} datetimes, in {self.groups.n_groups} groups: ", + f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ", end="", ) - lengths = [len(g) for g in self.groups.groups] + lengths = [len(g) for g in self.groups] self.print( f"Found {len(dates)} datetimes {'+'.join([str(_) for _ in lengths])}." 
) @@ -251,9 +243,6 @@ def initialise(self, check_name=True): metadata["start_date"] = dates[0].isoformat() metadata["end_date"] = dates[-1].isoformat() - # metadata["statistics_start_date"]=self.output.get("statistics_start") - # metadata["statistics_end_date"]=self.output.get("statistics_end") - if check_name: basename, ext = os.path.splitext(os.path.basename(self.path)) ds_name = DatasetName( @@ -309,7 +298,7 @@ def __init__(self, config, **kwargs): super().__init__(**kwargs) self.main_config = loader_config(config) - self.groups = build_groups(*self.main_config.loop) + self.groups = Groups(**self.main_config.dates) self.output = build_output(self.main_config.output, parent=self) self.input = self.build_input() self.read_dataset_metadata() @@ -323,19 +312,21 @@ def load(self, parts): ) total = len(self.registry.get_flags()) - n_groups = len(self.groups.groups) filter = CubesFilter(parts=parts, total=total) - for igroup, group in enumerate(self.groups.groups): + for igroup, group in enumerate(self.groups): if self.registry.get_flag(igroup): - LOG.info(f" -> Skipping {igroup} total={n_groups} (already done)") + LOG.info( + f" -> Skipping {igroup} total={len(self.groups)} (already done)" + ) continue if not filter(igroup): continue - self.print(f" -> Processing {igroup} total={n_groups}") + self.print(f" -> Processing {igroup} total={len(self.groups)}") + print("========", group) assert isinstance(group[0], datetime.datetime), group - inputs = self.input.select(dates=group) - data_writer.write(inputs, igroup) + result = self.input.select(dates=group) + data_writer.write(result, igroup, group) self.registry.add_to_history("loading_data_end", parts=parts) self.registry.add_provenance(name="provenance_load") @@ -469,10 +460,8 @@ def recompute_temporary_statistics(self): self.statistics_registry.create(exist_ok=True) self.print( - ( - f"Building temporary statistics from data {self.path}. " - f"From {self.date_start} to {self.date_end}" - ) + f"Building temporary statistics from data {self.path}. " + f"From {self.date_start} to {self.date_end}" ) shape = (self.i_end + 1 - self.i_start, len(self.variables_names)) @@ -511,11 +500,9 @@ def get_detailed_stats(self): dates = open_dataset(self.path).dates missing_dates = dates[missing_index[0]] print( - ( - f"Missing dates: " - f"{missing_dates[0]} ... {missing_dates[len(missing_dates)-1]} " - f"({missing_dates.shape[0]} missing)" - ) + f"Missing dates: " + f"{missing_dates[0]} ... 
{missing_dates[len(missing_dates)-1]} " + f"({missing_dates.shape[0]} missing)" ) raise @@ -581,3 +568,9 @@ def add_total_size(self): print(f"Total number of files: {n}") self.update_metadata(total_size=size, total_number_of_files=n) + + +class CleanupLoader(Loader): + def run(self): + self.statistics_registry.delete() + self.registry.clean() diff --git a/ecml_tools/create/template.py b/ecml_tools/create/template.py index d7d87e6..230a5e3 100644 --- a/ecml_tools/create/template.py +++ b/ecml_tools/create/template.py @@ -8,157 +8,155 @@ # import logging -import os import re - -from ecml_tools.create.utils import to_datetime +import textwrap +from functools import wraps LOG = logging.getLogger(__name__) +TRACE_INDENT = 0 + + +def step(action_path): + return f"[{'.'.join(action_path)}]" + + +def trace(emoji, *args): + print(emoji, " " * TRACE_INDENT, *args) + + +def trace_datasource(method): + @wraps(method) + def wrapper(self, *args, **kwargs): + global TRACE_INDENT + trace( + "🌍", + "=>", + step(self.action_path), + self._trace_datasource(*args, **kwargs), + ) + TRACE_INDENT += 1 + result = method(self, *args, **kwargs) + TRACE_INDENT -= 1 + trace( + "🍎", + "<=", + step(self.action_path), + textwrap.shorten(repr(result), 256), + ) + return result + + return wrapper + + +def trace_select(method): + @wraps(method) + def wrapper(self, *args, **kwargs): + global TRACE_INDENT + trace( + "πŸ‘“", + "=>", + ".".join(self.action_path), + self._trace_select(*args, **kwargs), + ) + TRACE_INDENT += 1 + result = method(self, *args, **kwargs) + TRACE_INDENT -= 1 + trace( + "🍍", + "<=", + ".".join(self.action_path), + textwrap.shorten(repr(result), 256), + ) + return result + + return wrapper + + +def notify_result(method): + @wraps(method) + def wrapper(self, *args, **kwargs): + result = method(self, *args, **kwargs) + self.context.notify_result(self.action_path, result) + return result + + return wrapper + + +class Context: + def __init__(self): + # used_references is a set of reference paths that will be needed + self.used_references = set() + # results is a dictionary of reference path -> obj + self.results = {} + + def will_need_reference(self, key): + assert isinstance(key, (list, tuple)), key + key = tuple(key) + self.used_references.add(key) + + def notify_result(self, key, result): + trace("🎯", step(key), "notify result", result) + assert isinstance(key, (list, tuple)), key + key = tuple(key) + if key in self.used_references: + if key in self.results: + raise ValueError(f"Duplicate result {key}") + self.results[key] = result + + def get_result(self, key): + assert isinstance(key, (list, tuple)), key + key = tuple(key) + if key in self.results: + return self.results[key] + raise ValueError(f"Cannot find result {key}") + + +class Substitution: + pass + + +class Reference(Substitution): + def __init__(self, context, action_path): + self.context = context + self.action_path = action_path + + def resolve(self, context): + return context.get_result(self.action_path) + + +def resolve(context, x): + if isinstance(x, tuple): + return tuple([resolve(context, y) for y in x]) + + if isinstance(x, list): + return [resolve(context, y) for y in x] -def substitute(x, vars=None, ignore_missing=False): - """Recursively substitute environment variables and dict values in a nested list ot dict of string. - substitution is performed using the environment var (if UPPERCASE) or the input dictionary. 
- - - >>> substitute({'bar': '$bar'}, {'bar': '43'}) - {'bar': '43'} - - >>> substitute({'bar': '$BAR'}, {'BAR': '43'}) - Traceback (most recent call last): - ... - KeyError: 'BAR' - - >>> substitute({'bar': '$BAR'}, ignore_missing=True) - {'bar': '$BAR'} - - >>> os.environ["BAR"] = "42" - >>> substitute({'bar': '$BAR'}) - {'bar': '42'} + if isinstance(x, dict): + return {k: resolve(context, v) for k, v in x.items()} - >>> substitute('$bar', {'bar': '43'}) - '43' + if isinstance(x, Substitution): + return x.resolve(context) - >>> substitute('$hdates_from_date($date, 2015, 2018)', {'date': '2023-05-12'}) - '2015-05-12/2016-05-12/2017-05-12/2018-05-12' + return x - """ - if vars is None: - vars = {} - assert isinstance(vars, dict), vars +def substitute(context, x): + if isinstance(x, tuple): + return tuple([substitute(context, y) for y in x]) - if isinstance(x, (tuple, list)): - return [substitute(y, vars, ignore_missing=ignore_missing) for y in x] + if isinstance(x, list): + return [substitute(context, y) for y in x] if isinstance(x, dict): - return { - k: substitute(v, vars, ignore_missing=ignore_missing) for k, v in x.items() - } - - if isinstance(x, str): - if "$" not in x: - return x - - lst = [] - - for i, bit in enumerate(re.split(r"(\$(\w+)(\([^\)]*\))?)", x)): - if bit is None: - continue - assert isinstance(bit, str), (bit, type(bit), x, type(x)) - - i %= 4 - if i in [2, 3]: - continue - if i == 1: - try: - if "(" in bit: - # substitute by a function - FUNCTIONS = dict( - hdates_from_date=hdates_from_date, - datetime_format=datetime_format, - ) - - pattern = r"\$(\w+)\(([^)]*)\)" - match = re.match(pattern, bit) - assert match, bit - - function_name = match.group(1) - params = [p.strip() for p in match.group(2).split(",")] - params = [ - substitute(p, vars, ignore_missing=ignore_missing) - for p in params - ] - - bit = FUNCTIONS[function_name](*params) - - elif bit.upper() == bit: - # substitute by the var env if $UPPERCASE - bit = os.environ[bit[1:]] - else: - # substitute by the value in the 'vars' dict - bit = vars[bit[1:]] - except KeyError as e: - if not ignore_missing: - raise e - - if bit != x: - bit = substitute(bit, vars, ignore_missing=ignore_missing) - - lst.append(bit) - - lst = [_ for _ in lst if _ != ""] - if len(lst) == 1: - return lst[0] - - out = [] - for elt in lst: - # if isinstance(elt, str): - # elt = [elt] - assert isinstance(elt, (list, tuple)), elt - out += elt - return out + return {k: substitute(context, v) for k, v in x.items()} - return x + if not isinstance(x, str): + return x + if re.match(r"^\${[\.\w]+}$", x): + path = x[2:-1].split(".") + context.will_need_reference(path) + return Reference(context, path) -def datetime_format(dates, format, join=None): - formated = [to_datetime(d).strftime(format) for d in dates] - formated = set(formated) - formated = list(formated) - formated = sorted(formated) - if join: - formated = join.join(formated) - return formated - - -def hdates_from_date(date, start_year, end_year): - """ - Returns a list of dates in the format '%Y%m%d' between start_year and end_year (inclusive), - with the year of the input date. - - Args: - date (str or datetime): The input date. - start_year (int): The start year. - end_year (int): The end year. - - Returns: - List[str]: A list of dates in the format '%Y%m%d'. 
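The new reference mechanism above replaces the old `$var` substitution: a config value that is exactly `${some.action.path}` becomes a `Reference`, its path is registered via `will_need_reference`, and once the step at that path has produced a result, `notify_result` stores it so `resolve` can swap the `Reference` for the actual value. A sketch, assuming `Context`, `substitute` and `resolve` are importable from this module (the reference path is illustrative):

    from ecml_tools.create.template import Context, resolve, substitute

    ctx = Context()
    cfg = substitute(ctx, {"template": "${input.join.0.mars}"})
    # ctx.used_references now contains ("input", "join", "0", "mars")

    ctx.notify_result(["input", "join", "0", "mars"], "some-fieldset")
    print(resolve(ctx, cfg))   # {'template': 'some-fieldset'}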
- """ - if not str(start_year).isdigit(): - raise ValueError(f"start_year must be an int: {start_year}") - if not str(end_year).isdigit(): - raise ValueError(f"end_year must be an int: {end_year}") - start_year = int(start_year) - end_year = int(end_year) - - if isinstance(date, (list, tuple)): - if len(date) != 1: - raise NotImplementedError(f"{date} should have only one element.") - date = date[0] - - date = to_datetime(date) - assert not (date.hour or date.minute or date.second), date - - hdates = [date.replace(year=year) for year in range(start_year, end_year + 1)] - return "/".join(d.strftime("%Y-%m-%d") for d in hdates) + return x diff --git a/ecml_tools/create/writer.py b/ecml_tools/create/writer.py index 6ce5a97..dcc4783 100644 --- a/ecml_tools/create/writer.py +++ b/ecml_tools/create/writer.py @@ -7,6 +7,7 @@ # nor does it submit to any jurisdiction. # +import datetime import logging import time import warnings @@ -75,8 +76,7 @@ def new_key(self, key, values_shape): class FastWriteArray(ArrayLike): - """ - A class that provides a caching mechanism for writing to a NumPy-like array. + """A class that provides a caching mechanism for writing to a NumPy-like array. The `FastWriteArray` instance is initialized with a NumPy-like array and its shape. The array is used to store the final data, while the cache is used to temporarily @@ -121,10 +121,9 @@ def compute_statistics_and_key(self, variables_names): class OffsetView(ArrayLike): - """ - A view on a portion of the large_array. - 'axis' is the axis along which the offset applies. - 'shape' is the shape of the view. + """A view on a portion of the large_array. + + 'axis' is the axis along which the offset applies. 'shape' is the shape of the view. """ def __init__(self, large_array, *, offset, axis, shape): @@ -175,10 +174,17 @@ def __init__(self, parts, full_array, parent, print=print): self.print = parent.print self.append_axis = parent.output.append_axis - self.n_cubes = parent.groups.n_groups + self.n_cubes = len(parent.groups) - def write(self, inputs, igroup): - cube = inputs.get_cube() + def write(self, result, igroup, dates): + cube = result.get_cube() + assert cube.extended_user_shape[0] == len(dates), ( + cube.extended_user_shape[0], + len(dates), + ) + dates_in_data = cube.user_coords["valid_datetime"] + dates_in_data = [datetime.datetime.fromisoformat(_) for _ in dates_in_data] + assert dates_in_data == list(dates), (dates_in_data, list(dates)) self.write_cube(cube, igroup) @property @@ -192,10 +198,8 @@ def write_cube(self, cube, icube): slice = self.registry.get_slice_for(icube) LOG.info( - ( - f"Building dataset '{self.path}' i={icube} total={self.n_cubes} " - f"(total shape ={shape}) at {slice}, {self.full_array.chunks=}" - ) + f"Building dataset '{self.path}' i={icube} total={self.n_cubes} " + f"(total shape ={shape}) at {slice}, {self.full_array.chunks=}" ) self.print( f"Building dataset (total shape ={shape}) at {slice}, {self.full_array.chunks=}" @@ -233,9 +237,7 @@ def load_datacube(self, cube, array): data = cubelet.to_numpy() cubelet_coords = cubelet.coords - bar.set_description( - f"Loading {i}/{total} {str(cubelet)} ({data.shape}) {cube=}" - ) + bar.set_description(f"Loading {i}/{total} {str(cubelet)} ({data.shape})") load += time.time() - now j = cubelet_coords[1] diff --git a/ecml_tools/data.py b/ecml_tools/data.py index 995fe61..15234c5 100644 --- a/ecml_tools/data.py +++ b/ecml_tools/data.py @@ -784,6 +784,12 @@ def __init__(self, datasets, axis): # Shape: (dates, variables, ensemble, 1d-values) assert 
len(datasets[0].shape) == 4, "Grids must be 1D for now"
 
+    def check_same_grid(self, d1, d2):
+        # We don't check the grid, because we want to be able to combine
+        pass
+
+
+class ConcatGrids(Grids):
     # TODO: select the statistics of the most global grid?
     @property
     def latitudes(self):
@@ -793,10 +799,75 @@ def latitudes(self):
     def longitudes(self):
         return np.concatenate([d.longitudes for d in self.datasets])
 
-    def check_same_grid(self, d1, d2):
-        # We don't check the grid, because we want to be able to combine
+
+class CutoutGrids(Grids):
+    def __init__(self, datasets, axis):
+        from .grids import cutout_mask
+
+        super().__init__(datasets, axis)
+        assert len(datasets) == 2, "CutoutGrids requires two datasets"
+        assert axis == 3, "CutoutGrids requires axis=3"
+
+        # We assume that the LAM is the first dataset, and the global is the second
+        # Note: the second dataset does not really need to be global
+
+        self.lam, self.globe = datasets
+        self.mask = cutout_mask(
+            self.lam.latitudes,
+            self.lam.longitudes,
+            self.globe.latitudes,
+            self.globe.longitudes,
+            plot="cutout",
+        )
+        assert len(self.mask) == self.globe.shape[3], (
+            len(self.mask),
+            self.globe.shape[3],
+        )
+
+    @cached_property
+    def shape(self):
+        shape = self.lam.shape
+        # Number of non-zero masked values in the globe dataset
+        nb_globe = np.count_nonzero(self.mask)
+        return shape[:-1] + (shape[-1] + nb_globe,)
+
+    def check_same_resolution(self, d1, d2):
+        # Turned off because we are combining different resolutions
         pass
 
+    @property
+    def latitudes(self):
+        return np.concatenate([self.lam.latitudes, self.globe.latitudes[self.mask]])
+
+    @property
+    def longitudes(self):
+        return np.concatenate([self.lam.longitudes, self.globe.longitudes[self.mask]])
+
+    def __getitem__(self, index):
+        if isinstance(index, (int, slice)):
+            index = (index, slice(None), slice(None), slice(None))
+        return self._get_tuple(index)
+
+    @debug_indexing
+    @expand_list_indexing
+    def _get_tuple(self, index):
+        assert index[self.axis] == slice(
+            None
+        ), "No support for selecting a subset of the 1D values"
+        index, changes = index_to_slices(index, self.shape)
+
+        # In case index_to_slices has changed the last slice
+        index, _ = update_tuple(index, self.axis, slice(None))
+
+        lam_data = self.lam[index]
+        globe_data = self.globe[index]
+
+        globe_data = globe_data[:, :, :, self.mask]
+
+        result = np.concatenate([lam_data, globe_data], axis=self.axis)
+
+        return apply_index_to_slices_changes(result, changes)
+
 
 class Join(Combined):
     """
@@ -1108,7 +1179,8 @@ def _frequency_to_hours(frequency):
 
 def _as_date(d, dates, last):
     if isinstance(d, datetime.datetime):
-        assert d.minutes == 0 and d.hours == 0 and d.seconds == 0, d
+        if not (d.minute == 0 and d.hour == 0 and d.second == 0):
+            return np.datetime64(d)
         d = datetime.date(d.year, d.month, d.day)
 
     if isinstance(d, datetime.date):
@@ -1116,7 +1188,7 @@ def _as_date(d, dates, last):
 
     try:
         d = int(d)
-    except ValueError:
+    except (ValueError, TypeError):
         pass
 
     if isinstance(d, int):
@@ -1184,12 +1256,14 @@ def _as_last_date(d, dates):
     return _as_date(d, dates, last=True)
 
 
-def _concat_or_join(datasets):
+def _concat_or_join(datasets, kwargs):
+    datasets, kwargs = _auto_adjust(datasets, kwargs)
+
     # Study the dates
     ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets]
 
     if len(set(ranges)) == 1:
-        return Join(datasets)._overlay()
+        return Join(datasets)._overlay(), kwargs
 
     # Make sure the dates are disjoint
     for i in range(len(ranges)):
@@ -1214,7 +1288,7 @@ def _concat_or_join(datasets):
                     f"{r} and
{s} ({datasets[i]} {datasets[i+1]})"
                 )
 
-    return Concat(datasets)
+    return Concat(datasets), kwargs
 
 
 def _open(a, zarr_root):
@@ -1239,6 +1313,48 @@ def _open(a, zarr_root):
     raise NotImplementedError()
 
 
+def _auto_adjust(datasets, kwargs):
+    """Adjust the datasets for concatenation or joining based
+    on parameters set to 'matching'"""
+
+    if kwargs.get("adjust") == "matching":
+        kwargs.pop("adjust")
+        for p in ("select", "frequency", "start", "end"):
+            kwargs[p] = "matching"
+
+    adjust = {}
+
+    if kwargs.get("select") == "matching":
+        kwargs.pop("select")
+        variables = None
+        for d in datasets:
+            if variables is None:
+                variables = set(d.variables)
+            else:
+                variables &= set(d.variables)
+        if len(variables) == 0:
+            raise ValueError("No common variables")
+
+        adjust["select"] = sorted(variables)
+
+    if kwargs.get("frequency") == "matching":
+        kwargs.pop("frequency")
+        adjust["frequency"] = max(d.frequency for d in datasets)
+
+    if kwargs.get("start") == "matching":
+        kwargs.pop("start")
+        adjust["start"] = max(d.dates[0] for d in datasets).astype(object)
+
+    if kwargs.get("end") == "matching":
+        kwargs.pop("end")
+        # The common range ends at the earliest of the end dates
+        adjust["end"] = min(d.dates[-1] for d in datasets).astype(object)
+
+    if adjust:
+        datasets = [d._subset(**adjust) for d in datasets]
+
+    return datasets, kwargs
+
+
 def _open_dataset(*args, zarr_root, **kwargs):
     sets = []
     for a in args:
@@ -1247,22 +1363,40 @@ def _open_dataset(*args, zarr_root, **kwargs):
     if "ensemble" in kwargs:
         if "grids" in kwargs:
             raise NotImplementedError("Cannot use both 'ensemble' and 'grids'")
+
         ensemble = kwargs.pop("ensemble")
         axis = kwargs.pop("axis", 2)
         assert len(args) == 0
         assert isinstance(ensemble, (list, tuple))
-        return Ensemble([_open(e, zarr_root) for e in ensemble], axis=axis)._subset(
-            **kwargs
-        )
+
+        datasets = [_open(e, zarr_root) for e in ensemble]
+        datasets, kwargs = _auto_adjust(datasets, kwargs)
+
+        return Ensemble(datasets, axis=axis)._subset(**kwargs)
 
     if "grids" in kwargs:
         if "ensemble" in kwargs:
             raise NotImplementedError("Cannot use both 'ensemble' and 'grids'")
+
         grids = kwargs.pop("grids")
+        mode = kwargs.pop("mode", "concatenate")
         axis = kwargs.pop("axis", 3)
         assert len(args) == 0
         assert isinstance(grids, (list, tuple))
-        return Grids([_open(e, zarr_root) for e in grids], axis=axis)._subset(**kwargs)
+
+        KLASSES = {
+            "concatenate": ConcatGrids,
+            "cutout": CutoutGrids,
+        }
+        if mode not in KLASSES:
+            raise ValueError(
+                f"Unknown grids mode: {mode}, values are {list(KLASSES.keys())}"
+            )
+
+        datasets = [_open(e, zarr_root) for e in grids]
+        datasets, kwargs = _auto_adjust(datasets, kwargs)
+
+        return KLASSES[mode](datasets, axis=axis)._subset(**kwargs)
 
     for name in ("datasets", "dataset"):
         if name in kwargs:
@@ -1275,7 +1409,8 @@ def _open_dataset(*args, zarr_root, **kwargs):
     assert len(sets) > 0, (args, kwargs)
 
     if len(sets) > 1:
-        return _concat_or_join(sets)._subset(**kwargs)
+        dataset, kwargs = _concat_or_join(sets, kwargs)
+        return dataset._subset(**kwargs)
 
     return sets[0]._subset(**kwargs)
 
diff --git a/ecml_tools/grids.py b/ecml_tools/grids.py
new file mode 100644
index 0000000..bfb3664
--- /dev/null
+++ b/ecml_tools/grids.py
@@ -0,0 +1,266 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
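Usage sketch for the new "grids" modes in _open_dataset() above, assuming the public open_dataset() entry point forwards its keyword arguments unchanged; the dataset names are placeholders:

from ecml_tools.data import open_dataset

# "cutout" keeps every LAM point plus only the global points outside the
# LAM domain; the default mode, "concatenate", maps to ConcatGrids.
ds = open_dataset(
    grids=["lam-dataset-name", "global-dataset-name"],
    mode="cutout",
    adjust="matching",  # common variables, coarsest frequency, overlapping dates
)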
+#
+
+import numpy as np
+from scipy.spatial import KDTree
+
+
+def plot_mask(path, mask, lats, lons, global_lats, global_lons):
+    import matplotlib.pyplot as plt
+
+    plt.figure(figsize=(10, 5))
+    plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r")
+    plt.savefig(path + "-global.png")
+
+    plt.figure(figsize=(10, 5))
+    plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k")
+    plt.savefig(path + "-cutout.png")
+
+    plt.figure(figsize=(10, 5))
+    plt.scatter(lons, lats, s=0.01)
+    plt.savefig(path + "-lam.png")
+
+
+def latlon_to_xyz(lat, lon, radius=1.0):
+    # https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates
+    # We assume that the Earth is a sphere of radius 1 so N(phi) = 1
+    # We assume h = 0
+    #
+    phi = np.deg2rad(lat)
+    lda = np.deg2rad(lon)
+
+    cos_phi = np.cos(phi)
+    cos_lda = np.cos(lda)
+    sin_phi = np.sin(phi)
+    sin_lda = np.sin(lda)
+
+    x = cos_phi * cos_lda * radius
+    y = cos_phi * sin_lda * radius
+    z = sin_phi * radius
+
+    return x, y, z
+
+
+class Triangle3D:
+    def __init__(self, v0, v1, v2):
+        self.v0 = v0
+        self.v1 = v1
+        self.v2 = v2
+
+    def intersect(self, ray_origin, ray_direction):
+        # MΓΆller–Trumbore intersection algorithm
+        # https://en.wikipedia.org/wiki/M%C3%B6ller%E2%80%93Trumbore_intersection_algorithm
+
+        epsilon = 0.0000001
+
+        h = np.cross(ray_direction, self.v2 - self.v0)
+        a = np.dot(self.v1 - self.v0, h)
+
+        if -epsilon < a < epsilon:
+            return None
+
+        f = 1.0 / a
+        s = ray_origin - self.v0
+        u = f * np.dot(s, h)
+
+        if u < 0.0 or u > 1.0:
+            return None
+
+        q = np.cross(s, self.v1 - self.v0)
+        v = f * np.dot(ray_direction, q)
+
+        if v < 0.0 or u + v > 1.0:
+            return None
+
+        t = f * np.dot(self.v2 - self.v0, q)
+
+        if t > epsilon:
+            return t
+
+        return None
+
+
+def cropping_mask(lats, lons, north, west, south, east):
+    mask = (
+        (lats >= south)
+        & (lats <= north)
+        & (
+            ((lons >= west) & (lons <= east))
+            | ((lons >= west + 360) & (lons <= east + 360))
+            | ((lons >= west - 360) & (lons <= east - 360))
+        )
+    )
+    return mask
+
+
+def cutout_mask(
+    lats,
+    lons,
+    global_lats,
+    global_lons,
+    cropping_distance=2.0,
+    min_distance=0.0,
+    plot=None,
+):
+    """
+    Return a mask for the points in [global_lats, global_lons] that are NOT
+    inside [lats, lons], i.e. the global points to keep once the LAM area
+    has been cut out.
+    """
+
+    # TODO: transform min_distance from lat/lon to xyz
+
+    assert global_lats.ndim == 1
+    assert global_lons.ndim == 1
+    assert lats.ndim == 1
+    assert lons.ndim == 1
+
+    assert global_lats.shape == global_lons.shape
+    assert lats.shape == lons.shape
+
+    north = np.amax(lats)
+    south = np.amin(lats)
+    east = np.amax(lons)
+    west = np.amin(lons)
+
+    # Reduce the global grid to the area of interest
+
+    mask = cropping_mask(
+        global_lats,
+        global_lons,
+        np.min([90.0, north + cropping_distance]),
+        west - cropping_distance,
+        np.max([-90.0, south - cropping_distance]),
+        east + cropping_distance,
+    )
+
+    global_lats_masked = global_lats[mask]
+    global_lons_masked = global_lons[mask]
+
+    global_xyx = latlon_to_xyz(global_lats_masked, global_lons_masked)
+    global_points = np.array(global_xyx).transpose()
+
+    xyx = latlon_to_xyz(lats, lons)
+    points = np.array(xyx).transpose()
+
+    # Use a KDTree to find the nearest points
+    kdtree = KDTree(points)
+    distances, indices = kdtree.query(global_points, k=3)
+
+    zero = np.array([0.0, 0.0, 0.0])
+    ok = []
+    for i, (global_point, distance, index) in enumerate(
+        zip(global_points, distances, indices)
+    ):
+        t = Triangle3D(points[index[0]], points[index[1]],
points[index[2]])
+        distance = np.min(distance)
+        # The point is inside the triangle if the intersection with the ray
+        # from the point to the center of the Earth is not None
+        # (the direction of the ray is not important)
+        ok.append(
+            (t.intersect(zero, global_point) or t.intersect(global_point, zero))
+            # and (distance >= min_distance)
+        )
+
+    j = 0
+    ok = np.array(ok)
+    for i, m in enumerate(mask):
+        if not m:
+            continue
+
+        mask[i] = ok[j]
+        j += 1
+
+    assert j == len(ok)
+
+    # Invert the mask, so we have only the points outside the cutout
+    mask = ~mask
+
+    if plot:
+        plot_mask(plot, mask, lats, lons, global_lats, global_lons)
+
+    return mask
+
+
+def thinning_mask(
+    lats,
+    lons,
+    global_lats,
+    global_lons,
+    cropping_distance=2.0,
+):
+    """
+    Return the indices of the points in [lats, lons] closest to each point
+    in [global_lats, global_lons]
+    """
+
+    assert global_lats.ndim == 1
+    assert global_lons.ndim == 1
+    assert lats.ndim == 1
+    assert lons.ndim == 1
+
+    assert global_lats.shape == global_lons.shape
+    assert lats.shape == lons.shape
+
+    north = np.amax(lats)
+    south = np.amin(lats)
+    east = np.amax(lons)
+    west = np.amin(lons)
+
+    # Reduce the global grid to the area of interest
+
+    mask = cropping_mask(
+        global_lats,
+        global_lons,
+        np.min([90.0, north + cropping_distance]),
+        west - cropping_distance,
+        np.max([-90.0, south - cropping_distance]),
+        east + cropping_distance,
+    )
+
+    global_lats_masked = global_lats[mask]
+    global_lons_masked = global_lons[mask]
+
+    global_xyx = latlon_to_xyz(global_lats_masked, global_lons_masked)
+    global_points = np.array(global_xyx).transpose()
+
+    xyx = latlon_to_xyz(lats, lons)
+    points = np.array(xyx).transpose()
+
+    # Use a KDTree to find the nearest points
+    kdtree = KDTree(points)
+    _, indices = kdtree.query(global_points, k=1)
+
+    return np.asarray(indices)
+
+
+if __name__ == "__main__":
+    global_lats, global_lons = np.meshgrid(
+        np.linspace(90, -90, 90),
+        np.linspace(-180, 180, 180),
+    )
+    global_lats = global_lats.flatten()
+    global_lons = global_lons.flatten()
+
+    lats, lons = np.meshgrid(
+        np.linspace(50, 40, 100),
+        np.linspace(-10, 15, 100),
+    )
+    lats = lats.flatten()
+    lons = lons.flatten()
+
+    mask = cutout_mask(lats, lons, global_lats, global_lons, cropping_distance=5.0)
+
+    import matplotlib.pyplot as plt
+
+    plt.figure(figsize=(10, 5))
+    plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r")
+    plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k")
+    plt.savefig("cutout.png")
diff --git a/ecml_tools/utils/__init__.py b/ecml_tools/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ecml_tools/utils/dates/__init__.py b/ecml_tools/utils/dates/__init__.py
new file mode 100644
index 0000000..4fd3f82
--- /dev/null
+++ b/ecml_tools/utils/dates/__init__.py
@@ -0,0 +1,113 @@
+# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
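The heart of cutout_mask() above is the point-in-triangle test: a global grid point is considered inside the LAM when the ray from the Earth's centre through that point crosses the planar triangle formed by its three nearest LAM points. A small illustrative check of that idea, using the helpers defined above with hand-picked coordinates:

import numpy as np

lam_lats = np.array([49.0, 51.0, 50.0])
lam_lons = np.array([9.0, 9.0, 11.0])
v0, v1, v2 = (np.array(p) for p in zip(*latlon_to_xyz(lam_lats, lam_lons)))
triangle = Triangle3D(v0, v1, v2)

zero = np.zeros(3)
inside = np.array(latlon_to_xyz(50.0, 9.5))  # within the LAM triangle
outside = np.array(latlon_to_xyz(20.0, 120.0))  # far away

assert triangle.intersect(zero, inside) is not None
assert triangle.intersect(zero, outside) is None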
+
+
+import datetime
+
+
+def frequency_to_hours(frequency):
+    if isinstance(frequency, int):
+        return frequency
+    assert isinstance(frequency, str), (type(frequency), frequency)
+
+    unit = frequency[-1].lower()
+    v = int(frequency[:-1])
+    return {"h": v, "d": v * 24}[unit]
+
+
+class Dates:
+    """Base class for date generation.
+
+    >>> Dates.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-02 00:00", "frequency": "1d"}).values
+    [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 2, 0, 0)]
+
+    >>> Dates.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-03 00:00", "frequency": "18h"}).values
+    [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 1, 18, 0), datetime.datetime(2023, 1, 2, 12, 0)]
+
+    >>> Dates.from_config(start="2023-01-01 00:00", end="2023-01-02 00:00", frequency=6).as_dict()
+    {'start': '2023-01-01T00:00:00', 'end': '2023-01-02T00:00:00', 'frequency': '6h'}
+
+    >>> len(Dates.from_config(start="2023-01-01 00:00", end="2023-01-02 00:00", frequency=12))
+    3
+    """
+
+    @classmethod
+    def from_config(cls, **kwargs):
+        if "values" in kwargs:
+            return ValuesDates(**kwargs)
+        return StartEndDates(**kwargs)
+
+    def __iter__(self):
+        for v in self.values:
+            yield v
+
+    def __getitem__(self, i):
+        return self.values[i]
+
+    def __len__(self):
+        return len(self.values)
+
+    @property
+    def summary(self):
+        return f"πŸ“… {self.values[0]} ... {self.values[-1]}"
+
+
+class ValuesDates(Dates):
+    def __init__(self, values, **kwargs):
+        self.values = sorted(values)
+        assert not kwargs, f"Unexpected arguments {kwargs}"
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.values[0]}..{self.values[-1]})"
+
+    def as_dict(self):
+        return {"values": self.values}
+
+
+class StartEndDates(Dates):
+    def __init__(self, start, end, frequency=1, **kwargs):
+        assert not kwargs, f"Unexpected arguments {kwargs}"
+
+        frequency = frequency_to_hours(frequency)
+
+        def _(x):
+            if isinstance(x, str):
+                return datetime.datetime.fromisoformat(x)
+            return x
+
+        start = _(start)
+        end = _(end)
+
+        if isinstance(start, datetime.date) and not isinstance(
+            start, datetime.datetime
+        ):
+            start = datetime.datetime(start.year, start.month, start.day)
+
+        if isinstance(end, datetime.date) and not isinstance(end, datetime.datetime):
+            end = datetime.datetime(end.year, end.month, end.day)
+
+        if end <= start:
+            raise ValueError(f"End date {end} must be after start date {start}")
+
+        increment = datetime.timedelta(hours=frequency)
+
+        self.start = start
+        self.end = end
+        self.frequency = frequency
+
+        date = start
+        self.values = []
+        while date <= end:
+            self.values.append(date)
+            date += increment
+
+    def as_dict(self):
+        return {
+            "start": self.start.isoformat(),
+            "end": self.end.isoformat(),
+            "frequency": f"{self.frequency}h",
+        }
diff --git a/ecml_tools/utils/dates/groups.py b/ecml_tools/utils/dates/groups.py
new file mode 100644
index 0000000..a2d5752
--- /dev/null
+++ b/ecml_tools/utils/dates/groups.py
@@ -0,0 +1,85 @@
+# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
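A quick usage sketch for the date helpers above; the values mirror the doctests, so this is documented behaviour:

from ecml_tools.utils.dates import Dates, frequency_to_hours

assert frequency_to_hours("12h") == 12
assert frequency_to_hours("2d") == 48

dates = Dates.from_config(
    start="2023-01-01 00:00", end="2023-01-02 00:00", frequency="6h"
)
assert len(dates) == 5  # 00:00, 06:00, 12:00, 18:00 and the end date itself
assert dates.as_dict() == {
    "start": "2023-01-01T00:00:00",
    "end": "2023-01-02T00:00:00",
    "frequency": "6h",
}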
+ + +import itertools + +from ecml_tools.utils.dates import Dates + + +class Groups: + """ + >>> list(Groups(group_by="daily", start="2023-01-01 00:00", end="2023-01-05 00:00", frequency=12))[0] + [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 1, 12, 0)] + + >>> list(Groups(group_by="daily", start="2023-01-01 00:00", end="2023-01-05 00:00", frequency=12))[1] + [datetime.datetime(2023, 1, 2, 0, 0), datetime.datetime(2023, 1, 2, 12, 0)] + + >>> g = Groups(group_by=3, start="2023-01-01 00:00", end="2023-01-05 00:00", frequency=24) + >>> len(list(g)) + 2 + >>> len(list(g)[0]) + 3 + >>> len(list(g)[1]) + 2 + """ + + def __init__(self, **kwargs): + group_by = kwargs.pop("group_by") + self.dates = Dates.from_config(**kwargs) + self.grouper = Grouper.from_config(group_by) + + def __iter__(self): + return self.grouper(self.dates) + + def __len__(self): + return len(list(self.grouper(self.dates))) + + +class Grouper: + @classmethod + def from_config(cls, group_by): + if isinstance(group_by, int) and group_by > 0: + return GrouperByFixedSize(group_by) + if group_by is None: + return GrouperOneGroup() + key = { + "monthly": lambda dt: (dt.year, dt.month), + "daily": lambda dt: (dt.year, dt.month, dt.day), + "weekly": lambda dt: (dt.weekday(),), + "MMDD": lambda dt: (dt.month, dt.day), + }[group_by] + return GrouperByKey(key) + + +class GrouperOneGroup(Grouper): + def __call__(self, dates): + yield dates.values + + +class GrouperByKey(Grouper): + def __init__(self, key): + self.key = key + + def __call__(self, dates): + for _, g in itertools.groupby(dates, key=self.key): + yield list(g) + + +class GrouperByFixedSize(Grouper): + def __init__(self, size): + self.size = size + + def __call__(self, dates): + batch = [] + for d in dates: + batch.append(d) + if len(batch) == self.size: + yield batch + batch = [] + if batch: + yield batch diff --git a/ecml_tools/utils/humanize.py b/ecml_tools/utils/humanize.py new file mode 100644 index 0000000..e5f3e57 --- /dev/null +++ b/ecml_tools/utils/humanize.py @@ -0,0 +1,377 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
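A matching sketch for the grouping logic above; the dates are chosen so that the "monthly" key splits December from January:

from ecml_tools.utils.dates.groups import Groups

g = Groups(
    group_by="monthly",
    start="2020-12-30 00:00",
    end="2021-01-02 00:00",
    frequency=24,
)
assert [len(group) for group in g] == [2, 2]  # two December days, two January days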
+#
+
+import datetime
+import re
+from collections import defaultdict
+
+
+def bytes(n):
+    """
+    >>> bytes(4096)
+    '4 KiB'
+    >>> bytes(4000)
+    '3.9 KiB'
+    """
+    if n < 0:
+        sign = "-"
+        n = -n
+    else:
+        sign = ""
+
+    u = ["", " KiB", " MiB", " GiB", " TiB", " PiB", " EiB", " ZiB", " YiB"]
+    i = 0
+    while n >= 1024:
+        n /= 1024.0
+        i += 1
+    return "%s%g%s" % (sign, int(n * 10 + 0.5) / 10.0, u[i])
+
+
+def base2(n):
+    """
+    >>> base2(4096)
+    '4K'
+    >>> base2(4000)
+    '3.9K'
+    """
+
+    u = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]
+    i = 0
+    while n >= 1024:
+        n /= 1024.0
+        i += 1
+    return "%g%s" % (int(n * 10 + 0.5) / 10.0, u[i])
+
+
+PERIODS = (
+    (7 * 24 * 60 * 60, "week"),
+    (24 * 60 * 60, "day"),
+    (60 * 60, "hour"),
+    (60, "minute"),
+    (1, "second"),
+)
+
+
+def _plural(count):
+    if count > 1:
+        return "s"
+    else:
+        return ""
+
+
+def seconds(seconds):
+    if isinstance(seconds, datetime.timedelta):
+        seconds = seconds.total_seconds()
+
+    if seconds == 0:
+        return "instantaneous"
+
+    if seconds < 0.1:
+        units = [
+            None,
+            "milli",
+            "micro",
+            "nano",
+            "pico",
+            "femto",
+            "atto",
+            "zepto",
+            "yocto",
+        ]
+        i = 0
+        while seconds < 1.0 and i < len(units) - 1:
+            seconds *= 1000
+            i += 1
+        if seconds > 100 and i > 0:
+            seconds /= 1000
+            i -= 1
+        seconds = round(seconds * 10) / 10
+        return f"{seconds:g} {units[i]}second{_plural(seconds)}"
+
+    n = seconds
+    s = []
+    for p in PERIODS:
+        m = int(n / p[0])
+        if m:
+            s.append("%d %s%s" % (m, p[1], _plural(m)))
+            n %= p[0]
+
+    if not s:
+        seconds = round(seconds * 10) / 10
+        s.append("%g second%s" % (seconds, _plural(seconds)))
+    return " ".join(s)
+
+
+def number(value):
+    return f"{value:,}"
+
+
+def plural(value, what):
+    return f"{number(value)} {what}{_plural(value)}"
+
+
+DOW = [
+    "Monday",
+    "Tuesday",
+    "Wednesday",
+    "Thursday",
+    "Friday",
+    "Saturday",
+    "Sunday",
+]
+
+MONTH = [
+    "January",
+    "February",
+    "March",
+    "April",
+    "May",
+    "June",
+    "July",
+    "August",
+    "September",
+    "October",
+    "November",
+    "December",
+]
+
+
+def __(n):
+    if n in (11, 12, 13):
+        return "th"
+
+    if n % 10 == 1:
+        return "st"
+
+    if n % 10 == 2:
+        return "nd"
+
+    if n % 10 == 3:
+        return "rd"
+
+    return "th"
+
+
+def when(then, now=None, short=True):
+    last = "last"
+
+    if now is None:
+        now = datetime.datetime.now()
+
+    diff = (now - then).total_seconds()
+
+    if diff < 0:
+        last = "next"
+        diff = -diff
+
+    diff = int(diff)
+
+    if diff == 0:
+        return "right now"
+
+    def _(x):
+        if last == "last":
+            return "%s ago" % (x,)
+        else:
+            return "in %s" % (x,)
+
+    if diff < 60:
+        diff = int(diff + 0.5)
+        return _("%s second%s" % (diff, _plural(diff)))
+
+    if diff < 60 * 60:
+        diff /= 60
+        diff = int(diff + 0.5)
+        return _("%s minute%s" % (diff, _plural(diff)))
+
+    if diff < 60 * 60 * 6:
+        diff /= 60 * 60
+        diff = int(diff + 0.5)
+        return _("%s hour%s" % (diff, _plural(diff)))
+
+    jnow = now.toordinal()
+    jthen = then.toordinal()
+
+    if jnow == jthen:
+        return "today at %02d:%02d" % (then.hour, then.minute)
+
+    if jnow == jthen + 1:
+        return "yesterday at %02d:%02d" % (then.hour, then.minute)
+
+    if jnow == jthen - 1:
+        return "tomorrow at %02d:%02d" % (then.hour, then.minute)
+
+    if abs(jnow - jthen) <= 7:
+        if last == "next":
+            last = "this"
+        return "%s %s" % (
+            last,
+            DOW[then.weekday()],
+        )
+
+    if abs(jnow - jthen) < 32 and now.month == then.month:
+        return "the %d%s of this month" % (then.day, __(then.day))
+
+    if abs(jnow - jthen) < 64 and now.month == then.month + 1:
+        return "the %d%s of %s month" % (then.day, __(then.day), last)
+
+    if short:
+        years = int(abs(jnow - jthen) / 365.25 + 0.5)
+        if years == 1:
+            return "%s year" % last
+
+        if years > 1:
+            return _("%d years" % (years,))
+
+        month = then.month
+        if now.year != then.year:
+            month -= 12
+
+        d = abs(now.month - month)
+        if d >= 12:
+            return _("a year")
+        else:
+            return _("%d month%s" % (d, _plural(d)))
+
+    return "on %s %d %s %d" % (
+        DOW[then.weekday()],
+        then.day,
+        MONTH[then.month - 1],
+        then.year,
+    )
+
+
+def string_distance(s, t):
+    import numpy as np
+
+    m = len(s)
+    n = len(t)
+    d = np.zeros((m + 1, n + 1), dtype=int)
+
+    one = int(1)
+    zero = int(0)
+
+    d[:, 0] = np.arange(m + 1)
+    d[0, :] = np.arange(n + 1)
+
+    for i in range(1, m + 1):
+        for j in range(1, n + 1):
+            cost = zero if s[i - 1] == t[j - 1] else one
+            d[i, j] = min(
+                d[i - 1, j] + one,
+                d[i, j - 1] + one,
+                d[i - 1, j - 1] + cost,
+            )
+
+    return d[m, n]
+
+
+def did_you_mean(word, vocabulary):
+    distance, best = min((string_distance(word, w), w) for w in vocabulary)
+    # if distance < min(len(word), len(best)):
+    return best
+
+
+def dict_to_human(query):
+    lst = [f"{k}={v}" for k, v in sorted(query.items())]
+
+    return list_to_human(lst)
+
+
+def list_to_human(lst, conjunction="and"):
+    if not lst:
+        return "??"
+
+    if len(lst) > 2:
+        lst = [", ".join(lst[:-1]), lst[-1]]
+
+    return f" {conjunction} ".join(lst)
+
+
+def as_number(value, name, units, none_ok):
+    if value is None and none_ok:
+        return None
+
+    value = str(value)
+    # TODO: support floats
+    m = re.search(r"^\s*(\d+)\s*([%\w]+)?\s*$", value)
+    if m is None:
+        raise ValueError(f"{name}: invalid number/unit {value}")
+    value = int(m.group(1))
+    if m.group(2) is None:
+        return value
+    unit = m.group(2)[0]
+    if unit not in units:
+        valid = ", ".join(units.keys())
+        raise ValueError(f"{name}: invalid unit '{unit}', valid values are {valid}")
+    return value * units[unit]
+
+
+def as_seconds(value, name=None, none_ok=False):
+    units = dict(s=1, m=60, h=3600, d=86400, w=86400 * 7)
+    return as_number(value, name, units, none_ok)
+
+
+def as_percent(value, name=None, none_ok=False):
+    units = {"%": 1}
+    return as_number(value, name, units, none_ok)
+
+
+def as_bytes(value, name=None, none_ok=False):
+    units = {}
+    n = 1
+    for u in "KMGTP":
+        n *= 1024
+        units[u] = n
+        units[u.lower()] = n
+
+    return as_number(value, name, units, none_ok)
+
+
+def as_timedelta(value, name=None, none_ok=False):
+    if value is None and none_ok:
+        return None
+
+    save = value
+    value = re.sub(r"[^a-zA-Z0-9]", "", value.lower())
+    value = re.sub(r"([a-zA-Z])[a-zA-Z]*", r"\1", value)
+    # value = re.sub(r"[^dmhsw0-9]", "", value)
+    bits = [b for b in re.split(r"([dmhsw])", value) if b != ""]
+
+    times = defaultdict(int)
+
+    val = None
+
+    for i, n in enumerate(bits):
+        if i % 2 == 0:
+            val = int(n)
+        else:
+            assert n in ("d", "m", "h", "s", "w")
+            times[n] = val
+            val = None
+
+    if val is not None:
+        if name:
+            raise ValueError(f"{name}: invalid period '{save}'")
+        raise ValueError(f"Invalid period '{save}'")
+
+    return datetime.timedelta(
+        weeks=times["w"],
+        days=times["d"],
+        hours=times["h"],
+        minutes=times["m"],
+        seconds=times["s"],
+    )
+
+
+def rounded_datetime(d):
+    if float(d.microsecond) / 1000.0 / 1000.0 >= 0.5:
+        d = d + datetime.timedelta(seconds=1)
+    d = d.replace(microsecond=0)
+    return d
diff --git a/ecml_tools/utils/text.py b/ecml_tools/utils/text.py
new file mode 100644
index 0000000..f8b1b9e
--- /dev/null
+++ b/ecml_tools/utils/text.py
@@ -0,0 +1,241 @@
+# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
+# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import sys + +# https://en.wikipedia.org/wiki/Box-drawing_character +from collections import defaultdict + +from termcolor import colored + + +def dotted_line(n=84, file=sys.stdout): + print("β”ˆ" * n, file=file) + + +def boxed(text, min_width=80, max_width=80): + lines = text.split("\n") + width = max(len(_) for _ in lines) + + if min_width is not None: + width = max(width, min_width) + + if max_width is not None: + width = min(width, max_width) + lines = [] + for line in text.split("\n"): + if len(line) > max_width: + line = line[: max_width - 1] + "…" + lines.append(line) + text = "\n".join(lines) + + box = [] + box.append("β”Œ" + "─" * (width + 2) + "┐") + for line in lines: + box.append(f"β”‚ {line:{width}} β”‚") + + box.append("β””" + "─" * (width + 2) + "β”˜") + return "\n".join(box) + + +def bold(text): + return colored(text, attrs=["bold"]) + + +def red(text): + return colored(text, "red") + + +def green(text): + return colored(text, "green") + + +class Tree: + def __init__(self, actor, parent=None): + self._actor = actor + self._kids = [] + self._parent = parent + + def adopt(self, kid): + kid._parent._kids.remove(kid) + self._kids.append(kid) + kid._parent = self + # assert False + + def forget(self): + self._parent._kids.remove(self) + self._parent = None + + @property + def is_leaf(self): + return len(self._kids) == 0 + + @property + def key(self): + return tuple(sorted(self._actor.as_dict().items())) + + @property + def _text(self): + return self._actor.summary + + @property + def summary(self): + return self._actor.summary + + def as_dict(self): + return self._actor.as_dict() + + def node(self, actor, insert=False): + node = Tree(actor, self) + if insert: + self._kids.insert(0, node) + else: + self._kids.append(node) + return node + + def print(self, file=sys.stdout): + padding = [] + + while self._factorise(): + pass + + self._print(padding, file=file) + + def _leaves(self, result): + if self.is_leaf: + result.append(self) + else: + for kid in self._kids: + kid._leaves(result) + + def _factorise(self): + if len(self._kids) == 0: + return False + + result = False + for kid in self._kids: + result = kid._factorise() or result + + if result: + return True + + same = defaultdict(list) + for kid in self._kids: + for grand_kid in kid._kids: + same[grand_kid.key].append((kid, grand_kid)) + + result = False + n = len(self._kids) + texts = [] + for text, v in same.items(): + if len(v) == n and n > 1: + for kid, grand_kid in v: + kid._kids.remove(grand_kid) + texts.append((text, v[1][1])) + result = True + + for text, actor in reversed(texts): + self.node(actor, True) + + if result: + return True + + if len(self._kids) != 1: + return False + + kid = self._kids[0] + texts = [] + for grand_kid in list(kid._kids): + if len(grand_kid._kids) == 0: + kid._kids.remove(grand_kid) + texts.append((grand_kid.key, grand_kid)) + result = True + + for text, actor in reversed(texts): + self.node(actor, True) + + return result + + def _print(self, padding, file=sys.stdout): + for i, p in enumerate(padding[:-1]): + if p == " β””": + padding[i] = " " + if p == " β”œ": + padding[i] = " β”‚" + if padding: + print(f"{''.join(padding)}─{self._text}", 
file=file) + else: + print(self._text, file=file) + padding.append(" ") + for i, k in enumerate(self._kids): + sep = " β”œ" if i < len(self._kids) - 1 else " β””" + padding[-1] = sep + k._print(padding, file=file) + + padding.pop() + + def to_json(self, depth=0): + while self._factorise(): + pass + + return { + "actor": self._actor.as_dict(), + "kids": [k.to_json(depth + 1) for k in self._kids], + "depth": depth, + } + + +def table(rows, header, align, margin=0): + def _(x): + try: + x = float(x) + except Exception: + pass + + if isinstance(x, float): + return f"{x:g}" + + if isinstance(x, str): + return x + if isinstance(x, int): + return str(x) + + return str(x) + + tmp = [] + for row in rows: + tmp.append([_(x) for x in row]) + + all_rows = [header] + tmp + + lens = [max(len(x) for x in col) for col in zip(*all_rows)] + + result = [] + for i, row in enumerate(all_rows): + result.append( + " β”‚ ".join( + [ + x.ljust(i) if align[j] == "<" else x.rjust(i) + for j, (x, i) in enumerate(zip(row, lens)) + ] + ) + ) + if i == 0: + result.append("─┼─".join(["─" * i for i in lens])) + + result.append("─┴─".join(["─" * i for i in lens])) + + if margin: + result = [margin * " " + x for x in result] + + return "\n".join(result) + + +def progress(done, todo, width=80): + done = min(int(done / todo * width + 0.5), width) + return green("β–ˆ" * done) + red("β–ˆ" * (width - done)) diff --git a/setup.py b/setup.py index 9fd94ee..df7bc84 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ def read(fname): "tqdm", "climetlab", # "earthkit-data" "earthkit-meteo", + "pyproj", ] @@ -80,7 +81,6 @@ def read(fname): }, zip_safe=True, keywords="tool", - entry_points={}, classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -93,4 +93,5 @@ def read(fname): "Programming Language :: Python :: Implementation :: PyPy", "Operating System :: OS Independent", ], + entry_points={"console_scripts": ["anemoi-datasets=ecml_tools.__main__:main"]}, ) diff --git a/tests/_test_create.py b/tests/_test_create.py old mode 100644 new mode 100755 index e633647..0b98be0 --- a/tests/_test_create.py +++ b/tests/_test_create.py @@ -1,13 +1,15 @@ +#!/usr/bin/env python3 # (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
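Stepping back to the text helpers added above, a hedged example of table() and progress(); the row contents are made up for illustration. table() takes rows of cells, a header row and one alignment character ("<" or ">") per column, while progress() renders a fixed-width two-colour bar:

from ecml_tools.utils.text import progress, table

print(
    table(
        rows=[["2t", 0.12], ["10u", 3.4]],
        header=["param", "score"],
        align=["<", ">"],
    )
)
print(progress(done=30, todo=120, width=40))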
- import json import os +import numpy as np + from ecml_tools.create import Creator from ecml_tools.data import open_dataset @@ -44,9 +46,21 @@ def compare_zarr(dir1, dir2): assert (a.dates == b.dates).all(), (a.dates, b.dates) for a_, b_ in zip(a.variables, b.variables): assert a_ == b_, (a, b) - for k, date in zip(range(a.shape[0]), a.dates): - for j in range(a.shape[1]): - assert (a[k, j] == b[k, j]).all(), (k, date, j, a[k, j], b[k, j]) + for i_date, date in zip(range(a.shape[0]), a.dates): + for i_param in range(a.shape[1]): + param = a.variables[i_param] + assert param == b.variables[i_param], ( + date, + param, + a.variables[i_param], + b.variables[i_param], + ) + a_ = a[i_date, i_param] + b_ = b[i_date, i_param] + assert a.shape == b.shape, (date, param, a.shape, b.shape) + delta = a_ - b_ + max_delta = np.max(np.abs(delta)) + assert max_delta == 0.0, (date, param, a_, b_, a_ - b_, max_delta) compare(dir1, dir2) @@ -95,10 +109,6 @@ def _test_create(name): ) c.create() - # ds = open_dataset(zarr_path) - # assert ds.shape == - # assert ds.variables == - compare_zarr(reference, output) diff --git a/tests/create-1.yaml b/tests/create-1.yaml index 361afbf..68688eb 100644 --- a/tests/create-1.yaml +++ b/tests/create-1.yaml @@ -8,12 +8,11 @@ remapping: &remapping #param_level: '{param}_{levtype}{levelist}' param_level: '{param}_{levelist}' -loop: -- dates: - start: 2020-12-30 00:00:00 - end: 2021-01-03 12:00:00 - frequency: 12h - group_by: monthly +dates: + start: 2020-12-30 00:00:00 + end: 2021-01-03 12:00:00 + frequency: 12h + group_by: monthly input: join: diff --git a/tests/create-concat.yaml b/tests/create-concat.yaml index b2b3544..076664c 100644 --- a/tests/create-concat.yaml +++ b/tests/create-concat.yaml @@ -4,52 +4,46 @@ purpose: aifs name: test-small config_format_version: 2 -loop: -- dates: - start: 2020-12-30 00:00:00 - end: 2021-01-03 12:00:00 - frequency: 12h - group_by: monthly +dates: + start: 2020-12-30 00:00:00 + end: 2021-01-03 12:00:00 + frequency: 12h + group_by: monthly common: mars_request: &mars_request - name: mars - expver: '0001' + expver: "0001" class: ea - date: $datetime_format($dates,%Y%m%d) - time: $datetime_format($dates,%H%M) grid: 20./20. levtype: sfc stream: oper type: an - + param: [2t] input: concat: - - dates: - start: 2020-12-30 00:00:00 - end: 2021-01-01 12:00:00 - frequency: 12h - source: - '<<': *mars_request - param: [2t] - - dates: - start: 2021-01-02 00:00:00 - end: 2021-01-03 12:00:00 - frequency: 12h - source: - '<<': *mars_request - param: [2t] + - dates: + start: 2020-12-30 00:00:00 + end: 2021-01-01 12:00:00 + frequency: 12h + mars: + <<: *mars_request + - dates: + start: 2021-01-02 00:00:00 + end: 2021-01-03 12:00:00 + frequency: 12h + mars: + <<: *mars_request output: - chunking: {dates: 1} + chunking: { dates: 1, ensembles: 1 } dtype: float32 flatten_grid: True order_by: - - valid_datetime - - param_level - - number + - valid_datetime + - param_level + - number statistics: param_level statistics_end: 2021 remapping: &remapping - param_level: '{param}_{levelist}' + param_level: "{param}_{levelist}" diff --git a/tests/create-join.yaml b/tests/create-join.yaml index 51c677d..fd18aee 100644 --- a/tests/create-join.yaml +++ b/tests/create-join.yaml @@ -6,64 +6,53 @@ config_format_version: 2 common: mars_request: &mars_request - name: mars - expver: '0001' + expver: "0001" class: ea grid: 20./20. 
- stream: oper - type: an - date: $datetime_format($dates,%Y%m%d) - time: $datetime_format($dates,%H%M) - dates: &dates_anchor - start: 2020-12-30 00:00:00 - end: 2021-01-03 12:00:00 - frequency: 12h -loop: -- dates: - "<<": *dates_anchor - group_by: monthly +dates: + start: 2020-12-30 00:00:00 + end: 2021-01-03 12:00:00 + frequency: 12h + group_by: monthly input: - dates: - "<<": *dates_anchor - join: - - label: - name: previous_data - source: - '<<': *mars_request - param: [2t] - levtype: sfc + join: + - mars: + <<: *mars_request + param: [2t] + levtype: sfc + stream: oper + type: an - - source: - '<<': *mars_request + - mars: + <<: *mars_request param: [q, t] levtype: pl level: [50, 100] + stream: oper + type: an - - source: - '<<': *mars_request - name: era5-accumulations + - accumulations: + <<: *mars_request + levtype: sfc param: [cp, tp] - # era5-accumulations requires time in hours - time: $datetime_format($dates,%H) + # accumulation_period: 6h - - source: - name: constants + - constants: + template: ${input.join.0.mars} param: - - cos_latitude - source_or_dataset: $previous_data - date: $dates + - cos_latitude output: - chunking: {dates: 1} + chunking: { dates: 1, ensembles: 1 } dtype: float32 flatten_grid: True order_by: - - valid_datetime - - param_level - - number + - valid_datetime + - param_level + - number statistics: param_level statistics_end: 2021 remapping: &remapping - param_level: '{param}_{levelist}' + param_level: "{param}_{levelist}" diff --git a/tests/create-perturbations-full.yaml b/tests/create-perturbations-full.yaml index 7485cf0..7900544 100644 --- a/tests/create-perturbations-full.yaml +++ b/tests/create-perturbations-full.yaml @@ -32,48 +32,47 @@ common: end: 2021-01-03 12:00:00 frequency: 12h -loop: -- dates: - "<<": *dates_anchor - group_by: monthly +dates: + <<: *dates_anchor + group_by: monthly input: dates: - "<<": *dates_anchor + <<: *dates_anchor join: - function: name: ensemble_perturbations ensembles: - '<<': *common_sfc + <<: *common_sfc stream: enda type: an number: 0/to/9 center: - '<<': *common_sfc + <<: *common_sfc stream: oper type: an mean: - '<<': *common_sfc + <<: *common_sfc stream: enda type: em - function: name: ensemble_perturbations ensembles: - '<<': *common_pl + <<: *common_pl stream: enda type: an number: 0/to/9 center: - '<<': *common_pl + <<: *common_pl stream: oper type: an mean: - '<<': *common_pl + <<: *common_pl stream: enda type: em diff --git a/tests/create-perturbations.yaml b/tests/create-perturbations.yaml index f1b7cb9..2b9c795 100644 --- a/tests/create-perturbations.yaml +++ b/tests/create-perturbations.yaml @@ -13,36 +13,35 @@ common: expver: '0001' grid: 20.0/20.0 levtype: sfc - param: [2t] + param: [2t, tp] dates: &dates_anchor start: 2020-12-30 00:00:00 end: 2021-01-03 12:00:00 frequency: 12h -loop: -- dates: - "<<": *dates_anchor - group_by: monthly +dates: + <<: *dates_anchor + group_by: monthly input: dates: - "<<": *dates_anchor + <<: *dates_anchor function: name: ensemble_perturbations ensembles: # the ensemble data has one additional dimension - '<<': *common + <<: *common stream: enda type: an number: 0/to/4/by/2 # number: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] center: # the new center of the data - '<<': *common + <<: *common stream: oper type: an mean: # the previous center of the data - '<<': *common + <<: *common stream: enda type: em diff --git a/tests/create-pipe.yaml b/tests/create-pipe.yaml index 6c76cfd..c8a80ac 100644 --- a/tests/create-pipe.yaml +++ b/tests/create-pipe.yaml @@ -4,72 +4,56 @@ purpose: aifs 
name: test-small config_format_version: 2 - common: mars_request: &mars_request - name: mars - expver: '0001' + expver: "0001" class: ea grid: 20./20. - stream: oper - type: an - date: $datetime_format($dates,%Y%m%d) - time: $datetime_format($dates,%H%M) - dates: &dates_anchor - start: 2020-12-30 00:00:00 - end: 2021-01-03 12:00:00 - frequency: 12h -loop: -- dates: - "<<": *dates_anchor - group_by: monthly +dates: &dates_anchor + start: 2020-12-30 00:00:00 + end: 2021-01-03 12:00:00 + frequency: 12h + group_by: monthly input: - dates: - "<<": *dates_anchor - join: - - label: - name: previous_data - source: - '<<': *mars_request - param: [2t] - levtype: sfc + join: + - mars: + <<: *mars_request + param: [2t] + levtype: sfc - pipe: - - source: - '<<': *mars_request - param: [q, t] - levtype: pl - level: [50, 100] - - filter: - param: [q] - - filter: - level: [50] - - - source: - '<<': *mars_request - name: era5-accumulations + - mars: + <<: *mars_request + param: [q, t] + levtype: pl + level: [50, 100] + stream: oper + type: an + - filter: + param: [q] + - filter: + level: [50] + + - accumulations: + <<: *mars_request param: [cp, tp] - # era5-accumulations requires time in hours - time: $datetime_format($dates,%H) - - source: - name: constants + - constants: + template: ${input.join.0.mars} param: - - cos_latitude - source_or_dataset: $previous_data - date: $dates + - cos_latitude output: - chunking: {dates: 1} + chunking: { dates: 1, ensembles: 1 } dtype: float32 flatten_grid: True order_by: - - valid_datetime - - param_level - - number + - valid_datetime + - param_level + - number statistics: param_level statistics_end: 2021 remapping: &remapping - param_level: '{param}_{levelist}' + param_level: "{param}_{levelist}" diff --git a/tests/create-shift.yaml b/tests/create-shift.yaml new file mode 100644 index 0000000..72dc620 --- /dev/null +++ b/tests/create-shift.yaml @@ -0,0 +1,51 @@ +description: "develop version of the dataset for a few days and a few variables, once data on mars is cached it should take a few seconds to generate the dataset" +dataset_status: testing +purpose: aifs +name: test-small +config_format_version: 2 + +common: + mars_request: &mars_request + expver: "0001" + class: ea + grid: 20./20. 
+ +dates: + start: 2020-12-30 00:00:00 + end: 2021-01-03 12:00:00 + frequency: 12h + group_by: monthly + +input: + join: + - mars: + <<: *mars_request + param: [2t] + levtype: sfc + stream: oper + type: an + + - constants: + template: ${input.join.0.mars} + param: + - insolation + + - date_shift: + delta: -25 + constants: + template: ${input.join.0.mars} + param: + - insolation + +output: + chunking: { dates: 1, ensembles: 1 } + dtype: float32 + flatten_grid: True + order_by: + - valid_datetime + - param_level + - number + statistics: param_level + statistics_end: 2021 + remapping: &remapping + param_level: "{param}_{levelist}" diff --git a/tests/test_data.py b/tests/test_data.py index 548206a..472ee5a 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -330,7 +330,7 @@ def test_simple(): time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 2 * 4, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", statistics_reference_dataset="test-2021-2022-6h-o96-abcd", statistics_reference_variables="abcd", ) @@ -341,7 +341,6 @@ def test_concat(): "test-2021-2022-6h-o96-abcd", "test-2023-2023-6h-o96-abcd", ) - test.run( expected_class=Concat, expected_length=365 * 3 * 4, @@ -349,7 +348,7 @@ def test_concat(): time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 3 * 4, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), statistics_reference_dataset="test-2021-2022-6h-o96-abcd", statistics_reference_variables="abcd", @@ -357,28 +356,15 @@ def test_concat(): def test_join_1(): - test = DatasetTester( - "test-2021-2021-6h-o96-abcd", - "test-2021-2021-6h-o96-efgh", - ) - + test = DatasetTester("test-2021-2021-6h-o96-abcd", "test-2021-2021-6h-o96-efgh") test.run( expected_class=Join, expected_length=365 * 4, start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 4, 8, 1, VALUES), - expected_variables=["a", "b", "c", "d", "e", "f", "g", "h"], - expected_name_to_index={ - "a": 0, - "b": 1, - "c": 2, - "d": 3, - "e": 4, - "f": 5, - "g": 6, - "h": 7, - }, + expected_variables="abcdefgh", + expected_name_to_index="abcdefgh", date_to_row=lambda date: simple_row(date, "abcdefgh"), # TODO: test second stats statistics_reference_dataset="test-2021-2021-6h-o96-abcd", @@ -387,19 +373,15 @@ def test_join_1(): def test_join_2(): - test = DatasetTester( - "test-2021-2021-6h-o96-abcd-1", - "test-2021-2021-6h-o96-bdef-2", - ) - + test = DatasetTester("test-2021-2021-6h-o96-abcd-1", "test-2021-2021-6h-o96-bdef-2") test.run( expected_class=Select, expected_length=365 * 4, start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 4, 6, 1, VALUES), - expected_variables=["a", "b", "c", "d", "e", "f"], - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3, "e": 4, "f": 5}, + expected_variables="abcdef", + expected_name_to_index="abcdef", date_to_row=lambda date: make_row( _(date, "a", 1), _(date, "b", 2), @@ -417,10 +399,7 @@ def test_join_2(): def test_join_3(): - test = DatasetTester( - "test-2021-2021-6h-o96-abcd-1", - "test-2021-2021-6h-o96-abcd-2", - ) + test = DatasetTester("test-2021-2021-6h-o96-abcd-1", "test-2021-2021-6h-o96-abcd-2") # TODO: This should trigger a warning about occulted dataset @@ -431,7 +410,7 @@ def test_join_3(): 
time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 4, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: make_row( _(date, "a", 2), _(date, "b", 2), @@ -445,7 +424,6 @@ def test_join_3(): def test_subset_1(): test = DatasetTester("test-2021-2023-1h-o96-abcd", frequency=12) - test.run( expected_class=Subset, expected_length=365 * 3 * 2, @@ -453,7 +431,7 @@ def test_subset_1(): time_increment=datetime.timedelta(hours=12), expected_shape=(365 * 3 * 2, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), statistics_reference_dataset="test-2021-2023-1h-o96-abcd", statistics_reference_variables="abcd", @@ -467,7 +445,7 @@ def test_subset_2(): expected_length=365 * 24, expected_shape=(365 * 24, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), start_date=datetime.datetime(2022, 1, 1), time_increment=datetime.timedelta(hours=1), @@ -478,17 +456,14 @@ def test_subset_2(): def test_subset_3(): test = DatasetTester( - "test-2021-2023-1h-o96-abcd", - start=2022, - end=2022, - frequency=12, + "test-2021-2023-1h-o96-abcd", start=2022, end=2022, frequency=12 ) test.run( expected_class=Subset, expected_length=365 * 2, expected_shape=(365 * 2, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), start_date=datetime.datetime(2022, 1, 1), time_increment=datetime.timedelta(hours=12), @@ -506,7 +481,7 @@ def test_subset_4(): time_increment=datetime.timedelta(hours=1), expected_shape=((30 + 31 + 31) * 24, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), statistics_reference_dataset="test-2021-2023-1h-o96-abcd", statistics_reference_variables="abcd", @@ -522,7 +497,7 @@ def test_subset_5(): time_increment=datetime.timedelta(hours=1), expected_shape=((30 + 31 + 31) * 24, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), statistics_reference_dataset="test-2021-2023-1h-o96-abcd", statistics_reference_variables="abcd", @@ -531,11 +506,8 @@ def test_subset_5(): def test_subset_6(): test = DatasetTester( - "test-2021-2023-1h-o96-abcd", - start="2022-06-01", - end="2022-08-31", + "test-2021-2023-1h-o96-abcd", start="2022-06-01", end="2022-08-31" ) - test.run( expected_class=Subset, expected_length=(30 + 31 + 31) * 24, @@ -543,7 +515,7 @@ def test_subset_6(): time_increment=datetime.timedelta(hours=1), expected_shape=((30 + 31 + 31) * 24, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), statistics_reference_dataset="test-2021-2023-1h-o96-abcd", statistics_reference_variables="abcd", @@ -559,7 +531,7 @@ def test_subset_7(): time_increment=datetime.timedelta(hours=1), expected_shape=((30 + 31 + 31) * 24, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + 
expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), statistics_reference_dataset="test-2021-2023-1h-o96-abcd", statistics_reference_variables="abcd", @@ -577,7 +549,7 @@ def test_subset_8(): expected_length=365 * 4, expected_shape=(365 * 4, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), start_date=datetime.datetime(2021, 1, 1, 3, 0, 0), time_increment=datetime.timedelta(hours=6), @@ -592,8 +564,8 @@ def test_select_1(): expected_class=Select, expected_length=365 * 4, expected_shape=(365 * 4, 2, 1, VALUES), - expected_variables=["b", "d"], - expected_name_to_index={"b": 0, "d": 1}, + expected_variables="bd", + expected_name_to_index="bd", date_to_row=lambda date: simple_row(date, "bd"), start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), @@ -608,8 +580,8 @@ def test_select_2(): expected_class=Select, expected_length=365 * 4, expected_shape=(365 * 4, 2, 1, VALUES), - expected_variables=["c", "a"], - expected_name_to_index={"c": 0, "a": 1}, + expected_variables="ca", + expected_name_to_index="ca", date_to_row=lambda date: simple_row(date, "ca"), start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), @@ -624,8 +596,8 @@ def test_select_3(): expected_class=Select, expected_length=365 * 4, expected_shape=(365 * 4, 2, 1, VALUES), - expected_variables=["a", "c"], - expected_name_to_index={"a": 0, "c": 1}, + expected_variables="ac", + expected_name_to_index="ac", date_to_row=lambda date: simple_row(date, "ac"), start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), @@ -640,11 +612,9 @@ def test_rename(): expected_class=Rename, expected_length=365 * 4, expected_shape=(365 * 4, 4, 1, VALUES), - expected_variables=["x", "b", "y", "d"], - expected_name_to_index={"x": 0, "b": 1, "y": 2, "d": 3}, - date_to_row=lambda date: make_row( - _(date, "a"), _(date, "b"), _(date, "c"), _(date, "d") - ), + expected_variables="xbyd", + expected_name_to_index="xbyd", + date_to_row=lambda date: simple_row(date, "abcd"), start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), statistics_reference_dataset=None, @@ -660,9 +630,9 @@ def test_drop(): expected_class=Select, expected_length=365 * 4, expected_shape=(365 * 4, 3, 1, VALUES), - expected_variables=["b", "c", "d"], - expected_name_to_index={"b": 0, "c": 1, "d": 2}, - date_to_row=lambda date: make_row(_(date, "b"), _(date, "c"), _(date, "d")), + expected_variables="bcd", + expected_name_to_index="bcd", + date_to_row=lambda date: simple_row(date, "bcd"), start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), statistics_reference_dataset="test-2021-2021-6h-o96-abcd", @@ -676,11 +646,9 @@ def test_reorder_1(): expected_class=Select, expected_length=365 * 4, expected_shape=(365 * 4, 4, 1, VALUES), - expected_variables=["d", "c", "b", "a"], - expected_name_to_index={"d": 0, "c": 1, "b": 2, "a": 3}, - date_to_row=lambda date: make_row( - _(date, "d"), _(date, "c"), _(date, "b"), _(date, "a") - ), + expected_variables="dcba", + expected_name_to_index="dcba", + date_to_row=lambda date: simple_row(date, "dcba"), start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), statistics_reference_dataset="test-2021-2021-6h-o96-abcd", @@ -694,11 +662,9 @@ def test_reorder_2(): expected_class=Select, expected_length=365 * 4, 
expected_shape=(365 * 4, 4, 1, VALUES), - expected_variables=["d", "c", "b", "a"], - expected_name_to_index={"d": 0, "c": 1, "b": 2, "a": 3}, - date_to_row=lambda date: make_row( - _(date, "d"), _(date, "c"), _(date, "b"), _(date, "a") - ), + expected_variables="dcba", + expected_name_to_index="dcba", + date_to_row=lambda date: simple_row(date, "dcba"), start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), statistics_reference_dataset="test-2021-2021-6h-o96-abcd", @@ -719,7 +685,7 @@ def test_constructor_1(): time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 2 * 4, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), statistics_reference_dataset="test-2021-2021-6h-o96-abcd", statistics_reference_variables="abcd", @@ -728,7 +694,10 @@ def test_constructor_1(): def test_constructor_2(): test = DatasetTester( - datasets=["test-2021-2021-6h-o96-abcd", "test-2022-2022-6h-o96-abcd"] + datasets=[ + "test-2021-2021-6h-o96-abcd", + "test-2022-2022-6h-o96-abcd", + ] ) test.run( expected_class=Concat, @@ -737,7 +706,7 @@ def test_constructor_2(): time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 2 * 4, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), statistics_reference_dataset="test-2021-2021-6h-o96-abcd", statistics_reference_variables="abcd", @@ -760,7 +729,7 @@ def test_constructor_3(): time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 2 * 4, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), statistics_reference_dataset="test-2021-2021-6h-o96-abcd", statistics_reference_variables="abcd", @@ -770,7 +739,10 @@ def test_constructor_3(): def test_constructor_4(): test = DatasetTester( "test-2021-2021-6h-o96-abcd", - {"dataset": "test-2022-2022-1h-o96-abcd", "frequency": 6}, + { + "dataset": "test-2022-2022-1h-o96-abcd", + "frequency": 6, + }, ) test.run( expected_class=Concat, @@ -779,7 +751,7 @@ def test_constructor_4(): time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 2 * 4, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), statistics_reference_dataset="test-2021-2021-6h-o96-abcd", statistics_reference_variables="abcd", @@ -797,8 +769,8 @@ def test_constructor_5(): start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 4, 7, 1, VALUES), - expected_variables=["x", "b", "y", "d", "a", "z", "t"], - expected_name_to_index={"x": 0, "b": 1, "y": 2, "d": 3, "a": 4, "z": 5, "t": 6}, + expected_variables="xbydazt", + expected_name_to_index="xbydazt", date_to_row=lambda date: make_row( _(date, "a", 1), _(date, "b", 2), @@ -847,7 +819,7 @@ def test_slice_1(): expected_length=365 * 1 * 4, expected_shape=(365 * 1 * 4, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), start_date=datetime.datetime(2021, 1, 1), time_increment=datetime.timedelta(hours=6), @@ -865,7 +837,7 @@ def test_slice_2(): 
expected_length=60632, expected_shape=(60632, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: simple_row(date, "abcd"), start_date=datetime.datetime(1940, 1, 1), time_increment=datetime.timedelta(hours=12), @@ -888,62 +860,8 @@ def test_slice_3(): start_date=datetime.datetime(2020, 1, 1), time_increment=datetime.timedelta(hours=6), expected_shape=(366 * 4, 26, 1, VALUES), - expected_variables=[ - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - ], - expected_name_to_index={ - "a": 0, - "b": 1, - "c": 2, - "d": 3, - "e": 4, - "f": 5, - "g": 6, - "h": 7, - "i": 8, - "j": 9, - "k": 10, - "l": 11, - "m": 12, - "n": 13, - "o": 14, - "p": 15, - "q": 16, - "r": 17, - "s": 18, - "t": 19, - "u": 20, - "v": 21, - "w": 22, - "x": 23, - "y": 24, - "z": 25, - }, + expected_variables="abcdefghijklmnopqrstuvwxyz", + expected_name_to_index="abcdefghijklmnopqrstuvwxyz", statistics_reference_dataset=None, statistics_reference_variables=None, ) @@ -961,7 +879,7 @@ def test_slice_4(): time_increment=datetime.timedelta(hours=1), expected_shape=(8784, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", statistics_reference_dataset=None, statistics_reference_variables=None, ) @@ -980,7 +898,7 @@ def test_slice_5(): time_increment=datetime.timedelta(hours=18), expected_shape=(4870, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", statistics_reference_dataset=None, statistics_reference_variables=None, ) @@ -998,7 +916,7 @@ def test_ensemble_1(): expected_length=365 * 1 * 4, expected_shape=(365 * 1 * 4, 4, 11, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: make_row( [_(date, "a", 1, i) for i in range(10)] + [_(date, "a", 2, 0)], [_(date, "b", 1, i) for i in range(10)] + [_(date, "b", 2, 0)], @@ -1026,7 +944,7 @@ def test_ensemble_2(): expected_length=365 * 1 * 4, expected_shape=(365 * 1 * 4, 4, 16, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: make_row( [_(date, "a", 1, i) for i in range(10)] + [_(date, "a", 2, 0)] @@ -1062,7 +980,7 @@ def test_ensemble_3(): expected_length=365 * 1 * 2, expected_shape=(365 * 1 * 2, 4, 16, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: make_row( [_(date, "a", 1, i) for i in range(10)] + [_(date, "a", 2, 0)] @@ -1097,7 +1015,7 @@ def test_grids(): expected_length=365 * 1 * 4, expected_shape=(365 * 1 * 4, 4, 1, VALUES + 25), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", date_to_row=lambda date: make_row( [ _(date, "a", 1), @@ -1145,7 +1063,7 @@ def test_statistics(): time_increment=datetime.timedelta(hours=6), expected_shape=(365 * 4, 4, 1, VALUES), expected_variables="abcd", - expected_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3}, + expected_name_to_index="abcd", statistics_reference_dataset="test-2000-2010-6h-o96-abcd", statistics_reference_variables="abcd", )
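A closing note on the test refactor above: expected_variables and expected_name_to_index are now given as compact strings such as "abcd" instead of explicit lists and dicts. A helper along these lines (hypothetical, since DatasetTester's body is outside this patch) is all the receiving side needs:

def as_name_to_index(expected):
    # Expand "abcd" to {"a": 0, "b": 1, "c": 2, "d": 3}; pass dicts through.
    if isinstance(expected, str):
        return {name: index for index, name in enumerate(expected)}
    return expected


assert as_name_to_index("abcd") == {"a": 0, "b": 1, "c": 2, "d": 3}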