Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the JAX Scan dispatcher #1202

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft

Conversation

rlouf
Copy link
Member

@rlouf rlouf commented Sep 20, 2022

This PR tries to address the issues observed in #710 and #924 with the transpilation of Scan operators. Most importantly, we increase the test coverage of Scan's functionalities.

@brandonwillard brandonwillard added bug Something isn't working JAX Involves JAX transpilation Scan Involves the `Scan` `Op` labels Sep 20, 2022
@rlouf
Copy link
Member Author

rlouf commented Sep 23, 2022

We can actually work around a lot of the dynamic indexing issues with jax.lax.dynamic_slice so I'm currently fixing this throughout the dispatcher. The basic test with the fixed number of steps is still not passing because of this, but I'm positive that once that change has been applied throughout many of the issues we thought were omnistaging related should disappear.

Then I'll keep making my way through testing more and more scan features.

@rlouf
Copy link
Member Author

rlouf commented Oct 14, 2022

After a lot of messing around I decided to go for a full rewrite and follow the Numba implementation. I have a minimal version that passes the first 3 xit_xot_types test.

That JAX easily complains about dynamic slicing may be a blessing in disguise as it highlights some gaps in Aesara's rewrites, e.g. with #1257 and others. Workarounds that I have currently had to implement could be easily avoided using the adequate rewrites at compile time.

I also switched to run the test without rewrites, and I should probably start gathering a set of rewrites that would help with transpilation. How would we go about having backend-specific rewrites?

  • Remove the unnecessary try...except for the DeepCopyOp since jnp.copy is implemented in JAX

@brandonwillard
Copy link
Member

How would we go about having backend-specific rewrites?

Numba mode already specializes its rewrites, so check out its definition in aesara.compile.mode.

@rlouf rlouf force-pushed the rewrite-jax-scan branch 2 times, most recently from 4bd71bb to a1b7b5c Compare October 17, 2022 20:07
@rlouf
Copy link
Member Author

rlouf commented Oct 18, 2022

This is turning into a much bigger PR than expected as I am also trying to fix any issue that prevents me from running the scan tests. Left to be done on the scan side:

  • while
  • mit-mots
  • The example with a random variable returns the correct output, but there is an error linked to the shared variable output that I need to investigate.

While I'm at it I am going to fix as many known issues with the JAX dispatcher as possible (issues and tests marked as xfail) as long as I do not need to play with the rewrites. This will be the object of a second PR, as this one is already a big improvement over the existing dispatcher.

@codecov
Copy link

codecov bot commented Oct 18, 2022

Codecov Report

Merging #1202 (c7097dd) into main (2434cb4) will increase coverage by 0.14%.
The diff coverage is 90.00%.

❗ Current head c7097dd differs from pull request most recent head b09a40e. Consider uploading reports for the commit b09a40e to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1202      +/-   ##
==========================================
+ Coverage   74.35%   74.49%   +0.14%     
==========================================
  Files         177      173       -4     
  Lines       49046    48658     -388     
  Branches    10379    10390      +11     
==========================================
- Hits        36468    36250     -218     
+ Misses      10285    10112     -173     
- Partials     2293     2296       +3     
Impacted Files Coverage Δ
aesara/link/jax/dispatch/scalar.py 94.80% <66.66%> (-2.60%) ⬇️
aesara/link/jax/linker.py 93.10% <82.60%> (-6.90%) ⬇️
aesara/link/jax/dispatch/elemwise.py 81.42% <83.33%> (+0.83%) ⬆️
aesara/link/jax/dispatch/scan.py 91.47% <91.01%> (+76.22%) ⬆️
aesara/link/jax/dispatch/basic.py 92.30% <100.00%> (+8.43%) ⬆️
aesara/link/jax/dispatch/shape.py 86.48% <100.00%> (-1.98%) ⬇️
aesara/link/jax/dispatch/subtensor.py 100.00% <100.00%> (+32.07%) ⬆️
aesara/link/jax/dispatch/tensor_basic.py 96.92% <100.00%> (+4.85%) ⬆️
aesara/compile/sharedvalue.py 80.24% <0.00%> (-13.51%) ⬇️
aesara/link/jax/dispatch/extra_ops.py 86.56% <0.00%> (-8.96%) ⬇️
... and 65 more

