diff --git a/src/tad_dftd3/reference-c6.pt b/src/tad_dftd3/reference-c6.pt new file mode 100644 index 0000000..db078e2 Binary files /dev/null and b/src/tad_dftd3/reference-c6.pt differ diff --git a/src/tad_dftd3/reference.py b/src/tad_dftd3/reference.py index 1375b8d..5c2d67a 100644 --- a/src/tad_dftd3/reference.py +++ b/src/tad_dftd3/reference.py @@ -20,7 +20,6 @@ C6 dispersion coefficients. """ import os.path as op - import torch from ._typing import Any, NoReturn, Optional, Tensor @@ -142,11 +141,29 @@ def _load_cn( ) -def _load_c6( +def _load_c6_pt( dtype: torch.dtype = torch.double, device: Optional[torch.device] = None ) -> Tensor: """ - Load reference C6 coefficients from file and fill them into a tensor + Load reference C6 coefficients from torch file. + """ + path = op.join(op.dirname(__file__), "reference-c6.pt") + ref = torch.load(path).type(dtype).to(device) + return ref + + +def _load_c6_npy( + dtype: torch.dtype = torch.double, device: Optional[torch.device] = None +) -> Tensor: + """ + Load reference C6 coefficients from file and fill them into a tensor. + + Warning + ------- + The loops in this function are actually really slow and create a bottleneck + for the whole calculation. Since the output of this function is not + dependent on the system, we only use it to store the tensor (".pt" file) + and now skip the calculation. """ # pylint: disable=import-outside-toplevel @@ -170,9 +187,16 @@ def _load_c6( ij = i * (i - 1) // 2 + j - 1 if j < i else j * (j - 1) // 2 + i - 1 c6[i, j, :, :] = ref[ij, :, :].T if j < i else ref[ij, :, :] + # torch.save(c6, "reference-c6.pt") return c6 +def _load_c6( + dtype: torch.dtype = torch.double, device: Optional[torch.device] = None +) -> Tensor: + return _load_c6_pt(dtype=dtype, device=device) + + class Reference: """ Reference systems for the D3 dispersion model diff --git a/tests/test_model/test_load.py b/tests/test_model/test_load.py new file mode 100644 index 0000000..0a1bcfd --- /dev/null +++ b/tests/test_model/test_load.py @@ -0,0 +1,31 @@ +# This file is part of tad-dftd3. +# SPDX-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test loading C6 coefficients. +""" +import torch + +from tad_dftd3 import reference + + +def test_ref(): + c6_np = reference._load_c6_npy(dtype=torch.double) + c6_pt = reference._load_c6_pt(dtype=torch.double) + + assert c6_np.shape == c6_pt.shape + assert (c6_np == c6_pt).all() + + maxelem = 104 # 103 + dummy + assert c6_np.shape == torch.Size((maxelem, maxelem, 7, 7))