diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
index bbe873df8094..0d998ab187d9 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
@@ -966,6 +966,55 @@ struct ScalarizeVectorTransferWrite final
   }
 };
 
+/// Converts IR like the following into a single mask bit
+///
+///   %mask = vector.create_mask %msize : vector<4xi1>
+///   %mbit = vector.extract %mask[1] : i1
+///
+/// into
+///
+///   %c1 = arith.constant 1 : index
+///   %mbit = arith.cmpi slt %c1, %msize
+///
+/// We run this at the same time as scalarizing masked transfers to try to fold
+/// away any remaining mask creation ops as SPIR-V lacks support for masked
+/// operations.
+struct ReifyExtractOfCreateMask final
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Restrict to the degenerate case where we are extracting a single
+    // element (i.e. the result is a scalar, not a sub-vector).
+    if (extractOp.getResult().getType().isa<VectorType>()) {
+      return failure();
+    }
+    auto maskOp = extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!maskOp) {
+      return failure();
+    }
+
+    // Fold the extracted bit into a conjunction of `index < size` comparisons,
+    // one per (position, mask-operand) dimension pair.
+    Location loc = maskOp.getLoc();
+    Value maskBit =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(true));
+    for (auto [idx, size] :
+         llvm::zip_equal(extractOp.getMixedPosition(), maskOp.getOperands())) {
+      // The position is an OpFoldResult: materialize static positions as
+      // constant index ops, and use dynamic positions directly.
+      Value idxVal;
+      if (idx.is<Attribute>()) {
+        idxVal = rewriter.create<arith::ConstantIndexOp>(
+            loc, idx.get<Attribute>().cast<IntegerAttr>().getInt());
+      } else {
+        idxVal = idx.get<Value>();
+      }
+      Value cmpIdx = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, idxVal, size);
+      maskBit = rewriter.create<arith::AndIOp>(loc, cmpIdx, maskBit);
+    }
+    rewriter.replaceOp(extractOp, maskBit);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Pass
 //===----------------------------------------------------------------------===//
@@ -1050,6 +1099,7 @@ void SPIRVVectorizeLoadStorePass::runOnOperation() {
   RewritePatternSet rewritingPatterns(context);
   rewritingPatterns.add<ScalarizeVectorTransferRead,
                         ScalarizeVectorTransferWrite>(context);
+  rewritingPatterns.add<ReifyExtractOfCreateMask>(context);
 
   if (failed(
           applyPatternsAndFoldGreedily(func, std::move(rewritingPatterns)))) {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir
index 3a8c4b465dae..dc14da7ace9a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir
@@ -592,3 +592,40 @@ func.func @scalarize_masked_vector_transfer_op(%arg: vector<3xf32>, %mask: vecto
 // CHECK:           memref.store %[[E2]], %{{.*}}[%[[C5]]] : memref<20xf32>
 // CHECK:         }
 // CHECK:         return %[[MASK_TR]] : vector<3xf32>
+
+// -----
+
+func.func @extract_vector_transfer_read_mask_bits(%arg: vector<3xf32>, %index: index) -> (vector<3xf32>) {
+  %c3 = arith.constant 3: index
+  %f0 = arith.constant 0.0 : f32
+  %mask = vector.create_mask %index : vector<3xi1>
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<20xf32>
+  %1 = vector.transfer_read %0[%c3], %f0, %mask : memref<20xf32>, vector<3xf32>
+  return %1: vector<3xf32>
+}
+
+// CHECK-LABEL: func.func @extract_vector_transfer_read_mask_bits
+//  CHECK-SAME: %{{.*}}: vector<3xf32>, %[[MASK_SIZE:.+]]: index
+//   CHECK-DAG:   %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<3xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
+//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
+
+//       CHECK:   %[[MB0:.+]] = arith.cmpi sgt, %[[MASK_SIZE]], %[[C0]] : index
+//       CHECK:   %[[MASK_LD0:.+]] = scf.if %[[MB0]] -> (f32) {
+//       CHECK:     %[[LD0:.+]] = memref.load {{.*}}[%[[C3]]] : memref<20xf32>
+//       CHECK:     scf.yield %[[LD0]] : f32
+//       CHECK:   } else {
+//       CHECK:     scf.yield %[[PAD]] : f32
+//       CHECK:   }
+//       CHECK:   vector.insert %[[MASK_LD0]], %[[INIT]] [0] : f32 into vector<3xf32>
+//       CHECK:   %[[MB1:.+]] = arith.cmpi sgt, %[[MASK_SIZE]], %[[C1]] : index
+//       CHECK:   scf.if %[[MB1]] -> (f32) {
+//       CHECK:     memref.load %{{.*}}[%[[C4]]] : memref<20xf32>
+//       CHECK:   %[[MB2:.+]] = arith.cmpi sgt, %[[MASK_SIZE]], %[[C2]] : index
+//       CHECK:   scf.if %[[MB2]] -> (f32) {
+//       CHECK:     memref.load %{{.*}}[%[[C5]]] : memref<20xf32>