From c3db7106df931e767a560b0c100bcf5c2b77c888 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Wed, 4 Dec 2024 17:33:51 +0000 Subject: [PATCH] [VectorDistribution] Add distribution for trivial vector.extract (#19318) This patch adds a distribution pattern for vector.extract when the list of indices is zero. This arises in the case of a scalar extract for 0-d vectors. --- .../Common/GPU/GPUDistributionPatterns.cpp | 29 ++++++++++++++++++- .../GPU/test/gpu_vector_distribution.mlir | 17 +++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp index 276b7fe11d4b..9cc704b196b2 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp @@ -303,10 +303,37 @@ struct DistributeGather final : OpDistributionPattern { } }; +/// Distribute a 0-rank vector to scalar vector.extract conversion. +struct DistributeTrivialExtract final + : OpDistributionPattern { + using OpDistributionPattern::OpDistributionPattern; + + LogicalResult matchAndRewrite(vector::ExtractOp extractOp, + DistributionSignature &signature, + PatternRewriter &rewriter) const override { + if (extractOp.getSourceVectorType().getRank() != 0) { + return rewriter.notifyMatchFailure( + extractOp, "Only 0-rank vector extractions supported"); + } + + VectorValue source = extractOp.getVector(); + VectorLayoutInterface sourceLayout = signature[source]; + + Value distributed = rewriter.create( + extractOp.getLoc(), getDistributed(rewriter, source, sourceLayout), + ArrayRef{}); + + replaceOpWithDistributedValues(rewriter, extractOp, distributed); + + return success(); + } +}; + } // namespace void populateGPUDistributionPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); // Elementwise patterns. patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir index 1f0833fa8768..c893ba30d0b5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir @@ -96,6 +96,23 @@ func.func @distribute_scf_for_0d(%a: vector, %b: vector) -> vector } +// CHECK-LABEL: @distribute_scalar_extract +func.func @distribute_scalar_extract(%a: f16, %b: vector) -> f16 { + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.0 : f16 + // CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector + %root = arith.constant dense<0.0> : vector + %rootl = iree_vector_ext.to_layout %root to layout(#layout_0d) : vector + // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector -> vector + // CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] : vector + // CHECK-DAG: %[[SCALAR:.*]] = vector.extract %[[C]][] : f16 from vector + %c = arith.mulf %rootl, %b : vector + %scalar = vector.extract %c[] : f16 from vector + // CHECK-DAG: %[[D:.*]] = arith.addf %[[SCALAR]], %{{.*}} : f16 + %d = arith.addf %scalar, %a : f16 + return %d : f16 +} + builtin.module attributes { transform.with_named_sequence } { transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op