Skip to content

Commit

Permalink
dialects: (stablehlo) Add support for andOp (#3081)
Browse files Browse the repository at this point in the history
Add stablehlo.and operation.

Co-authored-by: Erick Ochoa <erick@ceci-nest-pas.me>
  • Loading branch information
efferifick and Erick Ochoa authored Aug 21, 2024
1 parent bd14181 commit 04fcab7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
3 changes: 3 additions & 0 deletions tests/filecheck/dialects/stablehlo/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]

// CHECK: %and = "stablehlo.and"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%and = "stablehlo.and"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
36 changes: 34 additions & 2 deletions xdsl/dialects/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
"""

import abc
from typing import Annotated, cast
from typing import Annotated, TypeAlias, cast

from xdsl.dialects.builtin import AnyTensorType, DenseArrayBase, TensorType
from xdsl.dialects.builtin import AnyTensorType, DenseArrayBase, IntegerType, TensorType
from xdsl.ir import Attribute, Dialect, SSAValue
from xdsl.irdl import (
ConstraintVar,
Expand Down Expand Up @@ -91,6 +91,37 @@ class AddOp(ElementwiseBinaryOperation):
name = "stablehlo.add"


IntegerTensorType: TypeAlias = TensorType[IntegerType]


@irdl_op_definition
class AndOp(IRDLOperation):
"""
Performs element-wise AND of two tensors lhs and rhs and produces a result tensor. Depending on the element type, does the following:
For booleans: logical AND.
For integers: bitwise AND.
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#and
"""

name = "stablehlo.and"

T = Annotated[IntegerTensorType, ConstraintVar("T")]

lhs = operand_def(T)
rhs = operand_def(T)

result = result_def(T)

def __init__(
self, lhs: SSAValue, rhs: SSAValue, result_type: Attribute | None = None
):
if result_type is None:
result_type = lhs.type
super().__init__(operands=(lhs, rhs), result_types=(result_type,))


@irdl_op_definition
class MultiplyOp(ElementwiseBinaryOperation):
"""
Expand Down Expand Up @@ -188,6 +219,7 @@ def verify_(self) -> None:
[
AbsOp,
AddOp,
AndOp,
MultiplyOp,
SubtractOp,
TransposeOp,
Expand Down

0 comments on commit 04fcab7

Please sign in to comment.