Add splitting and calculation of statistics #10

Merged (10 commits, Jul 17, 2024)
3 changes: 3 additions & 0 deletions .github/workflows/python-package-pip.yml
@@ -25,5 +25,8 @@ jobs:
python -m pip install pytest

- name: Run tests
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: |
python -m pytest tests/
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased](https://github.com/mllam/mllam-data-prep/compare/v0.1.0...HEAD)

### Added

- add support for creating dataset splits (e.g. train, validation, test)
through the `output.splitting` section in the config file, and support for
optionally computing statistics for a given split (with
`output.splitting.splits.{split_name}.compute_statistics`).
[\#10](https://github.com/mllam/mllam-data-prep/pull/10)

### Changed

- split dataset creation and storage to zarr into separate functions `mllam_data_prep.create_dataset(...)` and `mllam_data_prep.create_dataset_zarr(...)` respectively [\#7](https://github.com/mllam/mllam-data-prep/pull/7)

- changes to spec from v0.1.0:
37 changes: 34 additions & 3 deletions README.md
@@ -76,6 +76,21 @@ output:
step: PT3H
chunking:
time: 1
splitting:
dim: time
splits:
train:
start: 1990-09-03T00:00
end: 1990-09-06T00:00
compute_statistics:
ops: [mean, std, diff_mean, diff_std]
dims: [grid_index, time]
validation:
start: 1990-09-06T00:00
end: 1990-09-07T00:00
test:
start: 1990-09-07T00:00
end: 1990-09-09T00:00

inputs:
danra_height_levels:
@@ -138,9 +153,9 @@ inputs:

```

Apart from identifiers to keep track of the configuration file format version and the datasets version, the configuration file is divided into two main sections:
Apart from identifiers to keep track of the configuration file format version and the dataset version (for you to keep track of changes that you make to the dataset), the configuration file is divided into two main sections:

- `output`: defines the input variables and dimensions of the output dataset produced by `mllam-data-prep`. These are the variables and dimensions that the inputs datasets will be mapped to. These should match the variables and dimensions expected by the model architecture you are training.
- `output`: defines the variables and dimensions of the output dataset produced by `mllam-data-prep`. These are the variables and dimensions that the input datasets will be mapped to. These output variables and dimensions should match the input variables and dimensions expected by the model architecture you are training.
- `inputs`: a list of source datasets to extract data from. These are the datasets that will be mapped to the variables and dimensions defined in the `output` section (a short usage sketch follows below).
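
For orientation, here is a minimal sketch of driving the package from Python with a config like the one above. This is a sketch only: `create_dataset` is the function named in the changelog entry, and the keyword-style call and the output path are assumptions.

```python
import mllam_data_prep as mdp
from mllam_data_prep.config import Config

# load and validate the YAML config shown above
config = Config.from_yaml_file("example.danra.yaml")

# build the output dataset in memory, then write it to zarr
ds = mdp.create_dataset(config=config)
ds.to_zarr("danra.zarr", mode="w")  # output path is illustrative, not taken from the config
```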

### The `output` section
@@ -158,13 +173,29 @@ output:
step: PT3H
chunking:
time: 1
splitting:
dim: time
splits:
train:
start: 1990-09-03T00:00
end: 1990-09-06T00:00
compute_statistics:
ops: [mean, std, diff_mean, diff_std]
dims: [grid_index, time]
validation:
start: 1990-09-06T00:00
end: 1990-09-07T00:00
test:
start: 1990-09-07T00:00
end: 1990-09-09T00:00
```

The `output` section defines four things:

1. `variables`: what input variables the model architecture you are targeting expects, and what the dimensions are for each of these variables.
2. `coord_ranges`: the range of values for each of the dimensions that the model architecture expects as input. These are optional, but allow you to ensure that the training dataset is created with the correct range of values for each dimension.
3. `chunking`: the chunk sizes to use when writing the training dataset to zarr. This is optional, but can be used to optimise the performance of the zarr dataset. By default the chunk sizes are set to the size of the dimension, but this can be overridden by setting the chunk size in the configuration file. A common choice is to set the chunk size along the dimension you are batching over to align with the size of each training item (e.g. if you are training a model with a time-step roll-out of 10 timesteps, you might choose a chunk size of 10 along the time dimension).
4. Splitting and calculation of statistics of the output variables, using the `splitting` section. The `output.splitting.splits` attribute defines the individual splits to create (for example `train`, `val` and `test`) and `output.splitting.dim` defines the dimension to split along. `compute_statistics` can optionally be set for a given split to calculate the requested statistical properties (for example `mean` and `std`); any method available as `xarray.Dataset.{op}` can be used. In addition, methods prefixed by `diff_` (so the operation is listed as `diff_{op}`) compute a statistic on the difference between consecutive time-steps, e.g. `diff_mean` gives the `mean` of the difference between consecutive timesteps (these are used for normalising increments). The `dims` attribute defines the dimensions to calculate the statistics over (for example `grid_index` and `time`). See the sketch below for what the `diff_` operations amount to.
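
To make the `diff_`-prefixed operations concrete, here is a rough xarray equivalent of what `diff_mean` computes for a single variable. This is a self-contained sketch with synthetic data, not the package's internal code.

```python
import numpy as np
import xarray as xr

# synthetic stand-in for one output variable restricted to a single split
da = xr.DataArray(np.random.rand(8, 4), dims=["time", "grid_index"])

mean = da.mean(dim=["grid_index", "time"])  # the "mean" op
# "diff_mean": first difference consecutive time-steps, then reduce with mean
diff_mean = da.diff(dim="time").mean(dim=["grid_index", "time"])
```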

### The `inputs` section

@@ -217,7 +248,7 @@ inputs:
...
```

The `inputs` section defines the source datasets to extract data from. Each source dataset is defined by a key (e.g. `danra_height_levels`) which names the source, and the attributes of the source dataset:
The `inputs` section defines the source datasets to extract data from. Each source dataset is defined by a key (e.g. `danra_height_levels`) which names the source dataset, and the attributes of the source dataset:

- `path`: the path to the source dataset. This can be a local path or a URL to e.g. a zarr dataset or netCDF file, anything that can be read by `xarray.open_dataset(...)`.
- `dims`: the dimensions that the source dataset is expected to have. This is used to check that the source dataset has the expected dimensions and also makes it clearer in the config file what the dimensions of the source dataset are.
15 changes: 15 additions & 0 deletions example.danra.yaml
@@ -13,6 +13,21 @@ output:
step: PT3H
chunking:
time: 1
splitting:
dim: time
splits:
train:
start: 1990-09-03T00:00
end: 1990-09-06T00:00
compute_statistics:
ops: [mean, std, diff_mean, diff_std]
dims: [grid_index, time]
validation:
start: 1990-09-06T00:00
end: 1990-09-07T00:00
test:
start: 1990-09-07T00:00
end: 1990-09-09T00:00

inputs:
danra_height_levels:
71 changes: 70 additions & 1 deletion mllam_data_prep/config.py
@@ -173,6 +173,62 @@ class InputDataset:
attributes: Dict[str, Any] = None


@dataclass
class Statistics:
"""
Define the statistics to compute for the output dataset. This includes defining
the statistics to compute and the dimensions to compute the statistics over.
The statistics will be computed for each variable in the output dataset separately.

Attributes
----------
ops: List[str]
The statistics to compute, e.g. ["mean", "std", "min", "max"].
dims: List[str]
The dimensions to compute the statistics over, e.g. ["time", "grid_index"].
"""

ops: List[str]
dims: List[str]


@dataclass
class Split:
"""
Define the `start` and `end` coordinate values (e.g. time) for a split of the dataset and optionally
the statistics to compute for the split.

Attributes
----------
start: str
The start of the split, e.g. "1990-09-03T00:00".
end: str
The end of the split, e.g. "1990-09-04T00:00".
compute_statistics: Statistics
The statistics to compute for the split (optional).
"""

start: str
end: str
compute_statistics: Statistics = None


@dataclass
class Splitting:
"""
Define how the dataset should be split into parts (e.g. train, validation, test).

Attributes
----------
dim: str
The dimension to split the dataset along, e.g. "time". This must be provided if splits are defined.

splits: Dict[str, Split]
Defines the splits of the dataset. The keys are the names of the splits and the values
are the `Split` objects defining the start and end of each split. Optionally, the
`compute_statistics` attribute of a split can be used to define the statistics to compute for that split.
"""

dim: str
splits: Dict[str, Split]


@dataclass
class Output:
"""
@@ -206,11 +262,16 @@ class Output:
names of the dimensions and the values are the chunk size for that dimension.
If chunking is not specified for a dimension, then the entire dimension
will be a single chunk.

splitting: Splitting
Defines the splits of the dataset (e.g. train, test, validation), the dimension to split
the dataset along, and optionally the statistics to compute for each split.
"""

variables: Dict[str, List[str]]
coord_ranges: Dict[str, Range] = None
chunking: Dict[str, int] = None
splitting: Splitting = None


@dataclass
@@ -246,7 +307,15 @@ class Config(dataclass_wizard.YAMLWizard):


if __name__ == "__main__":
config = Config.from_yaml_file("example.danra.yaml")
import argparse

argparser = argparse.ArgumentParser()
argparser.add_argument(
"-f", help="Path to the yaml file to load.", default="example.danra.yaml"
)
args = argparser.parse_args()

config = Config.from_yaml_file(args.f)
import rich

rich.print(config)
42 changes: 42 additions & 0 deletions mllam_data_prep/create_dataset.py
@@ -1,14 +1,17 @@
import datetime
import shutil
from collections import defaultdict
from pathlib import Path

import numpy as np
import xarray as xr
from loguru import logger

from .config import Config, InvalidConfigException
from .ops.loading import load_and_subset_dataset
from .ops.mapping import map_dims_and_variables
from .ops.selection import select_by_kwargs
from .ops.statistics import calc_stats


def _check_dataset_attributes(ds, expected_attributes, dataset_name):
@@ -175,6 +178,45 @@ def create_dataset(config: Config):
chunks = {d: chunking_config.get(d, int(ds[d].count())) for d in ds.dims}
ds = ds.chunk(chunks)

splitting = config.output.splitting

if splitting is not None:
splits = splitting.splits
logger.info(
f"Setting splitting information to define `{list(splits.keys())}` splits "
f"along dimension `{splitting.dim}`"
)

for split_name, split_config in splits.items():
if split_config.compute_statistics is not None:
ds_split = ds.sel(
{splitting.dim: slice(split_config.start, split_config.end)}
)
logger.info(f"Computing statistics for split {split_name}")
split_stats = calc_stats(
ds=ds_split,
statistics_config=split_config.compute_statistics,
splitting_dim=splitting.dim,
)
for op, op_dataarrays in split_stats.items():
for var_name, da in op_dataarrays.items():
ds[f"{var_name}__{split_name}__{op}"] = da

# add a new variable which contains the start, stop for each split, the coords would then be the split names
# and the data would be the start, stop values
split_vals = np.array([[split.start, split.end] for split in splits.values()])
da_splits = xr.DataArray(
split_vals,
dims=["split_name", "split_part"],
coords={"split_name": list(splits.keys()), "split_part": ["start", "end"]},
)
ds["splits"] = da_splits

ds.attrs = {}
ds.attrs["schema_version"] = config.schema_version
ds.attrs["dataset_version"] = config.dataset_version
ds.attrs["created_on"] = datetime.datetime.now().replace(microsecond=0).isoformat()

return ds
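
For a config that defines a `train` split with `compute_statistics` (as in the README example above), the statistics computed here end up as flat variables in the returned dataset. A sketch of reading them back; the variable name `t` is only a placeholder for one of the configured output variables, and the keyword-style call is an assumption:

```python
import mllam_data_prep as mdp
from mllam_data_prep.config import Config

config = Config.from_yaml_file("example.danra.yaml")
ds = mdp.create_dataset(config=config)

# per-split statistics are stored as variables named {variable}__{split}__{op};
# "t" is a placeholder for an output variable defined in the config
da_mean = ds["t__train__mean"]
da_diff_std = ds["t__train__diff_std"]  # statistic of one-step differences

# the "splits" variable records the start and end coordinate value of every split
train_start = ds["splits"].sel(split_name="train", split_part="start").item()
```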


52 changes: 52 additions & 0 deletions mllam_data_prep/ops/statistics.py
@@ -0,0 +1,52 @@
from typing import Dict

import xarray as xr

from ..config import Statistics


def calc_stats(
ds: xr.Dataset, statistics_config: Statistics, splitting_dim: str
) -> Dict[str, xr.Dataset]:
"""
Calculate statistics for a given Dataset by applying the operations
specified in the Statistics object and reducing over the dimensions
specified in the Statistics object.

Parameters
----------
ds : xr.Dataset
Dataset to calculate statistics for
statistics_config : Statistics
Configuration object specifying the operations and dimensions to reduce over
splitting_dim : str
Dimension along which splits are made, this is used to calculate differences
for operations prefixed with "diff_", for example "diff_mean" or "diff_std".
Only the variables which actually span along the splitting_dim will be included
in the output.

Returns
-------
stats : Dict[str, xr.Dataset]
Dictionary with the operation names as keys and the calculated statistics as values
"""
stats = {}
for op_split in statistics_config.ops:
try:
pre_op, op = op_split.split("_")
except ValueError:
op = op_split
pre_op = None

# by default compute each statistic on the dataset as-is
ds_op = ds
if pre_op is not None:
if pre_op == "diff":
# subset to select only the variables which have the splitting_dim and
# take the difference between consecutive steps along it, leaving the
# original dataset untouched for any remaining operations
vars_to_keep = [v for v in ds.data_vars if splitting_dim in ds[v].dims]
ds_op = ds[vars_to_keep].diff(dim=splitting_dim)
else:
raise NotImplementedError(pre_op)

fn = getattr(ds_op, op)
stats[op_split] = fn(dim=statistics_config.dims)

return stats
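
For reference, a small self-contained sketch of how `calc_stats` might be exercised directly on a synthetic dataset (the variable name `state` and the values are purely illustrative):

```python
import numpy as np
import xarray as xr

from mllam_data_prep.config import Statistics
from mllam_data_prep.ops.statistics import calc_stats

# toy dataset with a single variable spanning the splitting dimension
ds = xr.Dataset(
    {"state": (("time", "grid_index"), np.random.rand(10, 4))},
    coords={"time": np.arange(10), "grid_index": np.arange(4)},
)

stats_config = Statistics(ops=["mean", "diff_std"], dims=["time", "grid_index"])
stats = calc_stats(ds=ds, statistics_config=stats_config, splitting_dim="time")

# stats["mean"] and stats["diff_std"] are datasets reduced over time and grid_index
print(stats["mean"]["state"].values, stats["diff_std"]["state"].values)
```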
10 changes: 5 additions & 5 deletions pdm.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
"aiohttp>=3.9.3",
"dataclass-wizard>=0.22.3",
"semver>=3.0.2",
"rich>=13.7.1",
]
requires-python = ">=3.10"
readme = "README.md"
@@ -35,6 +36,5 @@ distribution = true
dev = [
"pytest>=8.0.2",
"ipdb>=0.13.13",
"rich>=13.7.1",
"pre-commit>=3.7.1",
]