From 417ad4cca8d9c8ba709f685038cf610b92f526fa Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Tue, 26 Sep 2023 08:08:25 +0000 Subject: [PATCH] add num_segments to support jit compilation --- pyscf_ipu/experimental/nuclear_gradients.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index e72b881..f008e1c 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -111,8 +111,10 @@ def take_primitives(indices): out = cl * cr * out.T out = out.reshape(3, basis.num_primitives, basis.num_primitives) - out = segment_sum(jnp.rollaxis(out, 1), orbital_index) - out = segment_sum(jnp.rollaxis(out, -1), orbital_index) + out = jnp.rollaxis(out, 1) + out = segment_sum(out, orbital_index, num_segments=basis.num_orbitals) + out = jnp.rollaxis(out, -1) + out = segment_sum(out, orbital_index, num_segments=basis.num_orbitals) return jnp.rollaxis(out, -1)