Skip to content

Commit

Permalink
replace vmap with list comprehension
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Sep 26, 2023
1 parent e446ddb commit ed8380f
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions pyscf_ipu/experimental/nuclear_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,9 @@ def grad_primitive_integral(
Float3: Gradient of the integral with respect to cartesian axes.
"""

axes = jnp.arange(3)
lhs_p1 = vmap(a.offset_lmn, (0, None))(axes, 1)
t1 = 2 * a.alpha * vmap(primitive_op, (0, None))(lhs_p1, b)

lhs_m1 = vmap(a.offset_lmn, (0, None))(axes, -1)
t2 = a.lmn * vmap(primitive_op, (0, None))(lhs_m1, b)
grad_out = t1 - t2
t1 = [primitive_op(a.offset_lmn(ax, 1), b) for ax in range(3)]
t2 = [primitive_op(a.offset_lmn(ax, -1), b) for ax in range(3)]
grad_out = 2 * a.alpha * jnp.stack(t1) - a.lmn * jnp.stack(t2)
return grad_out


Expand Down

0 comments on commit ed8380f

Please sign in to comment.