From 3e4a00ed20232be48735212824d878cf94cf7eb2 Mon Sep 17 00:00:00 2001
From: Florian Pinault
Date: Thu, 29 Feb 2024 17:39:14 +0000
Subject: [PATCH] statistics

---
 ecml_tools/create/__init__.py   |  27 +-----
 ecml_tools/create/loaders.py    | 166 ++++++++++----------------------
 ecml_tools/create/statistics.py |  52 +++-------
 ecml_tools/create/writer.py     |  29 +++---
 4 files changed, 88 insertions(+), 186 deletions(-)

diff --git a/ecml_tools/create/__init__.py b/ecml_tools/create/__init__.py
index 0c22af9..aefe64e 100644
--- a/ecml_tools/create/__init__.py
+++ b/ecml_tools/create/__init__.py
@@ -35,9 +35,7 @@ def init(self, check_name=False):
         from .loaders import InitialiseLoader
 
         if self._path_readable() and not self.overwrite:
-            raise Exception(
-                f"{self.path} already exists. Use overwrite=True to overwrite."
-            )
+            raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.")
 
         with self._cache_context():
             obj = InitialiseLoader.from_config(
@@ -57,11 +55,7 @@ def load(self, parts=None):
         )
         loader.load(parts=parts)
 
-    def statistics(
-        self,
-        force=False,
-        output=None,
-    ):
+    def statistics(self, force=False, output=None, start=None, end=None):
         from .loaders import StatisticsLoader
 
         loader = StatisticsLoader.from_dataset(
@@ -71,25 +65,8 @@ def statistics(
             statistics_tmp=self.statistics_tmp,
             statistics_output=output,
             recompute=False,
-        )
-        loader.run()
-
-    def recompute_statistics(
-        self,
-        start=None,
-        end=None,
-        force=False,
-    ):
-        from .loaders import StatisticsLoader
-
-        loader = StatisticsLoader.from_dataset(
-            path=self.path,
-            print=self.print,
-            force=force,
-            statistics_tmp=self.statistics_tmp,
             statistics_start=start,
             statistics_end=end,
-            recompute=True,
         )
         loader.run()
 
diff --git a/ecml_tools/create/loaders.py b/ecml_tools/create/loaders.py
index 5e15022..244ff23 100644
--- a/ecml_tools/create/loaders.py
+++ b/ecml_tools/create/loaders.py
@@ -20,12 +20,7 @@
 from .config import build_output, loader_config
 from .input import build_input
 from .statistics import TempStatistics
-from .utils import (
-    bytes,
-    compute_directory_sizes,
-    normalize_and_check_dates,
-    progress_bar,
-)
+from .utils import bytes, compute_directory_sizes, normalize_and_check_dates
 from .writer import CubesFilter, DataWriter
 from .zarr import ZarrBuiltRegistry, add_zarr_dataset
 
@@ -82,6 +77,15 @@ def build_input(self):
         print(builder)
         return builder
 
+    def build_statistics_dates(self, start, end):
+        ds = open_dataset(self.path)
+        subset = ds.dates_interval_to_indices(start, end)
+        start, end = ds.dates[subset[0]], ds.dates[subset[-1]]
+        return (
+            start.astype(datetime.datetime).isoformat(),
+            end.astype(datetime.datetime).isoformat(),
+        )
+
     def read_dataset_metadata(self):
         ds = open_dataset(self.path)
         self.dataset_shape = ds.shape
@@ -90,21 +94,8 @@ def read_dataset_metadata(self):
         self.dates = ds.dates
 
         z = zarr.open(self.path, "r")
-        self.missing_dates = z.attrs.get("missing_dates")
-        if self.missing_dates:
-            self.missing_dates = [np.datetime64(d) for d in self.missing_dates]
-            assert type(self.missing_dates[0]) == type(self.dates[0]), (
-                self.missing_dates[0],
-                self.dates[0],
-            )
-
-    def date_to_index(self, date):
-        if isinstance(date, str):
-            date = np.datetime64(date)
-        if isinstance(date, datetime.datetime):
-            date = np.datetime64(date)
-        assert type(date) is type(self.dates[0]), (type(date), type(self.dates[0]))
-        return np.where(self.dates == date)[0][0]
+        self.missing_dates = z.attrs.get("missing_dates", [])
+        self.missing_dates = [np.datetime64(d) for d in self.missing_dates]
 
     @cached_property
     def registry(self):
@@ -284,7 +275,7 @@ def initialise(self, check_name=True):
         self.statistics_registry.create(exist_ok=False)
         self.registry.add_to_history("statistics_registry_initialised", version=self.statistics_registry.version)
 
-        statistics_start, statistics_end = self._build_statistics_dates(
+        statistics_start, statistics_end = self.build_statistics_dates(
             self.main_config.output.get("statistics_start"),
             self.main_config.output.get("statistics_end"),
         )
@@ -298,15 +289,6 @@ def initialise(self, check_name=True):
 
         assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks())
 
-    def _build_statistics_dates(self, start, end):
-        ds = open_dataset(self.path)
-        subset = ds.dates_interval_to_indices(start, end)
-        start, end = ds.dates[subset[0]], ds.dates[subset[-1]]
-        return (
-            start.astype(datetime.datetime).isoformat(),
-            end.astype(datetime.datetime).isoformat(),
-        )
-
 
 class ContentLoader(Loader):
     def __init__(self, config, **kwargs):
@@ -322,7 +304,7 @@ def load(self, parts):
         self.registry.add_to_history("loading_data_start", parts=parts)
 
         z = zarr.open(self.path, mode="r+")
-        data_writer = DataWriter(parts, parent=self, full_array=z["data"], print=self.print)
+        data_writer = DataWriter(parts, full_array=z["data"], owner=self)
 
         total = len(self.registry.get_flags())
         filter = CubesFilter(parts=parts, total=total)
@@ -356,112 +338,70 @@ def __init__(
         statistics_start=None,
         statistics_end=None,
         force=False,
-        recompute=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        assert statistics_start is None, statistics_start
-        assert statistics_end is None, statistics_end
-
-        self.recompute = recompute
-
-        self._write_to_dataset = True
+        self.user_statistics_start = statistics_start
+        self.user_statistics_end = statistics_end
 
         self.statistics_output = statistics_output
 
+        self.output_writer = {
+            None: self.write_stats_to_dataset,
+            "-": self.write_stats_to_stdout,
+        }.get(self.statistics_output, self.write_stats_to_file)
+
         if config:
             self.main_config = loader_config(config)
 
-        self.check_complete(force=force)
         self.read_dataset_metadata()
 
-    def run(self):
-        # if requested, recompute statistics from data
-        # into the temporary statistics directory
-        # (this should have been done already when creating the dataset content)
-        if self.recompute:
-            self.recompute_temporary_statistics()
-
-        dates = [d for d in self.dates if d not in self.missing_dates]
+    def _get_statistics_dates(self):
+        dates = self.dates
+        dtype = type(dates[0])
 
+        # remove missing dates
         if self.missing_dates:
-            assert type(self.missing_dates[0]) is type(dates[0]), (type(self.missing_dates[0]), type(dates[0]))
-
-            dates_computed = self.statistics_registry.dates_computed
-            for d in dates:
-                if d in self.missing_dates:
-                    assert d not in dates_computed, (d, date_computed)
-                else:
-                    assert d in dates_computed, (d, dates_computed)
+            assert type(self.missing_dates[0]) is dtype, (type(self.missing_dates[0]), dtype)
+            dates = [d for d in dates if d not in self.missing_dates]
 
+        # filter dates according to the start and end dates in the metadata
         z = zarr.open(self.path, mode="r")
-        start = z.attrs.get("statistics_start_date")
-        end = z.attrs.get("statistics_end_date")
-        start = np.datetime64(start)
-        end = np.datetime64(end)
+        start, end = z.attrs.get("statistics_start_date"), z.attrs.get("statistics_end_date")
+        start, end = np.datetime64(start), np.datetime64(end)
+        assert type(start) is dtype, (type(start), dtype)
         dates = [d for d in dates if d >= start and d <= end]
-        assert type(start) is type(dates[0]), (type(start), type(dates[0]))
 
-        stats = self.statistics_registry.get_aggregated(dates, self.variables_names)
+        # filter dates according to the user-specified start and end dates
+        if self.user_statistics_start or self.user_statistics_end:
+            start, end = self.build_statistics_dates(self.user_statistics_start, self.user_statistics_end)
+            start, end = np.datetime64(start), np.datetime64(end)
+            assert type(start) is dtype, (type(start), dtype)
+            dates = [d for d in dates if d >= start and d <= end]
 
-        writer = {
-            None: self.write_stats_to_dataset,
-            "-": self.write_stats_to_stdout,
-        }.get(self.statistics_output, self.write_stats_to_file)
-        writer(stats)
-
-    def check_complete(self, force):
-        if self._complete:
-            return
-        if not force:
-            raise Exception(f"❗Zarr {self.path} is not fully built. Use 'force' option.")
-        if self._write_to_dataset:
-            print(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.")
-            self._write_to_dataset = False
-
-    @property
-    def _complete(self):
-        return all(self.registry.get_flags(sync=False))
-
-    def recompute_temporary_statistics(self):
-        raise NotImplementedError("Untested code")
-        self.statistics_registry.create(exist_ok=True)
-
-        self.print(
-            f"Building temporary statistics from data {self.path}. " f"From {self.date_start} to {self.date_end}"
-        )
-
-        shape = (self.i_end + 1 - self.i_start, len(self.variables_names))
-        detailed_stats = dict(
-            minimum=np.full(shape, np.nan, dtype=np.float64),
-            maximum=np.full(shape, np.nan, dtype=np.float64),
-            sums=np.full(shape, np.nan, dtype=np.float64),
-            squares=np.full(shape, np.nan, dtype=np.float64),
-            count=np.full(shape, -1, dtype=np.int64),
-        )
+        return dates
 
-        ds = open_dataset(self.path)
-        key = (slice(self.i_start, self.i_end + 1), slice(None, None))
-        for i in progress_bar(
-            desc="Computing Statistics",
-            iterable=range(self.i_start, self.i_end + 1),
-        ):
-            i_ = i - self.i_start
-            data = ds[slice(i, i + 1), :]
-            one = compute_statistics(data, self.variables_names)
-            for k, v in one.items():
-                detailed_stats[k][i_] = v
-
-        print(f"✅ Saving statistics for {key} shape={detailed_stats['count'].shape}")
-        self.statistics_registry[key] = detailed_stats
-        self.statistics_registry.add_provenance(name="provenance_recompute_statistics", config=self.main_config)
+    def run(self):
+        dates = self._get_statistics_dates()
+        stats = self.statistics_registry.get_aggregated(dates, self.variables_names)
+        self.output_writer(stats)
 
     def write_stats_to_file(self, stats):
         stats.save(self.statistics_output, provenance=dict(config=self.main_config))
         print(f"✅ Statistics written in {self.statistics_output}")
-        return
 
     def write_stats_to_dataset(self, stats):
+        if self.user_statistics_start or self.user_statistics_end:
+            raise ValueError(
+                (
+                    "Cannot write statistics into the dataset with user-specified dates. "
+                    "This would conflict with the dataset metadata."
+ ) + ) + + if not all(self.registry.get_flags(sync=False)): + raise Exception(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.") + for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count"]: self._add_dataset(name=k, array=stats[k]) diff --git a/ecml_tools/create/statistics.py b/ecml_tools/create/statistics.py index 6d4cbea..9b0a58b 100644 --- a/ecml_tools/create/statistics.py +++ b/ecml_tools/create/statistics.py @@ -15,8 +15,7 @@ import pickle import shutil import socket -from collections import Counter, defaultdict -from functools import cached_property +from collections import defaultdict import numpy as np @@ -125,12 +124,9 @@ def delete(self): except FileNotFoundError: pass - def _hash_key(self, key): - return hashlib.sha256(str(key).encode("utf-8")).hexdigest() - def write(self, key, data, dates): self.create(exist_ok=True) - h = self._hash_key(dates) + h = hashlib.sha256(str(dates).encode("utf-8")).hexdigest() path = os.path.join(self.dirname, f"{h}.npz") if not self.overwrite: @@ -146,27 +142,12 @@ def write(self, key, data, dates): def _gather_data(self): # use glob to read all pickles files = glob.glob(self.dirname + "/*.npz") - LOG.info(f"Reading stats data, found {len(files)} in {self.dirname}") + LOG.info(f"Reading stats data, found {len(files)} files in {self.dirname}") assert len(files) > 0, f"No files found in {self.dirname}" for f in files: with open(f, "rb") as f: yield pickle.load(f) - @property - def dates_computed(self): - all_dates = [] - for _, dates, data in self._gather_data(): - all_dates += dates - - # assert no duplicates - duplicates = [item for item, count in Counter(all_dates).items() if count > 1] - if duplicates: - raise StatisticsValueError(f"Duplicate dates found in statistics: {duplicates}") - - all_dates = normalise_dates(all_dates) - all_dates = sorted(all_dates) - return all_dates - def get_aggregated(self, dates, variables_names): aggregator = StatAggregator(dates, variables_names, self) return aggregator.aggregate() @@ -245,13 +226,13 @@ def _read(self): for d in self.dates: assert d in available_dates, f"Statistics for date {d} not precomputed." assert len(available_dates) == len(self.dates) + print(f"Statistics for {len(available_dates)} dates found.") def aggregate(self): if not np.all(self.flags): not_found = np.where(self.flags == False) # noqa: E712 raise Exception(f"Statistics not precomputed for {not_found}", not_found) - print(f"Aggregating statistics on {self.minimum.shape}") for name in self.NAMES: if name == "count": continue @@ -273,7 +254,7 @@ def aggregate(self): check_variance(x, self.variables_names, minimum, maximum, mean, count, sums, squares) stdev = np.sqrt(x) - stats = Statistics( + return Statistics( minimum=minimum, maximum=maximum, mean=mean, @@ -284,16 +265,13 @@ def aggregate(self): variables_names=self.variables_names, ) - return stats - class Statistics(dict): STATS_NAMES = ["minimum", "maximum", "mean", "stdev"] # order matter for __str__. 
-    def __init__(self, *args, check=True, **kwargs):
-        super().__init__(*args, **kwargs)
-        if check:
-            self.check()
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.check()
 
     @property
     def size(self):
@@ -324,12 +302,14 @@ def check(self):
             raise
 
     def __str__(self):
-        header = ["Variables"] + [self[name] for name in self.STATS_NAMES]
-        out = " ".join(header)
-
-        for i, v in enumerate(self["variables_names"]):
-            out += " ".join([v] + [f"{x[i]:.2f}" for x in self.values()])
-        return out
+        header = ["Variables"] + self.STATS_NAMES
+        out = [" ".join(header)]
+
+        out += [
+            " ".join([v] + [f"{self[n][i]:.2f}" for n in self.STATS_NAMES])
+            for i, v in enumerate(self["variables_names"])
+        ]
+        return "\n".join(out)
 
     def save(self, filename, provenance=None):
         assert filename.endswith(".json"), filename
diff --git a/ecml_tools/create/writer.py b/ecml_tools/create/writer.py
index 322834e..21f0a7a 100644
--- a/ecml_tools/create/writer.py
+++ b/ecml_tools/create/writer.py
@@ -106,21 +106,26 @@ def __call__(self, first, *others):
 
 
 class DataWriter:
-    def __init__(self, parts, full_array, parent, print=print):
-        self.parent = parent
+    def __init__(self, parts, full_array, owner):
         self.full_array = full_array
 
-        self.path = parent.path
-        self.statistics_registry = parent.statistics_registry
-        self.registry = parent.registry
-        self.print = parent.print
+        self.path = owner.path
+        self.statistics_registry = owner.statistics_registry
+        self.registry = owner.registry
+        self.print = owner.print
+        self.dates = owner.dates
+        self.variables_names = owner.variables_names
 
-        self.append_axis = parent.output.append_axis
-        self.n_groups = len(parent.groups)
+        self.append_axis = owner.output.append_axis
+        self.n_groups = len(owner.groups)
 
-    @property
-    def variables_names(self):
-        return self.parent.variables_names
+    def date_to_index(self, date):
+        if isinstance(date, str):
+            date = np.datetime64(date)
+        if isinstance(date, datetime.datetime):
+            date = np.datetime64(date)
+        assert type(date) is type(self.dates[0]), (type(date), type(self.dates[0]))
+        return np.where(self.dates == date)[0][0]
 
     def write(self, result, igroup, dates):
         cube = result.get_cube()
@@ -137,7 +142,7 @@ def write(self, result, igroup, dates):
         LOG.info(msg)
         self.print(msg)
 
-        indexes = [self.parent.date_to_index(d) for d in dates_in_data]
+        indexes = [self.date_to_index(d) for d in dates_in_data]
         array = ViewCacheArray(self.full_array, shape=shape, indexes=indexes)
 
         self.load_datacube(cube, array)
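
Usage note (not part of the patch): after this change the former recompute_statistics() entry point is gone and date-bounded statistics go through statistics() itself. Below is a minimal sketch of how the updated entry point might be driven; the Creator class name, its constructor arguments and the file names are assumptions for illustration, only the statistics(force, output, start, end) signature comes from the diff above.

    # Hypothetical driver, assuming the creator class in ecml_tools/create/__init__.py
    # is called Creator and takes the dataset path as "path".
    from ecml_tools.create import Creator

    c = Creator(path="dataset.zarr")

    # Default: aggregate the precomputed per-date statistics and write them into the
    # dataset, using the statistics_start/statistics_end dates stored in its metadata.
    c.statistics()

    # User-specified dates must be written to stdout ("-") or a file, since
    # write_stats_to_dataset() now rejects them to avoid conflicting with the
    # dataset metadata.
    c.statistics(output="-", start="2020-01-01", end="2020-12-31")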