[LLVMGPUVectorDistribute] Fix vector step distribute (iree-org#19227)
Currently, the 'thread_strides' of NestedLayoutAttr are misinterpreted as
the access strides of the multi-dimensional vector.

However, it turns out they correspond to the tid -> vtid mapping, while the
undistributed vector is packed as:
subgroup x batch x outer x thread x element
where vtid is used to index the 'thread' dimension.

Therefore, this commit removes the usage of 'thread_strides' and
'subgroup_strides' when calculating the base constant offset, and instead
obtains the strides from the packed undistributed vector shape.

Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
manupak authored Nov 20, 2024
1 parent 1aada43 commit b5b8059
Showing 4 changed files with 50 additions and 21 deletions.
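As background, here is a small worked example of the stride arithmetic involved. This is a standalone sketch, not part of the commit: getStrides is restated independently, and the layout numbers are hypothetical, chosen only to mirror the step_1 test further below.

#include <cstdint>
#include <cstdio>
#include <vector>

// Row-major strides of a packed shape: {2, 3, 4} -> {12, 4, 1}.
static std::vector<int64_t> getStrides(const std::vector<int64_t> &shape) {
  int64_t stride = 1;
  for (int64_t len : shape)
    stride *= len;
  std::vector<int64_t> strides;
  for (int64_t len : shape) {
    stride /= len;
    strides.push_back(stride);
  }
  return strides;
}

int main() {
  // Hypothetical rank-1 layout over 16 elements, packed as
  // subgroup(1) x batch(4) x outer(1) x thread(4) x element(1).
  std::vector<int64_t> packed = {1, 4, 1, 4, 1};
  std::vector<int64_t> s = getStrides(packed); // {16, 4, 4, 1, 1}
  // A lane's contribution to the offset is vtid * s[3], read off the packed
  // shape; the attribute's thread_strides only encode the tid -> vtid
  // mapping and must not be used as an access stride.
  std::printf("packed thread stride = %lld\n", (long long)s[3]);
  return 0;
}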
@@ -990,21 +990,27 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
     return lens;
   }

+  // This is a helper to extract strides from a given shape
+  // E.g. : a shape of 2x3x4 will return strides [12, 4, 1]
+  SmallVector<int64_t> getStrides(ArrayRef<int64_t> shape) const {
+    int64_t elementCount = ShapedType::getNumElements(shape);
+    SmallVector<int64_t> strides;
+    int64_t currStride = elementCount;
+    for (int64_t len : shape) {
+      currStride = currStride / len;
+      strides.push_back(currStride);
+    }
+    return strides;
+  }
+
   // Once we are in the realm of remaining dimensions,
   // the strides are not packed. This is a helper to
   // obtain the packed strides of the remaining dimensions.
   // (See above for an example of remaining dimensions under
   // getRemainingDims)
   SmallVector<int64_t> getPackedStrides(ArrayRef<DimInfo> dims) const {
     SmallVector<int64_t> lens = getLens(dims);
-    int64_t elementCount = ShapedType::getNumElements(lens);
-    SmallVector<int64_t> packedStrides;
-    int64_t currStride = elementCount;
-    for (int64_t len : lens) {
-      currStride = currStride / len;
-      packedStrides.push_back(currStride);
-    }
-    return packedStrides;
+    return getStrides(lens);
   }

   // This function emulates the slicing of otherwise large constant
@@ -1091,9 +1097,14 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
     SmallVector<Value> subgroupIndices, threadIndices;
     populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, resultLayout,
                                  subgroupIndices, threadIndices);
-    ArrayRef<int64_t> subgroupStrides = resultLayout.getSubgroupStrides();
+
+    SmallVector<int64_t> undistributedShape =
+        resultLayout.getUndistributedPackedShape();
+    SmallVector<int64_t> undistributedStrides = getStrides(undistributedShape);
+    constexpr int64_t subgroupIdx = 0;
+    constexpr int64_t threadIdx = 3;
+
     ArrayRef<int64_t> subgroupLengths = resultLayout.getSubgroupTile();
-    ArrayRef<int64_t> threadStrides = resultLayout.getThreadStrides();
     ArrayRef<int64_t> threadLengths = resultLayout.getThreadTile();
     // Step op by definition should be single dimensional.
     SmallVector<int64_t> distributedShape =
@@ -1102,8 +1113,9 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
     int64_t distributedElements = ShapedType::getNumElements(distributedShape);
     int64_t originalElements = result.getType().getNumElements();
     SmallVector<DimInfo, 2> distributedDims{
-        {subgroupIndices[0], subgroupLengths[0], subgroupStrides[0]},
-        {threadIndices[0], threadLengths[0], threadStrides[0]}};
+        {subgroupIndices[0], subgroupLengths[0],
+         undistributedStrides[subgroupIdx]},
+        {threadIndices[0], threadLengths[0], undistributedStrides[threadIdx]}};
     llvm::sort(distributedDims, [](const DimInfo &lhs, const DimInfo &rhs) {
       return lhs.dimStride > rhs.dimStride;
     });
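In effect, the pattern now composes the offset for a rank-1 step op as follows. This is a sketch of the resulting arithmetic, not the pattern's literal code:

  // offset[i] = subgroupId * undistributedStrides[subgroupIdx]  // packed subgroup stride
  //           + vtid       * undistributedStrides[threadIdx]    // packed thread stride
  //           + baseConst[i]   // slice offsets over the remaining batch/outer/element dims

The sort above merely orders the subgroup and thread contributions by decreasing packed stride, so the coarser-stride dimension varies slowest when the offsets are materialized.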
@@ -26,11 +26,10 @@ builtin.module attributes { transform.with_named_sequence } {
 }

 // CHECK-LABEL: func @step_1
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
 // CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 16) mod 4)>()[%thread_id_x]
-// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c16 : index
-// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<4xindex>
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<4xindex>
+// CHECK: %[[TIDB:.+]] = vector.broadcast %[[TID]] : index to vector<4xindex>
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDB]], %[[CST]] : vector<4xindex>
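One packing consistent with the new step_1 expectations (the test's layout attribute is not shown in this excerpt): a 16-element vector packed as subgroup(1) x batch(4) x outer(1) x thread(4) x element(1) has packed strides [16, 4, 4, 1, 1]. The thread dimension then carries stride 1, so the lane id is added directly, while the remaining batch dimension (stride 4) yields the base constant [0, 4, 8, 12]. The old expectation instead scaled the lane id by the attribute's thread stride of 16 over a base constant of [0, 1, 2, 3].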

// -----

@@ -94,10 +93,10 @@ builtin.module attributes { transform.with_named_sequence } {
 }

 // CHECK-LABEL: func @step_3
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 24, 25]> : vector<4xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 8, 9]> : vector<4xindex>
 // CHECK: %[[WID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 512) mod 3)>()[%thread_id_x]
 // CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 2) mod 4)>()[%thread_id_x]
-// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c8 : index
+// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c16 : index
 // CHECK: %[[WID_STRIDEV:.+]] = vector.broadcast %[[WID_STRIDE]] : index to vector<4xindex>
 // CHECK: %[[OFFSET0:.+]] = arith.addi %[[WID_STRIDEV]], %[[CST]] : vector<4xindex>
 // CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c2 : index
Expand Down Expand Up @@ -132,7 +131,8 @@ builtin.module attributes { transform.with_named_sequence } {
 }

 // CHECK-LABEL: func @step_4
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112]> : vector<8xindex>
-// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c8 : index
-// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<8xindex>
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<8xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%thread_id_x]
+// CHECK: %[[TIDV:.+]] = vector.broadcast %[[TID]] : index to vector<8xindex>
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDV]], %[[CST]] : vector<8xindex>
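The step_3 values can be reproduced by hand the same way. The sketch below assumes a packing of subgroup(3) x batch(2) x outer(1) x thread(4) x element(2); the layout attribute is not shown in this excerpt, but this assumption is consistent with all three new CHECK values:

#include <cstdint>
#include <cstdio>

int main() {
  // Row-major strides of the assumed 3x2x1x4x2 packing (48 elements).
  const int64_t strides[5] = {16, 8, 8, 2, 1};
  // Base constant: the per-thread slice over the remaining batch(2) and
  // element(2) dims -> prints 0 1 8 9, matching dense<[0, 1, 8, 9]>.
  for (int64_t b = 0; b < 2; ++b)
    for (int64_t e = 0; e < 2; ++e)
      std::printf("%lld ", (long long)(b * strides[1] + e * strides[4]));
  std::printf("\n");
  // Each lane then adds wid * strides[0] (= wid * 16) and
  // vtid * strides[3] (= vtid * 2), matching the emitted arith.muli ops.
  return 0;
}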
@@ -311,6 +311,20 @@ SmallVector<int64_t> NestedLayoutAttr::getDistributedShape() const {
   return shape;
 }

+/// Before we distribute, we would like to see this as:
+/// <SUBGROUP x BATCH x OUTER x THREAD x ELEMENT>
+SmallVector<int64_t> NestedLayoutAttr::getUndistributedPackedShape() const {
+  SmallVector<int64_t> shape;
+  int64_t rank = getRank();
+  shape.reserve(rank * 5);
+  shape.append(getSubgroupTile().begin(), getSubgroupTile().end());
+  shape.append(getBatchTile().begin(), getBatchTile().end());
+  shape.append(getOuterTile().begin(), getOuterTile().end());
+  shape.append(getThreadTile().begin(), getThreadTile().end());
+  shape.append(getElementTile().begin(), getElementTile().end());
+  return shape;
+}
+
 // Gets the rank of the undistributed vector for this layout.
 int64_t NestedLayoutAttr::getRank() const {
   // The layout requires that all size lists are the same length and match
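A usage sketch with invented tile values: for a rank-2 layout with subgroup_tile = [2, 1], batch_tile = [1, 4], outer_tile = [1, 1], thread_tile = [8, 4], and element_tile = [1, 4], getUndistributedPackedShape() returns [2, 1, 1, 4, 1, 1, 8, 4, 1, 4] — the five tile lists concatenated in <SUBGROUP x BATCH x OUTER x THREAD x ELEMENT> order, rank * 5 entries in total, whose row-major strides are what DistributeStep now consumes.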
@@ -292,6 +292,9 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
     // Returns the subgroup/lane ids delinearized from a single linearized
     // thread ID.
     SmallVector<Value> computeThreadIds(Value threadId, int64_t subgroupSize, RewriterBase &rewriter) const;
+
+    // Get the undistributed shape that is subgroup x batch x outer x thread x element
+    SmallVector<int64_t> getUndistributedPackedShape() const;
   }];

   let genVerifyDecl = 1;