From ddcee7dcae1abc5fc8679fba6cb9f3af328ae6d5 Mon Sep 17 00:00:00 2001 From: Aram Salihi Date: Wed, 18 Dec 2024 18:15:51 +0100 Subject: [PATCH] Added reorder sort string, which sorts variables alphabetically. (#144) * Added reorder sort string, which sorts variables alphabetically. Crucial for pre-training and transfer learning. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/anemoi/datasets/data/dataset.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index fcb7a384..6036e630 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -261,13 +261,19 @@ def _drop_to_columns(self, vars): return sorted([v for k, v in self.name_to_index.items() if k not in vars]) def _reorder_to_columns(self, vars): + if isinstance(vars, str) and vars == "sort": + # Sorting the variables alphabetically. + # This is cruical for pre-training then transfer learning in combination with + # cutout and adjust = 'all' + + indices = [self.name_to_index[k] for k, v in sorted(self.name_to_index.items(), key=lambda x: x[0])] + assert set(indices) == set(range(len(self.name_to_index))) + return indices + if isinstance(vars, (list, tuple)): vars = {k: i for i, k in enumerate(vars)} - indices = [] - - for k, v in sorted(vars.items(), key=lambda x: x[1]): - indices.append(self.name_to_index[k]) + indices = [self.name_to_index[k] for k, v in sorted(vars.items(), key=lambda x: x[1])] # Make sure we don't forget any variables assert set(indices) == set(range(len(self.name_to_index)))