diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 31f416784453..4991a35d4668 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -990,6 +990,19 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
     return lens;
   }
 
+  // This is a helper to extract the row-major strides of a given shape.
+  // E.g.: a shape of 2x3x4 will return strides [12, 4, 1].
+  SmallVector<int64_t> getStrides(ArrayRef<int64_t> shape) const {
+    int64_t elementCount = ShapedType::getNumElements(shape);
+    SmallVector<int64_t> strides;
+    int64_t currStride = elementCount;
+    for (int64_t len : shape) {
+      currStride = currStride / len;
+      strides.push_back(currStride);
+    }
+    return strides;
+  }
+
   // Once we are in the realm of remaining dimensions,
   // the strides are not packed. This is a helper to
   // obtain the packed strides of the remaining dimensions.
@@ -997,14 +1010,7 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
   // getRemainingDims)
   SmallVector<int64_t> getPackedStrides(ArrayRef<DimInfo> dims) const {
     SmallVector<int64_t> lens = getLens(dims);
-    int64_t elementCount = ShapedType::getNumElements(lens);
-    SmallVector<int64_t> packedStrides;
-    int64_t currStride = elementCount;
-    for (int64_t len : lens) {
-      currStride = currStride / len;
-      packedStrides.push_back(currStride);
-    }
-    return packedStrides;
+    return getStrides(lens);
   }
 
   // This function emulates the slicing of otherwise large constant
@@ -1091,9 +1097,14 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
     SmallVector<Value> subgroupIndices, threadIndices;
     populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, resultLayout,
                                  subgroupIndices, threadIndices);
-    ArrayRef<int64_t> subgroupStrides = resultLayout.getSubgroupStrides();
+
+    SmallVector<int64_t> undistributedShape =
+        resultLayout.getUndistributedPackedShape();
+    SmallVector<int64_t> undistributedStrides = getStrides(undistributedShape);
+    constexpr int64_t subgroupIdx = 0;
+    constexpr int64_t threadIdx = 3;
+
     ArrayRef<int64_t> subgroupLengths = resultLayout.getSubgroupTile();
-    ArrayRef<int64_t> threadStrides = resultLayout.getThreadStrides();
     ArrayRef<int64_t> threadLengths = resultLayout.getThreadTile();
     // Step op by definition should be single dimensional.
     SmallVector<int64_t> distributedShape =
@@ -1102,8 +1113,9 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
     int64_t distributedElements = ShapedType::getNumElements(distributedShape);
     int64_t originalElements = result.getType().getNumElements();
     SmallVector<DimInfo> distributedDims{
-        {subgroupIndices[0], subgroupLengths[0], subgroupStrides[0]},
-        {threadIndices[0], threadLengths[0], threadStrides[0]}};
+        {subgroupIndices[0], subgroupLengths[0],
+         undistributedStrides[subgroupIdx]},
+        {threadIndices[0], threadLengths[0], undistributedStrides[threadIdx]}};
     llvm::sort(distributedDims, [](const DimInfo &lhs, const DimInfo &rhs) {
       return lhs.dimStride > rhs.dimStride;
     });
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir
index fbc532549f23..76c33e4d31e2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir
@@ -26,11 +26,10 @@ builtin.module attributes { transform.with_named_sequence } {
 }
 
 // CHECK-LABEL: func @step_1
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
 // CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 16) mod 4)>()[%thread_id_x]
-// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c16 : index
-// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<4xindex>
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<4xindex>
+// CHECK: %[[TIDB:.+]] = vector.broadcast %[[TID]] : index to vector<4xindex>
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDB]], %[[CST]] : vector<4xindex>
 
 // -----
 
@@ -94,10 +93,10 @@ builtin.module attributes { transform.with_named_sequence } {
 }
 
 // CHECK-LABEL: func @step_3
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 24, 25]> : vector<4xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 8, 9]> : vector<4xindex>
 // CHECK: %[[WID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 512) mod 3)>()[%thread_id_x]
 // CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 2) mod 4)>()[%thread_id_x]
-// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c8 : index
+// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c16 : index
 // CHECK: %[[WID_STRIDEV:.+]] = vector.broadcast %[[WID_STRIDE]] : index to vector<4xindex>
 // CHECK: %[[OFFSET0:.+]] = arith.addi %[[WID_STRIDEV]], %[[CST]] : vector<4xindex>
 // CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c2 : index
@@ -132,7 +131,8 @@ builtin.module attributes { transform.with_named_sequence } {
 }
 
 // CHECK-LABEL: func @step_4
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112]> : vector<8xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%thread_id_x]
-// CHECK: %[[TIDV:.+]] = vector.broadcast %[[TID]] : index to vector<8xindex>
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDV]], %[[CST]] : vector<8xindex>
+// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c8 : index
+// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<8xindex>
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<8xindex>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index f086ecb8c1bd..92134cb9e1f1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -311,6 +311,20 @@ SmallVector<int64_t> NestedLayoutAttr::getDistributedShape() const {
   return shape;
 }
 
+/// Before we distribute, we would like to see this as:
+/// subgroup x batch x outer x thread x element
+SmallVector<int64_t> NestedLayoutAttr::getUndistributedPackedShape() const {
+  SmallVector<int64_t> shape;
+  int64_t rank = getRank();
+  shape.reserve(rank * 5);
+  shape.append(getSubgroupTile().begin(), getSubgroupTile().end());
+  shape.append(getBatchTile().begin(), getBatchTile().end());
+  shape.append(getOuterTile().begin(), getOuterTile().end());
+  shape.append(getThreadTile().begin(), getThreadTile().end());
+  shape.append(getElementTile().begin(), getElementTile().end());
+  return shape;
+}
+
 // Gets the rank of the undistributed vector for this layout.
 int64_t NestedLayoutAttr::getRank() const {
   // The layout requires that all size lists are the same length and match
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
index 446ff7733938..913fb9f92dd3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -292,6 +292,9 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
     // Returns the subgroup/lane ids delinearized from a single linearized
    // thread ID.
     SmallVector<Value> computeThreadIds(Value threadId, int64_t subgroupSize, RewriterBase &rewriter) const;
+
+    // Get the undistributed packed shape, i.e. subgroup x batch x outer x thread x element.
+    SmallVector<int64_t> getUndistributedPackedShape() const;
   }];
 
   let genVerifyDecl = 1;
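
Note for reviewers: the sketch below is plain C++ and not part of the patch or of IREE/MLIR; it only walks through the stride math the patch switches to. Strides are now derived from the packed undistributed shape, ordered subgroup x batch x outer x thread x element, so for a rank-1 layout the subgroup stride sits at index 0 and the thread stride at index 3 (the `subgroupIdx`/`threadIdx` constants above). The `{3, 2, 1, 4, 2}` shape is an assumed reading of the `@step_3` test layout that reproduces the updated CHECK values, and `std::vector` stands in for `llvm::SmallVector`.

```cpp
// Standalone illustration of the stride computation (no IREE/MLIR deps).
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides of a shape, e.g. {2, 3, 4} -> {12, 4, 1}.
static std::vector<int64_t> getStrides(const std::vector<int64_t> &shape) {
  int64_t elementCount = 1;
  for (int64_t len : shape)
    elementCount *= len;
  std::vector<int64_t> strides;
  int64_t currStride = elementCount;
  for (int64_t len : shape) {
    currStride /= len;
    strides.push_back(currStride);
  }
  return strides;
}

int main() {
  // Example from the comment on getStrides.
  assert((getStrides({2, 3, 4}) == std::vector<int64_t>({12, 4, 1})));

  // Assumed rank-1 layout for @step_3:
  // subgroup=3, batch=2, outer=1, thread=4, element=2.
  std::vector<int64_t> packed = {3, 2, 1, 4, 2};
  std::vector<int64_t> strides = getStrides(packed); // {16, 8, 8, 2, 1}
  assert(strides[0] == 16); // subgroup stride -> "arith.muli %[[WID]], %c16"
  assert(strides[3] == 2);  // thread stride   -> "arith.muli %[[TID]], %c2"

  // Base offsets enumerated over the remaining (batch, outer, element) dims
  // with undistributed strides: expect the dense<[0, 1, 8, 9]> constant.
  std::vector<int64_t> offsets;
  for (int64_t b = 0; b < packed[1]; ++b)
    for (int64_t o = 0; o < packed[2]; ++o)
      for (int64_t e = 0; e < packed[4]; ++e)
        offsets.push_back(b * strides[1] + o * strides[2] + e * strides[4]);
  assert((offsets == std::vector<int64_t>({0, 1, 8, 9})));
  return 0;
}
```

Compiling and running the sketch (e.g. `c++ -std=c++17 stride_sketch.cpp && ./a.out`) exercises all the asserts against the updated test expectations.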