Merge branch 'jax-fix-subtensor' into rewrite-jax-scan
rlouf committed Dec 8, 2022
2 parents d43bf96 + d2b41c4 commit d718bdd
Showing 148 changed files with 3,416 additions and 2,791 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -182,7 +182,7 @@ jobs:
path: coverage

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v3
with:
directory: ./coverage/
fail_ci_if_error: true
7 changes: 4 additions & 3 deletions .pre-commit-config.yaml
@@ -7,7 +7,7 @@ exclude: |
)$
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: debug-statements
exclude: |
@@ -25,9 +25,10 @@ repos:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 5.0.4
rev: 6.0.0
hooks:
- id: flake8
language_version: python39
- repo: https://github.com/pycqa/isort
rev: 5.10.1
hooks:
@@ -47,7 +48,7 @@ repos:
)$
args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.982
rev: v0.991
hooks:
- id: mypy
additional_dependencies:
4 changes: 2 additions & 2 deletions README.rst
@@ -122,10 +122,10 @@ Contributing
We welcome bug reports and fixes and improvements to the documentation.

For more information on contributing, please see the
`contributing guide <https://github.com/aesara-devs/aesara/CONTRIBUTING.md>`.
`contributing guide <https://github.com/aesara-devs/aesara/CONTRIBUTING.md>`__.

A good place to start contributing is by looking through the issues
`here <https://github.com/aesara-devs/aesara/issues`.
`here <https://github.com/aesara-devs/aesara/issues>`__.

Support
=======
138 changes: 84 additions & 54 deletions aesara/compile/builders.py
@@ -2,7 +2,7 @@
from collections import OrderedDict
from copy import copy
from functools import partial
from typing import List, Optional, Sequence, cast
from typing import Dict, List, Optional, Sequence, Tuple, cast

import aesara.tensor as at
from aesara import function
@@ -19,7 +19,6 @@
clone_replace,
graph_inputs,
io_connection_pattern,
replace_nominals_with_dummies,
)
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
@@ -82,6 +81,81 @@ def local_traverse(out):
return ret


def construct_nominal_fgraph(
inputs: Sequence[Variable], outputs: Sequence[Variable]
) -> Tuple[
FunctionGraph,
Sequence[Variable],
Dict[Variable, Variable],
Dict[Variable, Variable],
]:
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
dummy_inputs = []
for n, inp in enumerate(inputs):
if (
not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)

dummy_inputs.append(inp.type())

dummy_shared_inputs = []
shared_inputs = []
for var in graph_inputs(outputs, inputs):
if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
# gradients.
# That's why we collect the shared variables and replace them
# with dummies.
shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")

replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))

new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + shared_inputs,
replace=replacements,
copy_inputs_over=False,
)
(
local_inputs,
local_outputs,
(clone_d, update_d, update_expr, new_shared_inputs),
) = new

assert len(local_inputs) == len(inputs) + len(shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
assert not new_shared_inputs

fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)

# The inputs need to be `NominalVariable`s so that we can merge
# inner-graphs
nominal_local_inputs = tuple(
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
)

fgraph.replace_all(zip(local_inputs, nominal_local_inputs))

for i, inp in enumerate(fgraph.inputs):
nom_inp = nominal_local_inputs[i]
fgraph.inputs[i] = nom_inp
fgraph.clients.pop(inp, None)
fgraph.add_input(nom_inp)

return fgraph, shared_inputs, update_d, update_expr


class OpFromGraph(Op, HasInnerGraph):
r"""
This creates an `Op` from inputs and outputs lists of variables.
@@ -333,66 +407,21 @@ def __init__(
if not (isinstance(inputs, list) and isinstance(outputs, list)):
raise TypeError("Inputs and outputs must be lists")

for i in inputs + outputs:
if not isinstance(i, Variable):
for out in outputs:
if not isinstance(out, Variable):
raise TypeError(
f"Inputs and outputs must be Variable instances; got {i}"
f"Inputs and outputs must be Variable instances; got {out}"
)
if i in inputs:
if isinstance(i, Constant):
raise TypeError(f"Constants not allowed as inputs; {i}")
if isinstance(i, SharedVariable):
raise TypeError(f"SharedVariables not allowed as inputs; {i}")

for var in graph_inputs(outputs, inputs):
if var not in inputs and not isinstance(var, (Constant, SharedVariable)):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")

if "updates" in kwargs or "givens" in kwargs:
raise NotImplementedError("Updates and givens are not allowed here")
raise NotImplementedError("Updates and givens are not supported")

self.is_inline = inline

# To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = []
for var in graph_inputs(outputs):
if isinstance(var, SharedVariable):
self.shared_inputs.append(var)

inputs, outputs = replace_nominals_with_dummies(inputs, outputs)

# The inputs should be `NominalVariable`s, so that graphs can be merged
replacements = {}
for n, v in enumerate(inputs):
replacements[v] = NominalVariable(n, v.type)

shared_vars = [
NominalVariable(n, var.type)
for n, var in enumerate(self.shared_inputs, start=len(inputs) + 1)
]

replacements.update(dict(zip(self.shared_inputs, shared_vars)))

new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + shared_vars,
replace=replacements,
copy_inputs_over=False,
self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
inputs, outputs
)
(
local_inputs,
local_outputs,
(clone_d, update_d, update_expr, shared_inputs),
) = new

assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
assert not shared_inputs

self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)

self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
Expand All @@ -415,6 +444,7 @@ def __init__(
else:
self.set_lop_overrides("default")
self._lop_type = "lop"

self.set_rop_overrides(rop_overrides)

self._connection_pattern = connection_pattern
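
The refactor above pulls the inner-graph construction out of `OpFromGraph.__init__` into the new module-level helper `construct_nominal_fgraph`, which returns the inner `FunctionGraph` together with the shared variables and update mappings it collected. A minimal usage sketch, with illustrative variable names that are not part of the diff:

    import aesara.tensor as at
    from aesara.compile.builders import OpFromGraph, construct_nominal_fgraph

    x, y = at.scalars("x", "y")
    out_expr = x * y + x

    # The helper clones the graph of `out_expr`, replaces the inputs with
    # NominalVariables, and returns the resulting inner FunctionGraph.
    fgraph, shared_inputs, update_d, update_expr = construct_nominal_fgraph(
        [x, y], [out_expr]
    )

    # OpFromGraph.__init__ now delegates to the same helper.
    op = OpFromGraph([x, y], [out_expr])
    a, b = at.scalars("a", "b")
    result = op(a, b)  # apply the composite Op to fresh inputs
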
28 changes: 14 additions & 14 deletions aesara/compile/debugmode.py
@@ -848,17 +848,17 @@ def _get_preallocated_maps(
or "ALL" in prealloc_modes
):
max_ndim = 0
rev_out_broadcastable = []
rev_out_shape = []
for r in considered_outputs:
if isinstance(r.type, TensorType):
if max_ndim < r.ndim:
rev_out_broadcastable += [True] * (r.ndim - max_ndim)
rev_out_shape += [1] * (r.ndim - max_ndim)
max_ndim = r.ndim
assert len(rev_out_broadcastable) == max_ndim
assert len(rev_out_shape) == max_ndim

for i, b in enumerate(r.broadcastable[::-1]):
rev_out_broadcastable[i] = rev_out_broadcastable[i] and b
out_broadcastable = rev_out_broadcastable[::-1]
for i, s in enumerate(r.type.shape[::-1]):
rev_out_shape[i] = 1 if rev_out_shape[i] == 1 and s == 1 else None
out_shape = rev_out_shape[::-1]

if "strided" in prealloc_modes or "ALL" in prealloc_modes:
check_ndim = config.DebugMode__check_preallocated_output_ndim
@@ -887,14 +887,14 @@ def _get_preallocated_maps(
# Moreover, to avoid memory problems, we do not test with strides
# 2 and -2 on those dimensions.
step_signs_list = []
for b in out_broadcastable[-check_ndim:]:
if b:
for s in out_shape[-check_ndim:]:
if s == 1:
step_signs_list.append((1,))
else:
step_signs_list.append((-1, 1))

# Use the same step on all dimensions before the last check_ndim.
if all(out_broadcastable[:-check_ndim]):
if all(s == 1 for s in out_shape[:-check_ndim]):
step_signs_list = [(1,)] + step_signs_list
else:
step_signs_list = [(-1, 1)] + step_signs_list
@@ -905,7 +905,7 @@

# First, the dimensions above check_ndim, then the other ones
# Do not test with 2 or -2 for dimensions above check_ndim
steps = [step_signs[0]] * len(out_broadcastable[:-check_ndim])
steps = [step_signs[0]] * len(out_shape[:-check_ndim])
steps += [s * step_size for s in step_signs[1:]]

name = f"strided{tuple(steps)}"
@@ -932,8 +932,8 @@

if "wrong_size" in prealloc_modes or "ALL" in prealloc_modes:
# For each dimension, try size-1, size, size+1
for dim, b in enumerate(out_broadcastable):
if b:
for dim, s in enumerate(out_shape):
if s == 1:
# The shape has to be 1
continue

@@ -947,11 +947,11 @@
for r in considered_outputs:
if isinstance(r.type, TensorType):
r_shape_diff = shape_diff[: r.ndim]
out_shape = [
new_buf_shape = [
max((s + sd), 0)
for s, sd in zip(r_vals[r].shape, r_shape_diff)
]
new_buf = np.empty(out_shape, dtype=r.type.dtype)
new_buf = np.empty(new_buf_shape, dtype=r.type.dtype)
new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
wrong_size[r] = new_buf

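
The `_get_preallocated_maps` changes above stop consulting `broadcastable` flags and instead read the static `type.shape`, where a broadcastable dimension corresponds to a static shape of 1 and an unknown extent is `None`. A small illustrative sketch of that correspondence and of the per-dimension combination rule used across the considered outputs (assumed semantics, not code from the diff):

    # A dimension is broadcastable exactly when its static shape is 1;
    # None marks an extent that is not known to be 1.
    broadcastable = (True, False, False)
    static_shape = tuple(1 if b else None for b in broadcastable)  # (1, None, None)

    def combine(s1, s2):
        # Mirrors `1 if rev_out_shape[i] == 1 and s == 1 else None`:
        # a dimension stays size-1 only if it is size 1 for every output.
        return 1 if s1 == 1 and s2 == 1 else None

    assert combine(1, 1) == 1
    assert combine(1, None) is None
    assert tuple(s == 1 for s in static_shape) == broadcastable
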
36 changes: 17 additions & 19 deletions aesara/compile/function/pfunc.py
@@ -3,7 +3,6 @@
"""

import logging
from copy import copy
from typing import Optional

@@ -16,11 +15,6 @@
from aesara.graph.fg import FunctionGraph


_logger = logging.getLogger("aesara.compile.function.pfunc")

__docformat__ = "restructuredtext en"


def rebuild_collect_shared(
outputs,
inputs=None,
@@ -78,10 +72,12 @@ def rebuild_collect_shared(
shared_inputs = []

def clone_v_get_shared_updates(v, copy_inputs_over):
"""
Clones a variable and its inputs recursively until all are in clone_d.
Also appends all shared variables met along the way to shared inputs,
and their default_update (if applicable) to update_d and update_expr.
r"""Clones a variable and its inputs recursively until all are in `clone_d`.
Also, it appends all `SharedVariable`\s met along the way to
`shared_inputs` and their corresponding
`SharedVariable.default_update`\s (when applicable) to `update_d` and
`update_expr`.
"""
# this co-recurses with clone_a
@@ -103,7 +99,7 @@ def clone_v_get_shared_updates(v, copy_inputs_over):
elif isinstance(v, SharedVariable):
if v not in shared_inputs:
shared_inputs.append(v)
if hasattr(v, "default_update"):
if v.default_update is not None:
# Check that v should not be excluded from the default
# updates list
if no_default_updates is False or (
@@ -419,22 +415,24 @@ def construct_pfunc_ins_and_outs(
givens = []

if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or " "a tuple")
raise TypeError("The `params` argument must be a list or a tuple")

if not isinstance(no_default_updates, bool) and not isinstance(
no_default_updates, list
):
raise TypeError("no_default_update should be either a boolean or " "a list")
raise TypeError("The `no_default_update` argument must be a boolean or list")

if len(updates) > 0 and any(
isinstance(v, Variable) for v in iter_over_pairs(updates)
if len(updates) > 0 and not all(
isinstance(pair, (tuple, list))
and len(pair) == 2
and isinstance(pair[0], Variable)
for pair in iter_over_pairs(updates)
):
raise ValueError(
"The updates parameter must be an OrderedDict/dict or a list of "
"lists/tuples with 2 elements"
raise TypeError(
"The `updates` parameter must be an ordered mapping or a list of pairs"
)

# transform params into aesara.compile.In objects.
# Transform params into aesara.compile.In objects.
inputs = [
_pfunc_param_to_in(p, allow_downcast=allow_input_downcast) for p in params
]
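
The `construct_pfunc_ins_and_outs` changes above tighten argument validation: `params` must be a list or tuple, every entry of `updates` must be a 2-element pair whose first element is a `Variable` (or the whole argument an ordered mapping), and `default_update` is now read as an attribute that is `None` when unset rather than probed with `hasattr`. A hedged sketch of the accepted `updates` forms, with illustrative names not taken from the diff:

    import aesara
    import aesara.tensor as at

    count = aesara.shared(0, name="count")
    x = at.scalar("x")

    # Accepted: a list of (shared variable, new value) pairs, or a mapping.
    f = aesara.function([x], x + count, updates=[(count, count + 1)])
    g = aesara.function([x], x + count, updates={count: count + 1})

    # Rejected by the new check: entries that are not (Variable, expression)
    # pairs, e.g. updates=[count + 1], now raise TypeError instead of ValueError.

    # default_update is consulted automatically unless no_default_updates is set:
    count.default_update = count + 2
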