Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Feb 29, 2024
1 parent 10dcc64 commit 3e4a00e
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 186 deletions.
27 changes: 2 additions & 25 deletions ecml_tools/create/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def init(self, check_name=False):
from .loaders import InitialiseLoader

if self._path_readable() and not self.overwrite:
raise Exception(
f"{self.path} already exists. Use overwrite=True to overwrite."
)
raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.")

with self._cache_context():
obj = InitialiseLoader.from_config(
Expand All @@ -57,11 +55,7 @@ def load(self, parts=None):
)
loader.load(parts=parts)

def statistics(
self,
force=False,
output=None,
):
def statistics(self, force=False, output=None, start=None, end=None):
from .loaders import StatisticsLoader

loader = StatisticsLoader.from_dataset(
Expand All @@ -71,25 +65,8 @@ def statistics(
statistics_tmp=self.statistics_tmp,
statistics_output=output,
recompute=False,
)
loader.run()

def recompute_statistics(self, start=None, end=None, force=False):
    """Build a statistics loader in recompute mode for this dataset and run it.

    ``start`` and ``end`` are forwarded as ``statistics_start`` and
    ``statistics_end``; ``force`` is forwarded unchanged. The loader is
    always created with ``recompute=True``.
    """
    # Imported locally, as in the rest of this module's entry points.
    from .loaders import StatisticsLoader

    StatisticsLoader.from_dataset(
        path=self.path,
        print=self.print,
        force=force,
        statistics_tmp=self.statistics_tmp,
        statistics_start=start,
        statistics_end=end,
        recompute=True,
    ).run()

Expand Down
166 changes: 53 additions & 113 deletions ecml_tools/create/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,7 @@
from .config import build_output, loader_config
from .input import build_input
from .statistics import TempStatistics
from .utils import (
bytes,
compute_directory_sizes,
normalize_and_check_dates,
progress_bar,
)
from .utils import bytes, compute_directory_sizes, normalize_and_check_dates
from .writer import CubesFilter, DataWriter
from .zarr import ZarrBuiltRegistry, add_zarr_dataset

Expand Down Expand Up @@ -82,6 +77,15 @@ def build_input(self):
print(builder)
return builder

def build_statistics_dates(self, start, end):
    """Return the first and last dataset dates selected by an interval, as ISO strings.

    The dataset at ``self.path`` is opened and its
    ``dates_interval_to_indices(start, end)`` method supplies the indices of
    the selected dates. The dates found at the first and at the last of
    those indices are converted with ``astype(datetime.datetime)`` and
    formatted with ``isoformat()``.

    Returns:
        A ``(start, end)`` tuple of ISO-formatted strings.
    """
    dataset = open_dataset(self.path)
    indices = dataset.dates_interval_to_indices(start, end)
    first_date = dataset.dates[indices[0]]
    last_date = dataset.dates[indices[-1]]
    # NOTE(review): the dates are presumably numpy datetime64 values, since
    # they are converted with ``astype`` -- confirm against ``open_dataset``.
    return tuple(d.astype(datetime.datetime).isoformat() for d in (first_date, last_date))

def read_dataset_metadata(self):
ds = open_dataset(self.path)
self.dataset_shape = ds.shape
Expand All @@ -90,21 +94,8 @@ def read_dataset_metadata(self):
self.dates = ds.dates

z = zarr.open(self.path, "r")
self.missing_dates = z.attrs.get("missing_dates")
if self.missing_dates:
self.missing_dates = [np.datetime64(d) for d in self.missing_dates]
assert type(self.missing_dates[0]) == type(self.dates[0]), (
self.missing_dates[0],
self.dates[0],
)

def date_to_index(self, date):
    """Return the position of ``date`` within ``self.dates``.

    Args:
        date: A string, a ``datetime.datetime`` or a value of the same type
            as the entries of ``self.dates``. Strings and ``datetime``
            objects are converted with ``np.datetime64`` first.

    Returns:
        The index of the first entry of ``self.dates`` equal to ``date``.

    Raises:
        AssertionError: If, after conversion, ``date`` does not have the same
            type as the entries of ``self.dates``.
        IndexError: If ``date`` is not present in ``self.dates``.
    """
    if isinstance(date, (str, datetime.datetime)):
        date = np.datetime64(date)
    assert type(date) is type(self.dates[0]), (type(date), type(self.dates[0]))

    matches = np.where(self.dates == date)[0]
    if len(matches) == 0:
        # Same exception type as indexing the empty match array would raise,
        # but with a message that names the offending date.
        raise IndexError(f"Date {date} not found in the dataset dates")
    return matches[0]
self.missing_dates = z.attrs.get("missing_dates", [])
self.missing_dates = [np.datetime64(d) for d in self.missing_dates]

@cached_property
def registry(self):
Expand Down Expand Up @@ -284,7 +275,7 @@ def initialise(self, check_name=True):
self.statistics_registry.create(exist_ok=False)
self.registry.add_to_history("statistics_registry_initialised", version=self.statistics_registry.version)

statistics_start, statistics_end = self._build_statistics_dates(
statistics_start, statistics_end = self.build_statistics_dates(
self.main_config.output.get("statistics_start"),
self.main_config.output.get("statistics_end"),
)
Expand All @@ -298,15 +289,6 @@ def initialise(self, check_name=True):

assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks())

def _build_statistics_dates(self, start, end):
    """Map a requested interval to the ISO strings of its first and last dataset dates.

    Opens the dataset at ``self.path``, obtains the selected indices from
    ``dates_interval_to_indices(start, end)`` and returns the dates at the
    first and last index, each converted with ``astype(datetime.datetime)``
    and formatted with ``isoformat()``.
    """
    ds = open_dataset(self.path)
    selected = ds.dates_interval_to_indices(start, end)
    bounds = (ds.dates[selected[0]], ds.dates[selected[-1]])
    first_iso, last_iso = [b.astype(datetime.datetime).isoformat() for b in bounds]
    return first_iso, last_iso


class ContentLoader(Loader):
def __init__(self, config, **kwargs):
Expand All @@ -322,7 +304,7 @@ def load(self, parts):
self.registry.add_to_history("loading_data_start", parts=parts)

z = zarr.open(self.path, mode="r+")
data_writer = DataWriter(parts, parent=self, full_array=z["data"], print=self.print)
data_writer = DataWriter(parts, full_array=z["data"], owner=self)

total = len(self.registry.get_flags())
filter = CubesFilter(parts=parts, total=total)
Expand Down Expand Up @@ -356,112 +338,70 @@ def __init__(
statistics_start=None,
statistics_end=None,
force=False,
recompute=False,
**kwargs,
):
super().__init__(**kwargs)
assert statistics_start is None, statistics_start
assert statistics_end is None, statistics_end

self.recompute = recompute

self._write_to_dataset = True
self.user_statistics_start = statistics_start
self.user_statistics_end = statistics_end

self.statistics_output = statistics_output

self.output_writer = {
None: self.write_stats_to_dataset,
"-": self.write_stats_to_stdout,
}.get(self.statistics_output, self.write_stats_to_file)

if config:
self.main_config = loader_config(config)

self.check_complete(force=force)
self.read_dataset_metadata()

def run(self):
# if requested, recompute statistics from data
# into the temporary statistics directory
# (this should have been done already when creating the dataset content)
if self.recompute:
self.recompute_temporary_statistics()

dates = [d for d in self.dates if d not in self.missing_dates]
def _get_statistics_dates(self):
dates = self.dates
dtype = type(dates[0])

# remove missing dates
if self.missing_dates:
assert type(self.missing_dates[0]) is type(dates[0]), (type(self.missing_dates[0]), type(dates[0]))

dates_computed = self.statistics_registry.dates_computed
for d in dates:
if d in self.missing_dates:
assert d not in dates_computed, (d, date_computed)
else:
assert d in dates_computed, (d, dates_computed)
assert type(self.missing_dates[0]) is dtype, (type(self.missing_dates[0]), dtype)
dates = [d for d in dates if d not in self.missing_dates]

# filter dates according to the start and end dates in the metadata
z = zarr.open(self.path, mode="r")
start = z.attrs.get("statistics_start_date")
end = z.attrs.get("statistics_end_date")
start = np.datetime64(start)
end = np.datetime64(end)
start, end = z.attrs.get("statistics_start_date"), z.attrs.get("statistics_end_date")
start, end = np.datetime64(start), np.datetime64(end)
assert type(start) is dtype, (type(start), dtype)
dates = [d for d in dates if d >= start and d <= end]
assert type(start) is type(dates[0]), (type(start), type(dates[0]))

stats = self.statistics_registry.get_aggregated(dates, self.variables_names)
# filter dates according to the user specified start and end dates
if self.user_statistics_start or self.user_statistics_end:
start, end = self.build_statistics_dates(self.user_statistics_start, self.user_statistics_end)
start, end = np.datetime64(start), np.datetime64(end)
assert type(start) is dtype, (type(start), dtype)
dates = [d for d in dates if d >= start and d <= end]

writer = {
None: self.write_stats_to_dataset,
"-": self.write_stats_to_stdout,
}.get(self.statistics_output, self.write_stats_to_file)
writer(stats)

def check_complete(self, force):
    """Check that the dataset is fully built before statistics are produced.

    Nothing happens when ``self._complete`` is true. Otherwise an
    ``Exception`` is raised unless ``force`` is set; with ``force`` a warning
    is printed (only while ``self._write_to_dataset`` is still true) and
    ``self._write_to_dataset`` is switched off.
    """
    if not self._complete:
        if not force:
            raise Exception(f"❗Zarr {self.path} is not fully built. Use 'force' option.")
        if self._write_to_dataset:
            print(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.")
            self._write_to_dataset = False

@property
def _complete(self):
    """True when every flag returned by ``self.registry.get_flags(sync=False)`` is truthy."""
    flags = self.registry.get_flags(sync=False)
    return all(flags)

def recompute_temporary_statistics(self):
raise NotImplementedError("Untested code")
self.statistics_registry.create(exist_ok=True)

self.print(
f"Building temporary statistics from data {self.path}. " f"From {self.date_start} to {self.date_end}"
)

shape = (self.i_end + 1 - self.i_start, len(self.variables_names))
detailed_stats = dict(
minimum=np.full(shape, np.nan, dtype=np.float64),
maximum=np.full(shape, np.nan, dtype=np.float64),
sums=np.full(shape, np.nan, dtype=np.float64),
squares=np.full(shape, np.nan, dtype=np.float64),
count=np.full(shape, -1, dtype=np.int64),
)
return dates

ds = open_dataset(self.path)
key = (slice(self.i_start, self.i_end + 1), slice(None, None))
for i in progress_bar(
desc="Computing Statistics",
iterable=range(self.i_start, self.i_end + 1),
):
i_ = i - self.i_start
data = ds[slice(i, i + 1), :]
one = compute_statistics(data, self.variables_names)
for k, v in one.items():
detailed_stats[k][i_] = v

print(f"✅ Saving statistics for {key} shape={detailed_stats['count'].shape}")
self.statistics_registry[key] = detailed_stats
self.statistics_registry.add_provenance(name="provenance_recompute_statistics", config=self.main_config)
def run(self):
    """Aggregate the statistics for the selected dates and hand them to the output writer.

    The dates come from ``self._get_statistics_dates()``; the aggregation is
    done by ``self.statistics_registry.get_aggregated`` for
    ``self.variables_names``; the result is passed to ``self.output_writer``.
    """
    selected_dates = self._get_statistics_dates()
    aggregated = self.statistics_registry.get_aggregated(selected_dates, self.variables_names)
    self.output_writer(aggregated)

def write_stats_to_file(self, stats):
    """Save ``stats`` to ``self.statistics_output`` and report where it was written.

    ``stats.save`` receives the output target and a ``provenance`` mapping
    holding ``self.main_config`` under the ``"config"`` key.
    """
    destination = self.statistics_output
    stats.save(destination, provenance={"config": self.main_config})
    print(f"✅ Statistics written in {self.statistics_output}")

def write_stats_to_dataset(self, stats):
if self.user_statistics_start or self.user_statistics_end:
raise ValueError(
(
"Cannot write statistics in dataset with user specified dates. "
"This would be conflicting with the dataset metadata."
)
)

if not all(self.registry.get_flags(sync=False)):
raise Exception(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.")

for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count"]:
self._add_dataset(name=k, array=stats[k])

Expand Down
Loading

0 comments on commit 3e4a00e

Please sign in to comment.