diff --git a/ecml_tools/create/loaders.py b/ecml_tools/create/loaders.py
index a4101ca..eca0d71 100644
--- a/ecml_tools/create/loaders.py
+++ b/ecml_tools/create/loaders.py
@@ -19,17 +19,12 @@
 from .check import DatasetName
 from .config import build_output, loader_config
 from .input import build_input
-from .statistics import (
-    StatisticsRegistry,
-    compute_aggregated_statistics,
-    compute_statistics,
-)
+from .statistics import TempStatistics
 from .utils import (
     bytes,
     compute_directory_sizes,
     normalize_and_check_dates,
     progress_bar,
-    to_datetime,
 )
 from .writer import CubesFilter, DataWriter
 from .zarr import ZarrBuiltRegistry, add_zarr_dataset
@@ -52,10 +47,7 @@ def __init__(self, *, path, print=print, **kwargs):
 
         statistics_tmp = kwargs.get("statistics_tmp") or self.path + ".statistics"
 
-        self.statistics_registry = StatisticsRegistry(
-            statistics_tmp,
-            history_callback=self.registry.add_to_history,
-        )
+        self.statistics_registry = TempStatistics(statistics_tmp)
 
     @classmethod
     def from_config(cls, *, config, path, print=print, **kwargs):
@@ -94,16 +86,17 @@ def read_dataset_metadata(self):
         ds = open_dataset(self.path)
         self.dataset_shape = ds.shape
         self.variables_names = ds.variables
+        assert len(self.variables_names) == ds.shape[1], self.dataset_shape
+        self.dates = ds.dates
 
         z = zarr.open(self.path, "r")
-        start = z.attrs.get("statistics_start_date")
-        end = z.attrs.get("statistics_end_date")
-        if start:
-            start = to_datetime(start)
-        if end:
-            end = to_datetime(end)
-        self._statistics_start_date_from_dataset = start
-        self._statistics_end_date_from_dataset = end
+        self.missing_dates = z.attrs.get("missing_dates")
+        if self.missing_dates:
+            self.missing_dates = [np.datetime64(d) for d in self.missing_dates]
+            assert type(self.missing_dates[0]) == type(self.dates[0]), (
+                self.missing_dates[0],
+                self.dates[0],
+            )
 
     @cached_property
     def registry(self):
@@ -283,10 +276,29 @@ def initialise(self, check_name=True):
         self.statistics_registry.create(exist_ok=False)
         self.registry.add_to_history("statistics_registry_initialised", version=self.statistics_registry.version)
 
+        statistics_start, statistics_end = self._build_statistics_dates(
+            self.main_config.output.get("statistics_start"),
+            self.main_config.output.get("statistics_end"),
+        )
+        self.update_metadata(
+            statistics_start_date=statistics_start,
+            statistics_end_date=statistics_end,
+        )
+        print(f"Will compute statistics from {statistics_start} to {statistics_end}")
+
         self.registry.add_to_history("init finished")
 
         assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks())
 
+    def _build_statistics_dates(self, start, end):
+        ds = open_dataset(self.path)
+        subset = ds.dates_interval_to_indices(start, end)
+        start, end = ds.dates[subset[0]], ds.dates[subset[-1]]
+        return (
+            start.astype(datetime.datetime).isoformat(),
+            end.astype(datetime.datetime).isoformat(),
+        )
+
 
 class ContentLoader(Loader):
     def __init__(self, config, **kwargs):
@@ -340,24 +352,20 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
+        assert statistics_start is None, statistics_start
+        assert statistics_end is None, statistics_end
+
         self.recompute = recompute
 
         self._write_to_dataset = True
 
         self.statistics_output = statistics_output
-        if self.statistics_output:
-            self._write_to_dataset = False
 
         if config:
            self.main_config = loader_config(config)
 
-        self._statistics_start = statistics_start
-        self._statistics_end = statistics_end
-
         self.check_complete(force=force)
         self.read_dataset_metadata()
-        self.read_dataset_dates_metadata()
 
     def run(self):
         # if requested, recompute statistics from data
@@ -366,19 +374,33 @@ def run(self):
         if self.recompute:
             self.recompute_temporary_statistics()
 
-        # compute the detailed statistics from temporary statistics directory
-        detailed = self.get_detailed_stats()
+        dates = [d for d in self.dates if d not in self.missing_dates]
 
