Skip to content

Commit

Permalink
add num_segments to support jit compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Sep 26, 2023
1 parent ed8380f commit 417ad4c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions pyscf_ipu/experimental/nuclear_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ def take_primitives(indices):

out = cl * cr * out.T
out = out.reshape(3, basis.num_primitives, basis.num_primitives)
out = segment_sum(jnp.rollaxis(out, 1), orbital_index)
out = segment_sum(jnp.rollaxis(out, -1), orbital_index)
out = jnp.rollaxis(out, 1)
out = segment_sum(out, orbital_index, num_segments=basis.num_orbitals)
out = jnp.rollaxis(out, -1)
out = segment_sum(out, orbital_index, num_segments=basis.num_orbitals)
return jnp.rollaxis(out, -1)


Expand Down

0 comments on commit 417ad4c

Please sign in to comment.