Breaking: rename robust loss functions (#77)
* fix ruff PGH003: use specific rule codes when ignoring type issues
* breaking: fix ruff N802: function names `RobustL(1|2)Loss` should be lowercase
* fix ruff FBT003: boolean positional `is_best` value in function call
* change `.pow(2)` to `**2`
* try installing torch and torch-scatter 2.1.0 in CI
* bump torch and torch-scatter to 2.1.0 in readme and GH codespaces config
janosh authored Oct 16, 2023
1 parent 5aa92f4 commit 8bd9a23
Showing 20 changed files with 66 additions and 61 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
@@ -1,7 +1,7 @@
 {
   "image": "mcr.microsoft.com/devcontainers/universal:2",
   "waitFor": "onCreateCommand",
-  "updateContentCommand": "pip install torch==1.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu && pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cpu.html && pip install -e .",
+  "updateContentCommand": "pip install torch==2.1.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu && pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cpu.html && pip install -e .",
   "customizations": {
     "codespaces": {
       "openFiles": ["examples/notebooks/wren-example.ipynb"]
8 changes: 4 additions & 4 deletions .github/workflows/test.yml
@@ -2,10 +2,10 @@ name: Tests

 on:
   push:
-    paths: ['**/*.py', .github/workflows/test.yml]
+    paths: ["**/*.py", .github/workflows/test.yml]
     branches: [main]
   pull_request:
-    paths: ['**/*.py', .github/workflows/test.yml]
+    paths: ["**/*.py", .github/workflows/test.yml]
     branches: [main]

 jobs:
@@ -24,8 +24,8 @@ jobs:

       - name: Install dependencies
         run: |
-          pip install torch==1.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu
-          pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cpu.html
+          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+          pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
           pip install .[test]
       - name: Run Tests
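Worth noting: besides the version bump, the torch install switched from `--extra-index-url` (adds the PyTorch index alongside PyPI) to `--index-url` (replaces PyPI entirely), pinning resolution to the CPU wheel index. A sketch of reproducing this CI setup locally on a CPU-only machine (the final `pytest` line is an assumption; the workflow's actual test command is truncated above):

```sh
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
pip install ".[test]"
pytest  # hypothetical: the test step isn't shown in this hunk
```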
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
       args: [--fix]

   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: check-case-conflict
       - id: check-symlinks
@@ -33,7 +33,7 @@ repos:
       - id: black-jupyter

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.5.1
+    rev: v1.6.0
     hooks:
       - id: mypy
         exclude: (tests|examples)/
4 changes: 2 additions & 2 deletions README.md
@@ -18,10 +18,10 @@ The aim of `aviary` is to contain multiple models for materials discovery under
 Aviary requires [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter). `pip install` it with

 ```sh
-pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cpu.html
+pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
 ```

-Make sure you replace `1.13.0` with your actual `torch.__version__` (`python -c 'import torch; print(torch.__version__)'`) and `cpu` with your CUDA version if applicable.
+Make sure you replace `2.1.0` with your actual `torch.__version__` (`python -c 'import torch; print(torch.__version__)'`) and `cpu` with your CUDA version if applicable.

 Then install `aviary` from source with
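For CUDA builds, the same wheel-index pattern applies with a CUDA tag in place of `cpu`. A hedged example assuming torch 2.1.0 built against CUDA 12.1 (check https://data.pyg.org/whl for the index that actually exists for your torch/CUDA combination):

```sh
python -c 'import torch; print(torch.__version__)'  # e.g. prints 2.1.0+cu121
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
```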
9 changes: 7 additions & 2 deletions aviary/core.py
@@ -176,7 +176,12 @@ def fit(
         }

         # TODO saving a model at each epoch may be slow?
-        save_checkpoint(checkpoint_dict, False, model_name, run_id)
+        save_checkpoint(
+            checkpoint_dict,
+            is_best=False,
+            model_name=model_name,
+            run_id=run_id,
+        )

         # TODO when to save best models? should this be done task-wise in
         # the multi-task case?
@@ -244,7 +249,7 @@ def evaluate(
         # compute output
         outputs = self(*inputs)

-        mixed_loss: Tensor = 0  # type: ignore
+        mixed_loss: Tensor = 0  # type: ignore[assignment]

         for target_name, targets, output, normalizer in zip(
             self.target_names, targets_list, outputs, normalizer_dict.values()
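Ruff's FBT003 fires on the bare `False` in the old call because a positional boolean reveals nothing at the call site. Passing `is_best=False` fixes the call; declaring the parameter keyword-only also prevents regressions. A minimal sketch of that pattern with a hypothetical signature (not aviary's actual `save_checkpoint`):

```python
import torch

def save_checkpoint(state: dict, *, is_best: bool, model_name: str, run_id: int) -> None:
    """Parameters after the bare * can only be passed by keyword."""
    torch.save(state, f"{model_name}-r{run_id}.pth.tar")
    if is_best:
        torch.save(state, f"{model_name}-r{run_id}-best.pth.tar")

save_checkpoint({}, is_best=False, model_name="wren", run_id=1)  # OK
# save_checkpoint({}, False, "wren", 1)  # TypeError under this signature
```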
8 changes: 3 additions & 5 deletions aviary/losses.py
@@ -2,7 +2,7 @@
 from torch import Tensor


-def RobustL1Loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor:
+def robust_l1_loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor:
     """Robust L1 loss using a Lorentzian prior. Trains the model to learn to predict aleatoric
     (per-sample) uncertainty.
@@ -21,7 +21,7 @@ def RobustL1Loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor:
     return torch.mean(loss)


-def RobustL2Loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor:
+def robust_l2_loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor:
     """Robust L2 loss using a Gaussian prior. Trains the model to learn to predict aleatoric
     (per-sample) uncertainty.
@@ -34,7 +34,5 @@ def RobustL2Loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> Tensor:
     Returns:
         Tensor: Evaluated robust L2 loss
     """
-    loss = (
-        0.5 * (pred_mean - target).pow(2) * torch.exp(-2 * pred_log_std) + pred_log_std
-    )
+    loss = 0.5 * (pred_mean - target) ** 2 * torch.exp(-2 * pred_log_std) + pred_log_std
     return torch.mean(loss)
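For context, the L2 form is the negative log-likelihood of a Gaussian with predicted mean μ and std σ = exp(s): up to a constant, NLL = 0.5·(x − μ)²·exp(−2s) + s, which is exactly the expression above; the `pred_log_std` term penalizes the model for claiming high uncertainty everywhere. The L1 variant swaps in a heavy-tailed Lorentzian likelihood (its body is not shown in this hunk). A usage sketch with the renamed functions (the tensor values are made up):

```python
import torch
from aviary.losses import robust_l1_loss, robust_l2_loss

pred_mean = torch.tensor([1.0, 2.0, 3.0])      # model's predicted means
pred_log_std = torch.tensor([0.0, -1.0, 0.5])  # predicted log of per-sample aleatoric std
target = torch.tensor([1.1, 1.8, 3.5])

print(robust_l1_loss(pred_mean, pred_log_std, target))  # scalar Tensor
print(robust_l2_loss(pred_mean, pred_log_std, target))  # scalar Tensor
```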
8 changes: 4 additions & 4 deletions aviary/roost/data.py
@@ -107,12 +107,12 @@ def __getitem__(self, idx: int):
                 f"{material_ids} ({composition}) composition cannot be parsed into elements"
             ) from exc

-        nele = len(elements)
+        n_elems = len(elements)
         self_idx = []
         nbr_idx = []
-        for i, _ in enumerate(elements):
-            self_idx += [i] * nele
-            nbr_idx += list(range(nele))
+        for idx in range(n_elems):
+            self_idx += [idx] * n_elems
+            nbr_idx += list(range(n_elems))

         # convert all data to tensors
         elem_weights = Tensor(weights)
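These index lists define a fully connected graph over a composition's elements: every element is a neighbor of every element, itself included. A standalone illustration of the renamed loop's output for a hypothetical ternary composition:

```python
n_elems = 3  # e.g. a three-element composition
self_idx: list[int] = []
nbr_idx: list[int] = []
for idx in range(n_elems):
    self_idx += [idx] * n_elems       # source node repeated for each neighbor
    nbr_idx += list(range(n_elems))   # all nodes as neighbors, self included

print(self_idx)  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
print(nbr_idx)   # [0, 1, 2, 0, 1, 2, 0, 1, 2]
```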
2 changes: 1 addition & 1 deletion aviary/roost/model.py
@@ -76,7 +76,7 @@ def __init__(
             "cry_msg": cry_msg,
         }

-        self.material_nn = DescriptorNetwork(**desc_dict)  # type: ignore
+        self.material_nn = DescriptorNetwork(**desc_dict)  # type: ignore[arg-type]

         model_params = {
             "robust": robust,
2 changes: 1 addition & 1 deletion aviary/segments.py
@@ -84,7 +84,7 @@ def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor:

     def __repr__(self) -> str:
         pow, gate_nn, message_nn = float(self.pow), self.gate_nn, self.message_nn
-        return f"{type(self).__name__}(pow={pow:.3}, {gate_nn=}, {message_nn=})"
+        return f"{type(self).__name__}({pow=:.3}, {gate_nn=}, {message_nn=})"


 class MessageLayer(nn.Module):
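The rewritten `__repr__` uses the f-string `=` specifier (Python 3.8+), which prints both the expression text and its value and still accepts a format spec, so the hand-written `pow=` prefix becomes redundant. A quick demo:

```python
pow = 3.14159  # shadows the builtin, mirroring the attribute name above
print(f"{pow=:.3}")     # pow=3.14
print(f"pow={pow:.3}")  # pow=3.14 (same output, name duplicated by hand)
```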
10 changes: 5 additions & 5 deletions aviary/train.py
@@ -11,7 +11,7 @@

 from aviary import ROOT
 from aviary.core import BaseModelClass, Normalizer, TaskType, np_softmax
-from aviary.losses import RobustL1Loss
+from aviary.losses import robust_l1_loss
 from aviary.utils import get_metrics, print_walltime
 from aviary.wrenformer.data import df_to_in_mem_dataloader
 from aviary.wrenformer.model import Wrenformer
@@ -130,7 +130,7 @@ def train_model(
     print(f"Pytorch running on {device=}")

     loss_func = (
-        (RobustL1Loss if robust else torch.nn.L1Loss())
+        (robust_l1_loss if robust else torch.nn.L1Loss())
         if task_type == reg_key
         else (torch.nn.NLLLoss() if robust else torch.nn.CrossEntropyLoss())
     )
@@ -193,7 +193,7 @@ def train_model(
         start=swa_start, epochs=int(swa_start * epochs), learning_rate=swa_lr
     )
     if task_type == reg_key and hasattr(train_loader, "df"):
-        train_df = getattr(train_loader, "df", train_loader.dataset.df)  # type: ignore
+        train_df = getattr(train_loader, "df", train_loader.dataset.df)  # type: ignore[union-attr]
         targets = train_df[target_col]
         run_params["dummy_mae"] = (targets - targets.mean()).abs().mean()
     if timestamp:
@@ -415,11 +415,11 @@ def train_wrenformer(
         embedding_type=embedding_type,
     )
     train_loader = df_to_in_mem_dataloader(
-        train_df, batch_size=batch_size, shuffle=True, **data_loader_kwargs  # type: ignore
+        train_df, batch_size=batch_size, shuffle=True, **data_loader_kwargs  # type: ignore[arg-type]
     )

     test_loader = df_to_in_mem_dataloader(
-        test_df, batch_size=512, shuffle=False, **data_loader_kwargs  # type: ignore
+        test_df, batch_size=512, shuffle=False, **data_loader_kwargs  # type: ignore[arg-type]
     )

     # embedding_len is the length of the embedding vector for a Wyckoff position encoding the
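The nested conditional in `train_model` picks one of four loss functions from the (task type, robust) pair. The same selection written as a lookup table, which may be easier to scan (a hypothetical rewrite for illustration; aviary uses the conditional shown above):

```python
import torch
from aviary.losses import robust_l1_loss

task_type, reg_key, robust = "regression", "regression", True  # example values

loss_by_task = {
    (True, True): robust_l1_loss,                 # robust regression (aleatoric uncertainty)
    (True, False): torch.nn.L1Loss(),             # plain regression
    (False, True): torch.nn.NLLLoss(),            # robust classification (expects log-probs)
    (False, False): torch.nn.CrossEntropyLoss(),  # plain classification
}
loss_func = loss_by_task[(task_type == reg_key, robust)]
```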
22 changes: 11 additions & 11 deletions aviary/utils.py
@@ -29,7 +29,7 @@

 from aviary import ROOT
 from aviary.core import BaseModelClass, Normalizer, TaskType, sampled_softmax
-from aviary.losses import RobustL1Loss, RobustL2Loss
+from aviary.losses import robust_l1_loss, robust_l2_loss

 if TYPE_CHECKING:
     from types import ModuleType
@@ -224,9 +224,9 @@ def initialize_losses(
     elif task == "regression":
         if robust:
             if loss_name_dict[name] == "L1":
-                loss_func_dict[name] = (task, RobustL1Loss)
+                loss_func_dict[name] = (task, robust_l1_loss)
             elif loss_name_dict[name] == "L2":
-                loss_func_dict[name] = (task, RobustL2Loss)
+                loss_func_dict[name] = (task, robust_l2_loss)
             else:
                 raise NameError(
                     "Only L1 or L2 losses are allowed for robust regression tasks"
@@ -472,7 +472,7 @@ def results_multitask(
         results_dict[target_name] = defaultdict(
             list
             if task_type == "classification"
-            else lambda: np.zeros((ensemble_folds, len(test_set)))  # type: ignore
+            else lambda: np.zeros((ensemble_folds, len(test_set)))  # type: ignore[call-overload]
         )

     for ens_idx in range(ensemble_folds):
@@ -518,11 +518,11 @@ def results_multitask(
                     mean, log_std = output.unbind(dim=1)
                     preds = normalizer.denorm(mean.data.cpu())
                     ale_std = torch.exp(log_std).data.cpu() * normalizer.std
-                    res_dict["ale"][ens_idx, :] = ale_std.view(-1).numpy()  # type: ignore
+                    res_dict["ale"][ens_idx, :] = ale_std.view(-1).numpy()  # type: ignore[call-overload]
                 else:
                     preds = normalizer.denorm(output.data.cpu())

-                res_dict["preds"][ens_idx, :] = preds.view(-1).numpy()  # type: ignore
+                res_dict["preds"][ens_idx, :] = preds.view(-1).numpy()  # type: ignore[call-overload]

             elif task_type == "classification":
                 if model.robust:
@@ -532,13 +532,13 @@
                     )
                     pre_logits = mean.data.cpu().numpy()
                     pre_logits_std = torch.exp(log_std).data.cpu().numpy()
-                    res_dict["pre-logits_ale"].append(pre_logits_std)  # type: ignore
+                    res_dict["pre-logits_ale"].append(pre_logits_std)  # type: ignore[union-attr]
                 else:
                     pre_logits = output.data.cpu().numpy()
                 logits = pre_logits.softmax(1)

-                res_dict["pre-logits"].append(pre_logits)  # type: ignore
-                res_dict["logits"].append(logits)  # type: ignore
+                res_dict["pre-logits"].append(pre_logits)  # type: ignore[union-attr]
+                res_dict["logits"].append(logits)  # type: ignore[union-attr]

             res_dict["targets"] = targets

@@ -555,9 +555,9 @@
     for target_name, task_type in task_dict.items():
         print(f"\nTask: {target_name=} on test set")
         if task_type == "regression":
-            print_metrics_regression(**results_dict[target_name])  # type: ignore
+            print_metrics_regression(**results_dict[target_name])  # type: ignore[arg-type]
         elif task_type == "classification":
-            print_metrics_classification(**results_dict[target_name])  # type: ignore
+            print_metrics_classification(**results_dict[target_name])  # type: ignore[arg-type]

     return results_dict
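All of these changes implement ruff's PGH003: a bare `# type: ignore` silences every checker error on a line, while `# type: ignore[code]` suppresses only the named error, so new, unrelated type errors on the same line still surface. A contrived sketch of the difference:

```python
from torch import Tensor

x: Tensor = 0  # type: ignore[assignment]  # silences only the int-vs-Tensor mismatch
y: Tensor = 0  # type: ignore  # bare form: hides all errors here; PGH003 flags it
```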
4 changes: 2 additions & 2 deletions aviary/wren/data.py
@@ -53,15 +53,15 @@ def __init__(
         self.identifiers = list(identifiers)
         self.df = df

-        if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
+        if elem_embedding in ("matscholar200", "cgcnn92", "megnet16", "onehot112"):
             elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json"

         with open(elem_embedding) as emb_file:
             self.elem_features = json.load(emb_file)

         self.elem_emb_len = len(next(iter(self.elem_features.values())))

-        if sym_emb in ["bra-alg-off", "spg-alg-off"]:
+        if sym_emb in ("bra-alg-off", "spg-alg-off"):
             sym_emb = f"{PKG_DIR}/embeddings/wyckoff/{sym_emb}.json"

         with open(sym_emb) as sym_file:
2 changes: 1 addition & 1 deletion aviary/wren/model.py
@@ -85,7 +85,7 @@ def __init__(
             "cry_msg": cry_msg,
         }

-        self.material_nn = DescriptorNetwork(**desc_dict)  # type: ignore
+        self.material_nn = DescriptorNetwork(**desc_dict)  # type: ignore[arg-type]

         model_params = {
             "robust": robust,
10 changes: 5 additions & 5 deletions aviary/wren/utils.py
@@ -389,12 +389,12 @@ def count_crystal_dof(aflow_label: str) -> int:
     Returns:
         int: Number of free-parameters in given prototype
     """
-    num_params = 0
+    n_params = 0

     aflow_label, _ = aflow_label.split(":")  # chop off chemical system
     _, pearson, spg, *wyks = aflow_label.split("_")

-    num_params += cry_param_dict[pearson[0]]
+    n_params += cry_param_dict[pearson[0]]

     for wyk_letters_per_elem in wyks:
         # normalize Wyckoff letters to start with 1 if missing digit
@@ -404,12 +404,12 @@ def count_crystal_dof(aflow_label: str) -> int:
         sep_el_wyks = [
             "".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha)
         ]
-        num_params += sum(
+        n_params += sum(
             float(n) * param_dict[spg][k]
             for n, k in zip(sep_el_wyks[0::2], sep_el_wyks[1::2])
         )

-    return int(num_params)
+    return n_params
@@ -455,7 +455,7 @@ def get_isopointal_proto_from_aflow(aflow_label: str) -> str:
     isopointal: list[str] = []

     for wyks_list in valid_permutations:
-        for trans in relab_dict[str(spg)]:
+        for trans in relab_dict[spg]:
             t = str.maketrans(trans)
             isopointal.append("_".join(wyks_list).translate(t))
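The `groupby(..., str.isalpha)` call in `count_crystal_dof` is what splits a normalized Wyckoff string into alternating multiplicity/letter runs, which the `[0::2]`/`[1::2]` slices then pair back up. A standalone illustration (the input string is made up):

```python
from itertools import groupby

wyk_letters_normalized = "2a3b1c"  # hypothetical digit-prefixed Wyckoff letters
sep_el_wyks = ["".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha)]
print(sep_el_wyks)  # ['2', 'a', '3', 'b', '1', 'c']
print(list(zip(sep_el_wyks[0::2], sep_el_wyks[1::2])))  # [('2', 'a'), ('3', 'b'), ('1', 'c')]
```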
2 changes: 1 addition & 1 deletion aviary/wrenformer/data.py
@@ -156,7 +156,7 @@ def df_to_in_mem_dataloader(
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"

-    if embedding_type not in ["wyckoff", "composition"]:
+    if embedding_type not in ("wyckoff", "composition"):
         raise ValueError(f"{embedding_type = } must be 'wyckoff' or 'composition'")

     initial_embeddings = df[input_col].map(
2 changes: 1 addition & 1 deletion aviary/wrenformer/model.py
@@ -91,7 +91,7 @@ def __init__(
             ResidualNetwork(out_hidden[0], n, out_hidden[1:]) for n in n_targets
         )

-    def forward(  # type: ignore
+    def forward(  # type: ignore[override]
         self, features: Tensor, mask: BoolTensor, *args
     ) -> tuple[Tensor, ...]:
         """Forward pass through the Wrenformer.
3 changes: 2 additions & 1 deletion examples/cgcnn-example.py
@@ -11,7 +11,7 @@
 from aviary.utils import results_multitask, train_ensemble


-def main(  # noqa: D103
+def main(
     data_path,
     targets,
     tasks,
@@ -52,6 +52,7 @@ def main(  # noqa: D103
     device=None,
     **kwargs,
 ):
+    """Train and evaluate a CGCNN model."""
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     print(f"The model will run on the {args.device} device")
3 changes: 2 additions & 1 deletion examples/roost-example.py
@@ -10,7 +10,7 @@
 from aviary.utils import results_multitask, train_ensemble


-def main(  # noqa: D103
+def main(
     data_path,
     targets,
     tasks,
@@ -45,6 +45,7 @@ def main(  # noqa: D103
     device=None,
     **kwargs,
 ):
+    """Train and evaluate a Roost model."""
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     print(f"The model will run on the {args.device} device")
3 changes: 2 additions & 1 deletion examples/wren-example.py
@@ -10,7 +10,7 @@
 from aviary.wren.model import Wren


-def main(  # noqa: D103
+def main(
     data_path,
     targets,
     tasks,
@@ -47,6 +47,7 @@ def main(  # noqa: D103
     device=None,
     **kwargs,
 ):
+    """Train and evaluate a Wren model."""
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     print(f"The model will run on the {args.device} device")