Skip to content

Commit

Permalink
Support while loops in the JAX dispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 1, 2022
1 parent 1e8c7fe commit 3de5fb3
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 33 deletions.
242 changes: 230 additions & 12 deletions aesara/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,232 @@ def parse_outer_inputs(outer_inputs):

return outer_in

breakpoint()
if op.info.as_while:
raise NotImplementedError("While loops are not supported in the JAX backend.")
# The inner function returns a boolean as the last value.
return make_jax_while_fn(scan_inner_fn, parse_outer_inputs, input_taps)
else:
return make_jax_scan_fn(
scan_inner_fn,
parse_outer_inputs,
input_taps,
)
return make_jax_scan_fn(scan_inner_fn, parse_outer_inputs, input_taps)


def make_jax_while_fn(
    scan_inner_fn: Callable,
    parse_outer_inputs: Callable[[TensorVariable], Dict[str, List[TensorVariable]]],
    input_taps: Dict,
):
    """Create a `jax.lax.while_loop` function to perform `Scan` computations when it
    is used as while loop.

    `jax.lax.while_loop` iterates by passing a value `carry` to a `body_fun` that
    must return a value of the same type (Pytree structure, shape and dtype of
    the leaves). Before calling `body_fun`, it calls `cond_fun` which takes the
    current value and returns a boolean that indicates whether to keep iterating
    or not.

    The JAX `while_loop` needs to perform the following operations:

    1. Extract the inner-inputs;
    2. Build the initial carry value;
    3. Inside the loop:
        1. `carry` -> inner-inputs;
        2. inner-outputs -> `carry`
    4. Post-process the `carry` storage and return outputs

    Parameters
    ----------
    scan_inner_fn
        The inner function, called as ``scan_inner_fn(*inner_inputs)``. Its
        last return value is used as the termination condition.
    parse_outer_inputs
        Maps the tuple of outer-inputs to a dictionary from which the keys
        ``"mit_sot"``, ``"sit_sot"``, ``"nit_sot"``, ``"shared"``,
        ``"sequences"`` and ``"non_sequences"`` are read.
    input_taps
        Dictionary whose ``"mit_sot"`` and ``"sit_sot"`` entries hold one list
        of taps per state.

    Returns
    -------
    A function ``while_loop(*outer_inputs)`` that returns the final `carry`
    dictionary (the storage arrays are not post-processed yet).

    """

    def build_while_carry(outer_in):
        """Build the initial `carry` of `jax.lax.while_loop` from the outer-inputs."""
        init_carry = {
            name: outer_in[name]
            for name in [
                "mit_sot",
                "sit_sot",
                "nit_sot",
                "shared",
                "sequences",
                "non_sequences",
            ]
        }
        init_carry["step"] = 0
        return init_carry

    def build_inner_outputs_map(outer_in):
        """Map the inner-output variables to their position in the tuple returned by the inner function.

        TODO: Copied from the scan builder

        Inner-outputs are ordered as follow:
        - mit-mot-outputs
        - mit-sot-outputs
        - sit-sot-outputs
        - nit-sots (no carry)
        - shared-outputs
        [+ while-condition]

        Only names that have at least one output get an entry in the map.

        """
        # NOTE(review): mit-mot outputs are not counted in the offset, so the
        # positions are only right when there are none -- confirm with caller.
        inner_outputs_names = ["mit_sot", "sit_sot", "nit_sot", "shared"]

        offset = 0
        inner_output_idx = defaultdict(list)
        for name in inner_outputs_names:
            num_outputs = len(outer_in[name])
            inner_output_idx[name].extend(range(offset, offset + num_outputs)) if (
                num_outputs
            ) else None
            offset += num_outputs

        return inner_output_idx

    def from_carry_storage(carry, step, input_taps):
        """Fetch the inner inputs from the values stored in the carry array.

        `Scan` passes storage arrays as inputs, which are then read from and
        updated in the loop body. At each step we need to read from this array
        the inputs that will be passed to the inner function.

        This mechanism is necessary because we handle multiple-input taps within
        the `scan` instead of letting users manage the memory in the use cases
        where this is necessary.

        TODO: Copied from the scan builder

        """
        inner_inputs = []
        for taps, carry_element in zip(input_taps, carry):
            # The oldest tap is stored at position `step`.
            storage_size = -min(taps)
            inner_inputs.extend(
                carry_element[step + storage_size + tap] for tap in taps
            )

        return inner_inputs

    def to_carry_storage(inner_outputs, carry, step, input_taps):
        """Create the new carry array from the inner output

        `Scan` passes storage arrays as inputs, which are then read from and
        updated in the loop body. At each step we need to update this array
        with the outputs of the inner function.

        Exactly one array is returned per storage array so the Pytree
        structure of the carry is preserved. The new value is written in the
        slot that follows the window read by `from_carry_storage`, i.e. at
        ``step - min(taps)``.

        TODO: Copied from the scan builder

        """
        return [
            carry_element.at[step - min(taps)].set(output)
            for taps, carry_element, output in zip(input_taps, carry, inner_outputs)
        ]

    def while_loop(*outer_inputs):
        """Run the `Scan` as a while loop and return the final carry."""

        outer_in = parse_outer_inputs(outer_inputs)
        init_carry = build_while_carry(outer_in)
        inner_output_idx = build_inner_outputs_map(outer_in)

        def inner_inputs_from_carry(carry):
            """Get inner-inputs from the arguments passed to the `jax.lax.while_loop` body function.

            Inner-inputs are ordered as follows:
            - sequences
            - mit-mot inputs
            - mit-sot inputs
            - sit-sot inputs
            - shared-inputs
            - non-sequences

            """
            current_step = carry["step"]

            # `sequences` is a list of arrays: index every sequence, not the list.
            inner_in_seqs = [seq[current_step] for seq in carry["sequences"]]
            inner_in_mit_sot = from_carry_storage(
                carry["mit_sot"], current_step, input_taps["mit_sot"]
            )
            inner_in_sit_sot = from_carry_storage(
                carry["sit_sot"], current_step, input_taps["sit_sot"]
            )
            inner_in_shared = carry.get("shared", [])
            inner_in_non_sequences = carry.get("non_sequences", [])

            return [
                *inner_in_seqs,
                *inner_in_mit_sot,
                *inner_in_sit_sot,
                *inner_in_shared,
                *inner_in_non_sequences,
            ]

        def carry_from_inner_outputs(inner_outputs, carry):
            """Build the next carry from the inner-outputs and the current carry."""
            step = carry["step"]
            new_carry = {
                "mit_sot": [],
                "sit_sot": [],
                "nit_sot": [],
                "shared": [],
                "step": step + 1,
                "sequences": carry["sequences"],
                "non_sequences": carry["non_sequences"],
                # The last inner-output is the termination (`until`)
                # condition: keep iterating while it is false.
                "do_continue": jax.numpy.logical_not(inner_outputs[-1]),
            }

            if "shared" in inner_output_idx:
                new_carry["shared"] = [
                    inner_outputs[idx] for idx in inner_output_idx["shared"]
                ]

            if "mit_sot" in inner_output_idx:
                mit_sot_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["mit_sot"]
                ]
                new_carry["mit_sot"] = to_carry_storage(
                    mit_sot_inner_outputs, carry["mit_sot"], step, input_taps["mit_sot"]
                )

            if "sit_sot" in inner_output_idx:
                sit_sot_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["sit_sot"]
                ]
                new_carry["sit_sot"] = to_carry_storage(
                    sit_sot_inner_outputs, carry["sit_sot"], step, input_taps["sit_sot"]
                )

            if "nit_sot" in inner_output_idx:
                # NOTE(review): this assumes `input_taps["nit_sot"]` exists and
                # that `carry["nit_sot"]` already holds storage arrays -- the
                # TODO below about building nit-sot storage still applies.
                nit_sot_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["nit_sot"]
                ]
                new_carry["nit_sot"] = to_carry_storage(
                    nit_sot_inner_outputs, carry["nit_sot"], step, input_taps["nit_sot"]
                )

            return new_carry

        def cond_fn(carry):
            # The inner-function of `Scan` returns a boolean as the last
            # value; `body_fn` stores its negation in `carry["do_continue"]`.
            # TODO: Also stop when the number of steps is exceeded.
            return carry["do_continue"]

        def body_fn(carry):
            inner_inputs = inner_inputs_from_carry(carry)
            inner_outputs = scan_inner_fn(*inner_inputs)
            return carry_from_inner_outputs(inner_outputs, carry)

        # TODO
        # The `Scan` implementation in the C backend will execute the
        # function once before checking the termination condition, while
        # `jax.lax.while_loop` checks the condition first. We thus need to call
        # `body_fn` once before calling `jax.lax.while_loop`. This allows us,
        # along with `n_steps`, to build the storage array for the `nit-sot`s
        # since there is no way to know their shape and dtype before executing
        # the function.
        carry = body_fn(init_carry)
        # Signature is `while_loop(cond_fun, body_fun, init_val)`.
        carry = jax.lax.while_loop(cond_fn, body_fn, carry)

        # TODO: Post-process the storage arrays
        outer_outputs = carry

        return outer_outputs

    return while_loop


