Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
statistics with missing data
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Feb 28, 2024
1 parent fa749b8 commit 84d7f1f
Show file tree
Hide file tree
Showing 14 changed files with 263 additions and 453 deletions.
205 changes: 68 additions & 137 deletions ecml_tools/create/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,12 @@
from .check import DatasetName
from .config import build_output, loader_config
from .input import build_input
from .statistics import (
StatisticsRegistry,
compute_aggregated_statistics,
compute_statistics,
)
from .statistics import TempStatistics
from .utils import (
bytes,
compute_directory_sizes,
normalize_and_check_dates,
progress_bar,
to_datetime,
)
from .writer import CubesFilter, DataWriter
from .zarr import ZarrBuiltRegistry, add_zarr_dataset
Expand All @@ -52,10 +47,7 @@ def __init__(self, *, path, print=print, **kwargs):

statistics_tmp = kwargs.get("statistics_tmp") or self.path + ".statistics"

self.statistics_registry = StatisticsRegistry(
statistics_tmp,
history_callback=self.registry.add_to_history,
)
self.statistics_registry = TempStatistics(statistics_tmp)

@classmethod
def from_config(cls, *, config, path, print=print, **kwargs):
Expand Down Expand Up @@ -94,16 +86,17 @@ def read_dataset_metadata(self):
ds = open_dataset(self.path)
self.dataset_shape = ds.shape
self.variables_names = ds.variables
assert len(self.variables_names) == ds.shape[1], self.dataset_shape
self.dates = ds.dates

z = zarr.open(self.path, "r")
start = z.attrs.get("statistics_start_date")
end = z.attrs.get("statistics_end_date")
if start:
start = to_datetime(start)
if end:
end = to_datetime(end)
self._statistics_start_date_from_dataset = start
self._statistics_end_date_from_dataset = end
self.missing_dates = z.attrs.get("missing_dates")
if self.missing_dates:
self.missing_dates = [np.datetime64(d) for d in self.missing_dates]
assert type(self.missing_dates[0]) == type(self.dates[0]), (
self.missing_dates[0],
self.dates[0],
)

@cached_property
def registry(self):
Expand Down Expand Up @@ -283,10 +276,29 @@ def initialise(self, check_name=True):
self.statistics_registry.create(exist_ok=False)
self.registry.add_to_history("statistics_registry_initialised", version=self.statistics_registry.version)

statistics_start, statistics_end = self._build_statistics_dates(
self.main_config.output.get("statistics_start"),
self.main_config.output.get("statistics_end"),
)
self.update_metadata(
statistics_start_date=statistics_start,
statistics_end_date=statistics_end,
)
print(f"Will compute statistics from {statistics_start} to {statistics_end}")

self.registry.add_to_history("init finished")

assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks())

def _build_statistics_dates(self, start, end):
    """Map a requested statistics interval onto dates present in the dataset.

    Opens the dataset at ``self.path``, converts the ``start``/``end``
    interval to indices with ``dates_interval_to_indices`` and returns the
    first and last selected dates as a ``(start, end)`` tuple of ISO-format
    strings.
    """
    dataset = open_dataset(self.path)
    indices = dataset.dates_interval_to_indices(start, end)
    first_date = dataset.dates[indices[0]]
    last_date = dataset.dates[indices[-1]]
    # NOTE(review): presumably the dates are numpy datetime64 values whose
    # astype(datetime.datetime) yields datetime objects — confirm the unit.
    return tuple(d.astype(datetime.datetime).isoformat() for d in (first_date, last_date))


class ContentLoader(Loader):
def __init__(self, config, **kwargs):
Expand Down Expand Up @@ -340,24 +352,20 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
assert statistics_start is None, statistics_start
assert statistics_end is None, statistics_end

self.recompute = recompute

self._write_to_dataset = True

self.statistics_output = statistics_output
if self.statistics_output:
self._write_to_dataset = False

if config:
self.main_config = loader_config(config)

self._statistics_start = statistics_start
self._statistics_end = statistics_end

self.check_complete(force=force)

self.read_dataset_metadata()
self.read_dataset_dates_metadata()

