Skip to content

Commit

Permalink
normalisation ok
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Oct 11, 2024
1 parent bfe77b2 commit 6a9dda4
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions src/anemoi/utils/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def dtype(self):
assert self._dtype # this should be set by the subclass
return self._dtype

def clone(self):
raise NotImplementedError


#######################################
# this may move to anemoi-training ?
Expand Down Expand Up @@ -125,6 +128,7 @@ class NestedTrainingSample(TrainingSample):
# states are not regular and cannot be stacked
# appropriate for Observations
def __init__(self, states, state_type="torch", **kwargs):
if kwargs: print('❌ the clone method assumes no kwargs')
super().__init__(**kwargs)
assert isinstance(states, (list, tuple, NestedTrainingSample)), type(states)
self._state_type = state_type
Expand All @@ -135,6 +139,9 @@ def __init__(self, states, state_type="torch", **kwargs):
self._len = len(states)
self._dtype = states[0].dtype

def clone(self):
return self.__class__(tuple(v.clone() for v in self), state_type=self._state_type)

@classmethod
def from_tuple_of_tuple_of_arrays(cls, tuple_of_tuple_of_arrays, **kwargs):
return cls(tuple_of_tuple_of_arrays, **kwargs)
Expand All @@ -159,12 +166,8 @@ def to(self, device):
def as_torch(self):
return self.__class__(tuple(v.as_torch() for v in self))

def as_tuple_of_tuples(self):
return tuple(v.as_tuple() for v in self)

def as_tuple_of_dicts(self, keys=None):
return tuple(v.as_dict(keys) for v in self)

def as_native(self):
return tuple(v.as_native() for v in self)

class EnsembleTrainingSample(TrainingSample):
# One additional dimension and potentially different behavior
Expand Down Expand Up @@ -226,6 +229,9 @@ def __init__(self, arrays, **kwargs):
arrays = {i: v for i, v in enumerate(arrays)}
self.arrays = arrays

def clone(self):
raise NotImplementedError

def check_array_type(self, arrays):
_type = None
for _, a in arrays.items():
Expand Down Expand Up @@ -277,13 +283,8 @@ def as_list(self):
def as_tuple(self):
return tuple(self.arrays.values())

def as_dict(self, keys=None):
if keys is None:
return self.arrays

assert all(isinstance(k, int) for k in self.arrays), (keys, self.arrays.keys())
assert len(keys) == len(self.arrays), (len(keys), len(self.arrays))
return {keys[k]: v for k, v in self.arrays.items()}
def as_native(self):
return self.arrays

class NumpyNestedAnemoiTensor(NestedAnemoiTensor):
def flatten(self):
Expand All @@ -300,12 +301,14 @@ def check_array_type(self, arrays):

class TorchNestedAnemoiTensor(NestedAnemoiTensor):
def __init__(self, arrays, **kwargs):

arrays = {k:self._cast_to_torch(v) for k, v in arrays.items()}

super().__init__(arrays, **kwargs)
self.check_array_type(arrays)

def clone(self):
return self.__class__({k: v.clone() for k, v in self.arrays.items()})


@classmethod
def _cast_to_torch(cls, v):
Expand Down

0 comments on commit 6a9dda4

Please sign in to comment.