This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Merge branch 'feature/nans' into develop
floriankrb committed Mar 6, 2024
2 parents b1bccb0 + 9f8cd5d commit e1cec79
Showing 8 changed files with 210 additions and 187 deletions.
4 changes: 2 additions & 2 deletions ecml_tools/create/__init__.py
@@ -51,9 +51,9 @@ def load(self, parts=None):

with self._cache_context():
loader = ContentLoader.from_dataset_config(
path=self.path, statistics_tmp=self.statistics_tmp, print=self.print
path=self.path, statistics_tmp=self.statistics_tmp, print=self.print, parts=parts
)
loader.load(parts=parts)
loader.load()

def statistics(self, force=False, output=None, start=None, end=None):
from .loaders import StatisticsLoader
44 changes: 22 additions & 22 deletions ecml_tools/create/check.py
@@ -23,9 +23,7 @@ def compute_directory_size(path):
return None
size = 0
n = 0
for dirpath, _, filenames in tqdm.tqdm(
os.walk(path), desc="Computing size", leave=False
):
for dirpath, _, filenames in tqdm.tqdm(os.walk(path), desc="Computing size", leave=False):
for filename in filenames:
file_path = os.path.join(dirpath, filename)
size += os.path.getsize(file_path)
@@ -57,10 +55,7 @@ def __init__(
self.check_end_date(end_date)

if self.messages:
self.messages.append(
f"{self} is parsed as :"
+ "/".join(f"{k}={v}" for k, v in self.parsed.items())
)
self.messages.append(f"{self} is parsed as :" + "/".join(f"{k}={v}" for k, v in self.parsed.items()))

@property
def error_message(self):
@@ -76,9 +71,7 @@ def raise_if_not_valid(self, print=print):
raise ValueError(self.error_message)

def _parse(self, name):
pattern = (
r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?(.*)$"
)
pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?(.*)$"
match = re.match(pattern, name)

assert match, (name, pattern)
@@ -112,10 +105,7 @@ def check_parsed(self):
)

def check_resolution(self, resolution):
if (
self.parsed.get("resolution")
and self.parsed["resolution"][0] not in "0123456789on"
):
if self.parsed.get("resolution") and self.parsed["resolution"][0] not in "0123456789on":
self.messages.append(
f"the resolution {self.parsed['resolution'] } should start "
f"with a number or 'o' or 'n' in the dataset name {self}."
@@ -150,22 +140,23 @@ def check_end_date(self, end_date):

def _check_missing(self, key, value):
if value not in self.name:
self.messages.append(
f"the {key} is {value}, but is missing in {self.name}."
)
self.messages.append(f"the {key} is {value}, but is missing in {self.name}.")

def _check_mismatch(self, key, value):
if self.parsed.get(key) and self.parsed[key] != value:
self.messages.append(
f"the {key} is {value}, but is {self.parsed[key]} in {self.name}."
)
self.messages.append(f"the {key} is {value}, but is {self.parsed[key]} in {self.name}.")


class StatisticsValueError(ValueError):
pass


def check_data_values(arr, *, name: str, log=[]):
def check_data_values(arr, *, name: str, log=[], allow_nan=False):
if allow_nan is False:
allow_nan = lambda x: False
if allow_nan(name):
arr = arr[~np.isnan(arr)]

min, max = arr.min(), arr.max()
assert not (np.isnan(arr).any()), (name, min, max, *log)

@@ -175,10 +166,19 @@ def check_data_values(arr, *, name: str, log=[]):
warnings.warn(f"Max value 9999 for {name}")

in_0_1 = dict(minimum=0, maximum=1)
is_temperature = dict(minimum=173.15, maximum=373.15) # -100 celsius to +100 celsius
# is_wind = dict(minimum=-500., maximum=500.)
limits = {
"lsm": in_0_1,
"cos_latitude": in_0_1,
"sin_latitude": in_0_1,
"cos_longitude": in_0_1,
"sin_longitude": in_0_1,
"insolation": in_0_1,
"2t": dict(minimum=173.15, maximum=373.15),
"2t": is_temperature,
"sst": is_temperature,
# "10u": is_wind,
# "10v": is_wind,
}

if name in limits:
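
The new allow_nan argument to check_data_values accepts either False (the default, NaNs trip the assertion) or a predicate mapping a variable name to a bool; when the predicate returns True for the variable being checked, NaN values are masked out before the min/max checks run. A minimal usage sketch, assuming the import path used by loaders.py below (the array values and the has_nans list are illustrative, not from the commit):

    import numpy as np
    from ecml_tools.create.check import check_data_values

    # e.g. sea-surface temperature, NaN over land points
    field = np.array([271.3, np.nan, 280.1])

    # Variables listed in a hypothetical has_nans list may contain NaNs.
    has_nans = ["sst"]
    check_data_values(field, name="sst", allow_nan=lambda name: name in has_nans)

    # With the default allow_nan=False the same array would fail the NaN assertion.
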
8 changes: 0 additions & 8 deletions ecml_tools/create/config.py
@@ -93,10 +93,6 @@ def get_chunking(self, coords):
)
return tuple(chunks)

@property
def append_axis(self):
return self.config.append_axis

@property
def order_by(self):
return self.config.order_by
@@ -163,10 +159,6 @@ def normalise(self):
assert "flatten_values" not in self.output
assert "flatten_grid" in self.output, self.output

# The axis along which we append new data
# TODO: assume grid points can be 2d as well
self.output.append_axis = 0

assert "statistics" in self.output
statistics_axis_name = self.output.statistics
statistics_axis = -1
5 changes: 5 additions & 0 deletions ecml_tools/create/input.py
@@ -332,6 +332,11 @@ def data_request(self):
"""Returns a dictionary with the parameters needed to retrieve the data."""
return _data_request(self.datasource)

@property
def variables_with_nans(self):
print("❌❌HERE")
return

def get_cube(self):
trace("🧊", f"getting cube from {self.__class__.__name__}")
ds = self.datasource
127 changes: 111 additions & 16 deletions ecml_tools/create/loaders.py
@@ -7,6 +7,7 @@
import datetime
import logging
import os
import time
import uuid
from functools import cached_property

@@ -16,12 +17,18 @@
from ecml_tools.data import open_dataset
from ecml_tools.utils.dates.groups import Groups

from .check import DatasetName
from .check import DatasetName, check_data_values
from .config import build_output, loader_config
from .input import build_input
from .statistics import TempStatistics
from .utils import bytes, compute_directory_sizes, normalize_and_check_dates
from .writer import CubesFilter, DataWriter
from .statistics import TempStatistics, compute_statistics
from .utils import (
bytes,
compute_directory_sizes,
normalize_and_check_dates,
progress_bar,
seconds,
)
from .writer import CubesFilter, ViewCacheArray
from .zarr import ZarrBuiltRegistry, add_zarr_dataset

LOG = logging.getLogger(__name__)
@@ -97,6 +104,9 @@ def read_dataset_metadata(self):
self.missing_dates = z.attrs.get("missing_dates", [])
self.missing_dates = [np.datetime64(d) for d in self.missing_dates]

def allow_nan(self, name):
return name in self.main_config.get("has_nans", [])

@cached_property
def registry(self):
return ZarrBuiltRegistry(self.path)
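
The allow_nan predicate added above is backed by an optional has_nans entry in the dataset configuration: a variable is allowed to contain NaNs only if it is listed there. A hedged sketch of that lookup, with a made-up configuration value:

    main_config = {"has_nans": ["sst"]}  # hypothetical recipe entry

    def allow_nan(name):
        return name in main_config.get("has_nans", [])

    assert allow_nan("sst")
    assert not allow_nan("2t")
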
@@ -180,6 +190,8 @@ def initialise(self, check_name=True):
variables = self.minimal_input.variables
self.print(f"Found {len(variables)} variables : {','.join(variables)}.")

variables_with_nans = self.config.get("has_nans", [])

ensembles = self.minimal_input.ensembles
self.print(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.")

@@ -222,6 +234,7 @@

metadata["ensemble_dimension"] = len(ensembles)
metadata["variables"] = variables
metadata["variables_with_nans"] = variables_with_nans
metadata["resolution"] = resolution

metadata["licence"] = self.main_config["licence"]
@@ -291,7 +304,7 @@ def initialise(self, check_name=True):


class ContentLoader(Loader):
def __init__(self, config, **kwargs):
def __init__(self, config, parts, **kwargs):
super().__init__(**kwargs)
self.main_config = loader_config(config)

@@ -300,33 +313,115 @@ def __init__(self, config, **kwargs):
self.input = self.build_input()
self.read_dataset_metadata()

def load(self, parts):
self.registry.add_to_history("loading_data_start", parts=parts)
self.parts = parts
total = len(self.registry.get_flags())
self.cube_filter = CubesFilter(parts=self.parts, total=total)

z = zarr.open(self.path, mode="r+")
data_writer = DataWriter(parts, full_array=z["data"], owner=self)
self.data_array = zarr.open(self.path, mode="r+")["data"]
self.n_groups = len(self.groups)

def load(self):
self.registry.add_to_history("loading_data_start", parts=self.parts)

total = len(self.registry.get_flags())
filter = CubesFilter(parts=parts, total=total)
for igroup, group in enumerate(self.groups):
if not self.cube_filter(igroup):
continue
if self.registry.get_flag(igroup):
LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)")
continue
if not filter(igroup):
continue
self.print(f" -> Processing {igroup} total={len(self.groups)}")
print("========", group)
assert isinstance(group[0], datetime.datetime), group

result = self.input.select(dates=group)
data_writer.write(result, igroup, group)
assert result.dates == group, (len(result.dates), len(group))

msg = f"Building data for group {igroup}/{self.n_groups}"
LOG.info(msg)
self.print(msg)

self.registry.add_to_history("loading_data_end", parts=parts)
# There are several groups.
# There is one result to load for each group.
self.load_result(result)
self.registry.set_flag(igroup)

self.registry.add_to_history("loading_data_end", parts=self.parts)
self.registry.add_provenance(name="provenance_load")
self.statistics_registry.add_provenance(name="provenance_load", config=self.main_config)

self.print_info()

def load_result(self, result):
# There is one cube to load for each result.
dates = result.dates

cube = result.get_cube()
assert cube.extended_user_shape[0] == len(dates), (cube.extended_user_shape[0], len(dates))

shape = cube.extended_user_shape
dates_in_data = cube.user_coords["valid_datetime"]

LOG.info(f"Loading {shape=} in {self.data_array.shape=}")

def check_dates_in_data(lst, lst2):
lst2 = list(lst2)
lst = [datetime.datetime.fromisoformat(_) for _ in lst]
assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2)

check_dates_in_data(dates_in_data, dates)

def dates_to_indexes(dates, all_dates):
x = np.array(dates, dtype=np.datetime64)
y = np.array(all_dates, dtype=np.datetime64)
bitmap = np.isin(x, y)
return np.where(bitmap)[0]

indexes = dates_to_indexes(self.dates, dates_in_data)

array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes)
self.load_cube(cube, array)

stats = compute_statistics(array.cache, self.variables_names, allow_nan=self.allow_nan)
self.statistics_registry.write(indexes, stats, dates=dates_in_data)

array.flush()

def load_cube(self, cube, array):
# There are several cubelets for each cube
start = time.time()
load = 0
save = 0

reading_chunks = None
total = cube.count(reading_chunks)
self.print(f"Loading datacube: {cube}")
bar = progress_bar(
iterable=cube.iterate_cubelets(reading_chunks),
total=total,
desc=f"Loading datacube {cube}",
)
for i, cubelet in enumerate(bar):
bar.set_description(f"Loading {i}/{total}")

now = time.time()
data = cubelet.to_numpy()
local_indexes = cubelet.coords
load += time.time() - now

name = self.variables_names[local_indexes[1]]
check_data_values(data[:], name=name, log=[i, data.shape, local_indexes], allow_nan=self.allow_nan)

now = time.time()
array[local_indexes] = data
save += time.time() - now

now = time.time()
save += time.time() - now
LOG.info("Written.")
msg = f"Elapsed: {seconds(time.time() - start)}, load time: {seconds(load)}, write time: {seconds(save)}."
self.print(msg)
LOG.info(msg)


class StatisticsLoader(Loader):
main_config = {}
@@ -383,7 +478,7 @@ def _get_statistics_dates(self):

def run(self):
dates = self._get_statistics_dates()
stats = self.statistics_registry.get_aggregated(dates, self.variables_names)
stats = self.statistics_registry.get_aggregated(dates, self.variables_names, self.allow_nan)
self.output_writer(stats)

def write_stats_to_file(self, stats):
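
For reference, the dates_to_indexes helper added in load_result maps a group's dates onto positions along the dataset's full date axis; those positions are what ViewCacheArray uses to flush the group into the right slots of the Zarr data array. A self-contained sketch of its behaviour, with illustrative dates:

    import numpy as np

    def dates_to_indexes(dates, all_dates):
        # Positions, within `dates`, of the entries that also occur in `all_dates`.
        x = np.array(dates, dtype=np.datetime64)
        y = np.array(all_dates, dtype=np.datetime64)
        return np.where(np.isin(x, y))[0]

    full_axis = ["2020-01-01T00:00", "2020-01-01T06:00", "2020-01-01T12:00"]
    group = ["2020-01-01T00:00", "2020-01-01T12:00"]
    print(dates_to_indexes(full_axis, group))  # -> [0 2]
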