From ed8380f643440899a56aefe4e304a68d87971f72 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 25 Sep 2023 09:00:57 +0000 Subject: [PATCH] replace vmap with list comprehension --- pyscf_ipu/experimental/nuclear_gradients.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index e01976a..e72b881 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -36,13 +36,9 @@ def grad_primitive_integral( Float3: Gradient of the integral with respect to cartesian axes. """ - axes = jnp.arange(3) - lhs_p1 = vmap(a.offset_lmn, (0, None))(axes, 1) - t1 = 2 * a.alpha * vmap(primitive_op, (0, None))(lhs_p1, b) - - lhs_m1 = vmap(a.offset_lmn, (0, None))(axes, -1) - t2 = a.lmn * vmap(primitive_op, (0, None))(lhs_m1, b) - grad_out = t1 - t2 + t1 = [primitive_op(a.offset_lmn(ax, 1), b) for ax in range(3)] + t2 = [primitive_op(a.offset_lmn(ax, -1), b) for ax in range(3)] + grad_out = 2 * a.alpha * jnp.stack(t1) - a.lmn * jnp.stack(t2) return grad_out