-        if self._write_to_dataset:
-            self.write_detailed_statistics(detailed)
+        if self.missing_dates:
+            assert type(self.missing_dates[0]) == type(dates[0]), (type(self.missing_dates[0]), type(dates[0]))
 
-        # compute the aggregated statistics from the detailed statistics
-        # for the selected dates
-        selected = {k: v[self.i_start : self.i_end + 1] for k, v in detailed.items()}
-        stats = compute_aggregated_statistics(selected, self.variables_names)
+        dates_computed = self.statistics_registry.dates_computed
+        for d in dates:
+            if d in self.missing_dates:
+                assert d not in dates_computed, (d, dates_computed)
+            else:
+                assert d in dates_computed, (d, dates_computed)
 
-        if self._write_to_dataset:
-            self.write_aggregated_statistics(stats)
+        z = zarr.open(self.path, mode="r")
+        start = z.attrs.get("statistics_start_date")
+        end = z.attrs.get("statistics_end_date")
+        start = np.datetime64(start)
+        end = np.datetime64(end)
+        dates = [d for d in dates if d >= start and d <= end]
+        assert type(start) == type(dates[0]), (type(start), type(dates[0]))
+
+        stats = self.statistics_registry.get_aggregated(dates, self.variables_names)
+
+        writer = {
+            None: self.write_stats_to_dataset,
+            "-": self.write_stats_to_stdout,
+        }.get(self.statistics_output, self.write_stats_to_file)
+        writer(stats)
 
     def check_complete(self, force):
         if self._complete:
