diff --git a/aesara/link/jax/dispatch/scan.py b/aesara/link/jax/dispatch/scan.py index 9f2c602d89..0ae2e681cb 100644 --- a/aesara/link/jax/dispatch/scan.py +++ b/aesara/link/jax/dispatch/scan.py @@ -26,8 +26,6 @@ def scan(*outer_inputs): raise NotImplementedError("sit-sot not supported") if len(outer_in_shared): raise NotImplementedError("shared variables not supported") - if len(outer_in_non_seqs): - raise NotImplementedError("non sequence are not supported") # If `output_infos` is empty we need to create an empty initial carry # value with the output's shape and dtype @@ -39,6 +37,7 @@ def scan(*outer_inputs): init_carry = outer_in_sit_sot sequences = outer_in_seqs + non_sequences = outer_in_non_seqs def scan_inner_in_args(carry, x): """Create an inner-input expression. @@ -59,7 +58,7 @@ def scan_inner_in_args(carry, x): else: inner_in_sit_sot = carry - return sum([inner_in_seqs, inner_in_sit_sot], []) + return sum([inner_in_seqs, inner_in_sit_sot, non_sequences], []) def scan_new_carry(inner_outputs): """Create a new carry expression diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py index 0c2249336a..5685a249ab 100644 --- a/tests/link/jax/test_scan.py +++ b/tests/link/jax/test_scan.py @@ -48,6 +48,17 @@ None, lambda op: op.info.n_nit_sot > 0, ), + # nit-sot, non_seq + ( + lambda c: at.as_tensor(2.0) * c, + [], + [{}], + [at.dscalar("c")], + 3, + [1.0], + None, + lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0, + ), ], ) def test_xit_xot_types(