Lift Subtensor over Join
ricardoV94 committed Jan 21, 2025
1 parent 844ae15 commit 508d5b9
Showing 2 changed files with 150 additions and 26 deletions.
120 changes: 95 additions & 25 deletions pytensor/tensor/rewriting/subtensor_lift.py
@@ -1,4 +1,5 @@
from collections.abc import Iterable, Sequence
from typing import cast

import numpy as np
from numpy.core.numeric import ( # type: ignore
@@ -7,16 +8,18 @@
)

from pytensor import Variable
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.graph import Constant, FunctionGraph, node_rewriter
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import (
    Alloc,
    Join,
    MakeVector,
    alloc,
    as_tensor,
    expand_dims,
    get_underlying_scalar_constant_value,
    join,
    register_infer_shape,
)
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -49,6 +52,7 @@
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import SliceType
from pytensor.tensor.variable import TensorVariable


def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
@@ -71,6 +75,41 @@ def _axis_is_indexed_by_basic_index(
    return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)


def _lift_subtensor_non_axis(
    local_subtensor_lift_rewrite: NodeRewriter,
    fgraph: FunctionGraph,
    variable: TensorVariable,
    idx_tuple: tuple[int | slice],
    axis: int,
    old_subtensor_variable: TensorVariable,
) -> None | list[TensorVariable]:
    # Apply generic subtensor lift rewrite along "non-axis" dimensions
    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
    if len(real_indices) > 1 and variable.type.ndim > 1:
        # Split the subtensor
        idx_to_keep = idx_tuple[axis]
        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])

        # Lift the non-axis indexes by calling the rewrite itself
        indexed_variable = variable[idxs_to_lift]
        [indexed_variable] = cast(
            list[TensorVariable],
            local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner),
        )
        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)

        # Then reintroduce the axis index
        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
        new_axis = axis - ndim_reduced_left
        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
        new_out = indexed_variable[idxs_to_keep]
        copy_stack_trace(old_subtensor_variable, new_out)
        return [new_out]

    else:
        return None


@register_canonicalize
@register_stabilize
@register_specialize
@@ -302,29 +341,14 @@ def local_subtensor_of_softmax(fgraph, node):
    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # If there are more dimensions being indexed, we can split them
        # And lift the non-axis indexes while keeping the axis index
        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
        if len(real_indices) > 1 and sm.type.ndim > 1:
            # Split the subtensor
            idx_to_keep = idx_tuple[axis]
            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])

            # Lift the non-axis indexes by calling the rewrite itself
            opt_sm = sm[idxs_to_lift]
            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
            copy_stack_trace([old_out, sm], opt_sm)

            # Then reintroduce the axis index
            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
                idx_tuple, axis
            )
            new_axis = axis - ndim_reduced_left
            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
            new_out = opt_sm[idxs_to_keep]
            copy_stack_trace(old_out, new_out)
            return [new_out]

        else:
            return None
        return _lift_subtensor_non_axis(
            local_subtensor_lift_rewrite=local_subtensor_of_softmax,
            fgraph=fgraph,
            variable=sm,
            idx_tuple=idx_tuple,
            axis=axis,
            old_subtensor_variable=old_out,
        )

    # Index input to softmax
    x_sub = x[idx_tuple]
@@ -695,6 +719,52 @@ def local_subtensor_make_vector(fgraph, node):
pass


@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_join(fgraph, node):
"""Lift a Subtensor through a Join.
join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
"""
join_var, *idx = node.inputs

if not (join_var.owner and isinstance(join_var.owner.op, Join)):
return None

if len(fgraph.clients[join_var]) > 1:
# Join involves a full_copy, so we don't want to do it twice
return None

join_axis, *join_components = join_var.owner.inputs

# Rewrite only works when the join axis is a constant along a non-indexed dimension
if not isinstance(join_axis, Constant):
return None

[old_out] = node.outputs
axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
if _axis_is_indexed_by_basic_index(idx_tuple, axis):
return _lift_subtensor_non_axis(
local_subtensor_lift_rewrite=local_subtensor_of_join,
fgraph=fgraph,
variable=join_var,
idx_tuple=idx_tuple,
axis=axis,
old_subtensor_variable=old_out,
)

# Lift index to the Join components
indexed_components = [component[idx_tuple] for component in join_components]
new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
out = join(new_axis, *indexed_components)

return [out]


@register_specialize
@register_canonicalize
@node_rewriter([Subtensor])
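Aside (not part of the commit): _lift_subtensor_non_axis relies on the fact that, with basic indexing, the index on the "kept" axis can be postponed until after all other indices are applied, as long as its position is shifted by the number of dimensions dropped to its left. A minimal numpy sketch of that identity, with made-up names:

import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
idx = (0, slice(1, None), -1)  # a[0, 1:, -1]; pretend axis=1 must stay in place

# Lift every index except the one on axis 1...
lifted = a[0, :, -1]
# ...then reapply the axis index. One integer index (0) was dropped to the left
# of axis 1, so the index is reapplied on axis 1 - 1 = 0 of the lifted result.
reapplied = lifted[1:]

np.testing.assert_array_equal(a[idx], reapplied)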
56 changes: 55 additions & 1 deletion tests/tensor/rewriting/test_subtensor_lift.py
@@ -42,7 +42,7 @@
    tensor3,
    vector,
)
from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.rewriting.subtensor_lift import (
@@ -252,6 +252,9 @@ def test_local_subtensor_of_softmax(original_fn, expected_fn):
)


shared_axis = shared(1, "axis")


def test_local_subtensor_of_unbroadcast():
    # Test that Subtensor(Unbroadcast(x)) gets optimized into
    # Unbroadcast(Subtensor(x)).
@@ -661,6 +664,57 @@ def test_empty_subtensor(self):
        assert local_subtensor_make_vector.transform(fgraph, node) == [v]


@pytest.mark.parametrize(
"original_fn, expected_fn",
[
(
lambda x, y: concatenate([x, y], axis=1)[1],
lambda x, y: concatenate([x[1], y[1]], axis=0),
),
(
lambda x, y: concatenate([x, y], axis=-1)[1:],
lambda x, y: concatenate([x[1:], y[1:]], axis=1),
),
# Indexing on both axis of concatenation and somewhere else:
(
lambda x, y: concatenate([x, y], axis=1)[0, 1:],
lambda x, y: concatenate([x[0], y[0]], axis=0)[1:],
),
# Not supported, indexing on axis of concatenation
(
lambda x, y: concatenate([x, y], axis=0)[0],
lambda x, y: concatenate([x, y], axis=0)[0],
),
(
lambda x, y: concatenate([x, y], axis=1)[:, 1:],
lambda x, y: concatenate([x, y], axis=1)[:, 1:],
),
# Not supported, axis of concatenation is dynamically determined
(
lambda x, y: concatenate([x, y], axis=shared_axis)[1],
lambda x, y: concatenate([x, y], axis=shared_axis)[1],
),
],
)
def test_local_subtensor_of_join(original_fn, expected_fn):
    rng = np.random.default_rng(257)
    x = pt.matrix("x", shape=(5, 3))
    y = pt.matrix("y", shape=(5, 3))
    x_test = rng.normal(size=x.type.shape)
    y_test = rng.normal(size=y.type.shape)

    out = original_fn(x, y)
    expected_opt_out = expected_fn(x, y)
    opt_out = rewrite_graph(out)
    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
        [expected_opt_out, opt_out], print_type=True
    )
    np.testing.assert_allclose(
        opt_out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
        out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
    )


def test_local_subtensor_shape_constant():
    x = tensor(dtype=np.float64, shape=(1, None)).shape[0]
    (res,) = local_subtensor_shape_constant.transform(None, x.owner)
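Aside (not part of the commit): the first and fourth parametrized cases above can be reproduced interactively. This sketch assumes rewrite_graph is importable from pytensor.graph.rewriting.utils, as in this test module's (elided) import block:

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3))

# Indexing a non-join axis: the Subtensor is pushed onto each component and the
# join axis is renumbered from 1 to 0, because the indexed dimension is dropped.
lifted = rewrite_graph(pt.concatenate([x, y], axis=1)[1])

# Indexing the join axis itself: this rewrite leaves the Subtensor of the Join
# in place, since the slice cannot be pushed onto the individual components.
unchanged = rewrite_graph(pt.concatenate([x, y], axis=0)[0])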