def make_jax_scan_fn(
Expand All @@ -58,7 +276,8 @@ def make_jax_scan_fn(
stacked to the previous outputs. We use this to our advantage to build
`Scan` outputs without having to post-process the storage arrays.
The JAX scan function needs to perform the following operations:
The JAX `scan` function needs to perform the following operations:
1. Extract the inner-inputs;
2. Build the initial `carry` and `sequence` values;
3. Inside the loop:
Expand Down Expand Up @@ -151,7 +370,6 @@ def scan(*outer_inputs):
outer_in = parse_outer_inputs(outer_inputs)
n_steps, sequences, init_carry = build_jax_scan_inputs(outer_in)
inner_output_idx = build_inner_outputs_map(outer_in)

def scan_inner_in_args(carry, x):
"""Get inner-inputs from the arguments passed to the `jax.lax.scan` body function.
Expand Down Expand Up @@ -265,11 +483,11 @@ def body_fn(carry, x):
)

shared_output = tuple(last_carry["shared"])
results = results + shared_output
outer_outputs = results + shared_output

if len(results) == 1:
return results[0]
if len(outer_outputs) == 1:
return outer_outputs[0]

return results
return outer_outputs

return scan
52 changes: 31 additions & 21 deletions tests/link/jax/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from aesara.link.jax.linker import JAXLinker
from aesara.scan.basic import scan
from aesara.scan.op import Scan
from aesara.scan.utils import until
from aesara.tensor.math import gammaln, log
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.type import ivector, lscalar, scalar
Expand All @@ -24,6 +25,15 @@
jax_mode = Mode(JAXLinker(), opts)


def test_while():
    """Compile a `Scan` with an `until` condition using the JAX mode.

    This is a smoke test: the compiled function is built but never called,
    so only the compilation itself is exercised.
    """

    def increment_until(a_tm1):
        return a_tm1 + 1, until(a_tm1 > 2)

    initial_state = {"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}
    res, updates = scan(
        fn=increment_until,
        outputs_info=[initial_state],
        n_steps=3,
    )
    function((), res, updates=updates, mode="JAX")


def test_sit_sot():
a_at = at.scalar("a", dtype="floatX")

Expand Down Expand Up @@ -87,27 +97,27 @@ def test_mit_sot_2():
"fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
[
# sequences
(
lambda a_t: 2 * a_t,
[at.dvector("a")],
[{}],
[],
None,
[np.arange(10)],
None,
lambda op: op.info.n_seqs > 0,
),
# nit-sot
(
lambda: at.as_tensor(2.0),
[],
[{}],
[],
3,
[],
None,
lambda op: op.info.n_nit_sot > 0,
),
# (
# lambda a_t: 2 * a_t,
# [at.dvector("a")],
# [{}],
# [],
# None,
# [np.arange(10)],
# None,
# lambda op: op.info.n_seqs > 0,
# ),
# # nit-sot
# (
# lambda: at.as_tensor(2.0),
# [],
# [{}],
# [],
# 3,
# [],
# None,
# lambda op: op.info.n_nit_sot > 0,
# ),
# nit-sot, non_seq
(
lambda c: at.as_tensor(2.0) * c,
Expand Down

0 comments on commit 3de5fb3

Please sign in to comment.