@rlouf rlouf force-pushed the rewrite-jax-scan branch 3 times, most recently from a761818 to 0183921 Compare October 19, 2022 09:01
@rlouf
Copy link
Member Author

rlouf commented Oct 19, 2022

The following test with a RandomStream fails:

def test_nit_sot_shared():
    res, updates = scan(
        fn=lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
            0, 1, name="a"
        ),
        n_steps=3,
    )

    jax_fn = function((), res, updates=updates, mode="JAX")
    jax_res = jax_fn()
    assert jax_res.shape == (3,)

The values are correct, but the __set__ method of aesara.link.basic.Container fails for the shared state when it calls aesara.tensor.random.type.RandomState.filter:

 65  	            gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
 66  	            state_keys = ["key", "pos"]
 67  	
 68  	            for key in gen_keys:
 69  	                if key not in data:
 70  	                    raise TypeError()
 71  	
 72  	            for key in state_keys:
 73  	                if key not in data["state"]:
 74  	                    raise TypeError()
 75  	
 76  	            state_key = data["state"]["key"]
 77  	            if state_key.shape == (624,) and state_key.dtype == np.uint32:
 78  	                # TODO: Add an option to convert to a `RandomState` instance?
 79  	                return data

Indeed, the shared state for random variables in the JAX backend also contains a jax_state key and state_key.shape = (624,3). Modifying this function to just return data makes the test pass, but we should not have to change code in Aesara everytime we want to accomodate a new backend.

The filter method of types seems to introduce a coupling between linkers and the Aesara IR that we should probably avoid having. In the meantime I am not sure how to go about solving the issue from the JAX linker.

Plus I don't think we need to carry this state around in the JAX backend, isn't only jax_state required here?

@rlouf rlouf force-pushed the rewrite-jax-scan branch 5 times, most recently from 973ec08 to 3de5fb3 Compare November 1, 2022 23:22
Copy link
Member

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the commit entitled "Return a scalar when the tensor values is a scalar", is there an associated MWE/test case?

Also, the commit description mentions that ScalarFromTensor is being called on scalars, and I want to make sure that those input scalars are TensorTypes scalars, and not ScalarType scalars. The latter would imply that we're missing a rewrite for useless ScalarFromTensors.

@rlouf rlouf force-pushed the rewrite-jax-scan branch 4 times, most recently from 2e7b3a8 to fd37b21 Compare November 15, 2022 20:05
@rlouf
Copy link
Member Author

rlouf commented Nov 16, 2022

  • Slice with a dynamic index:
import aesara
import aesara.tensor as at

a = at.iscalar("a")
x = at.arange(3)
out = x[:a]
aesara.dprint(out)
# Subtensor{:int32:} [id A]
#  |ARange{dtype='int64'} [id B]
#  | |TensorConstant{0} [id C]
#  | |TensorConstant{3} [id D]
#  | |TensorConstant{1} [id E]
#  |ScalarFromTensor [id F]
#    |a [id G]

try:
    fn = aesara.function((a,), out, mode="JAX")
    fn(1)
except Exception as e:
    print(f"\n{e}")
# Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
# Apply node that caused the error: DeepCopyOp(Subtensor{:int32:}.0)
# Toposort index: 2
# Inputs types: [TensorType(int64, (None,))]
# Inputs shapes: [()]
# Inputs strides: [()]
# Inputs values: [array(1, dtype=int32)]
# Outputs clients: [['output']]

In this case there are two solutions:

  • Use jax.lax.dynamic_slice to translate Subtensor.
  • Setting a as static_argnum when jit-compiling; this is not the best solution since the function will be re-compiled every time the value of a changes, but this mechanism will be useful in situations where there is no other choice.

@brandonwillard
Copy link
Member

  • Setting a as static_argnum when jit-compiling; this is not the best solution since the function will be re-compiled every time the value of a changes, but this mechanism will be useful in situations where there is no other choice.

As I recall, the trouble with using that is that it's limited to only the (outermost) graph inputs, and we can't compose jax.jited functions that use that setting. If that's not the case, then it's worth trying; otherwise, we have to find workable paths from input-restricted JAX operations to the graph inputs, and that's probably not possible in some cases.

@rlouf
Copy link
Member Author

rlouf commented Dec 7, 2022

As I recall, the trouble with using that is that it's limited to only the (outermost) graph inputs, and we can't compose jax.jited functions that use that setting. If that's not the case, then it's worth trying; otherwise, we have to find workable paths from input-restricted JAX operations to the graph inputs, and that's probably not possible in some cases.

