Skip to content

Commit

Permalink
Fix tests (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Jan 13, 2024
1 parent 64b7044 commit a8e44b7
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 13 deletions.
13 changes: 9 additions & 4 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]
# PyTorch does not support Python 3.11 on non-Linux platforms
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
# PyTorch now fully supports Python =<3.11
# see: https://github.com/pytorch/pytorch/issues/86566
#
# PyTorch does not support Python 3.12 (all platforms)
# see: https://github.com/pytorch/pytorch/issues/110436
exclude:
- os: ubuntu-latest
python-version: "3.12"
- os: macos-latest
python-version: "3.11"
python-version: "3.12"
- os: windows-latest
python-version: "3.11"
python-version: "3.12"

defaults:
run:
Expand Down
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,5 @@ repos:
hooks:
- id: mypy
additional_dependencies: [types-all]
exclude: 'test/conftest.py'
pass_filenames: false
args: [--config-file=pyproject.toml, --ignore-missing-imports, src]
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ disallow_untyped_defs = true
warn_redundant_casts = true
warn_unreachable = true
warn_unused_ignores = true
exclude = '''
(?x)
^test?s/conftest.py$
'''
exclude = '''(?x)(
test/conftest.py
)'''


[tool.coverage.run]
Expand Down
2 changes: 1 addition & 1 deletion test/test_grad/test_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_fail() -> None:
param = {"a1": numbers}

# differentiable variable is not a tensor
with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
hessian(dftd3, (numbers, positions, param), argnums=2)


Expand Down
43 changes: 40 additions & 3 deletions test/test_model/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
"""
Test the reference.
"""
from typing import Union
from typing import Optional, Union
from unittest.mock import patch

import pytest
import torch
from tad_mctc.convert import str_to_device

from tad_dftd3 import reference
from tad_dftd3.typing import DD
from tad_dftd3.typing import DD, Any, Tensor, TypedDict

from ..conftest import DEVICE

Expand All @@ -37,8 +38,12 @@ def test_reference_dtype(dtype: torch.dtype) -> None:

@pytest.mark.parametrize("dtype", [torch.float16, None])
def test_reference_dtype_both(dtype: Union[torch.dtype, None]) -> None:
class DDNone(TypedDict):
device: torch.device
dtype: Optional[torch.dtype]

dev = torch.device("cpu")
dd = {"device": dev, "dtype": dtype}
dd: DDNone = {"device": dev, "dtype": dtype}
ref = reference.Reference(device=dev).to(**dd)
assert ref.dtype == torch.tensor(1.0, dtype=dtype).dtype

Expand All @@ -63,6 +68,38 @@ def test_reference_device(device_str: str, device_str2: str) -> None:
ref.device = device


def test_reference_different_devices() -> None:
# Custom Tensor class with overridable device property
class MockTensor(Tensor):
@property
def device(self) -> Any:
return self._device

@device.setter
def device(self, value: Any) -> None:
self._device = value

# Custom mock functions
def mock_load_cn(*_: Any, **__: Any) -> Tensor:
tensor = MockTensor([1, 2, 3])
tensor.device = torch.device("cpu")
return tensor

def mock_load_c6(*_: Any, **__: Any) -> Tensor:
tensor = MockTensor([4, 5, 6])
tensor.device = torch.device("cuda")
return tensor

with patch("tad_dftd3.reference._load_cn", new=mock_load_cn):
with patch("tad_dftd3.reference._load_c6", new=mock_load_c6):
with pytest.raises(RuntimeError) as exc:
# Assuming the device is not explicitly passed, so it picks
# from _load_cn and _load_c6
reference.Reference()

assert "All tensors must be on the same device!" in str(exc.value)


def test_reference_fail() -> None:
c6 = reference._load_c6() # pylint: disable=protected-access

Expand Down

0 comments on commit a8e44b7

Please sign in to comment.