Skip to content

Commit

Permalink
separate test file
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Oct 16, 2023
1 parent 64deee4 commit 58e6383
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 133 deletions.
6 changes: 2 additions & 4 deletions pyscf_ipu/experimental/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
def to_pyscf(
structure: Structure, basis_name: str = "sto-3g", unit: str = "Bohr"
) -> "gto.Mole":
position = np.asarray(structure.position)
mol = gto.Mole(unit=unit, spin=structure.num_electrons % 2, cart=True)
mol.atom = [
(symbol, pos)
for symbol, pos in zip(structure.atomic_symbol, structure.position)
]
mol.atom = [(symbol, pos) for symbol, pos in zip(structure.atomic_symbol, position)]
mol.basis = basis_name
mol.build(unit=unit)
return mol
Expand Down
1 change: 0 additions & 1 deletion pyscf_ipu/experimental/nuclear_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def grad_primitive_integral(
Returns:
Float3: Gradient of the integral with respect to cartesian axes.
"""

t1 = [primitive_op(a.offset_lmn(ax, 1), b) for ax in range(3)]
t2 = [primitive_op(a.offset_lmn(ax, -1), b) for ax in range(3)]
grad_out = 2 * a.alpha * jnp.stack(t1) - a.lmn * jnp.stack(t2)
Expand Down
1 change: 1 addition & 0 deletions pyscf_ipu/experimental/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

Float3 = Float[Array, "3"]
Float3xNxN = Float[Array, "3 N N"]
Float3xNxNxNxN = Float[Array, "3 N N N N"]
FloatNx3 = Float[Array, "N 3"]
FloatN = Float[Array, "N"]
FloatNxN = Float[Array, "N N"]
Expand Down
128 changes: 0 additions & 128 deletions test/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jax import tree_map, vmap
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.basis import basisset
from pyscf_ipu.experimental.device import has_ipu, ipu_func
from pyscf_ipu.experimental.integrals import (
eri_basis,
eri_basis_sparse,
Expand All @@ -19,51 +17,10 @@
overlap_primitives,
)
from pyscf_ipu.experimental.interop import to_pyscf
from pyscf_ipu.experimental.nuclear_gradients import (
grad_kinetic_basis,
grad_nuclear_basis,
grad_overlap_basis,
)
from pyscf_ipu.experimental.primitive import Primitive
from pyscf_ipu.experimental.structure import molecule


@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"])
def test_to_pyscf(basis_name):
mol = molecule("water")
basis = basisset(mol, basis_name)
pyscf_mol = to_pyscf(mol, basis_name)
assert basis.num_orbitals == pyscf_mol.nao


@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g"])
def test_gto(basis_name):
from pyscf.dft.numint import eval_rho

# Atomic orbitals
structure = molecule("water")
basis = basisset(structure, basis_name)
mesh, _ = uniform_mesh()
actual = basis(mesh)

mol = to_pyscf(structure, basis_name)
expect_ao = mol.eval_gto("GTOval_cart", np.asarray(mesh))
assert_allclose(actual, expect_ao, atol=1e-6)

# Molecular orbitals
mf = mol.KS()
mf.kernel()
C = jnp.array(mf.mo_coeff, dtype=jnp.float32)
actual = basis.occupancy * C @ C.T
expect = jnp.array(mf.make_rdm1(), dtype=jnp.float32)
assert_allclose(actual, expect, atol=1e-6)

# Electron density
actual = electron_density(basis, mesh, C)
expect = eval_rho(mol, expect_ao, mf.make_rdm1(), "lda")
assert_allclose(actual, expect, atol=1e-6)


def test_overlap():
# Exercise 3.21 of "Modern quantum chemistry: introduction to advanced
# electronic structure theory."" by Szabo and Ostlund
Expand Down Expand Up @@ -151,19 +108,6 @@ def test_water_nuclear():
assert_allclose(actual, expect, atol=1e-4)


def eri_orbitals(orbitals):
def take(orbital, index):
p = tree_map(lambda *xs: jnp.stack(xs), *orbital.primitives)
p = tree_map(lambda x: jnp.take(x, index, axis=0), p)
c = jnp.take(orbital.coefficients, index)
return p, c

indices = [jnp.arange(o.num_primitives) for o in orbitals]
indices = [i.reshape(-1) for i in jnp.meshgrid(*indices)]
prim, coef = zip(*[take(o, i) for o, i in zip(orbitals, indices)])
return jnp.sum(jnp.prod(jnp.stack(coef), axis=0) * vmap(eri_primitives)(*prim))


def test_eri():
# PyQuante test cases for ERI
a, b, c, d = [Primitive()] * 4
Expand All @@ -172,18 +116,6 @@ def test_eri():
c, d = [Primitive(lmn=jnp.array([1, 0, 0]))] * 2
assert_allclose(eri_primitives(a, b, c, d), 0.940316, atol=1e-5)

# H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund
h2 = molecule("h2")
basis = basisset(h2, "sto-3g")
indices = [(0, 0, 0, 0), (0, 0, 1, 1), (1, 0, 0, 0), (1, 0, 1, 0)]
expected = [0.7746, 0.5697, 0.4441, 0.2970]

for ijkl, expect in zip(indices, expected):
actual = eri_orbitals([basis.orbitals[aoid] for aoid in ijkl])
assert_allclose(actual, expect, atol=1e-4)


def test_eri_basis():
# H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund
h2 = molecule("h2")
basis = basisset(h2, "sto-3g")
Expand Down Expand Up @@ -219,63 +151,3 @@ def test_water_eri(sparse):
aosym = "s8" if sparse else "s1"
expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym)
assert_allclose(actual, expect, atol=1e-4)


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_ipu_overlap():
from pyscf_ipu.experimental.integrals import _overlap_primitives

a, b = [Primitive()] * 2
actual = ipu_func(_overlap_primitives)(a, b)
assert_allclose(actual, overlap_primitives(a, b))


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_ipu_kinetic():
from pyscf_ipu.experimental.integrals import _kinetic_primitives

a, b = [Primitive()] * 2
actual = ipu_func(_kinetic_primitives)(a, b)
assert_allclose(actual, kinetic_primitives(a, b))


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_ipu_nuclear():
from pyscf_ipu.experimental.integrals import _nuclear_primitives

# PyQuante test case for nuclear attraction integral
a, b = [Primitive()] * 2
c = jnp.zeros(3)
actual = ipu_func(_nuclear_primitives)(a, b, c)
assert_allclose(actual, -1.595769, atol=1e-5)


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_ipu_eri():
from pyscf_ipu.experimental.integrals import _eri_primitives

# PyQuante test cases for ERI
a, b, c, d = [Primitive()] * 4
actual = ipu_func(_eri_primitives)(a, b, c, d)
assert_allclose(actual, 1.128379, atol=1e-5)


@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g"])
def test_nuclear_gradients(basis_name):
h2 = molecule("h2")
scfmol = to_pyscf(h2, basis_name)
basis = basisset(h2, basis_name)

actual = grad_overlap_basis(basis)
expect = scfmol.intor("int1e_ipovlp_cart", comp=3)
assert_allclose(actual, expect, atol=1e-6)

actual = grad_kinetic_basis(basis)
expect = scfmol.intor("int1e_ipkin_cart", comp=3)
assert_allclose(actual, expect, atol=1e-6)

# TODO: investigate possible inconsistency in libcint outputs?
actual = grad_nuclear_basis(basis)
expect = scfmol.intor("int1e_ipnuc_cart", comp=3)
expect = -np.moveaxis(expect, 1, 2)
assert_allclose(actual, expect, atol=1e-6)
33 changes: 33 additions & 0 deletions test/test_nuclear_gradients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import numpy as np
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.basis import basisset
from pyscf_ipu.experimental.interop import to_pyscf
from pyscf_ipu.experimental.nuclear_gradients import (
grad_kinetic_basis,
grad_nuclear_basis,
grad_overlap_basis,
)
from pyscf_ipu.experimental.structure import molecule


def test_nuclear_gradients():
basis_name = "sto-3g"
h2 = molecule("h2")
scfmol = to_pyscf(h2, basis_name)
basis = basisset(h2, basis_name)

actual = grad_overlap_basis(basis)
expect = scfmol.intor("int1e_ipovlp_cart", comp=3)
assert_allclose(actual, expect, atol=1e-6)

actual = grad_kinetic_basis(basis)
expect = scfmol.intor("int1e_ipkin_cart", comp=3)
assert_allclose(actual, expect, atol=1e-6)

# TODO: investigate possible inconsistency in libcint outputs?
actual = grad_nuclear_basis(basis)
expect = scfmol.intor("int1e_ipnuc_cart", comp=3)
expect = -np.moveaxis(expect, 1, 2)
assert_allclose(actual, expect, atol=1e-6)

0 comments on commit 58e6383

Please sign in to comment.