[SPIRV] Handle extraction from create_mask during load/store vectorization (iree-org#15524)

Proper handling of masks requires direct lowerings of the mask. However, we
can skip materializing the mask entirely (and the subsequent reliance on
unrolling/canonicalization to clean up the masks) by folding
`vector.extract(vector.create_mask)` into the appropriate boolean. This works
for subgroup reduce (the first dynamic codegen problem we're tackling)
because masks and transfers are never unrolled in this pipeline; instead we
distribute them in place to a pre-configured vector size.
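
For a multi-dimensional mask the fold conjoins one comparison per dimension.
A rough sketch (the %d0/%d1 sizes and the 4x8 shape are hypothetical):

  %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
  %bit = vector.extract %mask[1, 2] : vector<4x8xi1>

becomes

  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %cmp0 = arith.cmpi slt, %c1, %d0 : index
  %cmp1 = arith.cmpi slt, %c2, %d1 : index
  %bit = arith.andi %cmp0, %cmp1 : i1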

We handle this during SPIRVVectorizeLoadStore because scalarizing
vector.transfer_read/write is what introduces the extracts on the mask, so
the same pass application can fold them immediately.
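
To illustrate the timing (names like %buf, %i, and %pad are illustrative,
mirroring the test below): scalarization guards each lane with one extracted
mask bit, which this pattern can then reify:

  %mask = vector.create_mask %size : vector<3xi1>
  %m0 = vector.extract %mask[0] : vector<3xi1>
  %v0 = scf.if %m0 -> (f32) {
    %ld = memref.load %buf[%i] : memref<20xf32>
    scf.yield %ld : f32
  } else {
    scf.yield %pad : f32
  }

after which %m0 folds to an index comparison against %size (the sgt form in
the test below is presumably what canonicalization makes of the emitted slt).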

qedawkins authored Nov 10, 2023
1 parent 8633629 commit f9d7599
Showing 2 changed files with 87 additions and 0 deletions.
@@ -966,6 +966,55 @@ struct ScalarizeVectorTransferWrite final
}
};

/// Converts IR like the following into a single mask bit
///
///   %mask = vector.create_mask %msize : vector<4xi1>
///   %mbit = vector.extract %mask[1] : i1
///
/// into
///
///   %c1 = arith.constant 1 : index
///   %mbit = arith.cmpi slt, %c1, %msize
///
/// We run this at the same time as scalarizing masked transfers to try to
/// fold away any remaining mask creation ops, as SPIR-V lacks support for
/// masked operations.
struct ReifyExtractOfCreateMask final
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Restrict to the degenerate case where we are extracting a single
    // element.
    if (extractOp.getResult().getType().isa<VectorType>()) {
      return failure();
    }
    auto maskOp = extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!maskOp) {
      return failure();
    }

    Location loc = maskOp.getLoc();
    // The extracted bit is true iff, for every dimension, the extracted
    // position is strictly less than the corresponding mask size. Seed the
    // result with `true` and AND in one comparison per dimension.
    Value maskBit =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(true));
    for (auto [idx, size] :
         llvm::zip_equal(extractOp.getMixedPosition(), maskOp.getOperands())) {
      // Positions may be static attributes or dynamic values; materialize
      // static positions as index constants.
      Value idxVal;
      if (idx.is<Attribute>()) {
        idxVal = rewriter.create<arith::ConstantIndexOp>(
            loc, idx.get<Attribute>().cast<IntegerAttr>().getInt());
      } else {
        idxVal = idx.get<Value>();
      }
      Value cmpIdx = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, idxVal, size);
      maskBit = rewriter.create<arith::AndIOp>(loc, cmpIdx, maskBit);
    }
    rewriter.replaceOp(extractOp, maskBit);
    return success();
  }
};
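
// As written, the pattern seeds the result with a `true` constant and ANDs in
// one comparison per extracted position. For the 1-D example in the comment
// above, the raw output (a rough sketch, before canonicalization folds the
// `andi` with `true` away) would be:
//
//   %true = arith.constant true
//   %c1   = arith.constant 1 : index
//   %cmp  = arith.cmpi slt, %c1, %msize : index
//   %mbit = arith.andi %cmp, %true : i1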

//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
@@ -1050,6 +1099,7 @@ void SPIRVVectorizeLoadStorePass::runOnOperation() {
  RewritePatternSet rewritingPatterns(context);
  rewritingPatterns.add<ScalarizeVectorTransferRead, ScalarizeVectorLoad,
                        ScalarizeVectorTransferWrite>(context);
  rewritingPatterns.add<ReifyExtractOfCreateMask>(context);

  if (failed(
          applyPatternsAndFoldGreedily(func, std::move(rewritingPatterns)))) {
@@ -592,3 +592,40 @@ func.func @scalarize_masked_vector_transfer_op(%arg: vector<3xf32>, %mask: vecto
// CHECK: memref.store %[[E2]], %{{.*}}[%[[C5]]] : memref<20xf32>
// CHECK: }
// CHECK: return %[[MASK_TR]] : vector<3xf32>

// -----

func.func @extract_vector_transfer_read_mask_bits(%arg: vector<3xf32>, %index: index) -> (vector<3xf32>) {
  %c3 = arith.constant 3 : index
  %f0 = arith.constant 0.0 : f32
  %mask = vector.create_mask %index : vector<3xi1>
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<20xf32>
  %1 = vector.transfer_read %0[%c3], %f0, %mask : memref<20xf32>, vector<3xf32>
  return %1 : vector<3xf32>
}

// CHECK-LABEL: func.func @extract_vector_transfer_read_mask_bits
// CHECK-SAME: %{{.*}}: vector<3xf32>, %[[MASK_SIZE:.+]]: index
// CHECK-DAG: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<3xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[PAD:.+]] = arith.constant 0.000000e+00 : f32

// CHECK: %[[MB0:.+]] = arith.cmpi sgt, %[[MASK_SIZE]], %[[C0]] : index
// CHECK: %[[MASK_LD0:.+]] = scf.if %[[MB0]] -> (f32) {
// CHECK: %[[LD0:.+]] = memref.load {{.*}}[%[[C3]]] : memref<20xf32>
// CHECK: scf.yield %[[LD0]] : f32
// CHECK: } else {
// CHECK: scf.yield %[[PAD]] : f32
// CHECK: }
// CHECK: vector.insert %[[MASK_LD0]], %[[INIT]] [0] : f32 into vector<3xf32>
// CHECK: %[[MB1:.+]] = arith.cmpi sgt, %[[MASK_SIZE]], %[[C1]] : index
// CHECK: scf.if %[[MB1]] -> (f32) {
// CHECK: memref.load %{{.*}}[%[[C4]]] : memref<20xf32>
// CHECK: %[[MB2:.+]] = arith.cmpi sgt, %[[MASK_SIZE]], %[[C2]] : index
// CHECK: scf.if %[[MB2]] -> (f32) {
// CHECK: memref.load %{{.*}}[%[[C5]]] : memref<20xf32>
