diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5308434..e8159d9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,7 +55,7 @@ repos: args: ["--profile", "black", "--filter-files"] - repo: https://github.com/psf/black - rev: 23.12.1 + rev: 24.2.0 hooks: - id: black stages: [commit] diff --git a/test/test_grad/test_param.py b/test/test_grad/test_param.py index dc80e94..dfc166d 100644 --- a/test/test_grad/test_param.py +++ b/test/test_grad/test_param.py @@ -33,9 +33,7 @@ tol = 1e-8 -def gradchecker( - dtype: torch.dtype, name: str -) -> tuple[ +def gradchecker(dtype: torch.dtype, name: str) -> tuple[ Callable[[Tensor, Tensor, Tensor, Tensor], Tensor], # autograd function tuple[Tensor, Tensor, Tensor, Tensor], # differentiable variables ]: @@ -85,9 +83,7 @@ def test_gradgradcheck(dtype: torch.dtype, name: str) -> None: assert dgradgradcheck(func, diffvars, atol=tol, fast_mode=FAST_MODE) -def gradchecker_batch( - dtype: torch.dtype, name1: str, name2: str -) -> tuple[ +def gradchecker_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[ Callable[[Tensor, Tensor, Tensor, Tensor], Tensor], # autograd function tuple[Tensor, Tensor, Tensor, Tensor], # differentiable variables ]: diff --git a/test/test_grad/test_pos.py b/test/test_grad/test_pos.py index 8b02302..277c0b8 100644 --- a/test/test_grad/test_pos.py +++ b/test/test_grad/test_pos.py @@ -33,9 +33,7 @@ tol = 1e-8 -def gradchecker( - dtype: torch.dtype, name: str -) -> tuple[ +def gradchecker(dtype: torch.dtype, name: str) -> tuple[ Callable[[Tensor], Tensor], # autograd function Tensor, # differentiable variables ]: @@ -86,9 +84,7 @@ def test_gradgradcheck(dtype: torch.dtype, name: str) -> None: assert dgradgradcheck(func, diffvars, atol=tol, fast_mode=FAST_MODE) -def gradchecker_batch( - dtype: torch.dtype, name1: str, name2: str -) -> tuple[ +def gradchecker_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[ Callable[[Tensor], Tensor], # autograd function Tensor, # differentiable variables ]: