diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index e72b881..f008e1c 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -111,8 +111,10 @@ def take_primitives(indices): out = cl * cr * out.T out = out.reshape(3, basis.num_primitives, basis.num_primitives) - out = segment_sum(jnp.rollaxis(out, 1), orbital_index) - out = segment_sum(jnp.rollaxis(out, -1), orbital_index) + out = jnp.rollaxis(out, 1) + out = segment_sum(out, orbital_index, num_segments=basis.num_orbitals) + out = jnp.rollaxis(out, -1) + out = segment_sum(out, orbital_index, num_segments=basis.num_orbitals) return jnp.rollaxis(out, -1)