Skip to content

Commit

Permalink
Fix dynamic slicing of arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Sep 23, 2022
1 parent db2a759 commit e63a54a
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion aesara/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,13 @@ def jax_inner_func(carry, x):
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def append_scan_out(scan_in_part, scan_out_part):
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
start_indices = [0] * scan_in_part.ndim
slice_sizes = list(scan_in_part.shape)
slice_sizes[0] = slice_sizes[0] - n_steps
scan_in_part_sliced = jax.lax.dynamic_slice(
scan_in_part, start_indices, slice_sizes
)
return jnp.concatenate([scan_in_part_sliced, scan_out_part], axis=0)

if scan_args.outer_in_mit_sot:
scan_out_final = [
Expand Down

0 comments on commit e63a54a

Please sign in to comment.