From 280dc49a30c156c0bc675ab9bd6278f252eb1451 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 12 May 2023 22:41:41 +0200 Subject: [PATCH] Implement Blockwise Op to vectorize existing Ops Inspired by: https://github.com/aesara-devs/aesara/pull/1215 Co-authored-by: Brandon T. Willard Co-authored-by: Purna Chandra Mansingh Co-authored-by: Sayam Kumar Co-authored-by: Kaustubh --- pytensor/tensor/blockwise.py | 437 +++++++++++++++++++++++++++++++++ pytensor/tensor/elemwise.py | 55 +++-- pytensor/tensor/nlinalg.py | 8 + pytensor/tensor/random/op.py | 6 + pytensor/tensor/utils.py | 24 ++ tests/tensor/test_blockwise.py | 302 +++++++++++++++++++++++ 6 files changed, 815 insertions(+), 17 deletions(-) create mode 100644 pytensor/tensor/blockwise.py create mode 100644 tests/tensor/test_blockwise.py diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py new file mode 100644 index 0000000000..ccf431c738 --- /dev/null +++ b/pytensor/tensor/blockwise.py @@ -0,0 +1,437 @@ +import re +from functools import singledispatch +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast + +import numpy as np + +from pytensor import config +from pytensor.gradient import DisconnectedType +from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.graph.null_type import NullType +from pytensor.graph.op import Op +from pytensor.tensor.shape import shape_padleft +from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor +from pytensor.tensor.utils import import_func_from_string +from pytensor.tensor.var import TensorVariable, as_tensor_variable + + +# TODO: Implement vectorize helper to batch whole graphs (similar to what Blockwise does for the grad) + +# Copied verbatim from numpy.lib.function_base +# https://github.com/numpy/numpy/blob/f2db090eb95b87d48a3318c9a3f9d38b67b0543c/numpy/lib/function_base.py#L1999-L2029 +_DIMENSION_NAME = r"\w+" +_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME) +_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" +_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT) +_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST) + + +def _parse_gufunc_signature(signature): + """ + Parse string signatures for a generalized universal function. + + Arguments + --------- + signature : string + Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)`` + for ``np.matmul``. + + Returns + ------- + Tuple of input and output core dimensions parsed from the signature, each + of the form List[Tuple[str, ...]]. + """ + signature = re.sub(r"\s+", "", signature) + + if not re.match(_SIGNATURE, signature): + raise ValueError(f"not a valid gufunc signature: {signature}") + return tuple( + [ + tuple(re.findall(_DIMENSION_NAME, arg)) + for arg in re.findall(_ARGUMENT, arg_list) + ] + for arg_list in signature.split("->") + ) + + +def safe_signature( + core_inputs: Sequence[Variable], + core_outputs: Sequence[Variable], +) -> str: + def operand_sig(operand: Variable, prefix: str) -> str: + operands = ",".join(f"{prefix}{i}" for i in range(operand.type.ndim)) + return f"({operands})" + + inputs_sig = ",".join( + operand_sig(i, prefix=f"i{n}") for n, i in enumerate(core_inputs) + ) + outputs_sig = ",".join( + operand_sig(o, prefix=f"o{n}") for n, o in enumerate(core_outputs) + ) + return f"{inputs_sig}->{outputs_sig}" + + +@singledispatch +def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply: + if hasattr(op, "gufunc_signature"): + signature = op.gufunc_signature + else: + # TODO: This is pretty bad for shape inference and merge optimization! + # Should get better as we add signatures to our Ops + signature = safe_signature(node.inputs, node.outputs) + return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs)) + + +def vectorize_node(node: Apply, *batched_inputs) -> Apply: + """Returns vectorized version of node with new batched inputs.""" + op = node.op + return _vectorize_node(op, node, *batched_inputs) + + +class Blockwise(Op): + """Generalizes a core `Op` to work with batched dimensions. + + TODO: Dispatch JAX (should be easy with the vectorize macro) + TODO: Dispatch Numba + TODO: C implementation? + TODO: Fuse Blockwise? + """ + + __props__ = ("core_op", "signature") + + def __init__( + self, + core_op: Op, + signature: Optional[str] = None, + name: Optional[str] = None, + **kwargs, + ): + """ + + Parameters + ---------- + core_op + An instance of a subclass of `Op` which works on the core case. + signature + Generalized universal function signature, + e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication + + """ + if isinstance(core_op, Blockwise): + raise TypeError("Core Op is already a Blockwise") + + if signature is None: + signature = getattr(core_op, "gufunc_signature", None) + if signature is None: + raise ValueError( + f"Signature not provided nor found in core_op {core_op}" + ) + + self.core_op = core_op + self.signature = signature + self.name = name + self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) + self._gufunc = None + super().__init__(**kwargs) + + def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: + core_input_types = [] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + if inp.type.ndim < len(sig): + raise ValueError( + f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}" + ) + # ndim_supp = 0 case + if not sig: + core_shape = () + else: + core_shape = inp.type.shape[-len(sig) :] + core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape)) + + core_node = self.core_op.make_node(*core_input_types) + + if len(core_node.outputs) != len(self.outputs_sig): + raise ValueError( + f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}" + ) + for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)): + if core_out.type.ndim != len(sig): + raise ValueError( + f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}" + ) + + return core_node + + def make_node(self, *inputs): + inputs = [as_tensor_variable(i) for i in inputs] + + core_node = self._create_dummy_core_node(inputs) + + batch_ndims = max( + inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig) + ) + + # Don't pollute the graph with useless BlockWise + # TODO: Do we want to do this? Or leave it as a Blockwise and later have a rewrite that removes useless casse + # A reason to not eagerly avoid Blockwise is that we could make all rewrites track the Blockwise version, + # instead of having to track both or only the more restricted core case. + if not batch_ndims: + return self.core_op.make_node(*inputs) + + batched_inputs = [] + batch_shapes = [] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + # Append missing dims to the left + missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig)) + if missing_batch_ndims: + inp = shape_padleft(inp, missing_batch_ndims) + batched_inputs.append(inp) + + if not sig: + batch_shapes.append(inp.type.shape) + else: + batch_shapes.append(inp.type.shape[: -len(sig)]) + + def get_most_specialized_batch_shape( + dims: Sequence[Union[None, int]] + ) -> Union[None, int]: + dims_set = set(dims) + # All dims are the same + if len(dims_set) == 1: + return tuple(dims_set)[0] + + # Only valid indeterminate case + if dims_set == {None, 1}: + return None + + dims_set.discard(1) + dims_set.discard(None) + if len(dims_set) > 1: + raise ValueError + return tuple(dims_set)[0] + + try: + batch_shape = tuple( + [ + get_most_specialized_batch_shape(batch_dims) + for batch_dims in zip(*batch_shapes) + ] + ) + except ValueError: + raise ValueError( + f"Incompatible Blockwise batch input shapes {[inp.type.shape for inp in inputs]}" + ) + + batched_outputs = [ + tensor(dtype=core_out.type.dtype, shape=batch_shape + core_out.type.shape) + for core_out in core_node.outputs + ] + + return Apply(self, batched_inputs, batched_outputs) + + def _batch_ndim_from_outputs(self, outputs: Sequence[TensorVariable]) -> int: + return cast(int, outputs[0].type.ndim - len(self.outputs_sig[0])) + + def infer_shape( + self, fgraph, node, input_shapes + ) -> List[Tuple[TensorVariable, ...]]: + from pytensor.tensor import broadcast_shape + from pytensor.tensor.shape import Shape_i + + batch_ndims = self._batch_ndim_from_outputs(node.outputs) + core_dims: Dict[str, Any] = {} + batch_shapes = [] + for input_shape, sig in zip(input_shapes, self.inputs_sig): + batch_shapes.append(input_shape[:batch_ndims]) + core_shape = input_shape[batch_ndims:] + + for core_dim, dim_name in zip(core_shape, sig): + prev_core_dim = core_dims.get(core_dim) + if prev_core_dim is None: + core_dims[dim_name] = core_dim + # Prefer constants + elif not isinstance(prev_core_dim, Constant): + core_dims[dim_name] = core_dim + + batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True) + + out_shapes = [] + for output, sig in zip(node.outputs, self.outputs_sig): + core_out_shape = [] + for i, dim_name in enumerate(sig): + # The output dim is the same as another input dim + if dim_name in core_dims: + core_out_shape.append(core_dims[dim_name]) + else: + # TODO: We could try to make use of infer_shape of core_op + core_out_shape.append(Shape_i(batch_ndims + i)(output)) + out_shapes.append((*batch_shape, *core_out_shape)) + + return out_shapes + + def connection_pattern(self, node): + if hasattr(self.core_op, "connection_pattern"): + return self.core_op.connection_pattern(node) + + return [[True for _ in node.outputs] for _ in node.inputs] + + def _bgrad(self, inputs, outputs, ograds): + # Grad, with respect to broadcasted versions of inputs + + def as_core(t, core_t): + # Inputs could be NullType or DisconnectedType + if isinstance(t.type, (NullType, DisconnectedType)): + return t + return core_t.type() + + with config.change_flags(compute_test_value="off"): + safe_inputs = [ + tensor(dtype=inp.type.dtype, shape=(None,) * len(sig)) + for inp, sig in zip(inputs, self.inputs_sig) + ] + core_node = self._create_dummy_core_node(safe_inputs) + + core_inputs = [ + as_core(inp, core_inp) + for inp, core_inp in zip(inputs, core_node.inputs) + ] + core_ograds = [ + as_core(ograd, core_ograd) + for ograd, core_ograd in zip(ograds, core_node.outputs) + ] + core_outputs = core_node.outputs + + core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds) + + batch_ndims = self._batch_ndim_from_outputs(outputs) + + def transform(var): + # From a graph of ScalarOps, make a graph of Broadcast ops. + if isinstance(var.type, (NullType, DisconnectedType)): + return var + if var in core_inputs: + return inputs[core_inputs.index(var)] + if var in core_outputs: + return outputs[core_outputs.index(var)] + if var in core_ograds: + return ograds[core_ograds.index(var)] + + node = var.owner + + # The gradient contains a constant, which may be responsible for broadcasting + if node is None: + if batch_ndims: + var = shape_padleft(var, batch_ndims) + return var + + batched_inputs = [transform(inp) for inp in node.inputs] + batched_node = vectorize_node(node, *batched_inputs) + batched_var = batched_node.outputs[var.owner.outputs.index(var)] + + return batched_var + + ret = [] + for core_igrad, ipt in zip(core_igrads, inputs): + # Undefined gradient + if core_igrad is None: + ret.append(None) + else: + ret.append(transform(core_igrad)) + + return ret + + def L_op(self, inputs, outs, ograds): + from pytensor.tensor.math import sum as pt_sum + + # Compute grad with respect to broadcasted input + rval = self._bgrad(inputs, outs, ograds) + + # TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable + # to the gradient.grad method when the outputs have + # some integer and some floating point outputs + if any(out.type.dtype not in continuous_dtypes for out in outs): + # For integer output, return value may only be zero or undefined + # We don't bother with trying to check that the scalar ops + # correctly returned something that evaluates to 0, we just make + # the return value obviously zero so that gradient.grad can tell + # this op did the right thing. + new_rval = [] + for elem, inp in zip(rval, inputs): + if isinstance(elem.type, (NullType, DisconnectedType)): + new_rval.append(elem) + else: + elem = inp.zeros_like() + if str(elem.type.dtype) not in continuous_dtypes: + elem = elem.astype(config.floatX) + assert str(elem.type.dtype) not in discrete_dtypes + new_rval.append(elem) + return new_rval + + # Sum out the broadcasted dimensions + batch_ndims = self._batch_ndim_from_outputs(outs) + batch_shape = outs[0].type.shape[:batch_ndims] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + if isinstance(rval[i].type, (NullType, DisconnectedType)): + continue + + assert inp.type.ndim == batch_ndims + len(sig) + + to_sum = [ + j + for j, (inp_s, out_s) in enumerate(zip(inp.type.shape, batch_shape)) + if inp_s == 1 and out_s != 1 + ] + if to_sum: + rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True) + + return rval + + def _create_gufunc(self, node): + if hasattr(self.core_op, "gufunc_spec"): + self._gufunc = import_func_from_string(self.core_op.gufunc_spec[0]) + if self._gufunc: + return self._gufunc + + n_outs = len(self.outputs_sig) + core_node = self._create_dummy_core_node(node.inputs) + + def core_func(*inner_inputs): + inner_outputs = [[None] for _ in range(n_outs)] + + inner_inputs = [np.asarray(inp) for inp in inner_inputs] + self.core_op.perform(core_node, inner_inputs, inner_outputs) + + if len(inner_outputs) == 1: + return inner_outputs[0][0] + else: + return tuple(r[0] for r in inner_outputs) + + self._gufunc = np.vectorize(core_func, signature=self.signature) + return self._gufunc + + def perform(self, node, inputs, output_storage): + gufunc = self._gufunc + + if gufunc is None: + gufunc = self._create_gufunc(node) + + res = gufunc(*inputs) + if not isinstance(res, tuple): + res = (res,) + + for node_out, out_storage, r in zip(node.outputs, output_storage, res): + out_dtype = getattr(node_out, "dtype", None) + if out_dtype and out_dtype != r.dtype: + r = np.asarray(r, dtype=out_dtype) + out_storage[0] = r + + def __str__(self): + if self.name is None: + return f"{type(self).__name__}{{{self.core_op}, {self.signature}}}" + else: + return self.name + + +@_vectorize_node.register(Blockwise) +def vectorize_not_needed(op, node, *batch_inputs): + return op.make_node(*batch_inputs) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 6d19579030..fde67dfc2b 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -22,6 +22,7 @@ from pytensor.tensor import _get_vector_length, as_tensor_variable from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length +from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed from pytensor.tensor.type import ( TensorType, continuous_dtypes, @@ -29,6 +30,7 @@ float_dtypes, lvector, ) +from pytensor.tensor.utils import import_func_from_string from pytensor.tensor.var import TensorVariable from pytensor.utils import uniq @@ -228,7 +230,7 @@ def __str__(self): return f"Transpose{{axes={self.shuffle}}}" return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}" - def perform(self, node, inp, out, params): + def perform(self, node, inp, out, params=None): (res,) = inp (storage,) = out @@ -662,22 +664,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): impl = "c" if getattr(self, "nfunc_spec", None) and impl != "c": - self.nfunc = getattr(np, self.nfunc_spec[0], None) - if self.nfunc is None: - # Not inside NumPy. So probably another package like scipy. - symb = self.nfunc_spec[0].split(".") - for idx in range(1, len(self.nfunc_spec[0])): - try: - module = __import__(".".join(symb[:idx])) - except ImportError: - break - for sub in symb[1:]: - try: - module = getattr(module, sub) - except AttributeError: - module = None - break - self.nfunc = module + self.nfunc = import_func_from_string(self.nfunc_spec[0]) if ( (len(node.inputs) + len(node.outputs)) <= 32 @@ -1759,3 +1746,37 @@ def _get_vector_length_Elemwise(op, var): return get_vector_length(var.owner.inputs[0]) raise ValueError(f"Length of {var} cannot be determined") + + +_vectorize_node.register(Elemwise, vectorize_not_needed) + + +@_vectorize_node.register(DimShuffle) +def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply: + batched_ndims = x.type.ndim - node.inputs[0].type.ndim + if not batched_ndims: + return node.op.make_node(x) + input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable + # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2)) + # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x")) + new_order = list(range(batched_ndims)) + [ + "x" if (o == "x") else (o + batched_ndims) for o in op.new_order + ] + return DimShuffle(input_broadcastable, new_order).make_node(x) + + +@_vectorize_node.register(CAReduce) +def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply: + batched_ndims = x.type.ndim - node.inputs[0].type.ndim + if not batched_ndims: + return node.op.make_node(x) + axes = op.axis + # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3)) + # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,)) + if axes is None: + axes = list(range(node.inputs[0].type.ndim)) + else: + axes = list(axes) + new_axes = [axis + batched_ndims for axis in axes] + new_op = op.clone(axis=new_axes) + return new_op.make_node(x) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 32fa47d28d..e2fef2f2c4 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -15,6 +15,7 @@ class MatrixPinv(Op): __props__ = ("hermitian",) + gufunc_signature = "(m,n)->(n,m)" def __init__(self, hermitian): self.hermitian = hermitian @@ -93,6 +94,8 @@ class MatrixInverse(Op): """ __props__ = () + gufunc_signature = "(m,m)->(m,m)" + gufunc_spec = ("numpy.linalg.inv", 1, 1) def __init__(self): pass @@ -181,6 +184,8 @@ class Det(Op): """ __props__ = () + gufunc_signature = "(m,m)->()" + gufunc_spec = ("numpy.linalg.det", 1, 1) def make_node(self, x): x = as_tensor_variable(x) @@ -218,6 +223,7 @@ class SLogDet(Op): """ __props__ = () + gufunc_signature = "(m, m)->(),()" def make_node(self, x): x = as_tensor_variable(x) @@ -252,6 +258,8 @@ class Eig(Op): """ __props__: Tuple[str, ...] = () + gufunc_signature = "(m,m)->(m),(m,m)" + gufunc_spec = ("numpy.linalg.eig", 1, 2) def make_node(self, x): x = as_tensor_variable(x) diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 1e4e44274f..d2f3c659a3 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -16,6 +16,7 @@ get_vector_length, infer_static_shape, ) +from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType from pytensor.tensor.random.utils import normalize_size_param, params_broadcast_shapes from pytensor.tensor.shape import shape_tuple @@ -428,3 +429,8 @@ class DefaultGeneratorMakerOp(AbstractRNGConstructor): default_rng = DefaultGeneratorMakerOp() + + +# RandomVariables are vectorized on the parameters by default. +# RNG, size and dtype can't be vectorized, but the Op will raise if the wrong input type is passed +_vectorize_node.register(RandomVariable, vectorize_not_needed) diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 7535f47c5c..f9b2cc27a5 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -107,3 +107,27 @@ def as_list(x): return list(x) except TypeError: return [x] + + +def import_func_from_string(func_string: str): # -> Optional[Callable]: + func = getattr(np, func_string, None) + if func is not None: + return func + + # Not inside NumPy or Scipy. So probably another package like scipy. + module = None + items = func_string.split(".") + for idx in range(1, len(items)): + try: + module = __import__(".".join(items[:idx])) + except ImportError: + break + + if module: + for sub in items[1:]: + try: + module = getattr(module, sub) + except AttributeError: + module = None + break + return module diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py new file mode 100644 index 0000000000..76f3eb5642 --- /dev/null +++ b/tests/tensor/test_blockwise.py @@ -0,0 +1,302 @@ +from itertools import product +from typing import Tuple, Union + +import numpy as np +import pytest + +import pytensor +from pytensor import config +from pytensor.gradient import grad +from pytensor.graph import Apply, Op +from pytensor.tensor import exp, tensor +from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature, vectorize_node +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.math import Any, Sum +from pytensor.tensor.math import any as pt_any +from pytensor.tensor.math import sum as pt_sum +from pytensor.tensor.nlinalg import MatrixInverse +from pytensor.tensor.random import normal +from pytensor.tensor.slinalg import Cholesky + + +def test_vectorize_node(): + vec = tensor(shape=(None,)) + mat = tensor(shape=(None, None)) + tns = tensor(shape=(None, None, None)) + + node = exp(vec).owner + vect_node = vectorize_node(node, mat) + assert vect_node.op == exp + assert vect_node.inputs[0] is mat + + col_mat = tensor(shape=(None, 1)) + tcol_mat = tensor(shape=(None, None, 1)) + node = col_mat.dimshuffle(0).owner # drop column + vect_node = vectorize_node(node, tcol_mat) + assert isinstance(vect_node.op, DimShuffle) + assert vect_node.op.new_order == (0, 1) + assert vect_node.inputs[0] is tcol_mat + assert vect_node.outputs[0].type.shape == (None, None) + + node = pt_sum(mat).owner + vect_node = vectorize_node(node, tns) + assert isinstance(vect_node.op, Sum) + assert vect_node.op.axis == (1, 2) + assert vect_node.inputs[0] is tns + + bool_mat = tensor(dtype="bool", shape=(None, None)) + bool_tns = tensor(dtype="bool", shape=(None, None, None)) + node = pt_any(bool_mat, axis=0).owner + vect_node = vectorize_node(node, bool_tns) + assert isinstance(vect_node.op, Any) + assert vect_node.op.axis == (1,) + assert vect_node.inputs[0] is bool_tns + + node = normal(vec).owner + new_inputs = node.inputs[:3] + [mat] + node.inputs[4:] + vect_node = vectorize_node(node, *new_inputs) + assert vect_node.op is normal + assert vect_node.inputs[3] is mat + + # Something that falls back to Blockwise + node = MatrixInverse()(mat).owner + vect_node = vectorize_node(node, tns) + assert isinstance(vect_node.op, Blockwise) and isinstance( + vect_node.op.core_op, MatrixInverse + ) + assert vect_node.op.signature == ("(m,m)->(m,m)") + assert vect_node.inputs[0] is tns + + # Useless blockwise + tns4 = tensor(shape=(5, None, None, None)) + new_vect_node = vectorize_node(vect_node, tns4) + assert new_vect_node.op is vect_node.op + assert isinstance(new_vect_node.op, Blockwise) and isinstance( + new_vect_node.op.core_op, MatrixInverse + ) + assert new_vect_node.inputs[0] is tns4 + + +def test_useless_blockwise(): + cop = MatrixInverse() + bop = Blockwise(cop, signature=("(m, m) -> (m, m)")) + + inp = tensor(shape=(None, None, None)) + out = bop(inp) + assert out.owner.op is bop + assert out.owner.inputs[0] is inp + + inp = tensor(shape=(None, None)) + out = bop(inp) + assert out.owner.op is cop + assert out.owner.inputs[0] is inp + + +class TestOp(Op): + def make_node(self, *inputs): + return Apply(self, inputs, [i.type() for i in inputs]) + + def perform(self, *args, **kwargs): + raise NotImplementedError("Test Op should not be present in final graph") + + +test_op = TestOp() + + +def test_vectorize_node_default_signature(): + vec = tensor(shape=(None,)) + mat = tensor(shape=(5, None)) + node = test_op.make_node(vec, mat) + + vect_node = vectorize_node(node, mat, mat) + assert isinstance(vect_node.op, Blockwise) and isinstance( + vect_node.op.core_op, TestOp + ) + assert vect_node.op.signature == ("(i00),(i10,i11)->(o00),(o10,o11)") + + with pytest.raises( + ValueError, match="Signature not provided nor found in core_op TestOp" + ): + Blockwise(test_op) + + vect_node = Blockwise(test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat) + assert vect_node.outputs[0].type.shape == ( + 5, + None, + ) + assert vect_node.outputs[0].type.shape == ( + 5, + None, + ) + + +def test_blockwise_shape(): + # Single output + inp = tensor(shape=(5, None, None)) + inp_test = np.zeros((5, 4, 3), dtype=config.floatX) + + # Shape can be inferred from inputs + op = Blockwise(test_op, signature="(m, n) -> (n, m)") + out = op(inp) + assert out.type.shape == (5, None, None) + + shape_fn = pytensor.function([inp], out.shape) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp_test)) == (5, 3, 4) + + # Shape can only be partially inferred from inputs + op = Blockwise(test_op, signature="(m, n) -> (m, k)") + out = op(inp) + assert out.type.shape == (5, None, None) + + shape_fn = pytensor.function([inp], out.shape) + assert any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + + shape_fn = pytensor.function([inp], out.shape[:-1]) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp_test)) == (5, 4) + + # Mutiple outputs + inp1 = tensor(shape=(7, 1, None, None)) + inp2 = tensor(shape=(1, 5, None, None)) + inp1_test = np.zeros((7, 1, 4, 3), dtype=config.floatX) + inp2_test = np.zeros((1, 5, 4, 3), dtype=config.floatX) + + op = Blockwise(test_op, signature="(m, n), (m, n) -> (n, m), (m, k)") + outs = op(inp1, inp2) + assert outs[0].type.shape == (7, 5, None, None) + assert outs[1].type.shape == (7, 5, None, None) + + shape_fn = pytensor.function([inp1, inp2], [out.shape for out in outs]) + assert any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + + shape_fn = pytensor.function([inp1, inp2], outs[0].shape) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp1_test, inp2_test)) == (7, 5, 3, 4) + + shape_fn = pytensor.function([inp1, inp2], [outs[0].shape, outs[1].shape[:-1]]) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp1_test, inp2_test)[0]) == (7, 5, 3, 4) + assert tuple(shape_fn(inp1_test, inp2_test)[1]) == (7, 5, 4) + + +class BlockwiseOpTester: + """Base class to test Blockwise works for specific Ops""" + + core_op = None + signature = None + batcheable_axes = None + + @classmethod + def setup_class(cls): + seed = sum(map(ord, cls.__class__.__name__)) + cls.rng = np.random.default_rng(seed) + cls.params_sig, cls.outputs_sig = _parse_gufunc_signature(cls.signature) + if cls.batcheable_axes is None: + cls.batcheable_axes = list(range(len(cls.outputs_sig))) + batch_shapes = [(), (1,), (5,), (1, 1), (1, 3), (3, 1), (3, 5)] + cls.test_batch_shapes = list( + product(batch_shapes, repeat=len(cls.batcheable_axes)) + ) + cls.block_op = Blockwise(core_op=cls.core_op, signature=cls.signature) + + @staticmethod + def parse_shape(shape: Tuple[Union[str, int], ...]) -> Tuple[int, ...]: + """ + Convert (5, "m", "n") -> (5, 7, 11) + """ + mapping = {"m": 7, "n": 11, "k": 19} + return tuple(mapping.get(p, p) for p in shape) + + def create_testvals(self, shape): + return self.rng.normal(size=self.parse_shape(shape)).astype(config.floatX) + + def create_batched_inputs(self): + for batch_shapes in self.test_batch_shapes: + vec_inputs = [] + vec_inputs_testvals = [] + for batch_shape, param_sig in zip(batch_shapes, self.params_sig): + vec_inputs.append(tensor(shape=batch_shape + (None,) * len(param_sig))) + vec_inputs_testvals.append( + self.create_testvals(shape=batch_shape + param_sig) + ) + yield vec_inputs, vec_inputs_testvals + + def test_perform(self): + base_inputs = [ + tensor(shape=(None,) * len(param_sig)) for param_sig in self.params_sig + ] + core_func = pytensor.function(base_inputs, self.core_op(*base_inputs)) + np_func = np.vectorize(core_func, signature=self.signature) + + for vec_inputs, vec_inputs_testvals in self.create_batched_inputs(): + pt_func = pytensor.function(vec_inputs, self.block_op(*vec_inputs)) + if len(self.outputs_sig) != 1: + raise NotImplementedError("Did not implement test for multi-output Ops") + np.testing.assert_allclose( + pt_func(*vec_inputs_testvals), + np_func(*vec_inputs_testvals), + ) + + def test_grad(self): + base_inputs = [ + tensor(shape=(None,) * len(param_sig)) for param_sig in self.params_sig + ] + out = self.core_op(*base_inputs).sum() + if len(base_inputs) == 1: + core_grad_func = pytensor.function( + base_inputs, grad(out, wrt=base_inputs[0]) + ) + else: + core_grad_func = pytensor.function(base_inputs, grad(out, wrt=base_inputs)) + + [param_sig, _] = self.signature.split("->") + grad_sig = f"{param_sig}->{param_sig}" + np_func_raw = np.vectorize(core_grad_func, signature=grad_sig) + if len(base_inputs): + np_func = lambda *args: [np_func_raw(*args)] # noqa: E731 + else: + np_func = np_func_raw + + for vec_inputs, vec_inputs_testvals in self.create_batched_inputs(): + out = self.block_op(*vec_inputs).sum() + pt_func = pytensor.function(vec_inputs, grad(out, wrt=vec_inputs)) + pt_outs = pt_func(*vec_inputs_testvals) + np_outs = np_func(*vec_inputs_testvals) + for pt_out, np_out in zip(pt_outs, np_outs): + np.testing.assert_allclose(pt_out, np_out) + + +class MatrixOpBlockwiseTester(BlockwiseOpTester): + def create_testvals(self, shape): + # Return a posdef matrix + X = super().create_testvals(shape) + return np.einsum("...ij,...kj->...ik", X, X) + + +class TestCholesky(MatrixOpBlockwiseTester): + core_op = Cholesky(lower=True) + signature = "(m, m) -> (m, m)" + + +class TestMatrixInverse(MatrixOpBlockwiseTester): + core_op = MatrixInverse() + signature = "(m, m) -> (m, m)"