diff --git a/ecml_tools/create/functions/actions/constants.py b/ecml_tools/create/functions/actions/constants.py index 6e48c15..207d5fa 100644 --- a/ecml_tools/create/functions/actions/constants.py +++ b/ecml_tools/create/functions/actions/constants.py @@ -55,12 +55,8 @@ def normalise_time_to_hours(r): return r -def constants(context, dates, **request): - param = request["param"] - - template = get_template_field(request) - - print(f"✅ load_source(constants, {template}, {request}") +def constants(context, dates, template, param): + print(f"✅ load_source(constants, {template}, {param}") return load_source("constants", source_or_dataset=template, date=dates, param=param) diff --git a/ecml_tools/create/functions/actions/opendap.py b/ecml_tools/create/functions/actions/opendap.py index b49ca2f..2fad5ed 100644 --- a/ecml_tools/create/functions/actions/opendap.py +++ b/ecml_tools/create/functions/actions/opendap.py @@ -13,7 +13,6 @@ def opendap(context, dates, url_pattern, *args, **kwargs): - all_urls = Pattern(url_pattern, ignore_missing_keys=True).substitute( *args, date=dates, **kwargs ) @@ -22,7 +21,6 @@ def opendap(context, dates, url_pattern, *args, **kwargs): levels = kwargs.get("level", kwargs.get("levelist")) for url in all_urls: - print("URL", url) s = load_source("opendap", url) s = s.sel( diff --git a/ecml_tools/create/functions/steps/rotate_winds.py b/ecml_tools/create/functions/steps/rotate_winds.py new file mode 100644 index 0000000..454f046 --- /dev/null +++ b/ecml_tools/create/functions/steps/rotate_winds.py @@ -0,0 +1,15 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +def execute(context, input, **kwargs): + print("🧭", input, kwargs) + for f in input: + print(f) + return input diff --git a/ecml_tools/create/input.py b/ecml_tools/create/input.py index 4f444c7..018eb8b 100644 --- a/ecml_tools/create/input.py +++ b/ecml_tools/create/input.py @@ -18,7 +18,7 @@ from climetlab.core.order import build_remapping from .group import build_groups -from .template import substitute +from .template import resolve, substitute from .utils import seconds LOG = logging.getLogger(__name__) @@ -100,16 +100,9 @@ class Cache: class Coords: def __init__(self, owner): self.owner = owner - self.cache = Cache() + @cached_property def _build_coords(self): - # assert isinstance(self.owner.context, Context), type(self.owner.context) - # assert isinstance(self.owner, Result), type(self.owner) - # assert hasattr(self.owner, "context"), self.owner - # assert hasattr(self.owner, "datasource"), self.owner - # assert hasattr(self.owner, "get_cube"), self.owner - # self.owner.datasource - from_data = self.owner.get_cube().user_coords from_config = self.owner.context.order_by @@ -135,63 +128,75 @@ def _build_coords(self): from_config[variables_key], ) - self.cache.variables = from_data[variables_key] # "param_level" - self.cache.ensembles = from_data[ensembles_key] # "number" + self._variables = from_data[variables_key] # "param_level" + self._ensembles = from_data[ensembles_key] # "number" first_field = self.owner.datasource[0] grid_points = first_field.grid_points() grid_values = list(range(len(grid_points[0]))) - self.cache.grid_points = grid_points - self.cache.resolution = first_field.resolution - self.cache.grid_values = grid_values + self._grid_points = grid_points + self._resolution = first_field.resolution + self._grid_values = grid_values - def __getattr__(self, name): - if name in [ - "variables", - "ensembles", - "resolution", - "grid_values", - "grid_points", - ]: - if not hasattr(self.cache, name): - self._build_coords() - return getattr(self.cache, name) - raise AttributeError(name) + @cached_property + def variables(self): + self._build_coords + return self._variables + + @cached_property + def ensembles(self): + self._build_coords + return self._ensembles + + @cached_property + def resolution(self): + self._build_coords + return self._resolution + + @cached_property + def grid_values(self): + self._build_coords + return self._grid_values + + @cached_property + def grid_points(self): + self._build_coords + return self._grid_points class HasCoordsMixin: - @property + @cached_property def variables(self): return self._coords.variables - @property + @cached_property def ensembles(self): return self._coords.ensembles - @property + @cached_property def resolution(self): return self._coords.resolution - @property + @cached_property def grid_values(self): return self._coords.grid_values - @property + @cached_property def grid_points(self): return self._coords.grid_points - @property + @cached_property def dates(self): if self._dates is None: raise ValueError(f"No dates for {self}") return self._dates.values - @property + @cached_property def frequency(self): return self._dates.frequency - @property + @cached_property def shape(self): return [ len(self.dates), @@ -200,7 +205,7 @@ def shape(self): len(self.grid_values), ] - @property + @cached_property def coords(self): return { "dates": self.dates, @@ -211,7 +216,7 @@ def coords(self): class Action: - def __init__(self, context, /, *args, **kwargs): + def __init__(self, context, path, /, *args, **kwargs): if "args" in kwargs and "kwargs" in kwargs: """We have: args = [] @@ -227,6 +232,7 @@ def __init__(self, context, /, *args, **kwargs): self.context = context self.kwargs = kwargs self.args = args + self.path = path @classmethod def _short_str(cls, x): @@ -252,14 +258,29 @@ def _raise_not_implemented(self): raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") +def check_references(method): + def wrapper(self, *args, **kwargs): + result = method(self, *args, **kwargs) + self.context.notify_result(self.path, result) + return result + + return wrapper + + class Result(HasCoordsMixin): empty = False - def __init__(self, context, dates=None): + def __init__(self, context, path, dates): assert isinstance(context, Context), type(context) + + assert path is None or isinstance(path, list), path + self.context = context self._coords = Coords(self) self._dates = dates + self.path = path + if path is not None: + context.register_reference(path, self) @property def datasource(self): @@ -318,9 +339,6 @@ def _raise_not_implemented(self): class EmptyResult(Result): empty = True - def __init__(self, context, dates=None): - super().__init__(context) - @cached_property def datasource(self): from climetlab import load_source @@ -332,44 +350,41 @@ def variables(self): return [] -class ReferencesSolver(dict): - def __init__(self, context, dates): - self.context = context - self.dates = dates - - def __getitem__(self, key): - if key == "dates": - return self.dates.values - if key in self.context.references: - result = self.context.references[key] - return result.datasource - raise KeyError(key) - - class FunctionResult(Result): - def __init__(self, context, dates, action, previous_sibling=None): - super().__init__(context, dates) + def __init__(self, context, path, dates, action): + super().__init__(context, path, dates) assert isinstance(action, Action), type(action) self.action = action - _args = self.action.args - _kwargs = self.action.kwargs - - vars = ReferencesSolver(context, dates) + self.args, self.kwargs = substitute( + context, (self.action.args, self.action.kwargs) + ) - self.args = substitute(_args, vars) - self.kwargs = substitute(_kwargs, vars) + self._result = None - # @cached_property @property + @check_references def datasource(self): print( - f"applying function {self.action.function} to {self.dates}, {self.args} {self.kwargs}, {self}" + "🌎", + self.path, + f"{self.action.function.__name__}.datasource({self.dates}, {self.args} {self.kwargs})", ) - return self.action.function( - FunctionContext(self), self.dates, *self.args, **self.kwargs + + # We don't use the cached_property here because if hides + # errors in the function. + + if self._result is not None: + return self._result + + args, kwargs = resolve(self.context, (self.args, self.kwargs)) + + self._result = self.action.function( + FunctionContext(self), self.dates, *args, **kwargs ) + return self._result + def __repr__(self): content = " ".join([f"{v}" for v in self.args]) content += " ".join([f"{k}={v}" for k, v in self.kwargs.items()]) @@ -382,13 +397,14 @@ def function(self): class JoinResult(Result): - def __init__(self, context, dates, results, **kwargs): - super().__init__(context, dates) + def __init__(self, context, path, dates, results, **kwargs): + super().__init__(context, path, dates) self.results = [r for r in results if not r.empty] - @property + @cached_property + @check_references def datasource(self): - ds = EmptyResult(self.context, self._dates).datasource + ds = EmptyResult(self.context, None, self._dates).datasource for i in self.results: ds += i.datasource assert_is_fieldset(ds), i @@ -399,46 +415,14 @@ def __repr__(self): return super().__repr__(content) -class DependencyAction(Action): - def __init__(self, context, **kwargs): - super().__init__(context) - self.content = action_factory(kwargs, context) - - def select(self, dates): - self.content.select(dates) - # this should trigger a registration of the result in the context - # if there is a label - # self.context.register_reference(self.name, result) - return EmptyResult(self.context, dates) - - def __repr__(self): - return super().__repr__(self.content) - - -class LabelAction(Action): - def __init__(self, context, name, **kwargs): - super().__init__(context) - if len(kwargs) != 1: - raise ValueError(f"Invalid kwargs for label : {kwargs}") - self.name = name - self.content = action_factory(kwargs, context) - - def select(self, dates): - result = self.content.select(dates) - self.context.register_reference(self.name, result) - return result - - def __repr__(self): - return super().__repr__(_inline_=self.name, _indent_=" ") - - class FunctionAction(Action): - def __init__(self, context, _name, **kwargs): - super().__init__(context, **kwargs) + def __init__(self, context, path, _name, **kwargs): + super().__init__(context, path, **kwargs) self.name = _name def select(self, dates): - return FunctionResult(self.context, dates, action=self) + print("🚀", self.path, f"{self.name}.select({dates})") + return FunctionResult(self.context, self.path, dates, action=self) @property def function(self): @@ -455,11 +439,12 @@ def __repr__(self): class ConcatResult(Result): - def __init__(self, context, results): + def __init__(self, context, path, results): super().__init__(context, dates=None) self.results = [r for r in results if not r.empty] - @property + @cached_property + @check_references def datasource(self): ds = EmptyResult(self.context, self.dates).datasource for i in self.results: @@ -504,9 +489,11 @@ def __repr__(self): class ActionWithList(Action): result_class = None - def __init__(self, context, *configs): - super().__init__(context, *configs) - self.actions = [action_factory(c, context) for c in configs] + def __init__(self, context, path, *configs): + super().__init__(context, path, *configs) + self.actions = [ + action_factory(c, context, path + [str(i)]) for i, c in enumerate(configs) + ] def __repr__(self): content = "\n".join([str(i) for i in self.actions]) @@ -514,61 +501,87 @@ def __repr__(self): class PipeAction(Action): - def __init__(self, context, *configs): - super().__init__(context, *configs) - current = action_factory(configs[0], context) - for c in configs[1:]: - current = step_factory(c, context, _upstream_action=current) + def __init__(self, context, path, *configs): + super().__init__(context, path, *configs) + assert len(configs) > 1, configs + current = action_factory(configs[0], context, path + ["0"]) + for i, c in enumerate(configs[1:]): + current = step_factory( + c, context, path + [str(i + 1)], previous_step=current + ) self.content = current def select(self, dates): - return self.content.select(dates) + print("🚀", self.path, f"PipeAction.select({dates}, {self.content})") + result = self.content.select(dates) + print("🍎", self.path, f"PipeAction.result", result) + return result def __repr__(self): return super().__repr__(self.content) class StepResult(Result): - def __init__(self, upstream, context, dates, action): - super().__init__(context, dates) - assert isinstance(upstream, Result), type(upstream) - self.content = upstream + def __init__(self, context, path, dates, action, upstream_result): + print("🐫", "step result", path, upstream_result, type(upstream_result)) + super().__init__(context, path, dates) + assert isinstance(upstream_result, Result), type(upstream_result) + self.upstream_result = upstream_result self.action = action @property + @check_references def datasource(self): - return self.content.datasource + raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") + # return self.upstream_result.datasource class StepAction(Action): result_class = None - def __init__(self, context, _upstream_action, **kwargs): - super().__init__(context, **kwargs) - self.content = _upstream_action + def __init__(self, context, path, previous_step, *args, **kwargs): + super().__init__(context, path, *args, **kwargs) + self.previous_step = previous_step def select(self, dates): return self.result_class( - self.content.select(dates), self.context, + self.path, dates, self, + self.previous_step.select(dates), ) def __repr__(self): - return super().__repr__(self.content, _inline_=str(self.kwargs)) + return super().__repr__(self.previous_step, _inline_=str(self.kwargs)) + +class StepFunctionResult(StepResult): + _result = None -class StepFunctionResult(StepAction): @property + @check_references def datasource(self): - return self.function( - FunctionContext(self), self.content.datasource, **self.kwargs + # We don't use the cached_property here because if hides + # errors in the function. + + print("🥧", "StepFunctionResult.datasource", self.action.function) + + if self._result is not None: + return self._result + + self._result = self.action.function( + FunctionContext(self), + self.upstream_result.datasource, + **self.action.kwargs, ) + return self._result + class FilterStepResult(StepResult): @property + @check_references def datasource(self): ds = self.content.datasource assert_is_fieldset(ds) @@ -581,6 +594,14 @@ class FilterStepAction(StepAction): result_class = FilterStepResult +class FunctionStepAction(StepAction): + def __init__(self, context, path, previous_step, *args, **kwargs): + super().__init__(context, path, previous_step, *args, **kwargs) + self.function = import_function(args[0], "steps") + + result_class = StepFunctionResult + + class ConcatAction(ActionWithList): def select(self, dates): return ConcatResult(self.context, [a.select(dates) for a in self.actions]) @@ -588,11 +609,13 @@ def select(self, dates): class JoinAction(ActionWithList): def select(self, dates): - return JoinResult(self.context, dates, [a.select(dates) for a in self.actions]) + return JoinResult( + self.context, self.path, dates, [a.select(dates) for a in self.actions] + ) class DateAction(Action): - def __init__(self, context, **kwargs): + def __init__(self, context, path, **kwargs): super().__init__(context, **kwargs) datesconfig = {} @@ -630,23 +653,23 @@ def merge_dicts(a, b): return deepcopy(b) -def action_factory(config, context): +def action_factory(config, context, path): assert isinstance(context, Context), (type, context) if not isinstance(config, dict): raise ValueError(f"Invalid input config {config}") - if len(config) == 2 and "label" in config: - config = deepcopy(config) - label = config.pop("label") - return action_factory( - dict( - label=dict( - name=label, - **config, - ) - ), - context, - ) + # if len(config) == 2 and "label" in config: + # config = deepcopy(config) + # label = config.pop("label") + # return action_factory( + # dict( + # label=dict( + # name=label, + # **config, + # ) + # ), + # context, + # ) if len(config) != 1: raise ValueError( @@ -658,12 +681,12 @@ def action_factory(config, context): cls = dict( concat=ConcatAction, join=JoinAction, - label=LabelAction, + # label=LabelAction, pipe=PipeAction, # source=SourceAction, function=FunctionAction, dates=DateAction, - dependency=DependencyAction, + # dependency=DependencyAction, ).get(key) if isinstance(config[key], list): @@ -678,10 +701,10 @@ def action_factory(config, context): cls = FunctionAction args = [key] + args - return cls(context, *args, **kwargs) + return cls(context, path + [key], *args, **kwargs) -def step_factory(config, context, _upstream_action): +def step_factory(config, context, path, previous_step): assert isinstance(context, Context), (type, context) if not isinstance(config, dict): raise ValueError(f"Invalid input config {config}") @@ -694,7 +717,7 @@ def step_factory(config, context, _upstream_action): filter=FilterStepAction, # rename=RenameAction, # remapping=RemappingAction, - )[key] + ).get(key) if isinstance(config[key], list): args, kwargs = config[key], {} @@ -702,11 +725,13 @@ def step_factory(config, context, _upstream_action): if isinstance(config[key], dict): args, kwargs = [], config[key] - if "_upstream_action" in kwargs: - raise ValueError(f"Reserverd keyword '_upsream_action' in {config}") - kwargs["_upstream_action"] = _upstream_action + if cls is None: + if not is_function(key, "steps"): + raise ValueError(f"Unknown step {key}") + cls = FunctionStepAction + args = [key] + args - return cls(context, *args, **kwargs) + return cls(context, path, previous_step, *args, **kwargs) class FunctionContext: @@ -721,21 +746,49 @@ def __init__(self, /, order_by, flatten_grid, remapping): self.remapping = build_remapping(remapping) self.references = {} - - def register_reference(self, name, obj): - assert isinstance(obj, Result), type(obj) - if name in self.references: - raise ValueError(f"Duplicate reference {name}") - self.references[name] = obj - - def find_reference(self, name): - if name in self.references: - return self.references[name] - # It can happend that the required name is not yet registered, + self.used_references = set() + self.results = {} + + def register_reference(self, path, obj): + assert isinstance(path, (list, tuple)), path + path = tuple(path) + print("=======> register", path, type(obj)) + if path in self.references: + raise ValueError(f"Duplicate reference {path}") + self.references[path] = obj + + def find_reference(self, path): + assert isinstance(path, (list, tuple)), path + path = tuple(path) + if path in self.references: + return self.references[path] + # It can happend that the required path is not yet registered, # even if it is defined in the config. # Handling this case implies implementing a lazy inheritance resolution # and would complexify the code. This is not implemented. - raise ValueError(f"Cannot find reference {name}") + + raise ValueError(f"Cannot find reference {path}") + + def will_need_reference(self, path): + assert isinstance(path, (list, tuple)), path + path = tuple(path) + self.used_references.add(path) + + def notify_result(self, path, result): + print("notify_result", path, result) + assert isinstance(path, (list, tuple)), path + path = tuple(path) + if path in self.used_references: + if path in self.results: + raise ValueError(f"Duplicate result {path}") + self.results[path] = result + + def get_result(self, path): + assert isinstance(path, (list, tuple)), path + path = tuple(path) + if path in self.results: + return self.results[path] + raise ValueError(f"Cannot find result {path}") class InputBuilder: @@ -747,12 +800,12 @@ def select(self, dates): """This changes the context.""" dates = build_groups(dates) context = Context(**self.kwargs) - action = action_factory(self.config, context) + action = action_factory(self.config, context, ["input"]) return action.select(dates) def __repr__(self): context = Context(**self.kwargs) - a = action_factory(self.config, context) + a = action_factory(self.config, context, ["input"]) return repr(a) diff --git a/ecml_tools/create/template.py b/ecml_tools/create/template.py index d7d87e6..3e7436f 100644 --- a/ecml_tools/create/template.py +++ b/ecml_tools/create/template.py @@ -8,157 +8,56 @@ # import logging -import os import re -from ecml_tools.create.utils import to_datetime - LOG = logging.getLogger(__name__) -def substitute(x, vars=None, ignore_missing=False): - """Recursively substitute environment variables and dict values in a nested list ot dict of string. - substitution is performed using the environment var (if UPPERCASE) or the input dictionary. +class Substitution: + pass + +class Reference(Substitution): + def __init__(self, context, path): + self.context = context + self.path = path - >>> substitute({'bar': '$bar'}, {'bar': '43'}) - {'bar': '43'} + def resolve(self, context): + return context.get_result(self.path) - >>> substitute({'bar': '$BAR'}, {'BAR': '43'}) - Traceback (most recent call last): - ... - KeyError: 'BAR' - >>> substitute({'bar': '$BAR'}, ignore_missing=True) - {'bar': '$BAR'} +def resolve(context, x): + if isinstance(x, tuple): + return tuple([resolve(context, y) for y in x]) - >>> os.environ["BAR"] = "42" - >>> substitute({'bar': '$BAR'}) - {'bar': '42'} + if isinstance(x, list): + return [resolve(context, y) for y in x] - >>> substitute('$bar', {'bar': '43'}) - '43' + if isinstance(x, dict): + return {k: resolve(context, v) for k, v in x.items()} - >>> substitute('$hdates_from_date($date, 2015, 2018)', {'date': '2023-05-12'}) - '2015-05-12/2016-05-12/2017-05-12/2018-05-12' + if isinstance(x, Substitution): + return x.resolve(context) - """ - if vars is None: - vars = {} + return x - assert isinstance(vars, dict), vars - if isinstance(x, (tuple, list)): - return [substitute(y, vars, ignore_missing=ignore_missing) for y in x] +def substitute(context, x): + if isinstance(x, tuple): + return tuple([substitute(context, y) for y in x]) + + if isinstance(x, list): + return [substitute(context, y) for y in x] if isinstance(x, dict): - return { - k: substitute(v, vars, ignore_missing=ignore_missing) for k, v in x.items() - } - - if isinstance(x, str): - if "$" not in x: - return x - - lst = [] - - for i, bit in enumerate(re.split(r"(\$(\w+)(\([^\)]*\))?)", x)): - if bit is None: - continue - assert isinstance(bit, str), (bit, type(bit), x, type(x)) - - i %= 4 - if i in [2, 3]: - continue - if i == 1: - try: - if "(" in bit: - # substitute by a function - FUNCTIONS = dict( - hdates_from_date=hdates_from_date, - datetime_format=datetime_format, - ) - - pattern = r"\$(\w+)\(([^)]*)\)" - match = re.match(pattern, bit) - assert match, bit - - function_name = match.group(1) - params = [p.strip() for p in match.group(2).split(",")] - params = [ - substitute(p, vars, ignore_missing=ignore_missing) - for p in params - ] - - bit = FUNCTIONS[function_name](*params) - - elif bit.upper() == bit: - # substitute by the var env if $UPPERCASE - bit = os.environ[bit[1:]] - else: - # substitute by the value in the 'vars' dict - bit = vars[bit[1:]] - except KeyError as e: - if not ignore_missing: - raise e - - if bit != x: - bit = substitute(bit, vars, ignore_missing=ignore_missing) - - lst.append(bit) - - lst = [_ for _ in lst if _ != ""] - if len(lst) == 1: - return lst[0] - - out = [] - for elt in lst: - # if isinstance(elt, str): - # elt = [elt] - assert isinstance(elt, (list, tuple)), elt - out += elt - return out + return {k: substitute(context, v) for k, v in x.items()} - return x + if not isinstance(x, str): + return x + if re.match(r"^\${[\.\w]+}$", x): + path = x[2:-1].split(".") + context.will_need_reference(path) + return Reference(context, path) -def datetime_format(dates, format, join=None): - formated = [to_datetime(d).strftime(format) for d in dates] - formated = set(formated) - formated = list(formated) - formated = sorted(formated) - if join: - formated = join.join(formated) - return formated - - -def hdates_from_date(date, start_year, end_year): - """ - Returns a list of dates in the format '%Y%m%d' between start_year and end_year (inclusive), - with the year of the input date. - - Args: - date (str or datetime): The input date. - start_year (int): The start year. - end_year (int): The end year. - - Returns: - List[str]: A list of dates in the format '%Y%m%d'. - """ - if not str(start_year).isdigit(): - raise ValueError(f"start_year must be an int: {start_year}") - if not str(end_year).isdigit(): - raise ValueError(f"end_year must be an int: {end_year}") - start_year = int(start_year) - end_year = int(end_year) - - if isinstance(date, (list, tuple)): - if len(date) != 1: - raise NotImplementedError(f"{date} should have only one element.") - date = date[0] - - date = to_datetime(date) - assert not (date.hour or date.minute or date.second), date - - hdates = [date.replace(year=year) for year in range(start_year, end_year + 1)] - return "/".join(d.strftime("%Y-%m-%d") for d in hdates) + return x diff --git a/ecml_tools/grids.py b/ecml_tools/grids.py index 19e9546..51e1413 100644 --- a/ecml_tools/grids.py +++ b/ecml_tools/grids.py @@ -32,9 +32,7 @@ def latlon_to_xyz(lat, lon, radius=1.0): class Triangle3D: - def __init__(self, v0, v1, v2): - self.v0 = v0 self.v1 = v1 self.v2 = v2 @@ -168,7 +166,6 @@ def cutout_mask( if __name__ == "__main__": - global_lats, global_lons = np.meshgrid( np.linspace(90, -90, 90), np.linspace(-180, 180, 180),