This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit e5e4618
make aggregated stats faster
floriankrb committed Mar 4, 2024
1 parent 7043278 commit e5e4618
Showing 1 changed file with 42 additions and 28 deletions.
ecml_tools/create/statistics.py (70 changes: 42 additions & 28 deletions)
@@ -184,7 +184,6 @@ def __init__(self, dates, variables_names, owner):
         self.sums = np.full(self.shape, np.nan, dtype=np.float64)
         self.squares = np.full(self.shape, np.nan, dtype=np.float64)
         self.count = np.full(self.shape, -1, dtype=np.int64)
-        self.flags = np.full(self.shape, False, dtype=np.bool_)
 
         self._read()

@@ -195,44 +194,59 @@ def _date_to_index(self, date):
         return np.where(self.dates == date)[0][0]
 
     def _read(self):
-        available_dates = []
-        for _, dates, stats in self.owner._gather_data():
+        def check_type(a, b):
+            a = list(a)
+            b = list(b)
+            a = a[0] if a else None
+            b = b[0] if b else None
+            assert type(a) is type(b), (type(a), type(b))
+
+        found = set()
+        offset = 0
+        for _, _dates, stats in self.owner._gather_data():
             assert isinstance(stats, dict), stats
+            assert stats["minimum"].shape[0] == len(_dates), (stats["minimum"].shape, len(_dates))
+            assert stats["minimum"].shape[1] == len(self.variables_names), (
+                stats["minimum"].shape,
+                len(self.variables_names),
+            )
             for n in self.NAMES:
                 assert n in stats, (n, list(stats.keys()))
-            dates = to_datetimes(dates)
-
-            indexes = []
-            stats_indexes = []
-            for i, d in enumerate(dates):
-                if d not in self.dates:
-                    continue
-                stats_indexes.append(i)
-                indexes.append(self._date_to_index(d))
-                available_dates.append(d)
-
-            if not indexes:
+            _dates = to_datetimes(_dates)
+            check_type(_dates, self.dates)
+            if found:
+                check_type(found, self.dates)
+                assert found.isdisjoint(_dates), "Duplicate dates found in precomputed statistics"
+
+            # filter dates
+            dates = set(_dates) & set(self.dates)
+
+            if not dates:
+                # dates have been completely filtered for this chunk
                 continue
 
-            self.flags[indexes] = True
+            # filter data
+            bitmap = np.isin(_dates, self.dates)
+            for k in self.NAMES:
+                stats[k] = stats[k][bitmap]
+
+            assert stats["minimum"].shape[0] == len(dates), (stats["minimum"].shape, len(dates))
+
+            # store data in self
+            found |= set(dates)
             for name in self.NAMES:
                 array = getattr(self, name)
-                data = stats[name]
-                data = data[stats_indexes]
-                array[indexes] = data
+                assert stats[name].shape[0] == len(dates), (stats[name].shape, len(dates))
+                array[offset : offset + len(dates)] = stats[name]
+            offset += len(dates)
 
-        assert type(available_dates[0]) is type(self.dates[0]), (available_dates[0], self.dates[0])
-        assert len(available_dates) == len(set(available_dates)), "Duplicate dates found in statistics"
         for d in self.dates:
-            assert d in available_dates, f"Statistics for date {d} not precomputed."
-        assert len(available_dates) == len(self.dates)
-        print(f"Statistics for {len(available_dates)} dates found.")
+            assert d in found, f"Statistics for date {d} not precomputed."
+        assert len(self.dates) == len(found), "Not all dates found in precomputed statistics"
+        assert len(self.dates) == offset, "Not all dates found in precomputed statistics."
+        print(f"Statistics for {len(found)} dates found.")
 
     def aggregate(self):
-        if not np.all(self.flags):
-            not_found = np.where(self.flags == False)  # noqa: E712
-            raise Exception(f"Statistics not precomputed for {not_found}", not_found)
-
         for name in self.NAMES:
             if name == "count":
                 continue
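The speedup comes from replacing the old per-date Python loop, which called self._date_to_index(d) (an np.where scan) once per date and then scattered values with array[indexes] = data, by a single np.isin mask per chunk followed by one contiguous slice write at a running offset. Below is a minimal standalone sketch of that pattern, not the library's actual code: the toy dates, the read_chunks helper, and the chunk layout are hypothetical, and it assumes, as the new code does, that chunks arrive in the same order as self.dates.

    import numpy as np

    NAMES = ["minimum", "maximum", "sums", "squares", "count"]

    # Hypothetical target dates and variable count, for illustration only.
    wanted_dates = np.array(["2024-01-01", "2024-01-02", "2024-01-03"], dtype="datetime64[D]")
    n_vars = 4
    arrays = {name: np.full((len(wanted_dates), n_vars), np.nan) for name in NAMES}

    def read_chunks(chunks):  # hypothetical stand-in for _read / _gather_data
        found = set()
        offset = 0
        for chunk_dates, stats in chunks:
            assert found.isdisjoint(chunk_dates), "Duplicate dates found in precomputed statistics"
            # One vectorised membership test per chunk, instead of a Python
            # loop doing np.where(self.dates == d) for every single date.
            bitmap = np.isin(chunk_dates, wanted_dates)
            kept = chunk_dates[bitmap]
            if kept.size == 0:
                continue  # chunk completely filtered out
            found |= set(kept)
            for name in NAMES:
                # Contiguous block write, instead of fancy indexing
                # through a per-date index list.
                arrays[name][offset : offset + kept.size] = stats[name][bitmap]
            offset += kept.size
        assert offset == len(wanted_dates), "Not all dates found in precomputed statistics"

    chunks = [
        (np.array(["2023-12-31", "2024-01-01"], dtype="datetime64[D]"),
         {name: np.full((2, n_vars), 1.0) for name in NAMES}),
        (np.array(["2024-01-02", "2024-01-03"], dtype="datetime64[D]"),
         {name: np.full((2, n_vars), 2.0) for name in NAMES}),
    ]
    read_chunks(chunks)

The contiguous write is also what lets the commit drop the flags bookkeeping: the found set plus the final offset check guarantee that every row was written exactly once, which is what aggregate() previously verified with np.all(self.flags).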
