diff --git a/ecml_tools/create/__init__.py b/ecml_tools/create/__init__.py
index aefe64e..6498d19 100644
--- a/ecml_tools/create/__init__.py
+++ b/ecml_tools/create/__init__.py
@@ -51,9 +51,9 @@ def load(self, parts=None):
         with self._cache_context():
             loader = ContentLoader.from_dataset_config(
-                path=self.path, statistics_tmp=self.statistics_tmp, print=self.print
+                path=self.path, statistics_tmp=self.statistics_tmp, print=self.print, parts=parts
             )
-            loader.load(parts=parts)
+            loader.load()

     def statistics(self, force=False, output=None, start=None, end=None):
         from .loaders import StatisticsLoader
diff --git a/ecml_tools/create/check.py b/ecml_tools/create/check.py
index acc5be9..ecc37d9 100644
--- a/ecml_tools/create/check.py
+++ b/ecml_tools/create/check.py
@@ -23,9 +23,7 @@ def compute_directory_size(path):
         return None
     size = 0
     n = 0
-    for dirpath, _, filenames in tqdm.tqdm(
-        os.walk(path), desc="Computing size", leave=False
-    ):
+    for dirpath, _, filenames in tqdm.tqdm(os.walk(path), desc="Computing size", leave=False):
         for filename in filenames:
             file_path = os.path.join(dirpath, filename)
             size += os.path.getsize(file_path)
@@ -57,10 +55,7 @@ def __init__(
         self.check_end_date(end_date)

         if self.messages:
-            self.messages.append(
-                f"{self} is parsed as :"
-                + "/".join(f"{k}={v}" for k, v in self.parsed.items())
-            )
+            self.messages.append(f"{self} is parsed as :" + "/".join(f"{k}={v}" for k, v in self.parsed.items()))

     @property
     def error_message(self):
@@ -76,9 +71,7 @@ def raise_if_not_valid(self, print=print):
         raise ValueError(self.error_message)

     def _parse(self, name):
-        pattern = (
-            r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?(.*)$"
-        )
+        pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?(.*)$"
         match = re.match(pattern, name)
         assert match, (name, pattern)

@@ -112,10 +105,7 @@ def check_parsed(self):
             )

     def check_resolution(self, resolution):
-        if (
-            self.parsed.get("resolution")
-            and self.parsed["resolution"][0] not in "0123456789on"
-        ):
+        if self.parsed.get("resolution") and self.parsed["resolution"][0] not in "0123456789on":
             self.messages.append(
                 f"the resolution {self.parsed['resolution'] } should start "
                 f"with a number or 'o' or 'n' in the dataset name {self}."
@@ -150,22 +140,23 @@ def check_end_date(self, end_date):

     def _check_missing(self, key, value):
         if value not in self.name:
-            self.messages.append(
-                f"the {key} is {value}, but is missing in {self.name}."
-            )
+            self.messages.append(f"the {key} is {value}, but is missing in {self.name}.")

     def _check_mismatch(self, key, value):
         if self.parsed.get(key) and self.parsed[key] != value:
-            self.messages.append(
-                f"the {key} is {value}, but is {self.parsed[key]} in {self.name}."
-            )
+            self.messages.append(f"the {key} is {value}, but is {self.parsed[key]} in {self.name}.")


 class StatisticsValueError(ValueError):
     pass


-def check_data_values(arr, *, name: str, log=[]):
+def check_data_values(arr, *, name: str, log=[], allow_nan=False):
+    if allow_nan is False:
+        allow_nan = lambda x: False
+    if allow_nan(name):
+        arr = arr[~np.isnan(arr)]
+
     min, max = arr.min(), arr.max()
     assert not (np.isnan(arr).any()), (name, min, max, *log)

@@ -175,10 +166,19 @@ def check_data_values(arr, *, name: str, log=[]):
         warnings.warn(f"Max value 9999 for {name}")

     in_0_1 = dict(minimum=0, maximum=1)
+    is_temperature = dict(minimum=173.15, maximum=373.15)  # -100 celsius to +200 celsius
+    # is_wind = dict(minimum=-500., maximum=500.)
     limits = {
         "lsm": in_0_1,
+        "cos_latitude": in_0_1,
+        "sin_latitude": in_0_1,
+        "cos_longitude": in_0_1,
+        "sin_longitude": in_0_1,
         "insolation": in_0_1,
-        "2t": dict(minimum=173.15, maximum=373.15),
+        "2t": is_temperature,
+        "sst": is_temperature,
+        # "10u": is_wind,
+        # "10v": is_wind,
     }

     if name in limits:
diff --git a/ecml_tools/create/config.py b/ecml_tools/create/config.py
index db396bd..c27e8ea 100644
--- a/ecml_tools/create/config.py
+++ b/ecml_tools/create/config.py
@@ -93,10 +93,6 @@ def get_chunking(self, coords):
         )
         return tuple(chunks)

-    @property
-    def append_axis(self):
-        return self.config.append_axis
-
     @property
     def order_by(self):
         return self.config.order_by
@@ -163,10 +159,6 @@ def normalise(self):
         assert "flatten_values" not in self.output
         assert "flatten_grid" in self.output, self.output

-        # The axis along which we append new data
-        # TODO: assume grid points can be 2d as well
-        self.output.append_axis = 0
-
         assert "statistics" in self.output
         statistics_axis_name = self.output.statistics
         statistics_axis = -1
diff --git a/ecml_tools/create/input.py b/ecml_tools/create/input.py
index 0e4902c..3631252 100644
--- a/ecml_tools/create/input.py
+++ b/ecml_tools/create/input.py
@@ -332,6 +332,11 @@ def data_request(self):
         """Returns a dictionary with the parameters needed to retrieve the data."""
         return _data_request(self.datasource)

+    @property
+    def variables_with_nans(self):
+        print("❌❌HERE")
+        return
+
     def get_cube(self):
         trace("🧊", f"getting cube from {self.__class__.__name__}")
         ds = self.datasource
diff --git a/ecml_tools/create/loaders.py b/ecml_tools/create/loaders.py
index 244ff23..1b14a5f 100644
--- a/ecml_tools/create/loaders.py
+++ b/ecml_tools/create/loaders.py
@@ -7,6 +7,7 @@
 import datetime
 import logging
 import os
+import time
 import uuid
 from functools import cached_property

@@ -16,12 +17,18 @@
 from ecml_tools.data import open_dataset
 from ecml_tools.utils.dates.groups import Groups

-from .check import DatasetName
+from .check import DatasetName, check_data_values
 from .config import build_output, loader_config
 from .input import build_input
-from .statistics import TempStatistics
-from .utils import bytes, compute_directory_sizes, normalize_and_check_dates
-from .writer import CubesFilter, DataWriter
+from .statistics import TempStatistics, compute_statistics
+from .utils import (
+    bytes,
+    compute_directory_sizes,
+    normalize_and_check_dates,
+    progress_bar,
+    seconds,
+)
+from .writer import CubesFilter, ViewCacheArray
 from .zarr import ZarrBuiltRegistry, add_zarr_dataset

 LOG = logging.getLogger(__name__)
@@ -97,6 +104,9 @@ def read_dataset_metadata(self):
         self.missing_dates = z.attrs.get("missing_dates", [])
         self.missing_dates = [np.datetime64(d) for d in self.missing_dates]

+    def allow_nan(self, name):
+        return name in self.main_config.get("has_nans", [])
+
     @cached_property
     def registry(self):
         return ZarrBuiltRegistry(self.path)
@@ -180,6 +190,8 @@ def initialise(self, check_name=True):
         variables = self.minimal_input.variables
         self.print(f"Found {len(variables)} variables : {','.join(variables)}.")

+        variables_with_nans = self.config.get("has_nans", [])
+
         ensembles = self.minimal_input.ensembles
         self.print(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.")

@@ -222,6 +234,7 @@ def initialise(self, check_name=True):

         metadata["ensemble_dimension"] = len(ensembles)
         metadata["variables"] = variables
+        metadata["variables_with_nans"] = variables_with_nans
         metadata["resolution"] = resolution
         metadata["licence"] = self.main_config["licence"]

@@ -291,7 +304,7 @@ def initialise(self, check_name=True):


 class ContentLoader(Loader):
-    def __init__(self, config, **kwargs):
+    def __init__(self, config, parts, **kwargs):
         super().__init__(**kwargs)
         self.main_config = loader_config(config)

@@ -300,33 +313,115 @@ def __init__(self, config, **kwargs):
         self.input = self.build_input()
         self.read_dataset_metadata()

-    def load(self, parts):
-        self.registry.add_to_history("loading_data_start", parts=parts)
+        self.parts = parts
+        total = len(self.registry.get_flags())
+        self.cube_filter = CubesFilter(parts=self.parts, total=total)

-        z = zarr.open(self.path, mode="r+")
-        data_writer = DataWriter(parts, full_array=z["data"], owner=self)
+        self.data_array = zarr.open(self.path, mode="r+")["data"]
+        self.n_groups = len(self.groups)
+
+    def load(self):
+        self.registry.add_to_history("loading_data_start", parts=self.parts)

-        total = len(self.registry.get_flags())
-        filter = CubesFilter(parts=parts, total=total)
         for igroup, group in enumerate(self.groups):
+            if not self.cube_filter(igroup):
+                continue
             if self.registry.get_flag(igroup):
                 LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)")
                 continue
-            if not filter(igroup):
-                continue

             self.print(f" -> Processing {igroup} total={len(self.groups)}")
             print("========", group)
             assert isinstance(group[0], datetime.datetime), group

             result = self.input.select(dates=group)
-            data_writer.write(result, igroup, group)
+            assert result.dates == group, (len(result.dates), len(group))
+
+            msg = f"Building data for group {igroup}/{self.n_groups}"
+            LOG.info(msg)
+            self.print(msg)

-        self.registry.add_to_history("loading_data_end", parts=parts)
+            # There are several groups.
+            # There is one result to load for each group.
+            self.load_result(result)
+            self.registry.set_flag(igroup)
+
+        self.registry.add_to_history("loading_data_end", parts=self.parts)
         self.registry.add_provenance(name="provenance_load")
         self.statistics_registry.add_provenance(name="provenance_load", config=self.main_config)

         self.print_info()

+    def load_result(self, result):
+        # There is one cube to load for each result.
+        dates = result.dates
+
+        cube = result.get_cube()
+        assert cube.extended_user_shape[0] == len(dates), (cube.extended_user_shape[0], len(dates))
+
+        shape = cube.extended_user_shape
+        dates_in_data = cube.user_coords["valid_datetime"]
+
+        LOG.info(f"Loading {shape=} in {self.data_array.shape=}")
+
+        def check_dates_in_data(lst, lst2):
+            lst2 = list(lst2)
+            lst = [datetime.datetime.fromisoformat(_) for _ in lst]
+            assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2)
+
+        check_dates_in_data(dates_in_data, dates)
+
+        def dates_to_indexes(dates, all_dates):
+            x = np.array(dates, dtype=np.datetime64)
+            y = np.array(all_dates, dtype=np.datetime64)
+            bitmap = np.isin(x, y)
+            return np.where(bitmap)[0]
+
+        indexes = dates_to_indexes(self.dates, dates_in_data)
+
+        array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes)
+        self.load_cube(cube, array)
+
+        stats = compute_statistics(array.cache, self.variables_names, allow_nan=self.allow_nan)
+        self.statistics_registry.write(indexes, stats, dates=dates_in_data)
+
+        array.flush()
+
+    def load_cube(self, cube, array):
+        # There are several cubelets for each cube
+        start = time.time()
+        load = 0
+        save = 0
+
+        reading_chunks = None
+        total = cube.count(reading_chunks)
+        self.print(f"Loading datacube: {cube}")
+        bar = progress_bar(
+            iterable=cube.iterate_cubelets(reading_chunks),
+            total=total,
+            desc=f"Loading datacube {cube}",
+        )
+        for i, cubelet in enumerate(bar):
+            bar.set_description(f"Loading {i}/{total}")
+
+            now = time.time()
+            data = cubelet.to_numpy()
+            local_indexes = cubelet.coords
+            load += time.time() - now
+
+            name = self.variables_names[local_indexes[1]]
+            check_data_values(data[:], name=name, log=[i, data.shape, local_indexes], allow_nan=self.allow_nan)
+
+            now = time.time()
+            array[local_indexes] = data
+            save += time.time() - now
+
+        now = time.time()
+        save += time.time() - now
+        LOG.info("Written.")
+        msg = f"Elapsed: {seconds(time.time() - start)}, load time: {seconds(load)}, write time: {seconds(save)}."
+        self.print(msg)
+        LOG.info(msg)
+

 class StatisticsLoader(Loader):
     main_config = {}
@@ -383,7 +478,7 @@ def _get_statistics_dates(self):

     def run(self):
         dates = self._get_statistics_dates()
-        stats = self.statistics_registry.get_aggregated(dates, self.variables_names)
+        stats = self.statistics_registry.get_aggregated(dates, self.variables_names, self.allow_nan)
         self.output_writer(stats)

     def write_stats_to_file(self, stats):
diff --git a/ecml_tools/create/statistics.py b/ecml_tools/create/statistics.py
index ec6e1a8..d0c2edb 100644
--- a/ecml_tools/create/statistics.py
+++ b/ecml_tools/create/statistics.py
@@ -65,7 +65,7 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squa
     raise ValueError("Negative variance")


-def compute_statistics(array, check_variables_names=None):
+def compute_statistics(array, check_variables_names=None, allow_nan=False):
     nvars = array.shape[1]

     print("Stats", nvars, array.shape, check_variables_names)
@@ -81,22 +81,21 @@ def compute_statistics(array, check_variables_names=None):

     for i, chunk in enumerate(array):
         values = chunk.reshape((nvars, -1))
-        minimum[i] = np.min(values, axis=1)
-        maximum[i] = np.max(values, axis=1)
-        sums[i] = np.sum(values, axis=1)
-        squares[i] = np.sum(values * values, axis=1)
-        count[i] = values.shape[1]

         for j, name in enumerate(check_variables_names):
-            check_data_values(values[j, :], name=name)
+            check_data_values(values[j, :], name=name, allow_nan=allow_nan)
+            if np.isnan(values[j, :]).all():
+                # LOG.warning(f"All NaN values for {name} ({j}) for date {i}")
+                raise ValueError(f"All NaN values for {name} ({j}) for date {i}")

-    return {
-        "minimum": minimum,
-        "maximum": maximum,
-        "sums": sums,
-        "squares": squares,
-        "count": count,
-    }
+        # Ignore NaN values
+        minimum[i] = np.nanmin(values, axis=1)
+        maximum[i] = np.nanmax(values, axis=1)
+        sums[i] = np.nansum(values, axis=1)
+        squares[i] = np.nansum(values * values, axis=1)
+        count[i] = np.sum(~np.isnan(values), axis=1)
+
+    return {"minimum": minimum, "maximum": maximum, "sums": sums, "squares": squares, "count": count}


 class TempStatistics:
@@ -148,8 +147,8 @@ def _gather_data(self):
             with open(f, "rb") as f:
                 yield pickle.load(f)

-    def get_aggregated(self, dates, variables_names):
-        aggregator = StatAggregator(dates, variables_names, self)
+    def get_aggregated(self, *args, **kwargs):
+        aggregator = StatAggregator(self, *args, **kwargs)
         return aggregator.aggregate()

     def __str__(self):
@@ -169,12 +168,13 @@ def normalise_dates(dates):
 class StatAggregator:
     NAMES = ["minimum", "maximum", "sums", "squares", "count"]

-    def __init__(self, dates, variables_names, owner):
+    def __init__(self, owner, dates, variables_names, allow_nan):
         dates = sorted(dates)
         dates = to_datetimes(dates)
         self.owner = owner
         self.dates = dates
         self.variables_names = variables_names
+        self.allow_nan = allow_nan
         self.shape = (len(self.dates), len(self.variables_names))

         print("Aggregating statistics on ", self.shape, self.variables_names)
@@ -187,12 +187,6 @@ def __init__(self, dates, variables_names, owner):
         self._read()

-    def _date_to_index(self, date):
-        date = to_datetime(date)
-        assert type(date) is type(self.dates[0]), (type(date), type(self.dates[0]))
-        assert date in self.dates, f"Statistics for date {date} is not needed."
-        return np.where(self.dates == date)[0][0]
-
     def _read(self):
         def check_type(a, b):
             a = list(a)
             b = list(b)
@@ -247,20 +241,15 @@ def check_type(a, b):
         print(f"Statistics for {len(found)} dates found.")

     def aggregate(self):
-        for name in self.NAMES:
-            if name == "count":
-                continue
-            array = getattr(self, name)
-            assert not np.isnan(array).any(), (name, array)
-
-        minimum = np.amin(self.minimum, axis=0)
-        maximum = np.amax(self.maximum, axis=0)
-        count = np.sum(self.count, axis=0)
-        sums = np.sum(self.sums, axis=0)
-        squares = np.sum(self.squares, axis=0)
+
+        minimum = np.nanmin(self.minimum, axis=0)
+        maximum = np.nanmax(self.maximum, axis=0)
+        sums = np.nansum(self.sums, axis=0)
+        squares = np.nansum(self.squares, axis=0)
+        count = np.nansum(self.count, axis=0)
         mean = sums / count

-        assert all(count[0] == c for c in count), count
+        assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape

         x = squares / count - mean * mean
         # remove negative variance due to numerical errors
@@ -268,6 +257,17 @@ def aggregate(self):
         check_variance(x, self.variables_names, minimum, maximum, mean, count, sums, squares)

         stdev = np.sqrt(x)
+        for j, name in enumerate(self.variables_names):
+            check_data_values(
+                np.array(
+                    [
+                        mean[j],
+                    ]
+                ),
+                name=name,
+                allow_nan=False,
+            )
+
         return Statistics(
             minimum=minimum,
             maximum=maximum,
diff --git a/ecml_tools/create/writer.py b/ecml_tools/create/writer.py
index 21f0a7a..d5f324b 100644
--- a/ecml_tools/create/writer.py
+++ b/ecml_tools/create/writer.py
@@ -13,9 +13,9 @@
 import warnings

 import numpy as np
+import zarr

 from .check import check_data_values
-from .statistics import compute_statistics
 from .utils import progress_bar, seconds

 LOG = logging.getLogger(__name__)
@@ -84,106 +84,3 @@ def flush(self):
         for i in range(self.cache.shape[0]):
             global_i = self.indexes[i]
             self.array[global_i] = self.cache[i]
-
-
-class ReindexFirst:
-    def __init__(self, indexes):
-        self.indexes = indexes
-
-    def __call__(self, first, *others):
-        if isinstance(first, int):
-            return (self.indexes[first], *others)
-
-        if isinstance(first, slice):
-            start, stop, step = first.start, first.stop, first.step
-            start = self.indexes[start]
-            stop = self.indexes[stop]
-            return (slice(start, stop, step), *others)
-        if isinstance(first, tuple):
-            return ([self.indexes[_] for _ in first], *others)
-
-        raise NotImplementedError(type(first))
-
-
-class DataWriter:
-    def __init__(self, parts, full_array, owner):
-        self.full_array = full_array
-
-        self.path = owner.path
-        self.statistics_registry = owner.statistics_registry
-        self.registry = owner.registry
-        self.print = owner.print
-        self.dates = owner.dates
-        self.variables_names = owner.variables_names
-
-        self.append_axis = owner.output.append_axis
-        self.n_groups = len(owner.groups)
-
-    def date_to_index(self, date):
-        if isinstance(date, str):
-            date = np.datetime64(date)
-        if isinstance(date, datetime.datetime):
-            date = np.datetime64(date)
-        assert type(date) is type(self.dates[0]), (type(date), type(self.dates[0]))
-        return np.where(self.dates == date)[0][0]
-
-    def write(self, result, igroup, dates):
-        cube = result.get_cube()
-        assert cube.extended_user_shape[0] == len(dates), (cube.extended_user_shape[0], len(dates))
-        dates_in_data = cube.user_coords["valid_datetime"]
-        dates_in_data = [datetime.datetime.fromisoformat(_) for _ in dates_in_data]
-        assert dates_in_data == list(dates), (dates_in_data, list(dates))
-
-        assert isinstance(igroup, int), igroup
-
-        shape = cube.extended_user_shape
-
-        msg = f"Building data for group {igroup}/{self.n_groups} ({shape=} in {self.full_array.shape=})"
-        LOG.info(msg)
-        self.print(msg)
-
-        indexes = [self.date_to_index(d) for d in dates_in_data]
-        array = ViewCacheArray(self.full_array, shape=shape, indexes=indexes)
-        self.load_datacube(cube, array)
-
-        stats = compute_statistics(array.cache, self.variables_names)
-        dates = cube.user_coords["valid_datetime"]
-        self.statistics_registry.write(indexes, stats, dates=dates)
-
-        array.flush()
-        self.registry.set_flag(igroup)
-
-    def load_datacube(self, cube, array):
-        start = time.time()
-        load = 0
-        save = 0
-
-        reading_chunks = None
-        total = cube.count(reading_chunks)
-        self.print(f"Loading datacube: {cube}")
-        bar = progress_bar(
-            iterable=cube.iterate_cubelets(reading_chunks),
-            total=total,
-            desc=f"Loading datacube {cube}",
-        )
-        for i, cubelet in enumerate(bar):
-            now = time.time()
-            data = cubelet.to_numpy()
-            local_indexes = cubelet.coords
-            load += time.time() - now
-
-            name = self.variables_names[local_indexes[1]]
-            check_data_values(data[:], name=name, log=[i, data.shape, local_indexes])
-
-            bar.set_description(f"Loading {i}/{total} {name} {str(cubelet)} ({data.shape})")
-
-            now = time.time()
-            array[local_indexes] = data
-            save += time.time() - now
-
-        now = time.time()
-        save += time.time() - now
-        LOG.info("Written.")
-        msg = f"Elapsed: {seconds(time.time() - start)}, load time: {seconds(load)}, write time: {seconds(save)}."
-        self.print(msg)
-        LOG.info(msg)
diff --git a/tests/create-nan.yaml b/tests/create-nan.yaml
new file mode 100644
index 0000000..7b708b7
--- /dev/null
+++ b/tests/create-nan.yaml
@@ -0,0 +1,34 @@
+description: "testing for nans using sst"
+dataset_status: testing
+purpose: aifs
+name: test-nan
+config_format_version: 2
+
+dates:
+  start: 2020-12-30 00:00:00
+  end: 2021-01-03 12:00:00
+  frequency: 12h
+  group_by: monthly
+
+input:
+  mars:
+    expver: "0001"
+    class: ea
+    grid: 20./20.
+    param: [2t, sst]
+    levtype: sfc
+    stream: oper
+    type: an
+
+#allow_nans: True
+has_nans: [sst]
+
+output:
+  chunking: { dates: 1, ensembles: 1 }
+  dtype: float32
+  flatten_grid: True
+  order_by: [valid_datetime, param_level, number]
+  remapping:
+    param_level: "{param}_{levelist}"
+  statistics: param_level
+  statistics_end: 2020
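
Below, outside the patch itself, is a small self-contained sketch of the NaN handling the diff introduces; the function and data names are illustrative only, not code from the repository. It mirrors what compute_statistics() now does: per-variable checks driven by an allow_nan(name) callback (populated from the has_nans config key, as in tests/create-nan.yaml above), followed by np.nanmin / np.nanmax / np.nansum reductions where "count" becomes the number of non-NaN values, so a field such as sst (NaN over land) no longer corrupts the statistics.

# Standalone sketch with synthetic data; names and shapes are assumptions.
import numpy as np


def nan_aware_statistics(array, variables_names, allow_nan):
    # array shape: (dates, variables, values)
    nvars = array.shape[1]
    shape = (array.shape[0], nvars)
    minimum = np.full(shape, np.nan)
    maximum = np.full(shape, np.nan)
    sums = np.full(shape, np.nan)
    squares = np.full(shape, np.nan)
    count = np.zeros(shape, dtype=np.int64)

    for i, chunk in enumerate(array):
        values = chunk.reshape((nvars, -1))
        for j, name in enumerate(variables_names):
            # NaNs are only tolerated for variables the config declares as nan-able
            if np.isnan(values[j]).any() and not allow_nan(name):
                raise ValueError(f"NaN values found for {name}")
            if np.isnan(values[j]).all():
                raise ValueError(f"All NaN values for {name} ({j}) for date {i}")
        # NaNs are ignored; count keeps the number of valid values per variable
        minimum[i] = np.nanmin(values, axis=1)
        maximum[i] = np.nanmax(values, axis=1)
        sums[i] = np.nansum(values, axis=1)
        squares[i] = np.nansum(values * values, axis=1)
        count[i] = np.sum(~np.isnan(values), axis=1)

    return {"minimum": minimum, "maximum": maximum, "sums": sums, "squares": squares, "count": count}


# Two dates, two variables, ten grid points; sst is NaN over three "land" points,
# which is allowed here the same way has_nans: [sst] allows it in the test config.
data = np.random.rand(2, 2, 10)
data[:, 1, :3] = np.nan
stats = nan_aware_statistics(data, ["2t", "sst"], allow_nan=lambda name: name == "sst")
print(stats["count"])  # -> [[10  7] [10  7]]
mean = stats["sums"] / stats["count"]  # per-date mean over valid points only

As in the patch, a variable that is NaN at every grid point for a given date still raises, since an all-NaN slice would leave nanmin undefined and the count at zero.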