From 5ef887619855942a0f7492d5a8ca288de964b8b0 Mon Sep 17 00:00:00 2001 From: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> Date: Sun, 5 Jan 2025 16:30:16 -0700 Subject: [PATCH] fixed bug in energy loss --- chgnet/trainer/trainer.py | 12 --------- tests/test_trainer.py | 56 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index b742118..0ce7adb 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -208,7 +208,6 @@ def __init__( self.criterion = CombinedLoss( target_str=self.targets, criterion=criterion, - is_intensive=self.model.is_intensive, energy_loss_ratio=energy_loss_ratio, force_loss_ratio=force_loss_ratio, stress_loss_ratio=stress_loss_ratio, @@ -725,7 +724,6 @@ def __init__( *, target_str: str = "ef", criterion: str = "MSE", - is_intensive: bool = True, energy_loss_ratio: float = 1, force_loss_ratio: float = 1, stress_loss_ratio: float = 0.1, @@ -740,8 +738,6 @@ def __init__( Default = "ef" criterion: loss criterion to use Default = "MSE" - is_intensive (bool): whether the energy label is intensive - Default = True energy_loss_ratio (float): energy loss ratio in loss function Default = 1 force_loss_ratio (float): force loss ratio in loss function @@ -765,7 +761,6 @@ def __init__( else: raise NotImplementedError self.target_str = target_str - self.is_intensive = is_intensive self.energy_loss_ratio = energy_loss_ratio if "f" not in self.target_str: self.force_loss_ratio = 0 @@ -803,19 +798,12 @@ def forward( if self.allow_missing_labels: valid_value_indices = ~torch.isnan(targets["e"]) valid_e_target = targets["e"][valid_value_indices] - valid_atoms_per_graph = prediction["atoms_per_graph"][ - valid_value_indices - ] valid_e_pred = prediction["e"][valid_value_indices] if valid_e_pred.shape == torch.Size([]): valid_e_pred = valid_e_pred.view(1) else: valid_e_target = targets["e"] - valid_atoms_per_graph = prediction["atoms_per_graph"] valid_e_pred = prediction["e"] - if self.is_intensive: - valid_e_target = valid_e_target / valid_atoms_per_graph - valid_e_pred = valid_e_pred / valid_atoms_per_graph out["loss"] += self.energy_loss_ratio * self.criterion( valid_e_target, valid_e_pred diff --git a/tests/test_trainer.py b/tests/test_trainer.py index db769dc..cf048a2 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -11,7 +11,7 @@ from chgnet.data.dataset import StructureData, get_train_val_test_loader from chgnet.model import CHGNet -from chgnet.trainer import Trainer +from chgnet.trainer.trainer import CombinedLoss, Trainer if TYPE_CHECKING: from pathlib import Path @@ -50,6 +50,60 @@ chgnet = CHGNet.load() +def test_combined_loss() -> None: + criterion = CombinedLoss( + target_str="ef", + criterion="MSE", + energy_loss_ratio=1, + force_loss_ratio=1, + stress_loss_ratio=0.1, + mag_loss_ratio=0.1, + allow_missing_labels=False, + ) + target1 = {"e": torch.Tensor([1]), "f": [torch.Tensor([[[1, 1, 1], [2, 2, 2]]])]} + prediction1 = chgnet.predict_structure(NaCl) + prediction1 = { + "e": torch.from_numpy(prediction1["e"]).unsqueeze(0), + "f": [torch.from_numpy(prediction1["f"])], + "atoms_per_graph": torch.tensor([2]), + } + out1 = criterion( + targets=target1, + prediction=prediction1, + ) + target2 = { + "e": torch.Tensor([1]), + "f": [ + torch.Tensor( + [ + [ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + ] + ] + ) + ], + } + supercell = NaCl.make_supercell([2, 2, 1], in_place=False) + prediction2 = chgnet.predict_structure(supercell) + prediction2 = { + "e": torch.from_numpy(prediction2["e"]).unsqueeze(0), + "f": [torch.from_numpy(prediction2["f"])], + "atoms_per_graph": torch.tensor([8]), + } + out2 = criterion( + targets=target2, + prediction=prediction2, + ) + assert np.isclose(out1["loss"], out2["loss"], rtol=1e-04, atol=1e-05) + + def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: extra_run_config = dict(some_other_hyperparam=42) trainer = Trainer(