forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Codegen][GPU] Update greedy tile + fuse pipeline to generate mfma (i…
…ree-org#17617) This adds intrinsic packing and reshape propagation patterns to LLVMGPUTileAndFuse to allow for generating mfma operations. This adds a few passes to invoke a few necessary patterns for the pipeline to generate (good) code. 1. PropagateReshapesByExpansion to propagate reshapes introduced after decomposing tensor.pack/unpack towards the edges of the kernel in the hopes that the destination can line up properly. 2. IREE::GPU::PackToIntrinsics to pack based on the lowering config specified mma kind. 3. IREE::GPU::DistributeMmaToLanes to distribute iree_gpu.multi_mma ops to lanes, similar to another tiling level. There are a few known outstanding issues. 1. We run `ConvertToDestinationPassingStyle` twice to re-link the kernel destination with the body after decomposing `tensor.unpack`. This is to work around an issue with EliminateEmptyTensors being unable to analyze `flow.dispatch.tensor.store` ops with slicing behavior properly. After workgroup distribution is refactored to generate an scf.forall, this needs to be revisited. 4. iree_gpu.shuffle_tensor lowering to `tensor.insert_slice` is still broken. This will need to be reworked to support dynamic shapes. 5. Currently, because of the way the layout works, only MFMA_16x16x16 works. To support other layouts we will need another level of expanding to the intrinsic implicit layout and then propagating those expand_shapes. This will likely need to happen after reduction tiling unless we want to teach tile + fuse to swap tensor.expand_shape ops with tensor.extract_slice.
- Loading branch information
Showing
48 changed files
with
1,210 additions
and
158 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Common/PassDetail.h" | ||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" | ||
#include "iree/compiler/Codegen/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
namespace { | ||
|
||
struct PropagateReshapesByExpansionPass | ||
: public PropagateReshapesByExpansionPassBase< | ||
PropagateReshapesByExpansionPass> { | ||
void runOnOperation() override; | ||
}; | ||
} // namespace | ||
|
||
void PropagateReshapesByExpansionPass::runOnOperation() { | ||
MLIRContext *context = &getContext(); | ||
|
||
{ | ||
RewritePatternSet patterns(context); | ||
// Preemptively attempt to fold any reshapes into interface bindings if | ||
// possible to simplify subsequent reshape propagation. | ||
populateReshapeToInterfaceTensorPatterns(patterns); | ||
if (failed(applyPatternsAndFoldGreedily(getOperation(), | ||
std::move(patterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
|
||
RewritePatternSet bubbleExpandShapePatterns(context); | ||
linalg::ControlFusionFn bubbleUpExpansionControlFn = | ||
[](OpOperand *fusedOperand) { | ||
Operation *producer = fusedOperand->get().getDefiningOp(); | ||
Operation *consumer = fusedOperand->getOwner(); | ||
|
||
// Block only if one of the operations has a lowering configuration | ||
// which means it likely expects tiling specific to its original shape. | ||
if (getLoweringConfig(producer) || getLoweringConfig(consumer)) { | ||
return false; | ||
} | ||
return true; | ||
}; | ||
linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, | ||
bubbleUpExpansionControlFn); | ||
// Add patterns to do some additional cleanup (on top of canonicalizations | ||
// that can be done later) of reshape ops. | ||
tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); | ||
linalg::FillOp::getCanonicalizationPatterns(bubbleExpandShapePatterns, | ||
context); | ||
tensor::CollapseShapeOp::getCanonicalizationPatterns( | ||
bubbleExpandShapePatterns, context); | ||
tensor::EmptyOp::getCanonicalizationPatterns(bubbleExpandShapePatterns, | ||
context); | ||
tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns, | ||
context); | ||
populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns); | ||
|
||
if (failed(applyPatternsAndFoldGreedily( | ||
getOperation(), std::move(bubbleExpandShapePatterns)))) { | ||
getOperation()->emitOpError("Failed to propagate reshapes"); | ||
return signalPassFailure(); | ||
} | ||
} | ||
|
||
std::unique_ptr<OperationPass<>> createPropagateReshapesByExpansionPass() { | ||
return std::make_unique<PropagateReshapesByExpansionPass>(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.