@@ -389,57 +411,12 @@ def check_complete(self, force):
             print(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.")
             self._write_to_dataset = False
 
-    @property
-    def statistics_start(self):
-        user = self._statistics_start
-        config = self.main_config.get("output", {}).get("statistics_start")
-        dataset = self._statistics_start_date_from_dataset
-        return user or config or dataset
-
-    @property
-    def statistics_end(self):
-        user = self._statistics_end
-        config = self.main_config.get("output", {}).get("statistics_end")
-        dataset = self._statistics_end_date_from_dataset
-        return user or config or dataset
-
     @property
     def _complete(self):
         return all(self.registry.get_flags(sync=False))
 
-    def read_dataset_dates_metadata(self):
-        ds = open_dataset(self.path)
-        subset = ds.dates_interval_to_indices(self.statistics_start, self.statistics_end)
-        self.i_start = subset[0]
-        self.i_end = subset[-1]
-        self.date_start = ds.dates[subset[0]]
-        self.date_end = ds.dates[subset[-1]]
-
-        # do not write statistics to dataset if dates do not match the ones in the dataset metadata
-        start = self._statistics_start_date_from_dataset
-        end = self._statistics_end_date_from_dataset
-
-        start_ok = start is None or to_datetime(self.date_start) == start
-        end_ok = end is None or to_datetime(self.date_end) == end
-        if not (start_ok and end_ok):
-            print(
-                f"Statistics start/end dates {self.date_start}/{self.date_end} "
-                f"do not match dates in the dataset metadata {start}/{end}. "
-                f"Will not write statistics to dataset."
-            )
-            self._write_to_dataset = False
-
-        def check():
-            i_len = self.i_end + 1 - self.i_start
-            self.print(f"Statistics computed on {i_len}/{len(ds.dates)} samples ")
-            print(f"Requested ({i_len}): from {self.date_start} to {self.date_end}.")
-            print(f"Available ({len(ds.dates)}): from {ds.dates[0]} to {ds.dates[-1]}.")
-            if i_len < 1:
-                raise ValueError("Cannot compute statistics on an empty interval.")
-
-        check()
-
     def recompute_temporary_statistics(self):
+        raise NotImplementedError("Untested code")
         self.statistics_registry.create(exist_ok=True)
 
         self.print(
@@ -471,67 +448,21 @@ def recompute_temporary_statistics(self):
             self.statistics_registry[key] = detailed_stats
         self.statistics_registry.add_provenance(name="provenance_recompute_statistics", config=self.main_config)
 
-    def get_detailed_stats(self):
-        expected_shape = (self.dataset_shape[0], self.dataset_shape[1])
-        try:
-            return self.statistics_registry.as_detailed_stats(expected_shape)
-        except self.statistics_registry.MissingDataException as e:
-            missing_index = e.args[1]
-            dates = open_dataset(self.path).dates
-            missing_dates = dates[missing_index[0]]
-            print(
-                f"Missing dates: "
-                f"{missing_dates[0]} ... {missing_dates[len(missing_dates)-1]} "
-                f"({missing_dates.shape[0]} missing)"
-            )
-            raise
-
-    def write_detailed_statistics(self, detailed_stats):
-        z = zarr.open(self.path)["_build"]
-        for k, v in detailed_stats.items():
-            if k == "variables_names":
-                continue
-            add_zarr_dataset(zarr_root=z, name=k, array=v)
-        print("Wrote detailed statistics to zarr.")
-
-    def write_aggregated_statistics(self, stats):
-        if self.statistics_output == "-":
-            print(stats)
-            return
-
-        if self.statistics_output:
-            stats.save(self.statistics_output, provenance=dict(config=self.main_config))
-            print(f"✅ Statistics written in {self.statistics_output}")
-            return
-
-        if not self._write_to_dataset:
-            return
+    def write_stats_to_file(self, stats):
+        stats.save(self.statistics_output, provenance=dict(config=self.main_config))
+        print(f"✅ Statistics written in {self.statistics_output}")
+        return
 
-        for k in [
-            "mean",
-            "stdev",
-            "minimum",
-            "maximum",
-            "sums",
-            "squares",
-            "count",
-        ]:
+    def write_stats_to_dataset(self, stats):
+        for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count"]:
             self._add_dataset(name=k, array=stats[k])
 
-        self.update_metadata(
-            statistics_start_date=str(self.date_start),
-            statistics_end_date=str(self.date_end),
-        )
-
-        self.registry.add_to_history(
-            "compute_statistics_end",
-            start=str(self.date_start),
-            end=str(self.date_end),
-            i_start=self.i_start,
-            i_end=self.i_end,
-        )
+        self.registry.add_to_history("compute_statistics_end")
         print(f"Wrote statistics in {self.path}")
 
+    def write_stats_to_stdout(self, stats):
+        print(stats)
+
 
 class SizeLoader(Loader):
     def __init__(self, path, print):
diff --git a/ecml_tools/create/statistics.py b/ecml_tools/create/statistics.py
index 03975e6..75d75bd 100644
--- a/ecml_tools/create/statistics.py
+++ b/ecml_tools/create/statistics.py
@@ -1,4 +1,4 @@
-# (C) Copyright 2023 ECMWF.
+# (C) Copyright 2024 ECMWF.
 #
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
@@ -13,7 +13,9 @@
 import pickle
 import shutil
 import socket
-from collections import defaultdict
+from collections import defaultdict, Counter
+from functools import cached_property
+
 
 import numpy as np
 
@@ -24,46 +26,94 @@
 LOG = logging.getLogger(__name__)
 
 
-class Registry:
-    # names = [ "mean", "stdev", "minimum", "maximum", "sums", "squares", "count", ]
-    # build_names = [ "minimum", "maximum", "sums", "squares", "count", ]
-    version = 2
+def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
+    if (x >= 0).all():
+        return
+    print(x)
+    print(variables_names)
+    print(count)
+    for i, (var, y) in enumerate(zip(variables_names, x)):
+        if y >= 0:
+            continue
+        print(
+            var,
+            y,
+            maximum[i],
+            minimum[i],
+            mean[i],
+            count[i],
+            sums[i],
+            squares[i],
+        )
+
+        print(var, np.min(sums[i]), np.max(sums[i]), np.argmin(sums[i]))
+        print(var, np.min(squares[i]), np.max(squares[i]), np.argmin(squares[i]))
+        print(var, np.min(count[i]), np.max(count[i]), np.argmin(count[i]))
+
+    raise ValueError("Negative variance")
+
+
+def compute_statistics(array, check_variables_names=None):
+    nvars = array.shape[1]
+
+    print("Stats", nvars, array.shape, check_variables_names)
+    if check_variables_names:
+        assert nvars == len(check_variables_names), (nvars, check_variables_names)
+    stats_shape = (array.shape[0], nvars)
+
+    count = np.zeros(stats_shape, dtype=np.int64)
+    sums = np.zeros(stats_shape, dtype=np.float64)
+    squares = np.zeros(stats_shape, dtype=np.float64)
+    minimum = np.zeros(stats_shape, dtype=np.float64)
+    maximum = np.zeros(stats_shape, dtype=np.float64)
+
+    for i, chunk in enumerate(array):
+        values = chunk.reshape((nvars, -1))
+        minimum[i] = np.min(values, axis=1)
+        maximum[i] = np.max(values, axis=1)
+        sums[i] = np.sum(values, axis=1)
+        squares[i] = np.sum(values * values, axis=1)
+        count[i] = values.shape[1]
 
-    def __init__(self, dirname, history_callback=None, overwrite=False):
-        if history_callback is None:
+        for j, name in enumerate(check_variables_names):
+            check_data_values(values[j, :], name=name)
 
-            def dummy(*args, **kwargs):
-                pass
+    return {
+        "minimum": minimum,
+        "maximum": maximum,
+        "sums": sums,
+        "squares": squares,
+        "count": count,
+    }
 
-            history_callback = dummy
 
+class TempStatistics:
+    version = 3
+    # Used in parallel, during data loading,
+    # to write statistics in pickled npz files.
+    # Can provide statistics for a subset of dates.
+
+    def __init__(self, dirname, overwrite=False):
         self.dirname = dirname
         self.overwrite = overwrite
-        self.history_callback = history_callback
-
-    def create(self, exist_ok):
-        os.makedirs(self.dirname, exist_ok=exist_ok)
 
-    def add_provenance(self, name="provenance", **kwargs):
+    def add_provenance(self, **kwargs):
+        self.create(exist_ok=True)
         out = dict(provenance=gather_provenance_info(), **kwargs)
-        with open(os.path.join(self.dirname, f"{name}.json"), "w") as f:
+        with open(os.path.join(self.dirname, "provenance.json"), "w") as f:
             json.dump(out, f)
 
+    def create(self, exist_ok):
+        os.makedirs(self.dirname, exist_ok=exist_ok)
+
     def delete(self):
         try:
             shutil.rmtree(self.dirname)
         except FileNotFoundError:
             pass
 
-    def __setitem__(self, key, data):
-        # if isinstance(key, slice):
-        #     # this is just to make the filenames nicer.
-        #     key_str = f"{key.start}_{key.stop}"
-        #     if key.step is not None:
-        #         key_str = f"{key_str}_{key.step}"
-        # else:
-        #     key_str = str(key_str)
-
+    def write(self, key, data, dates):
+        self.create(exist_ok=True)
         key_str = (
             str(key)
             .replace("(", "")
@@ -81,159 +131,144 @@ def __setitem__(self, key, data):
         tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}"
 
         with open(tmp_path, "wb") as f:
-            pickle.dump((key, data), f)
+            pickle.dump((key, dates, data), f)
         shutil.move(tmp_path, path)
 
-        LOG.info(f"Written {self.name} data for {key} in {path}")
+        LOG.info(f"Written statistics data for {key} in {path} ({dates})")
 
-    def __iter__(self):
+    def _gather_data(self):
         # use glob to read all pickles
         files = glob.glob(self.dirname + "/*.npz")
-
-        LOG.info(
-            f"Reading {self.name} data, found {len(files)} for {self.name} in {self.dirname}"
-        )
-
+        LOG.info(f"Reading stats data, found {len(files)} in {self.dirname}")
         assert len(files) > 0, f"No files found in {self.dirname}"
 
         key_strs = dict()
         for f in files:
             with open(f, "rb") as f:
-                key, data = pickle.load(f)
+                key, dates, data = pickle.load(f)
 
             key_str = str(key)
             if key_str in key_strs:
-                raise Exception(
-                    f"Duplicate key {key}, found in {f} and {key_strs[key_str]}"
-                )
+                raise Exception(f"Duplicate key {key}, found in {f} and {key_strs[key_str]}")
             key_strs[key_str] = f
 
-            yield key, data
+            yield key, dates, data
 
-    def __str__(self):
-        return f"Registry({self.dirname})"
+    @cached_property
+    def n_dates_computed(self):
+        return len(self.dates_computed)
 
+    @property
+    def dates_computed(self):
+        all_dates = []
+        for key, dates, data in self._gather_data():
+            all_dates += dates
 
-class MissingDataException(Exception):
-    pass
+        # assert no duplicates
+        duplicates = [item for item, count in Counter(all_dates).items() if count > 1]
+        if duplicates:
+            raise StatisticsValueError(f"Duplicate dates found in statistics: {duplicates}")
 
+        all_dates = normalise_dates(all_dates)
+        return all_dates
 
-class StatisticsRegistry(Registry):
-    name = "statistics"
+    def get_aggregated(self, dates, variables_names):
+        aggregator = StatAggregator(variables_names, self)
+        aggregator.read(dates)
+        return aggregator.aggregate()
 
-    MissingDataException = MissingDataException
+    def __str__(self):
+        return f"TempStatistics({self.dirname})"
 
-    def as_detailed_stats(self, shape):
-        detailed_stats = dict(
-            minimum=np.full(shape, np.nan, dtype=np.float64),
-            maximum=np.full(shape, np.nan, dtype=np.float64),
-            sums=np.full(shape, np.nan, dtype=np.float64),
-            squares=np.full(shape, np.nan, dtype=np.float64),
-            count=np.full(shape, -1, dtype=np.int64),
-        )
-        flags = np.full(shape, False, dtype=np.bool_)
-        for key, data in self:
-            assert isinstance(data, dict), data
-            assert not np.any(flags[key]), f"Overlapping values for {key} {flags}"
-            flags[key] = True
-            for name, array in detailed_stats.items():
-                d = data[name]
-                array[key] = d
-        if not np.all(flags):
-            missing_indexes = np.where(flags == False)  # noqa: E712
-            raise self.MissingDataException(
-                f"Missing statistics data for {missing_indexes}", missing_indexes
-            )
-
-        return detailed_stats
-
-
-def compute_aggregated_statistics(data, variables_names):
-    i_len = None
-    for name, array in data.items():
-        if i_len is None:
-            i_len = len(array)
-        assert len(array) == i_len, (name, len(array), i_len)
-
-    for name, array in data.items():
-        if name == "count":
-            continue
-        assert not np.isnan(array).any(), (name, array)
-
-    # for i in range(0, i_len):
-    #     for j in range(len(variables_names)):
-    #         stats = Statistics(
-    #             minimum=data["minimum"][i,:],
-    #             maximum=data["maximum"][i,:],
-    #             mean=data["mean"][i,:],
-    #             stdev=data["stdev"][i,:],
-    #             variables_names=variables_names,
-    #         )
-
-    _minimum = np.amin(data["minimum"], axis=0)
-    _maximum = np.amax(data["maximum"], axis=0)
-    _count = np.sum(data["count"], axis=0)
-    _sums = np.sum(data["sums"], axis=0)
-    _squares = np.sum(data["squares"], axis=0)
-    _mean = _sums / _count
-
-    assert all(_count[0] == c for c in _count), _count
-
-    x = _squares / _count - _mean * _mean
-    # remove negative variance due to numerical errors
-    # x[- 1e-15 < (x / (np.sqrt(_squares / _count) + np.abs(_mean))) < 0] = 0
-    check_variance_is_positive(
-        x, variables_names, _minimum, _maximum, _mean, _count, _sums, _squares
-    )
-    _stdev = np.sqrt(x)
-
-    stats = Statistics(
-        minimum=_minimum,
-        maximum=_maximum,
-        mean=_mean,
-        count=_count,
-        sums=_sums,
-        squares=_squares,
-        stdev=_stdev,
-        variables_names=variables_names,
-    )
-
-    return stats
 
+def normalise_date(d):
+    if isinstance(d, str):
+        d = np.datetime64(d)
+    return d
 
-def compute_statistics(array, check_variables_names=None):
-    nvars = array.shape[1]
 
+def normalise_dates(dates):
+    return [normalise_date(d) for d in dates]
 
-    print(nvars, array.shape, check_variables_names)
-    if check_variables_names:
-        assert nvars == len(check_variables_names), (nvars, check_variables_names)
-    stats_shape = (array.shape[0], nvars)
 
-    count = np.zeros(stats_shape, dtype=np.int64)
-    sums = np.zeros(stats_shape, dtype=np.float64)
-    squares = np.zeros(stats_shape, dtype=np.float64)
-    minimum = np.zeros(stats_shape, dtype=np.float64)
-    maximum = np.zeros(stats_shape, dtype=np.float64)
+class StatAggregator:
+    NAMES = ["minimum", "maximum", "sums", "squares", "count"]
 
-    for i, chunk in enumerate(array):
-        values = chunk.reshape((nvars, -1))
-        minimum[i] = np.min(values, axis=1)
-        maximum[i] = np.max(values, axis=1)
-        sums[i] = np.sum(values, axis=1)
-        squares[i] = np.sum(values * values, axis=1)
-        count[i] = values.shape[1]
+    def __init__(self, variables_names, owner):
+        self.owner = owner
+        self.computed_dates = owner.dates_computed
+        self.shape = (len(self.computed_dates), len(variables_names))
+        self.variables_names = variables_names
+        print("Aggregating on ", self.shape, variables_names)
 
-        for j, name in enumerate(check_variables_names):
-            check_data_values(values[j, :], name=name)
+        self.minimum = np.full(self.shape, np.nan, dtype=np.float64)
+        self.maximum = np.full(self.shape, np.nan, dtype=np.float64)
+        self.sums = np.full(self.shape, np.nan, dtype=np.float64)
+        self.squares = np.full(self.shape, np.nan, dtype=np.float64)
+        self.count = np.full(self.shape, -1, dtype=np.int64)
+        self.flags = np.full(self.shape, False, dtype=np.bool_)
 
-    return {
-        "minimum": minimum,
-        "maximum": maximum,
-        "sums": sums,
-        "squares": squares,
-        "count": count,
-    }
+    def read(self, dates):
+        assert type(dates[0]) == type(self.computed_dates[0]), (
+            dates[0],
+            self.computed_dates[0],
+        )
+
+        dates_bitmap = np.isin(self.computed_dates, dates)
+
+        for key, dates, data in self.owner._gather_data():
+            assert isinstance(data, dict), data
+            assert not np.any(self.flags[key]), f"Overlapping values for {key} {self.flags} ({dates})"
+            self.flags[key] = True
+            for name in self.NAMES:
+                array = getattr(self, name)
+                array[key] = data[name]
+
+        if not np.all(self.flags[dates_bitmap]):
+            not_found = np.where(self.flags == False)  # noqa: E712
+            raise Exception(f"Missing statistics data for {not_found}", not_found)
+
+        print(f"Selecting statistics data from {self.minimum.shape[0]} to {self.minimum[dates_bitmap].shape[0]} dates.")
+        for name in self.NAMES:
+            array = getattr(self, name)
+            array = array[dates_bitmap]
+            setattr(self, name, array)
+
+    def aggregate(self):
+        print(f"Aggregating statistics on {self.minimum.shape}")
+        for name in self.NAMES:
+            if name == "count":
+                continue
+            array = getattr(self, name)
+            assert not np.isnan(array).any(), (name, array)
+
+        minimum = np.amin(self.minimum, axis=0)
+        maximum = np.amax(self.maximum, axis=0)
+        count = np.sum(self.count, axis=0)
+        sums = np.sum(self.sums, axis=0)
+        squares = np.sum(self.squares, axis=0)
+        mean = sums / count
+
+        assert all(count[0] == c for c in count), count
+
+        x = squares / count - mean * mean
+        # remove negative variance due to numerical errors
+        # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0
+        check_variance(x, self.variables_names, minimum, maximum, mean, count, sums, squares)
+        stdev = np.sqrt(x)
+
+        stats = Statistics(
+            minimum=minimum,
+            maximum=maximum,
+            mean=mean,
+            count=count,
+            sums=sums,
+            squares=squares,
+            stdev=stdev,
+            variables_names=self.variables_names,
+        )
+
+        return stats
 
 
 class Statistics(dict):
@@ -298,7 +333,7 @@ def save(self, filename, provenance=None):
 
     def load(self, filename):
         assert filename.endswith(".json"), filename
-        with open(filename, "r") as f:
+        with open(filename) as f:
            dic = json.load(f)
 
         dic_ = {}
@@ -311,32 +346,3 @@ def load(self, filename):
                 continue
             dic_[k] = np.array(v, dtype=np.float64)
         return Statistics(dic_)
-
-
-def check_variance_is_positive(
-    x, variables_names, minimum, maximum, mean, count, sums, squares
-):
-    if (x >= 0).all():
-        return
-    print(x)
-    print(variables_names)
-    print(count)
-    for i, (var, y) in enumerate(zip(variables_names, x)):
-        if y >= 0:
-            continue
-        print(
-            var,
-            y,
-            maximum[i],
-            minimum[i],
-            mean[i],
-            count[i],
-            sums[i],
-            squares[i],
-        )
-
-        print(var, np.min(sums[i]), np.max(sums[i]), np.argmin(sums[i]))
-        print(var, np.min(squares[i]), np.max(squares[i]), np.argmin(squares[i]))
-        print(var, np.min(count[i]), np.max(count[i]), np.argmin(count[i]))
-
-    raise ValueError("Negative variance")
diff --git a/ecml_tools/create/writer.py b/ecml_tools/create/writer.py
index dcc4783..907d85a 100644
--- a/ecml_tools/create/writer.py
+++ b/ecml_tools/create/writer.py
@@ -44,11 +44,7 @@ def __init__(self, *, parts, total):
         )
         chunk_size = total / n_chunks
 
-        parts = [
-            x
-            for x in range(total)
-            if x >= (i_chunk - 1) * chunk_size and x < i_chunk * chunk_size
-        ]
+        parts = [x for x in range(total) if x >= (i_chunk - 1) * chunk_size and x < i_chunk * chunk_size]
         parts = [int(_) for _ in parts]
 
         LOG.info(f"Running parts: {parts}")
@@ -140,18 +136,12 @@ def new_key(self, key, values_shape):
             # Create a new key for indexing the large array.
             new_key = tuple(
-                (
-                    slice(self.offset, self.offset + values_shape[i])
-                    if i == self.axis
-                    else slice(None)
-                )
+                (slice(self.offset, self.offset + values_shape[i]) if i == self.axis else slice(None))
                 for i in range(len(self.shape))
             )
         else:
             # For non-slice keys, adjust the key based on the offset and axis.
-            new_key = tuple(
-                k + self.offset if i == self.axis else k for i, k in enumerate(key)
-            )
+            new_key = tuple(k + self.offset if i == self.axis else k for i, k in enumerate(key))
         return new_key
 
     def __setitem__(self, key, values):
@@ -195,25 +185,22 @@ def write_cube(self, cube, icube):
         assert isinstance(icube, int), icube
 
         shape = cube.extended_user_shape
+        dates = cube.user_coords["valid_datetime"]
 
         slice = self.registry.get_slice_for(icube)
 
         LOG.info(
             f"Building dataset '{self.path}' i={icube} total={self.n_cubes} "
             f"(total shape ={shape}) at {slice}, {self.full_array.chunks=}"
         )
-        self.print(
-            f"Building dataset (total shape ={shape}) at {slice}, {self.full_array.chunks=}"
-        )
+        self.print(f"Building dataset (total shape ={shape}) at {slice}, {self.full_array.chunks=}")
 
         offset = slice.start
-        array = OffsetView(
-            self.full_array, offset=offset, axis=self.append_axis, shape=shape
-        )
+        array = OffsetView(self.full_array, offset=offset, axis=self.append_axis, shape=shape)
         array = FastWriteArray(array, shape=shape)
         self.load_datacube(cube, array)
 
         new_key, stats = array.compute_statistics_and_key(self.variables_names)
-        self.statistics_registry[new_key] = stats
+        self.statistics_registry.write(new_key, stats, dates=dates)
 
         array.flush()
 
@@ -258,12 +245,8 @@ def load_datacube(self, cube, array):
         LOG.info("Written.")
 
         self.print(
-            f"Elapsed: {seconds(time.time() - start)},"
-            f" load time: {seconds(load)},"
-            f" write time: {seconds(save)}."
+            f"Elapsed: {seconds(time.time() - start)}," f" load time: {seconds(load)}," f" write time: {seconds(save)}."
         )
         LOG.info(
-            f"Elapsed: {seconds(time.time() - start)},"
-            f" load time: {seconds(load)},"
-            f" write time: {seconds(save)}."
+            f"Elapsed: {seconds(time.time() - start)}," f" load time: {seconds(load)}," f" write time: {seconds(save)}."
         )
diff --git a/tests/create-join-reference/create-join.zarr/_build/count/.zarray b/tests/create-join-reference/create-join.zarr/_build/count/.zarray
deleted file mode 100644
index c4a5621..0000000
--- a/tests/create-join-reference/create-join.zarr/_build/count/.zarray
+++ /dev/null
@@ -1,22 +0,0 @@
-{
-    "chunks": [
-        10,
-        8
-    ],
-    "compressor": {
-        "blocksize": 0,
-        "clevel": 5,
-        "cname": "lz4",
-        "id": "blosc",
-        "shuffle": 1
-    },
-    "dtype": "