Skip to content

Commit

Permalink
Remove duplicated Inv Op
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 22, 2023
1 parent e58bd91 commit f356c7c
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 56 deletions.
13 changes: 0 additions & 13 deletions pytensor/link/numba/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Det,
Eig,
Eigh,
Inv,
MatrixInverse,
MatrixPinv,
QRFull,
Expand Down Expand Up @@ -125,18 +124,6 @@ def eigh(x):
return eigh


@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)

@numba_basic.numba_njit(inline="always")
def inv(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)

return inv


@numba_funcify.register(MatrixInverse)
def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
Expand Down
21 changes: 1 addition & 20 deletions pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,6 @@ def pinv(x, hermitian=False):
return MatrixPinv(hermitian=hermitian)(x)


class Inv(Op):
"""Computes the inverse of one or more matrices."""

def make_node(self, x):
x = as_tensor_variable(x)
return Apply(self, [x], [x.type()])

def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs
z[0] = np.linalg.inv(x).astype(x.dtype)

def infer_shape(self, fgraph, node, shapes):
return shapes


inv = Inv()


class MatrixInverse(Op):
r"""Computes the inverse of a matrix :math:`A`.
Expand Down Expand Up @@ -169,7 +150,7 @@ def infer_shape(self, fgraph, node, shapes):
return shapes


matrix_inverse = MatrixInverse()
inv = matrix_inverse = MatrixInverse()


def matrix_dot(*args):
Expand Down
26 changes: 3 additions & 23 deletions tests/link/numba/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
],
)
def test_Cholesky(x, lower, exc):
g = slinalg.Cholesky(lower)(x)
g = slinalg.Cholesky(lower=lower)(x)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc):
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower)(A, x)
g = slinalg.Solve(lower=lower, b_ndim=1)(A, x)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc):
],
)
def test_SolveTriangular(A, x, lower, exc):
g = slinalg.SolveTriangular(lower)(A, x)
g = slinalg.SolveTriangular(lower=lower, b_ndim=1)(A, x)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
Expand Down Expand Up @@ -352,26 +352,6 @@ def test_Eigh(x, uplo, exc):
None,
(),
),
(
nlinalg.Inv,
set_test_value(
at.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
(),
),
(
nlinalg.Inv,
set_test_value(
at.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
None,
(),
),
(
nlinalg.MatrixPinv,
set_test_value(
Expand Down

0 comments on commit f356c7c

Please sign in to comment.