Skip to content

Commit

Permalink
[SPIRV] Enable dynamic reduction through subgroup reduce
Browse files — browse the repository at this point in the history
  • Loading branch information
qedawkins committed Nov 10, 2023
1 parent 8524274 commit 7df3702
Showing 1 changed file with 41 additions and 10 deletions.
51 changes: 41 additions & 10 deletions compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1186,17 +1186,22 @@ static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv,
return failure();

// Make sure reduction dimensions are static and innermost ones.
int64_t numDynamicReduc = 0;
for (unsigned dim : reductionDims) {
if (ShapedType::isDynamic(bounds[dim])) {
LLVM_DEBUG(llvm::dbgs() << "failed: dynamic shapes in reduction dims\n");
return failure();
numDynamicReduc++;
}
if (dim < numParallelDims) {
LLVM_DEBUG(llvm::dbgs() << "failed: non-innermost reduction dims\n");
return failure();
}
}

// lazy: Don't support dynamic multi-reduce yet.
if (numDynamicReduc > 1) {
return failure();
}

if (op.getRegionOutputArgs().size() != 1)
return failure();

Expand Down Expand Up @@ -1227,6 +1232,40 @@ static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv,
return failure();

const int subgroupSize = targetEnv.getResourceLimits().getSubgroupSize();

// Tile all the parallel dimension to 1.
SmallVector<unsigned> partitionedLoops =
cast<PartitionableLoopsInterface>(op.getOperation())
.getPartitionableLoops(kNumMaxParallelDims);
llvm::SmallDenseSet<unsigned, 4> partitionedLoopsSet;
partitionedLoopsSet.insert(partitionedLoops.begin(), partitionedLoops.end());
size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
SmallVector<int64_t> workgroupTileSizes(numLoops, 1);

// Without any bounds on dynamic reduction dims, we need specialization to
// get peak performance. For now, just use the subgroup size.
if (numDynamicReduc) {
SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
reductionTileSizes[reductionDims[0]] = subgroupSize;
TileSizesListType tileSizes;
tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
tileSizes.emplace_back(std::move(reductionTileSizes)); // reduction level
std::array<int64_t, 3> workgroupSize = {subgroupSize, 1, 1};
if (failed(setOpConfigAndEntryPointFnTranslation(
op->getParentOfType<func::FuncOp>(), op, tileSizes,
CodeGenPipeline::SPIRVSubgroupReduce, workgroupSize))) {
return failure();
}

// Set lowering configuration to drive tiling for other Linalg ops too---the
// pipeline expects it.
op->getParentOfType<FunctionOpInterface>().walk([&](linalg::LinalgOp op) {
setLoweringConfig(
op, IREE::Codegen::LoweringConfigAttr::get(op.getContext(), tileSizes));
});
return success();
}

int64_t reductionSize = 1;
for (int64_t dim : reductionDims)
reductionSize *= bounds[dim];
Expand Down Expand Up @@ -1295,14 +1334,6 @@ static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv,
return failure();

std::array<int64_t, 3> workgroupSize = {groupSize, 1, 1};
// Tile all the parallel dimension to 1.
SmallVector<unsigned> partitionedLoops =
cast<PartitionableLoopsInterface>(op.getOperation())
.getPartitionableLoops(kNumMaxParallelDims);
llvm::SmallDenseSet<unsigned, 4> partitionedLoopsSet;
partitionedLoopsSet.insert(partitionedLoops.begin(), partitionedLoops.end());
size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
SmallVector<int64_t> workgroupTileSizes(numLoops, 1);

SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
int64_t remaingGroupSize = groupSize;
Expand Down

0 comments on commit 7df3702

Please sign in to comment.