We can always ask users to JIT-compile functions themselves if that's the case, and raise a warning at compilation ("JAX will only be able to JIT-compile your function if you specifiy the {input_position}-th argument ({variable_name}) as static").

Given the number of issues with the JAX backend this work is uncovering, I decided to break the changes down in several smaller PRs and fix the issues unrelated to Scan there, like I did for RandomVariables in #1284. I will remove related commits in this PR. In the meantime we can either decide to wrap this up marking the tests failing for other reasons as xfail, and leave mit-mots for later as I am currently unable to make the test run for unrelated reasons, or decide to keep this open until everything else is fixed.

@rlouf rlouf force-pushed the rewrite-jax-scan branch 3 times, most recently from d718bdd to e528e44 Compare December 8, 2022 08:33
@rlouf
Copy link
Member Author

rlouf commented Dec 8, 2022

The following code fails:

import aesara
import aesara.tensor as at


a_at = at.dvector("a")
res, updates = aesara.scan(
   fn=lambda a_t: 2 * a_t,
   sequences=a_at
)
jax_fn = aesara.function((a_at,), res, updates=updates, mode="JAX")
jax_fn([0, 1, 2, 3, 4])
# IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
# Apply node that caused the error: Elemwise{mul,no_inplace}(TensorConstant{(1,) of 2.0}, Subtensor{int64:int64:int8}.0)

I print the associated function graph:

aesara.dprint(jax_fn)
# Elemwise{mul,no_inplace} [id A] 5
#  |TensorConstant{(1,) of 2.0} [id B]
#  |Subtensor{int64:int64:int8} [id C] 4
#    |a [id D]
#    |ScalarFromTensor [id E] 3
#    | |Elemwise{Composite{Switch(LE(i0, i1), i1, i2)}}[(0, 0)] [id F] 2
#    |   |Shape_i{0} [id G] 0
#    |   | |a [id D]
#    |   |TensorConstant{0} [id H]
#    |   |TensorConstant{0} [id I]
#    |ScalarFromTensor [id J] 1
#    | |Shape_i{0} [id G] 0
#    |ScalarConstant{1} [id K]

jax_fn.maker.fgraph.toposort()[4].tag
# scratchpad{'imported_by': ['local_subtensor_merge']}

jax_fn.maker.fgraph.toposort()[2].tag
# scratchpad{'imported_by': ['inplace_elemwise_optimizer']}

# jax_fn.maker.fgraph.toposort()[1].tag
scratchpad{'imported_by': ['local_subtensor_merge']}

Several remarks:

  • The graphs created by inplace_elemwise_optimizer are often problematic in the JAX backend; the switch statement can produce traced arrays in situation where we could have kept concrete values;
  • ScalarFromTensor is called on Shape_i of a vector; fixing or adding a rewrite Aesara-side could prevent this.
  • Elemwise{mul} transpiles to jax.lax.mul which produces a TracedArray while using the python operator * does not. This is not a problem in this case, but is in general.

@rlouf
Copy link
Member Author

rlouf commented Dec 8, 2022

The following code also fails, because of an Elemwise{add} that transforms concrete values in a TracedArray:

import aesara
import aesara.tensor as at
from aesara.compile.mode import Mode
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax.linker import JAXLinker


opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)


res, updates = aesara.scan(
    fn=lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
    outputs_info=[
        {"initial": at.as_tensor(1.0, dtype="floatX"), "taps": [-1]},
        {"initial": at.as_tensor(0.5, dtype="floatX"), "taps": [-1]},
    ],
    n_steps=10,
)
jax_fn = function((), res, updates=updates, mode=jax_mode)

