Skip to content

Commit

Permalink
add indexing maps for iree_linalg_ext.scatter's out operand (iree-o…
Browse files Browse the repository at this point in the history
…rg#17704)

Addresses iree-org#17691, indexing maps
weren't provided for output operands

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 authored Jun 20, 2024
1 parent 643a7cd commit d01fb23
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,32 @@ util.func public @linalgext_reverse_fusion() -> tensor<10x10xi32> {
// CHECK: ins(%[[SHRUNK]] : tensor<10x10xi32>)
// CHECK: flow.dispatch.workgroups
// CHECK: %[[GEN:.+]] = linalg.generic

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> {
%6 = tensor.empty() : tensor<4x1xi32>
%2 = tensor.empty() : tensor<4x1x16x8x128xf32>
%4 = tensor.empty() : tensor<10x8192x16x8x128xf32>

%outs = tensor.extract_slice %4[0, 0, 0, 0, 0][1, 8192, 16, 8, 128][1, 1, 1, 1, 1] :
tensor<10x8192x16x8x128xf32> to tensor<8192x16x8x128xf32>

%8 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false)
ins(%2, %6 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
outs(%outs : tensor<8192x16x8x128xf32>) {
^bb0(%arg0: f32, %arg1: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<8192x16x8x128xf32>

util.return %8 : tensor<8192x16x8x128xf32>
}

// CHECK: util.func public @linalgext_scatter_fusion
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
// CHECK: %[[OUTS:.+]] = tensor.extract_slice
// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: outs(%[[OUTS]] : tensor<8192x16x8x128xf32>)
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ ScatterOp::reifyResultShapes(OpBuilder &b,
SmallVector<AffineMap> ScatterOp::getIndexingMapsForOperands() {
Builder builder(getContext());
return {builder.getMultiDimIdentityMap(getUpdateType().getRank()),
builder.getMultiDimIdentityMap(getIndicesType().getRank())};
builder.getMultiDimIdentityMap(getIndicesType().getRank()),
/*output=*/AffineMap(nullptr)};
}

SmallVector<AffineMap> ScatterOp::getIndexingMapsForResults() {
Expand Down Expand Up @@ -499,7 +500,8 @@ ReverseOp::reifyResultShapes(OpBuilder &b,

SmallVector<AffineMap> ReverseOp::getIndexingMapsForOperands() {
Builder builder(getContext());
return {builder.getMultiDimIdentityMap(getOperandRank())};
return {builder.getMultiDimIdentityMap(getOperandRank()),
/*output=*/AffineMap(nullptr)};
}

SmallVector<AffineMap> ReverseOp::getIndexingMapsForResults() {
Expand Down

0 comments on commit d01fb23

Please sign in to comment.