def run(self):
# if requested, recompute statistics from data
Expand All @@ -366,19 +374,33 @@ def run(self):
if self.recompute:
self.recompute_temporary_statistics()

# compute the detailed statistics from temporary statistics directory
detailed = self.get_detailed_stats()
dates = [d for d in self.dates if d not in self.missing_dates]

if self._write_to_dataset:
self.write_detailed_statistics(detailed)
if self.missing_dates:
assert type(self.missing_dates[0]) == type(dates[0]), (type(self.missing_dates[0]), type(dates[0]))

# compute the aggregated statistics from the detailed statistics
# for the selected dates
selected = {k: v[self.i_start : self.i_end + 1] for k, v in detailed.items()}
stats = compute_aggregated_statistics(selected, self.variables_names)
dates_computed = self.statistics_registry.dates_computed
for d in dates:
if d in self.missing_dates:
assert d not in dates_computed, (d, date_computed)
else:
assert d in dates_computed, (d, dates_computed)

if self._write_to_dataset:
self.write_aggregated_statistics(stats)
z = zarr.open(self.path, mode="r")
start = z.attrs.get("statistics_start_date")
end = z.attrs.get("statistics_end_date")
start = np.datetime64(start)
end = np.datetime64(end)
dates = [d for d in dates if d >= start and d <= end]
assert type(start) == type(dates[0]), (type(start), type(dates[0]))

stats = self.statistics_registry.get_aggregated(dates, self.variables_names)

writer = {
None: self.write_stats_to_dataset,
"-": self.write_stats_to_stdout,
}.get(self.statistics_output, self.write_stats_to_file)
writer(stats)