aesara.dprint(jax_fn)
# Subtensor{int64::} [id A] 17
#  |for{cpu,scan_fn}.0 [id B] 16
#  | |TensorConstant{10} [id C]
#  | |IncSubtensor{Set;:int64:} [id D] 15
#  | | |AllocEmpty{dtype='float64'} [id E] 14
#  | | | |Elemwise{add,no_inplace} [id F] 13
#  | | |   |TensorConstant{10} [id C]
#  | | |   |Subtensor{int64} [id G] 11
#  | | |     |Shape [id H] 10
#  | | |     | |Unbroadcast{0} [id I] 9
#  | | |     |   |InplaceDimShuffle{x} [id J] 8
#  | | |     |     |TensorConstant{1.0} [id K]
#  | | |     |ScalarConstant{0} [id L]
#  | | |Unbroadcast{0} [id I] 9
#  | | |ScalarFromTensor [id M] 12
#  | |   |Subtensor{int64} [id G] 11
#  | |IncSubtensor{Set;:int64:} [id N] 7
#  |   |AllocEmpty{dtype='float64'} [id O] 6
#  |   | |Elemwise{add,no_inplace} [id P] 5
#  |   |   |TensorConstant{10} [id C]
#  |   |   |Subtensor{int64} [id Q] 3
#  |   |     |Shape [id R] 2
#  |   |     | |Unbroadcast{0} [id S] 1
#  |   |     |   |InplaceDimShuffle{x} [id T] 0
#  |   |     |     |TensorConstant{0.5} [id U]
#  |   |     |ScalarConstant{0} [id V]
#  |   |Unbroadcast{0} [id S] 1
#  |   |ScalarFromTensor [id W] 4
#  |     |Subtensor{int64} [id Q] 3
#  |ScalarConstant{1} [id X]
# Subtensor{int64::} [id Y] 18
#  |for{cpu,scan_fn}.1 [id B] 16
#  |ScalarConstant{1} [id Z]

# Inner graphs:

# for{cpu,scan_fn}.0 [id B]
#  >Elemwise{mul,no_inplace} [id BA]
#  > |TensorConstant{2} [id BB]
#  > |*0-<TensorType(float64, ())> [id BC] -> [id D]
#  >Elemwise{mul,no_inplace} [id BD]
#  > |TensorConstant{2} [id BE]
#  > |*1-<TensorType(float64, ())> [id BF] -> [id N]

# for{cpu,scan_fn}.1 [id B]
#  >Elemwise{mul,no_inplace} [id BA]
#  >Elemwise{mul,no_inplace} [id BD]

JAX indeed complains that the input to AllocEmpty (shape) is a TracedArray. The following fails for the same reasons:

import aesara
import aesara.tensor as at
from aesara.compile.mode import Mode
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax.linker import JAXLinker

opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)

res, updates = aesara.scan(
    fn=lambda a_tm1: 2 * a_tm1,
    outputs_info=[
        {"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}
    ],
    n_steps=6,
)

jax_fn = function((), res, updates=updates, mode=jax_mode)
aesara.dprint(jax_fn)
# Subtensor{int64::} [id A] 8
#  |for{cpu,scan_fn} [id B] 7
#  | |TensorConstant{6} [id C]
#  | |IncSubtensor{Set;:int64:} [id D] 6
#  |   |AllocEmpty{dtype='float64'} [id E] 5
#  |   | |Elemwise{add,no_inplace} [id F] 4
#  |   |   |TensorConstant{6} [id C]
#  |   |   |Subtensor{int64} [id G] 2
#  |   |     |Shape [id H] 1
#  |   |     | |Subtensor{:int64:} [id I] 0
#  |   |     |   |TensorConstant{[0. 1.]} [id J]
#  |   |     |   |ScalarConstant{2} [id K]
#  |   |     |ScalarConstant{0} [id L]
#  |   |Subtensor{:int64:} [id I] 0
#  |   |ScalarFromTensor [id M] 3
#  |     |Subtensor{int64} [id G] 2
#  |ScalarConstant{2} [id N]

# Inner graphs:

# for{cpu,scan_fn} [id B]
#  >Elemwise{mul,no_inplace} [id O]
#  > |TensorConstant{2} [id P]
#  > |*0-<TensorType(float64, ())> [id Q] -> [id D]
#     fn = function((), res, updates=updates)
#     assert np.allclose(fn(), jax_fn())

@rlouf
Copy link
Member Author

rlouf commented Dec 9, 2022

I am currently waiting for #1338 to be merged to see what else needs to be fixed in the backend to allow the tests to pass.

@rlouf rlouf force-pushed the rewrite-jax-scan branch 4 times, most recently from 5eafcd5 to 0932c8e Compare December 15, 2022 14:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working important JAX Involves JAX transpilation Scan Involves the `Scan` `Op`
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement Scan's as_while in JAX Implement mit-mots conversion in the JAX translation of the Scan Op
2 participants