Skip to content

Commit

Permalink
Added reorder sort string, which sorts variables alphabetically. (ecm…
Browse files Browse the repository at this point in the history
…wf#144)

* Added reorder sort string, which sorts variables alphabetically. Crucial for pre-training and transfer learning.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
einrone and pre-commit-ci[bot] authored Dec 18, 2024
1 parent 8fd1000 commit ddcee7d
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/anemoi/datasets/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,19 @@ def _drop_to_columns(self, vars):
return sorted([v for k, v in self.name_to_index.items() if k not in vars])

def _reorder_to_columns(self, vars):
if isinstance(vars, str) and vars == "sort":
# Sorting the variables alphabetically.
# This is cruical for pre-training then transfer learning in combination with
# cutout and adjust = 'all'

indices = [self.name_to_index[k] for k, v in sorted(self.name_to_index.items(), key=lambda x: x[0])]
assert set(indices) == set(range(len(self.name_to_index)))
return indices

if isinstance(vars, (list, tuple)):
vars = {k: i for i, k in enumerate(vars)}

indices = []

for k, v in sorted(vars.items(), key=lambda x: x[1]):
indices.append(self.name_to_index[k])
indices = [self.name_to_index[k] for k, v in sorted(vars.items(), key=lambda x: x[1])]

# Make sure we don't forget any variables
assert set(indices) == set(range(len(self.name_to_index)))
Expand Down

0 comments on commit ddcee7d

Please sign in to comment.