Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Mar 24, 2024
1 parent 74d1915 commit 887496f
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 14 deletions.
6 changes: 3 additions & 3 deletions ecml_tools/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def _subset(self, **kwargs):

from .subset import Subset

return Subset(self, self._dates_to_indices(start, end))._subset(**kwargs)
return Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs)

if "frequency" in kwargs:
from .subset import Subset

frequency = kwargs.pop("frequency")
return Subset(self, self._frequency_to_indices(frequency))._subset(**kwargs)
return Subset(self, self._frequency_to_indices(frequency), dict(frequency=frequency))._subset(**kwargs)

if "select" in kwargs:
from .select import Select
Expand Down Expand Up @@ -90,7 +90,7 @@ def _subset(self, **kwargs):
shuffle = kwargs.pop("shuffle")

if shuffle:
return Subset(self, self._shuffle_indices())._subset(**kwargs)
return Subset(self, self._shuffle_indices(), dict(shuffle=True))._subset(**kwargs)

raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs))

Expand Down
47 changes: 47 additions & 0 deletions ecml_tools/data/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,53 @@ def __repr__(self):
self._put(0, result)
return "\n".join(result)

def graph(self, digraph, nodes):
label = self.dataset.__class__.__name__.lower()
if self.kwargs:
param = []
for k, v in self.kwargs.items():
if k == "path" and isinstance(v, str):
v = os.path.basename(v)
if isinstance(v, (list, tuple)):
v = ", ".join(str(i) for i in v)
else:
v = str(v)
v = textwrap.shorten(v, width=40, placeholder="...")
# if len(self.kwargs) == 1:
# param.append(v)
# else:
param.append(f"{k}={v}")
label = f'{label}({",".join(param)})'

label += "\n" + "\n".join(
textwrap.shorten(str(v), width=40, placeholder="...")
for v in (
self.dataset.dates[0],
self.dataset.dates[-1],
self.dataset.frequency,
self.dataset.shape,
self.dataset.variables,
)
)

nodes[f"N{id(self)}"] = label
for kid in self.kids:
digraph.append(f"N{id(self)} -> N{id(kid)}")
kid.graph(digraph, nodes)

def digraph(self):
digraph = ["digraph {"]
digraph.append("node [shape=box];")
nodes = {}

self.graph(digraph, nodes)

for node, label in nodes.items():
digraph.append(f'{node} [label="{label}"];')

digraph.append("}")
return "\n".join(digraph)


class Source:
"""Class used to follow the provenance of a data point."""
Expand Down
5 changes: 3 additions & 2 deletions ecml_tools/data/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def _overlay(self):
# No overlay
return self

variables = [v[1:-1] for v in self.variables if v[0] == "(" and v[-1] == ")"]
indices = list(indices.values())

i = 0
Expand All @@ -95,7 +96,7 @@ def _overlay(self):

from .select import Select

return Select(self, indices, {"overlay": True})
return Select(self, indices, {"overlay": variables})

@cached_property
def variables(self):
Expand Down Expand Up @@ -154,4 +155,4 @@ def join_factory(args, kwargs, zarr_root):

datasets, kwargs = _auto_adjust(datasets, kwargs)

return Join(datasets)._subset(**kwargs)
return Join(datasets)._overlay()._subset(**kwargs)
24 changes: 17 additions & 7 deletions ecml_tools/data/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _auto_adjust(datasets, kwargs):
for p in ("select", "frequency", "start", "end"):
kwargs[p] = "matching"

adjust = {}
adjust = [{} for _ in datasets]

if kwargs.get("select") == "matching":
kwargs.pop("select")
Expand All @@ -230,22 +230,32 @@ def _auto_adjust(datasets, kwargs):
if len(variables) == 0:
raise ValueError("No common variables")

adjust["select"] = sorted(variables)
for i, d in enumerate(datasets):
if set(d.variables) != variables:
adjust[i]["select"] = sorted(variables)

if kwargs.get("start") == "matching":
kwargs.pop("start")
adjust["start"] = max(d.dates[0] for d in datasets).astype(object)
start = max(d.dates[0] for d in datasets).astype(object)
for i, d in enumerate(datasets):
if start != d.dates[0]:
adjust[i]["start"] = start

if kwargs.get("end") == "matching":
kwargs.pop("end")
adjust["end"] = min(d.dates[-1] for d in datasets).astype(object)
end = min(d.dates[-1] for d in datasets).astype(object)
for i, d in enumerate(datasets):
if end != d.dates[-1]:
adjust[i]["end"] = end

if kwargs.get("frequency") == "matching":
kwargs.pop("frequency")
adjust["frequency"] = max(d.frequency for d in datasets)
frequency = max(d.frequency for d in datasets)
for i, d in enumerate(datasets):
if d.frequency != frequency:
adjust[i]["frequency"] = frequency

if adjust:
datasets = [d._subset(**adjust) for d in datasets]
datasets = [d._subset(**adjust[i]) for i, d in enumerate(datasets)]

return datasets, kwargs

Expand Down
5 changes: 3 additions & 2 deletions ecml_tools/data/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@
class Subset(Forwards):
"""Select a subset of the dates."""

def __init__(self, dataset, indices):
def __init__(self, dataset, indices, reason):
while isinstance(dataset, Subset):
indices = [dataset.indices[i] for i in indices]
dataset = dataset.dataset

self.dataset = dataset
self.indices = list(indices)
self.reason = {k: v for k, v in reason.items() if v is not None}

# Forward other properties to the super dataset
super().__init__(dataset)
Expand Down Expand Up @@ -101,4 +102,4 @@ def missing(self):
return {self.indices[i] for i in self.dataset.missing if i in self.indices}

def tree(self):
return Node(self, [self.dataset.tree()])
return Node(self, [self.dataset.tree()], **self.reason)

0 comments on commit 887496f

Please sign in to comment.