def check_complete(self, force):
if self._complete:
Expand All @@ -389,57 +411,12 @@ def check_complete(self, force):
print(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.")
self._write_to_dataset = False

@property
def statistics_start(self):
    """Start of the statistics interval.

    Returns the first truthy value among, in order: the value stored in
    ``self._statistics_start``, the ``output.statistics_start`` entry of
    ``self.main_config``, and the value read from the dataset metadata.
    When none of the first two is truthy, the dataset value is returned
    as-is (possibly ``None``).
    """
    candidates = (
        self._statistics_start,
        self.main_config.get("output", {}).get("statistics_start"),
        self._statistics_start_date_from_dataset,
    )
    *preferred, fallback = candidates
    return next((value for value in preferred if value), fallback)

@property
def statistics_end(self):
    """End of the statistics interval.

    Returns the first truthy value among, in order: the value stored in
    ``self._statistics_end``, the ``output.statistics_end`` entry of
    ``self.main_config``, and the value read from the dataset metadata.
    When none of the first two is truthy, the dataset value is returned
    as-is (possibly ``None``).
    """
    candidates = (
        self._statistics_end,
        self.main_config.get("output", {}).get("statistics_end"),
        self._statistics_end_date_from_dataset,
    )
    *preferred, fallback = candidates
    return next((value for value in preferred if value), fallback)

@property
def _complete(self):
    """Whether every flag returned by ``self.registry.get_flags`` is truthy.

    The flags are fetched with ``sync=False``.
    """
    flags = self.registry.get_flags(sync=False)
    return all(flags)

def read_dataset_dates_metadata(self):
    """Resolve the statistics interval against the dataset and record it.

    Sets ``self.i_start``/``self.i_end`` (first and last selected indices)
    and ``self.date_start``/``self.date_end`` (the matching dataset dates).
    If the dataset metadata already records statistics dates and they differ
    from the resolved ones, writing to the dataset is disabled.

    Raises:
        ValueError: if the selected interval contains fewer than one sample.
    """
    ds = open_dataset(self.path)
    indices = ds.dates_interval_to_indices(self.statistics_start, self.statistics_end)
    self.i_start, self.i_end = indices[0], indices[-1]
    self.date_start, self.date_end = ds.dates[indices[0]], ds.dates[indices[-1]]

    # Dates previously recorded in the dataset metadata (None when absent).
    start = self._statistics_start_date_from_dataset
    end = self._statistics_end_date_from_dataset

    start_matches = start is None or to_datetime(self.date_start) == start
    end_matches = end is None or to_datetime(self.date_end) == end
    if not start_matches or not end_matches:
        # Do not overwrite dataset statistics computed on a different interval.
        print(
            f"Statistics start/end dates {self.date_start}/{self.date_end} "
            f"do not match dates in the dataset metadata {start}/{end}. "
            f"Will not write statistics to dataset."
        )
        self._write_to_dataset = False

    # Report what was selected and reject an empty selection.
    i_len = self.i_end + 1 - self.i_start
    self.print(f"Statistics computed on {i_len}/{len(ds.dates)} samples ")
    print(f"Requested ({i_len}): from {self.date_start} to {self.date_end}.")
    print(f"Available ({len(ds.dates)}): from {ds.dates[0]} to {ds.dates[-1]}.")
    if i_len < 1:
        raise ValueError("Cannot compute statistics on an empty interval.")

def recompute_temporary_statistics(self):
raise NotImplementedError("Untested code")
self.statistics_registry.create(exist_ok=True)

self.print(
Expand Down Expand Up @@ -471,67 +448,21 @@ def recompute_temporary_statistics(self):
self.statistics_registry[key] = detailed_stats
self.statistics_registry.add_provenance(name="provenance_recompute_statistics", config=self.main_config)

def get_detailed_stats(self):
    """Fetch the detailed statistics from the statistics registry.

    The registry is asked for statistics shaped like the first two
    dimensions of ``self.dataset_shape``. If it raises its
    ``MissingDataException``, a summary of the missing dates is printed
    and the exception is re-raised unchanged.
    """
    dim0, dim1 = self.dataset_shape[0], self.dataset_shape[1]
    try:
        detailed = self.statistics_registry.as_detailed_stats((dim0, dim1))
    except self.statistics_registry.MissingDataException as exc:
        # The second exception argument carries the missing indices.
        missing_index = exc.args[1]
        missing_dates = open_dataset(self.path).dates[missing_index[0]]
        print(
            f"Missing dates: "
            f"{missing_dates[0]} ... {missing_dates[len(missing_dates)-1]} "
            f"({missing_dates.shape[0]} missing)"
        )
        raise
    return detailed

def write_detailed_statistics(self, detailed_stats):
    """Store each detailed-statistics array under the zarr ``_build`` group.

    Every entry of ``detailed_stats`` except ``"variables_names"`` is added
    with ``add_zarr_dataset`` under its own key.
    """
    build_group = zarr.open(self.path)["_build"]
    arrays = ((name, array) for name, array in detailed_stats.items() if name != "variables_names")
    for name, array in arrays:
        add_zarr_dataset(zarr_root=build_group, name=name, array=array)
    print("Wrote detailed statistics to zarr.")

def write_aggregated_statistics(self, stats):
if self.statistics_output == "-":
print(stats)
return

if self.statistics_output:
stats.save(self.statistics_output, provenance=dict(config=self.main_config))
print(f"✅ Statistics written in {self.statistics_output}")
return

if not self._write_to_dataset:
return
def write_stats_to_file(self, stats):
    """Save ``stats`` to the path held in ``self.statistics_output``.

    The main configuration is passed to ``stats.save`` as provenance, and a
    confirmation line is printed afterwards.
    """
    provenance = dict(config=self.main_config)
    stats.save(self.statistics_output, provenance=provenance)
    print(f"✅ Statistics written in {self.statistics_output}")

for k in [
"mean",
"stdev",
"minimum",
"maximum",
"sums",
"squares",
"count",
]:
def write_stats_to_dataset(self, stats):
for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count"]:
self._add_dataset(name=k, array=stats[k])

self.update_metadata(
statistics_start_date=str(self.date_start),
statistics_end_date=str(self.date_end),
)

self.registry.add_to_history(
"compute_statistics_end",
start=str(self.date_start),
end=str(self.date_end),
i_start=self.i_start,
i_end=self.i_end,
)
self.registry.add_to_history("compute_statistics_end")
print(f"Wrote statistics in {self.path}")

def write_stats_to_stdout(self, stats):
    """Print ``stats`` to standard output instead of persisting them."""
    print(stats)


class SizeLoader(Loader):
def __init__(self, path, print):
Expand Down
Loading

0 comments on commit 84d7f1f

Please sign in to comment.