Skip to content

Commit

Permalink
[VectorDistribution] Add distribution for trivial vector.extract (ire…
Browse files Browse the repository at this point in the history
…e-org#19318)

This patch adds a distribution pattern for vector.extract when the list
of indices is empty. This arises when a scalar is extracted from a 0-d
vector.
  • Loading branch information
Groverkss authored Dec 4, 2024
1 parent 8894f5a commit c3db710
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,37 @@ struct DistributeGather final : OpDistributionPattern<vector::GatherOp> {
}
};

/// Distribution pattern for a `vector.extract` whose source is a rank-0
/// vector, i.e. an extract with an empty index list that yields a scalar.
///
/// The op is rebuilt as the same empty-index extract on the distributed form
/// of its source, and the original op is replaced with that result.
struct DistributeTrivialExtract final
    : OpDistributionPattern<vector::ExtractOp> {
  using OpDistributionPattern::OpDistributionPattern;

  /// Rewrites `extractOp` when its source vector has rank 0; otherwise reports
  /// a match failure and leaves the op untouched.
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                DistributionSignature &signature,
                                PatternRewriter &rewriter) const override {
    // Guard: only sources of rank 0 are handled by this pattern.
    const bool isZeroRankSource =
        extractOp.getSourceVectorType().getRank() == 0;
    if (!isZeroRankSource) {
      return rewriter.notifyMatchFailure(
          extractOp, "Only 0-rank vector extractions supported");
    }

    // Fetch the source together with the layout the signature holds for it,
    // then obtain the distributed counterpart of the source.
    VectorValue source = extractOp.getVector();
    VectorLayoutInterface sourceLayout = signature[source];
    Value distributedSource = getDistributed(rewriter, source, sourceLayout);

    // Re-create the extract with an empty position list on the distributed
    // source.
    Value distributed = rewriter.create<vector::ExtractOp>(
        extractOp.getLoc(), distributedSource, ArrayRef<int64_t>{});

    replaceOpWithDistributedValues(rewriter, extractOp, distributed);
    return success();
  }
};

} // namespace

void populateGPUDistributionPatterns(RewritePatternSet &patterns) {
patterns.add<DistributeConstants, DistributeScfFor>(patterns.getContext());
patterns.add<DistributeConstants, DistributeScfFor, DistributeTrivialExtract>(
patterns.getContext());
// Elementwise patterns.
patterns.add<DistributeElementwise>(patterns.getContext());
patterns.add<DistributeTrivialLayoutConversions>(patterns.getContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,23 @@ func.func @distribute_scf_for_0d(%a: vector<i32>, %b: vector<i32>) -> vector<i32
return %out : vector<i32>
}

// Test: an empty-index vector.extract on a 0-d vector. The function tags a
// 0-d constant with #layout_0d, multiplies it by %b, extracts the scalar from
// the product, and adds %a to it. The CHECK lines expect the extract to still
// be an empty-index vector.extract on the mulf result, whose operands are the
// constant and a to_simt value, with the scalar addf consuming the extracted
// value.
// CHECK-LABEL: @distribute_scalar_extract
func.func @distribute_scalar_extract(%a: f16, %b: vector<f16>) -> f16 {
// NOTE(review): %c0 and %cst_0 have no visible uses in this function —
// confirm whether they can be dropped.
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.0 : f16
// CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<f16>
%root = arith.constant dense<0.0> : vector<f16>
// Attach the 0-d layout to the constant (#layout_0d is defined outside this
// function).
%rootl = iree_vector_ext.to_layout %root to layout(#layout_0d) : vector<f16>
// CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<f16> -> vector<f16>
// CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] : vector<f16>
// CHECK-DAG: %[[SCALAR:.*]] = vector.extract %[[C]][] : f16 from vector<f16>
%c = arith.mulf %rootl, %b : vector<f16>
// The op under test: scalar extract with an empty index list.
%scalar = vector.extract %c[] : f16 from vector<f16>
// CHECK-DAG: %[[D:.*]] = arith.addf %[[SCALAR]], %{{.*}} : f16
%d = arith.addf %scalar, %a : f16
return %d : f16
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
Expand Down

0 comments on commit c3db710

Please sign in to comment.