From 8bd9a23416926075b286f0082ab5ed5a53ef5bdb Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 16 Oct 2023 16:02:47 -0700 Subject: [PATCH] Breaking: rename robust loss functions (#77) * fix ruff PGH003 Use specific rule codes when ignoring type issues * breaking: fix ruff N802 Function name RobustL(1|2)Loss should be lowercase * fix ruff FBT003 Boolean positional is_best value in function call * .pow(2) to **2 * try installing torch and torch-scatter 2.1.0 in CI * bump torch and torch-scatter to 2.1.0 in readme and GH codespaces config --- .devcontainer/devcontainer.json | 2 +- .github/workflows/test.yml | 8 ++++---- .pre-commit-config.yaml | 4 ++-- README.md | 4 ++-- aviary/core.py | 9 +++++++-- aviary/losses.py | 8 +++----- aviary/roost/data.py | 8 ++++---- aviary/roost/model.py | 2 +- aviary/segments.py | 2 +- aviary/train.py | 10 +++++----- aviary/utils.py | 22 +++++++++++----------- aviary/wren/data.py | 4 ++-- aviary/wren/model.py | 2 +- aviary/wren/utils.py | 10 +++++----- aviary/wrenformer/data.py | 2 +- aviary/wrenformer/model.py | 2 +- examples/cgcnn-example.py | 3 ++- examples/roost-example.py | 3 ++- examples/wren-example.py | 3 ++- pyproject.toml | 19 +++++++++---------- 20 files changed, 66 insertions(+), 61 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index c97bd416..c6d93f06 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,7 +1,7 @@ { "image": "mcr.microsoft.com/devcontainers/universal:2", "waitFor": "onCreateCommand", - "updateContentCommand": "pip install torch==1.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu && pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cpu.html && pip install -e .", + "updateContentCommand": "pip install torch==2.1.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu && pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cpu.html && pip install -e .", "customizations": { "codespaces": { "openFiles": ["examples/notebooks/wren-example.ipynb"] diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5ca69bd4..9e4a9891 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,10 +2,10 @@ name: Tests on: push: - paths: ['**/*.py', .github/workflows/test.yml] + paths: ["**/*.py", .github/workflows/test.yml] branches: [main] pull_request: - paths: ['**/*.py', .github/workflows/test.yml] + paths: ["**/*.py", .github/workflows/test.yml] branches: [main] jobs: @@ -24,8 +24,8 @@ jobs: - name: Install dependencies run: | - pip install torch==1.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu - pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cpu.html + pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu + pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cpu.html pip install .[test] - name: Run Tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e2a2702f..5eedd1dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: args: [--fix] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-case-conflict - id: check-symlinks @@ -33,7 +33,7 @@ repos: - id: black-jupyter - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.5.1 + rev: v1.6.0 hooks: - id: mypy exclude: (tests|examples)/ diff --git a/README.md b/README.md index 10c3a54f..6d1e799c 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,10 @@ The aim of `aviary` is to contain multiple models for materials discovery under Aviary requires [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter). `pip install` it with ```sh -pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cpu.html +pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cpu.html ``` -Make sure you replace `1.13.0` with your actual `torch.__version__` (`python -c 'import torch; print(torch.__version__)'`) and `cpu` with your CUDA version if applicable. +Make sure you replace `2.1.0` with your actual `torch.__version__` (`python -c 'import torch; print(torch.__version__)'`) and `cpu` with your CUDA version if applicable. Then install `aviary` from source with diff --git a/aviary/core.py b/aviary/core.py index 2d546f53..78738a63 100644 --- a/aviary/core.py +++ b/aviary/core.py @@ -176,7 +176,12 @@ def fit( } # TODO saving a model at each epoch may be slow? - save_checkpoint(checkpoint_dict, False, model_name, run_id) + save_checkpoint( + checkpoint_dict, + is_best=False, + model_name=model_name, + run_id=run_id, + ) # TODO when to save best models? should this be done task-wise in # the multi-task case? @@ -244,7 +249,7 @@ def evaluate( # compute output outputs = self(*inputs) - mixed_loss: Tensor = 0 # type: ignore + mixed_loss: Tensor = 0 # type: ignore[assignment] for target_name, targets, output, normalizer in zip( self.target_names, targets_list, outputs, normalizer_dict.values() diff --git a/aviary/losses.py b/aviary/losses.py index 82b9ee12..b983ddb2 100644 --- a/aviary/losses.py +++ b/aviary/losses.py @@ -2,7 +2,7 @@ from torch import Tensor -def RobustL1Loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor: +def robust_l1_loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor: """Robust L1 loss using a Lorentzian prior. Trains the model to learn to predict aleatoric (per-sample) uncertainty. @@ -21,7 +21,7 @@ def RobustL1Loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Ten return torch.mean(loss) -def RobustL2Loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor: +def robust_l2_loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor: """Robust L2 loss using a Gaussian prior. Trains the model to learn to predict aleatoric (per-sample) uncertainty. @@ -34,7 +34,5 @@ def RobustL2Loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Ten Returns: Tensor: Evaluated robust L2 loss """ - loss = ( - 0.5 * (pred_mean - target).pow(2) * torch.exp(-2 * pred_log_std) + pred_log_std - ) + loss = 0.5 * (pred_mean - target) ** 2 * torch.exp(-2 * pred_log_std) + pred_log_std return torch.mean(loss) diff --git a/aviary/roost/data.py b/aviary/roost/data.py index 84706cd1..945fafba 100644 --- a/aviary/roost/data.py +++ b/aviary/roost/data.py @@ -107,12 +107,12 @@ def __getitem__(self, idx: int): f"{material_ids} ({composition}) composition cannot be parsed into elements" ) from exc - nele = len(elements) + n_elems = len(elements) self_idx = [] nbr_idx = [] - for i, _ in enumerate(elements): - self_idx += [i] * nele - nbr_idx += list(range(nele)) + for idx in range(n_elems): + self_idx += [idx] * n_elems + nbr_idx += list(range(n_elems)) # convert all data to tensors elem_weights = Tensor(weights) diff --git a/aviary/roost/model.py b/aviary/roost/model.py index 116958c5..a5f161ff 100644 --- a/aviary/roost/model.py +++ b/aviary/roost/model.py @@ -76,7 +76,7 @@ def __init__( "cry_msg": cry_msg, } - self.material_nn = DescriptorNetwork(**desc_dict) # type: ignore + self.material_nn = DescriptorNetwork(**desc_dict) # type: ignore[arg-type] model_params = { "robust": robust, diff --git a/aviary/segments.py b/aviary/segments.py index 8bb5d701..5b650b72 100644 --- a/aviary/segments.py +++ b/aviary/segments.py @@ -84,7 +84,7 @@ def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor: def __repr__(self) -> str: pow, gate_nn, message_nn = float(self.pow), self.gate_nn, self.message_nn - return f"{type(self).__name__}(pow={pow:.3}, {gate_nn=}, {message_nn=})" + return f"{type(self).__name__}({pow=:.3}, {gate_nn=}, {message_nn=})" class MessageLayer(nn.Module): diff --git a/aviary/train.py b/aviary/train.py index 1612bf57..01996f82 100644 --- a/aviary/train.py +++ b/aviary/train.py @@ -11,7 +11,7 @@ from aviary import ROOT from aviary.core import BaseModelClass, Normalizer, TaskType, np_softmax -from aviary.losses import RobustL1Loss +from aviary.losses import robust_l1_loss from aviary.utils import get_metrics, print_walltime from aviary.wrenformer.data import df_to_in_mem_dataloader from aviary.wrenformer.model import Wrenformer @@ -130,7 +130,7 @@ def train_model( print(f"Pytorch running on {device=}") loss_func = ( - (RobustL1Loss if robust else torch.nn.L1Loss()) + (robust_l1_loss if robust else torch.nn.L1Loss()) if task_type == reg_key else (torch.nn.NLLLoss() if robust else torch.nn.CrossEntropyLoss()) ) @@ -193,7 +193,7 @@ def train_model( start=swa_start, epochs=int(swa_start * epochs), learning_rate=swa_lr ) if task_type == reg_key and hasattr(train_loader, "df"): - train_df = getattr(train_loader, "df", train_loader.dataset.df) # type: ignore + train_df = getattr(train_loader, "df", train_loader.dataset.df) # type: ignore[union-attr] targets = train_df[target_col] run_params["dummy_mae"] = (targets - targets.mean()).abs().mean() if timestamp: @@ -415,11 +415,11 @@ def train_wrenformer( embedding_type=embedding_type, ) train_loader = df_to_in_mem_dataloader( - train_df, batch_size=batch_size, shuffle=True, **data_loader_kwargs # type: ignore + train_df, batch_size=batch_size, shuffle=True, **data_loader_kwargs # type: ignore[arg-type] ) test_loader = df_to_in_mem_dataloader( - test_df, batch_size=512, shuffle=False, **data_loader_kwargs # type: ignore + test_df, batch_size=512, shuffle=False, **data_loader_kwargs # type: ignore[arg-type] ) # embedding_len is the length of the embedding vector for a Wyckoff position encoding the diff --git a/aviary/utils.py b/aviary/utils.py index f6c9ea03..42fd7c65 100644 --- a/aviary/utils.py +++ b/aviary/utils.py @@ -29,7 +29,7 @@ from aviary import ROOT from aviary.core import BaseModelClass, Normalizer, TaskType, sampled_softmax -from aviary.losses import RobustL1Loss, RobustL2Loss +from aviary.losses import robust_l1_loss, robust_l2_loss if TYPE_CHECKING: from types import ModuleType @@ -224,9 +224,9 @@ def initialize_losses( elif task == "regression": if robust: if loss_name_dict[name] == "L1": - loss_func_dict[name] = (task, RobustL1Loss) + loss_func_dict[name] = (task, robust_l1_loss) elif loss_name_dict[name] == "L2": - loss_func_dict[name] = (task, RobustL2Loss) + loss_func_dict[name] = (task, robust_l2_loss) else: raise NameError( "Only L1 or L2 losses are allowed for robust regression tasks" @@ -472,7 +472,7 @@ def results_multitask( results_dict[target_name] = defaultdict( list if task_type == "classification" - else lambda: np.zeros((ensemble_folds, len(test_set))) # type: ignore + else lambda: np.zeros((ensemble_folds, len(test_set))) # type: ignore[call-overload] ) for ens_idx in range(ensemble_folds): @@ -518,11 +518,11 @@ def results_multitask( mean, log_std = output.unbind(dim=1) preds = normalizer.denorm(mean.data.cpu()) ale_std = torch.exp(log_std).data.cpu() * normalizer.std - res_dict["ale"][ens_idx, :] = ale_std.view(-1).numpy() # type: ignore + res_dict["ale"][ens_idx, :] = ale_std.view(-1).numpy() # type: ignore[call-overload] else: preds = normalizer.denorm(output.data.cpu()) - res_dict["preds"][ens_idx, :] = preds.view(-1).numpy() # type: ignore + res_dict["preds"][ens_idx, :] = preds.view(-1).numpy() # type: ignore[call-overload] elif task_type == "classification": if model.robust: @@ -532,13 +532,13 @@ def results_multitask( ) pre_logits = mean.data.cpu().numpy() pre_logits_std = torch.exp(log_std).data.cpu().numpy() - res_dict["pre-logits_ale"].append(pre_logits_std) # type: ignore + res_dict["pre-logits_ale"].append(pre_logits_std) # type: ignore[union-attr] else: pre_logits = output.data.cpu().numpy() logits = pre_logits.softmax(1) - res_dict["pre-logits"].append(pre_logits) # type: ignore - res_dict["logits"].append(logits) # type: ignore + res_dict["pre-logits"].append(pre_logits) # type: ignore[union-attr] + res_dict["logits"].append(logits) # type: ignore[union-attr] res_dict["targets"] = targets @@ -555,9 +555,9 @@ def results_multitask( for target_name, task_type in task_dict.items(): print(f"\nTask: {target_name=} on test set") if task_type == "regression": - print_metrics_regression(**results_dict[target_name]) # type: ignore + print_metrics_regression(**results_dict[target_name]) # type: ignore[arg-type] elif task_type == "classification": - print_metrics_classification(**results_dict[target_name]) # type: ignore + print_metrics_classification(**results_dict[target_name]) # type: ignore[arg-type] return results_dict diff --git a/aviary/wren/data.py b/aviary/wren/data.py index f4c2a971..8adf2d40 100644 --- a/aviary/wren/data.py +++ b/aviary/wren/data.py @@ -53,7 +53,7 @@ def __init__( self.identifiers = list(identifiers) self.df = df - if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]: + if elem_embedding in ("matscholar200", "cgcnn92", "megnet16", "onehot112"): elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json" with open(elem_embedding) as emb_file: @@ -61,7 +61,7 @@ def __init__( self.elem_emb_len = len(next(iter(self.elem_features.values()))) - if sym_emb in ["bra-alg-off", "spg-alg-off"]: + if sym_emb in ("bra-alg-off", "spg-alg-off"): sym_emb = f"{PKG_DIR}/embeddings/wyckoff/{sym_emb}.json" with open(sym_emb) as sym_file: diff --git a/aviary/wren/model.py b/aviary/wren/model.py index 1c98b0ff..7060b8ef 100644 --- a/aviary/wren/model.py +++ b/aviary/wren/model.py @@ -85,7 +85,7 @@ def __init__( "cry_msg": cry_msg, } - self.material_nn = DescriptorNetwork(**desc_dict) # type: ignore + self.material_nn = DescriptorNetwork(**desc_dict) # type: ignore[arg-type] model_params = { "robust": robust, diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py index 2a3fff63..206e3081 100644 --- a/aviary/wren/utils.py +++ b/aviary/wren/utils.py @@ -389,12 +389,12 @@ def count_crystal_dof(aflow_label: str) -> int: Returns: int: Number of free-parameters in given prototype """ - num_params = 0 + n_params = 0 aflow_label, _ = aflow_label.split(":") # chop off chemical system _, pearson, spg, *wyks = aflow_label.split("_") - num_params += cry_param_dict[pearson[0]] + n_params += cry_param_dict[pearson[0]] for wyk_letters_per_elem in wyks: # normalize Wyckoff letters to start with 1 if missing digit @@ -404,12 +404,12 @@ def count_crystal_dof(aflow_label: str) -> int: sep_el_wyks = [ "".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha) ] - num_params += sum( + n_params += sum( float(n) * param_dict[spg][k] for n, k in zip(sep_el_wyks[0::2], sep_el_wyks[1::2]) ) - return int(num_params) + return n_params def get_isopointal_proto_from_aflow(aflow_label: str) -> str: @@ -455,7 +455,7 @@ def get_isopointal_proto_from_aflow(aflow_label: str) -> str: isopointal: list[str] = [] for wyks_list in valid_permutations: - for trans in relab_dict[str(spg)]: + for trans in relab_dict[spg]: t = str.maketrans(trans) isopointal.append("_".join(wyks_list).translate(t)) diff --git a/aviary/wrenformer/data.py b/aviary/wrenformer/data.py index 065148b1..4d63547b 100644 --- a/aviary/wrenformer/data.py +++ b/aviary/wrenformer/data.py @@ -156,7 +156,7 @@ def df_to_in_mem_dataloader( if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - if embedding_type not in ["wyckoff", "composition"]: + if embedding_type not in ("wyckoff", "composition"): raise ValueError(f"{embedding_type = } must be 'wyckoff' or 'composition'") initial_embeddings = df[input_col].map( diff --git a/aviary/wrenformer/model.py b/aviary/wrenformer/model.py index 1da52d4c..d09dee91 100644 --- a/aviary/wrenformer/model.py +++ b/aviary/wrenformer/model.py @@ -91,7 +91,7 @@ def __init__( ResidualNetwork(out_hidden[0], n, out_hidden[1:]) for n in n_targets ) - def forward( # type: ignore + def forward( # type: ignore[override] self, features: Tensor, mask: BoolTensor, *args ) -> tuple[Tensor, ...]: """Forward pass through the Wrenformer. diff --git a/examples/cgcnn-example.py b/examples/cgcnn-example.py index 3f1f1493..f02b1e88 100644 --- a/examples/cgcnn-example.py +++ b/examples/cgcnn-example.py @@ -11,7 +11,7 @@ from aviary.utils import results_multitask, train_ensemble -def main( # noqa: D103 +def main( data_path, targets, tasks, @@ -52,6 +52,7 @@ def main( # noqa: D103 device=None, **kwargs, ): + """Train and evaluate a CGCNN model.""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" print(f"The model will run on the {args.device} device") diff --git a/examples/roost-example.py b/examples/roost-example.py index fcef5130..1c5532b0 100644 --- a/examples/roost-example.py +++ b/examples/roost-example.py @@ -10,7 +10,7 @@ from aviary.utils import results_multitask, train_ensemble -def main( # noqa: D103 +def main( data_path, targets, tasks, @@ -45,6 +45,7 @@ def main( # noqa: D103 device=None, **kwargs, ): + """Train and evaluate a Roost model.""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" print(f"The model will run on the {args.device} device") diff --git a/examples/wren-example.py b/examples/wren-example.py index 7c40226f..6a68b401 100644 --- a/examples/wren-example.py +++ b/examples/wren-example.py @@ -10,7 +10,7 @@ from aviary.wren.model import Wren -def main( # noqa: D103 +def main( data_path, targets, tasks, @@ -47,6 +47,7 @@ def main( # noqa: D103 device=None, **kwargs, ): + """Train and evaluate a Wren model.""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" print(f"The model will run on the {args.device} device") diff --git a/pyproject.toml b/pyproject.toml index 8b5c5a8f..ead8722e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,16 +105,15 @@ select = [ "YTT", # flake8-2020 ] ignore = [ - "C408", # Unnecessary dict call - rewrite as a literal - "D100", # Missing docstring in public module - "D104", # Missing docstring in public package - "D105", # Missing docstring in magic method - "D205", # 1 blank line required between summary line and description - "E731", # Do not assign a lambda expression, use a def - "PD901", # pandas-df-variable-name - "PLC1901", # compare-to-empty-string - "PLR", # pylint refactor - "PT006", # pytest-parametrize-names-wrong-type + "C408", # Unnecessary dict call - rewrite as a literal + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D205", # 1 blank line required between summary line and description + "E731", # Do not assign a lambda expression, use a def + "PD901", # pandas-df-variable-name + "PLR", # pylint refactor + "PT006", # pytest-parametrize-names-wrong-type ] pydocstyle.convention = "google" isort.known-third-party = ["wandb"]