diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 1ebd462997f6..4b5957a50186 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -102,7 +102,9 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
         "//compiler/src/iree/compiler/Dialect/Encoding/IR",
+        "//compiler/src/iree/compiler/Dialect/HAL/Analysis",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
+        "//compiler/src/iree/compiler/Dialect/Stream/Analysis",
         "//compiler/src/iree/compiler/Utils",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AMDGPUDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 8dfe2413816d..b48869ffcabd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -133,7 +133,9 @@ iree_cc_library(
     iree::compiler::Codegen::Utils
     iree::compiler::Codegen::Utils::VectorOpUtils
     iree::compiler::Dialect::Encoding::IR
+    iree::compiler::Dialect::HAL::Analysis
     iree::compiler::Dialect::HAL::IR
+    iree::compiler::Dialect::Stream::Analysis
     iree::compiler::Utils
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
index 73491a869ff4..86e5ee260db9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -16,7 +16,9 @@
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
 #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -35,6 +37,7 @@
 namespace mlir::iree_compiler {
 
 #define GEN_PASS_DEF_GPUMATERIALIZEDEVICEENCODINGPASS
+#define GEN_PASS_DEF_GPUMATERIALIZEHOSTENCODINGPASS
 #include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
 
 static bool hasIntrinsic(IREE::GPU::TargetAttr target,
@@ -130,6 +133,22 @@ materializeEncodingForTarget(RankedTensorType tensorType,
 }
 
 namespace {
+
+// TODO(hanchung): Delete this pass and rely on tensor-based analysis to
+// materialize encodings based on where tensors are used. This pass is not able
+// to handle that.
+struct GPUMaterializeHostEncodingPass
+    : public impl::GPUMaterializeHostEncodingPassBase<
+          GPUMaterializeHostEncodingPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, tensor::TensorDialect,
+                    linalg::LinalgDialect, IREE::Encoding::IREEEncodingDialect,
+                    IREE::GPU::IREEGPUDialect>();
+  }
+
+  void runOnOperation() override;
+};
+
 struct GPUMaterializeDeviceEncodingPass final
     : impl::GPUMaterializeDeviceEncodingPassBase<
           GPUMaterializeDeviceEncodingPass> {
@@ -248,8 +267,15 @@ struct GPUUnsetEncodingOpLoweringConversion
     FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
         converter->getEncodingInfo(unsetEncodingOp.getSource().getType());
     if (failed(maybeEncodingInfo)) {
-      return rewriter.notifyMatchFailure(unsetEncodingOp,
-                                         "unhandled result encoding");
+      Value result = adaptor.getSource();
+      Type targetType =
+          getTypeConverter()->convertType(unsetEncodingOp.getSourceType());
+      if (targetType != result.getType()) {
+        result = rewriter.create<tensor::CastOp>(unsetEncodingOp.getLoc(),
+                                                 targetType, result);
+      }
+      rewriter.replaceOp(unsetEncodingOp, result);
+      return success();
     }
 
     Location loc = unsetEncodingOp.getLoc();
@@ -400,12 +426,10 @@ class GPUConvertToMultiMma final
   const MaterializeEncodingValueFn materializeEncodingValueFn;
 };
 
-} // namespace
-
-void GPUMaterializeDeviceEncodingPass::runOnOperation() {
-  MLIRContext *ctx = &getContext();
-  FunctionOpInterface funcOp = getOperation();
-  auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
+static LogicalResult
+materializeFuncOpEncodings(FunctionOpInterface funcOp,
+                           IREE::HAL::ExecutableTargetAttr targetAttr) {
+  MLIRContext *ctx = funcOp.getContext();
   {
     RewritePatternSet patterns(ctx);
     MaterializeEncodingTypeConverter typeConverter(materializeEncodingForTarget,
@@ -424,7 +448,7 @@ void GPUMaterializeDeviceEncodingPass::runOnOperation() {
     memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
     if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) {
       funcOp.emitOpError("materialization failed");
-      return signalPassFailure();
+      return failure();
     }
   }
 
@@ -436,9 +460,92 @@ void GPUMaterializeDeviceEncodingPass::runOnOperation() {
     memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
       funcOp.emitOpError("folding patterns failed");
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+static std::optional<SetVector<IREE::HAL::ExecutableTargetAttr>>
+getFuncExecutableTargetAttrs(FunctionOpInterface funcOp,
+                             IREE::Stream::AffinityAnalysis &affinityAnalysis,
+                             IREE::HAL::DeviceAnalysis &deviceAnalysis) {
+  // Get a set of all unique affinities used by resources within the function.
+  SetVector<IREE::Stream::AffinityAttr> uniqueAffinityAttrs;
+  SmallVector<IREE::Stream::AffinityAttr> lookupAffinityAttrs;
+  funcOp.walk([&](Operation *op) {
+    if (affinityAnalysis.tryLookupExecutionAffinity(op, lookupAffinityAttrs)) {
+      uniqueAffinityAttrs.insert(lookupAffinityAttrs.begin(),
+                                 lookupAffinityAttrs.end());
+    }
+    lookupAffinityAttrs.clear();
+  });
+
+  // Resolve affinities to executable targets.
+  SetVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
+  for (auto affinityAttr : uniqueAffinityAttrs) {
+    deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, funcOp,
+                                                   executableTargetAttrs);
+  }
+  return executableTargetAttrs;
+}
+
+} // namespace
+
+void GPUMaterializeHostEncodingPass::runOnOperation() {
+  auto moduleOp = getOperation();
+
+  // Run required analysis passes.
+  IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
+  if (failed(affinityAnalysis.run())) {
+    return signalPassFailure();
+  }
+  IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
+  if (failed(deviceAnalysis.run())) {
+    return signalPassFailure();
+  }
+
+  for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+    // Gather the required executable targets for the function. Note that it's
+    // possible there are more required for ops nested within the function but
+    // this pass is a hack and can't handle that :shrug:.
+    auto executableTargets =
+        getFuncExecutableTargetAttrs(funcOp, affinityAnalysis, deviceAnalysis);
+    if (!executableTargets) {
+      funcOp.emitOpError()
+          << "could not determine executable targets for the function";
+      return signalPassFailure();
+    } else if (executableTargets->empty()) {
+      // Probably no tensors.
+      continue;
+    }
+
+    // HACK: this pass is run on the host _but shouldn't be_. Because it's
+    // run on the host and IREE is a compiler capable of multi-targeting there
+    // may be multiple executable targets at any point in the host program.
+    // This pass can't handle that and assumes it's been checked earlier by
+    // spooky action at a distance. This needs to be fixed.
+    if (executableTargets->size() != 1) {
+      funcOp.emitOpError() << "has multiple executable targets and GPU data "
+                              "tiling isn't built to support that";
+      return signalPassFailure();
+    }
+
+    // Materialize encodings within the function.
+    if (failed(
+            materializeFuncOpEncodings(funcOp, executableTargets->front()))) {
       return signalPassFailure();
     }
   }
 }
 
+void GPUMaterializeDeviceEncodingPass::runOnOperation() {
+  FunctionOpInterface funcOp = getOperation();
+  auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
+  if (failed(materializeFuncOpEncodings(funcOp, targetAttr))) {
+    return signalPassFailure();
+  }
+}
+
 } // namespace mlir::iree_compiler
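Illustration (not part of the patch): the new fallback in GPUUnsetEncodingOpLoweringConversion handles encodings that resolve to nothing on the current target by forwarding the already-converted source value. With a hypothetical #enc that this target does not tile,

    %2 = iree_encoding.unset_encoding %1 : tensor<?x?xf32, #enc> -> tensor<129x255xf32>

lowers to the converted %1 directly, with a tensor.cast inserted only when the materialized type and the expected result type differ:

    %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<129x255xf32>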
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 1b5136df1e3b..1f22bad8ea77 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -197,6 +197,11 @@ def GPUApplyTilingLevelPass :
   ];
 }
 
+def GPUMaterializeHostEncodingPass :
+    Pass<"iree-codegen-gpu-materialize-host-encoding", "mlir::ModuleOp"> {
+  let summary = "Materialize the encoding for tensors as specified by the backend.";
+}
+
 def GPUMaterializeDeviceEncodingPass :
     InterfacePass<"iree-codegen-gpu-materialize-device-encoding", "mlir::FunctionOpInterface"> {
   let summary = "Materialize the encoding for tensor as specified by the backend.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index d487e95f907d..b1725418d31d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -14,6 +14,7 @@
 #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -876,6 +877,27 @@ struct MaterializeOperation : public OpMaterializeEncodingPattern<OpTy> {
   }
 };
 
+struct MaterializeOptimizationBarrierOp
+    : public OpMaterializeEncodingPattern<IREE::Util::OptimizationBarrierOp> {
+  using OpMaterializeEncodingPattern<
+      IREE::Util::OptimizationBarrierOp>::OpMaterializeEncodingPattern;
+
+  LogicalResult
+  matchAndRewrite(IREE::Util::OptimizationBarrierOp op,
+                  IREE::Util::OptimizationBarrierOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only fire when at least one operand still carries a tensor encoding.
+    if (llvm::none_of(op.getOperandTypes(), [](Type type) -> bool {
+          auto tensorType = dyn_cast<RankedTensorType>(type);
+          return tensorType && tensorType.getEncoding();
+        })) {
+      return failure();
+    }
+    rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(
+        op, adaptor.getOperands());
+    return success();
+  }
+};
+
 /// Pattern to convert contraction operations.
 class MaterializeContractionOp : public OpInterfaceConversionPattern<
                                      mlir::linalg::ContractionOpInterface> {
@@ -953,12 +975,12 @@ void populateShapeIndependentMaterializeEncodingPatterns(
         return resultType == typeConverter.convertType(resultType);
       });
 
-  patterns.insert<MaterializeDPSOperation<linalg::FillOp>,
-                  MaterializeOperation<tensor::EmptyOp>,
-                  MaterializeFlowDispatchTensorLoadOp,
-                  MaterializeFlowDispatchTensorStoreOp,
-                  MaterializeInterfaceBindingEncoding>(
-      context, typeConverter, materializeEncodingValueFn);
+  patterns.insert<
+      MaterializeDPSOperation<linalg::FillOp>,
+      MaterializeOperation<tensor::EmptyOp>, MaterializeOptimizationBarrierOp,
+      MaterializeFlowDispatchTensorLoadOp, MaterializeFlowDispatchTensorStoreOp,
+      MaterializeInterfaceBindingEncoding>(context, typeConverter,
+                                           materializeEncodingValueFn);
 };
 
 } // namespace mlir::iree_compiler
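Illustration (not part of the patch): the new MaterializeOptimizationBarrierOp pattern fires only when some operand type carries a tensor encoding, and then rebuilds the barrier over the type-converted operands. Assuming a hypothetical #enc that the target materializes into 16x16 inner tiles:

    // Before materialization: the barrier pins an encoded tensor.
    %b = util.optimization_barrier %1 : tensor<129x255xf32, #enc>
    // After: the same barrier, now carrying the packed type. The actual
    // shape depends on the tile sizes the target picks.
    %b = util.optimization_barrier %1_packed : tensor<9x16x16x16xf32>

This is what lets the e2e tests below put barriers around encoded values without blocking materialization.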
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index 5575a05f9d6a..3797824fa122 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -77,6 +77,7 @@ iree_compiler_cc_library(
         ":PassesIncGen",
        "//compiler/src/iree/compiler/Codegen/Common",
         "//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses",
+        "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
         "//compiler/src/iree/compiler/Dialect/Encoding/IR",
         "//compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index 19e2d0eb0ef1..70bd927bfc7e 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -92,6 +92,7 @@ iree_cc_library(
     MLIRTransforms
     iree::compiler::Codegen::Common
     iree::compiler::Codegen::Common::CPU::CommonCPUPasses
+    iree::compiler::Codegen::Common::GPU::CommonGPUPasses
     iree::compiler::Dialect::Encoding::IR
     iree::compiler::Dialect::Flow::Conversion::TensorToFlow
     iree::compiler::Dialect::Flow::IR
diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
index 750ae788f63b..46a985f4bb9b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/CPU/Passes.h"
+#include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
 #include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
@@ -23,6 +24,14 @@
 
 namespace mlir::iree_compiler::GlobalOptimization {
 
+// TODO: Remove the flag once the codegen can handle the late materialization
+// path. This is mainly for testing.
+static llvm::cl::opt<bool> clEnableExperimentalRocmDataTiling(
+    "iree-global-opt-experimental-rocm-data-tiling",
+    llvm::cl::desc("Enables data-tiling materialization for rocm backends "
+                   "(experimental)."),
+    llvm::cl::init(false));
+
 #define GEN_PASS_DEF_MATERIALIZEHOMOGENEOUSENCODINGSPASS
 #include "iree/compiler/GlobalOptimization/Passes.h.inc"
 
@@ -38,13 +47,9 @@ class MaterializeHomogeneousEncodingsPass
     registry.insert<IREE::HAL::HALDialect>();
   }
 
-  void runNopPipeline(ModuleOp &moduleOp) {
-    OpPassManager passManager(moduleOp.getOperationName());
+  void addNopPipeline(OpPassManager &passManager) {
     FunctionLikeNest(passManager).addPass(createMaterializeEncodingIntoNopPass);
     FunctionLikeNest(passManager).addPass(createCanonicalizerPass);
-    if (failed(runPipeline(passManager, moduleOp))) {
-      return signalPassFailure();
-    }
   }
 
   void runOnOperation() override {
@@ -55,8 +60,13 @@ class MaterializeHomogeneousEncodingsPass
     SetVector<IREE::HAL::ExecutableTargetAttr> executableTargets;
     deviceAnalysis.gatherAllExecutableTargets(executableTargets);
 
+    OpPassManager passManager(moduleOp.getOperationName());
     if (executableTargets.size() != 1) {
-      return runNopPipeline(moduleOp);
+      addNopPipeline(passManager);
+      if (failed(runPipeline(passManager, moduleOp))) {
+        return signalPassFailure();
+      }
+      return;
     }
 
     // TODO: vmvx has its own logic about supporting dynamic tile
@@ -67,13 +77,21 @@ class MaterializeHomogeneousEncodingsPass
       return;
     }
 
-    // Only llvm-cpu backends handle encodings for now, others just go with nop.
-    if (executableTarget.getBackend() != "llvm-cpu") {
-      return runNopPipeline(moduleOp);
+    // Only llvm-cpu and rocm backends handle encodings for now; others just go
+    // with nop.
+    if (executableTarget.getBackend() == "llvm-cpu") {
+      passManager.addPass(createCPUMaterializeHostEncodingPass());
+    } else if (clEnableExperimentalRocmDataTiling &&
+               executableTarget.getBackend() == "rocm") {
+      passManager.addPass(createGPUMaterializeHostEncodingPass());
+      FunctionLikeNest(passManager).addPass([&]() {
+        return createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/false,
+                                                /*useOnlyReshapes=*/true,
+                                                /*controlFn=*/std::nullopt);
+      });
+    } else {
+      addNopPipeline(passManager);
     }
-
-    OpPassManager passManager(moduleOp.getOperationName());
-    passManager.addPass(createCPUMaterializeHostEncodingPass());
     if (failed(runPipeline(passManager, moduleOp))) {
       return signalPassFailure();
     }
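Illustration (not part of the patch): with --iree-global-opt-experimental-rocm-data-tiling set, the host-side pipeline becomes GPUMaterializeHostEncoding followed by DecomposePackUnPackOps with useOnlyReshapes=true, so a host-level set_encoding materializes into tensor.pack and is then decomposed into reshape-style ops. A rough sketch with illustrative 16x16 tiles (the real sizes come from the target's MMA intrinsics):

    // %1 = iree_encoding.set_encoding %src : tensor<129x255xf32> -> tensor<129x255xf32, #enc>
    // ...becomes, after materialization and decomposition, roughly:
    %cst = arith.constant 0.0 : f32
    %padded = tensor.pad %src low[0, 0] high[15, 1] {
    ^bb0(%i: index, %j: index):
      tensor.yield %cst : f32
    } : tensor<129x255xf32> to tensor<144x256xf32>
    %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]] output_shape [9, 16, 16, 16]
        : tensor<144x256xf32> into tensor<9x16x16x16xf32>
    %init = tensor.empty() : tensor<9x16x16x16xf32>
    %tiled = linalg.transpose ins(%expanded : tensor<9x16x16x16xf32>)
        outs(%init : tensor<9x16x16x16xf32>) permutation = [0, 2, 1, 3]

so the host program never needs to execute a tensor.pack directly.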
diff --git a/tests/e2e/rocm_specific/BUILD.bazel b/tests/e2e/rocm_specific/BUILD.bazel
new file mode 100644
index 000000000000..438fb0728d67
--- /dev/null
+++ b/tests/e2e/rocm_specific/BUILD.bazel
@@ -0,0 +1,24 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Tests for end-to-end IREE support specific to the ROCm (HIP) lowering.
+
+load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite")
+
+package(
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_check_single_backend_test_suite(
+    name = "check_rocm_hip",
+    srcs = ["encoding.mlir"],
+    compiler_flags = [
+        "--iree-global-opt-experimental-rocm-data-tiling",
+    ],
+    driver = "hip",
+    target_backend = "rocm",
+)
diff --git a/tests/e2e/rocm_specific/CMakeLists.txt b/tests/e2e/rocm_specific/CMakeLists.txt
new file mode 100644
index 000000000000..c428b12fc7f7
--- /dev/null
+++ b/tests/e2e/rocm_specific/CMakeLists.txt
@@ -0,0 +1,26 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# tests/e2e/rocm_specific/BUILD.bazel                                          #
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_check_single_backend_test_suite(
+  NAME
+    check_rocm_hip
+  SRCS
+    "encoding.mlir"
+  TARGET_BACKEND
+    "rocm"
+  DRIVER
+    "hip"
+  COMPILER_FLAGS
+    "--iree-global-opt-experimental-rocm-data-tiling"
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/tests/e2e/rocm_specific/encoding.mlir b/tests/e2e/rocm_specific/encoding.mlir
new file mode 100644
index 000000000000..2718f9d00805
--- /dev/null
+++ b/tests/e2e/rocm_specific/encoding.mlir
@@ -0,0 +1,232 @@
+//===----------------------------------------------------------------------===//
+// Utility Methods
+//===----------------------------------------------------------------------===//
+
+func.func private @generate_2D_source_f16(%height : index, %width : index) -> tensor<?x?xf16> {
+  %init_source = tensor.empty(%height, %width) : tensor<?x?xf16>
+  %source = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      outs(%init_source : tensor<?x?xf16>) {
+    ^bb0(%b0 : f16):
+      %outer = linalg.index 0 : index
+      %inner = linalg.index 1 : index
+      %strided = arith.muli %outer, %width : index
+      %linearized = arith.addi %inner, %strided : index
+      %linearized_i16 = arith.index_cast %linearized : index to i16
+      %linearized_f16 = arith.sitofp %linearized_i16 : i16 to f16
+      linalg.yield %linearized_f16 : f16
+  } -> tensor<?x?xf16>
+  // This blocks fusion between the generated inputs and the testing ops.
+  %0 = util.optimization_barrier %source : tensor<?x?xf16>
+  %1 = flow.tensor.tie_shape %0 : tensor<?x?xf16>{%height, %width}
+  return %1 : tensor<?x?xf16>
+}
+
+func.func private @generate_2D_source_f32(%height : index, %width : index) -> tensor<?x?xf32> {
+  %init_source = tensor.empty(%height, %width) : tensor<?x?xf32>
+  %source = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      outs(%init_source : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32):
+      %outer = linalg.index 0 : index
+      %inner = linalg.index 1 : index
+      %strided = arith.muli %outer, %width : index
+      %linearized = arith.addi %inner, %strided : index
+      %linearized_i32 = arith.index_cast %linearized : index to i32
+      %linearized_f32 = arith.sitofp %linearized_i32 : i32 to f32
+      linalg.yield %linearized_f32 : f32
+  } -> tensor<?x?xf32>
+  // This blocks fusion between the generated inputs and the testing ops.
+  %0 = util.optimization_barrier %source : tensor<?x?xf32>
+  %1 = flow.tensor.tie_shape %0 : tensor<?x?xf32>{%height, %width}
+  return %1 : tensor<?x?xf32>
+}
+
+func.func private @generate_2D_source_i8(%height : index, %width : index) -> tensor<?x?xi8> {
+  %c255 = arith.constant 255 : index
+  %init_source = tensor.empty(%height, %width) : tensor<?x?xi8>
+  %source = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      outs(%init_source : tensor<?x?xi8>) {
+    ^bb0(%b0 : i8):
+      %outer = linalg.index 0 : index
+      %inner = linalg.index 1 : index
+      %strided = arith.muli %outer, %width : index
+      %linearized = arith.addi %inner, %strided : index
+      %linearized_rem = arith.remsi %linearized, %c255 : index
+      %linearized_i8 = arith.index_cast %linearized_rem : index to i8
+      linalg.yield %linearized_i8 : i8
+  } -> tensor<?x?xi8>
+  // This blocks fusion between the generated inputs and the testing ops.
+  %0 = util.optimization_barrier %source : tensor<?x?xi8>
+  %1 = flow.tensor.tie_shape %0 : tensor<?x?xi8>{%height, %width}
+  return %1 : tensor<?x?xi8>
+}
+
+func.func private @generate_2D_source_i32(%height : index, %width : index) -> tensor<?x?xi32> {
+  %init_source = tensor.empty(%height, %width) : tensor<?x?xi32>
+  %source = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      outs(%init_source : tensor<?x?xi32>) {
+    ^bb0(%b0 : i32):
+      %outer = linalg.index 0 : index
+      %inner = linalg.index 1 : index
+      %strided = arith.muli %outer, %width : index
+      %linearized = arith.addi %inner, %strided : index
+      %linearized_i32 = arith.index_cast %linearized : index to i32
+      linalg.yield %linearized_i32 : i32
+  } -> tensor<?x?xi32>
+  // This blocks fusion between the generated inputs and the testing ops.
+  %0 = util.optimization_barrier %source : tensor<?x?xi32>
+  %1 = flow.tensor.tie_shape %0 : tensor<?x?xi32>{%height, %width}
+  return %1 : tensor<?x?xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// f32.f32.f32 variants
+//===----------------------------------------------------------------------===//
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#encoding_f32f32f32_lhs = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_f32f32f32_rhs = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_f32f32f32_acc = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+
+func.func @set_encoding_f32f32f32_lhs() {
+  %height = arith.constant 129 : index
+  %width = arith.constant 255 : index
+  %0 = call @generate_2D_source_f32(%height, %width) : (index, index) -> tensor<?x?xf32>
+  %source = tensor.cast %0 : tensor<?x?xf32> to tensor<129x255xf32>
+
+  %1 = iree_encoding.set_encoding %source : tensor<129x255xf32> -> tensor<129x255xf32, #encoding_f32f32f32_lhs>
+  %barrier = util.optimization_barrier %1 : tensor<129x255xf32, #encoding_f32f32f32_lhs>
+  %2 = iree_encoding.unset_encoding %1 : tensor<129x255xf32, #encoding_f32f32f32_lhs> -> tensor<129x255xf32>
+  check.expect_almost_eq(%2, %source) : tensor<129x255xf32>
+  return
+}
+
+func.func @set_encoding_f32f32f32_rhs() {
+  %height = arith.constant 129 : index
+  %width = arith.constant 255 : index
+  %0 = call @generate_2D_source_f32(%height, %width) : (index, index) -> tensor<?x?xf32>
+  %source = tensor.cast %0 : tensor<?x?xf32> to tensor<129x255xf32>
+
+  %1 = iree_encoding.set_encoding %source : tensor<129x255xf32> -> tensor<129x255xf32, #encoding_f32f32f32_rhs>
+  %barrier = util.optimization_barrier %1 : tensor<129x255xf32, #encoding_f32f32f32_rhs>
+  %2 = iree_encoding.unset_encoding %1 : tensor<129x255xf32, #encoding_f32f32f32_rhs> -> tensor<129x255xf32>
+  check.expect_almost_eq(%2, %source) : tensor<129x255xf32>
+  return
+}
+
+func.func @set_encoding_f32f32f32_acc() {
+  %height = arith.constant 129 : index
+  %width = arith.constant 255 : index
+  %0 = call @generate_2D_source_f32(%height, %width) : (index, index) -> tensor<?x?xf32>
+  %source = tensor.cast %0 : tensor<?x?xf32> to tensor<129x255xf32>
+
+  %1 = iree_encoding.set_encoding %source : tensor<129x255xf32> -> tensor<129x255xf32, #encoding_f32f32f32_acc>
+  %barrier = util.optimization_barrier %1 : tensor<129x255xf32, #encoding_f32f32f32_acc>
+  %2 = iree_encoding.unset_encoding %1 : tensor<129x255xf32, #encoding_f32f32f32_acc> -> tensor<129x255xf32>
+  check.expect_almost_eq(%2, %source) : tensor<129x255xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// i8.i8.i32 variants
+//===----------------------------------------------------------------------===//
+
+#encoding_i8i8i32_lhs = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_i8i8i32_rhs = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_i8i8i32_acc = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+
+func.func @set_encoding_i8i8i32_lhs() {
+  %height = arith.constant 129 : index
+  %width = arith.constant 255 : index
+  %0 = call @generate_2D_source_i8(%height, %width) : (index, index) -> tensor<?x?xi8>
+  %source = tensor.cast %0 : tensor<?x?xi8> to tensor<129x255xi8>
+
+  %1 = iree_encoding.set_encoding %source : tensor<129x255xi8> -> tensor<129x255xi8, #encoding_i8i8i32_lhs>
+  %barrier = util.optimization_barrier %1 : tensor<129x255xi8, #encoding_i8i8i32_lhs>
+  %2 = iree_encoding.unset_encoding %1 : tensor<129x255xi8, #encoding_i8i8i32_lhs> -> tensor<129x255xi8>
+  check.expect_eq(%2, %source) : tensor<129x255xi8>
+  return
+}
+
+func.func @set_encoding_i8i8i32_rhs() {
+  %height = arith.constant 129 : index
+  %width = arith.constant 255 : index
+  %0 = call @generate_2D_source_i8(%height, %width) : (index, index) -> tensor<?x?xi8>
+  %source = tensor.cast %0 : tensor<?x?xi8> to tensor<129x255xi8>
+
+  %1 = iree_encoding.set_encoding %source : tensor<129x255xi8> -> tensor<129x255xi8, #encoding_i8i8i32_rhs>
+  %barrier = util.optimization_barrier %1 : tensor<129x255xi8, #encoding_i8i8i32_rhs>
+  %2 = iree_encoding.unset_encoding %1 : tensor<129x255xi8, #encoding_i8i8i32_rhs> -> tensor<129x255xi8>
+  check.expect_eq(%2, %source) : tensor<129x255xi8>
+  return
+}
+
+func.func @set_encoding_i8i8i32_acc() {
+  %height = arith.constant 129 : index
+  %width = arith.constant 255 : index
+  %0 = call @generate_2D_source_i32(%height, %width) : (index, index) -> tensor<?x?xi32>
+  %source = tensor.cast %0 : tensor<?x?xi32> to tensor<129x255xi32>
+
+  %1 = iree_encoding.set_encoding %source : tensor<129x255xi32> -> tensor<129x255xi32, #encoding_i8i8i32_acc>
+  %barrier = util.optimization_barrier %1 : tensor<129x255xi32, #encoding_i8i8i32_acc>
+  %2 = iree_encoding.unset_encoding %1 : tensor<129x255xi32, #encoding_i8i8i32_acc> -> tensor<129x255xi32>
+  check.expect_eq(%2, %source) : tensor<129x255xi32>
+  return
+}
+
+
+//===----------------------------------------------------------------------===//
+// f16.f16.f32 variants
+//===----------------------------------------------------------------------===//
+
+#encoding_f16f16f32_lhs = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_f16f16f32_rhs = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_f16f16f32_acc = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+
+func.func @set_encoding_f16f16f32_lhs() {
+  %height = arith.constant 129 : index
+  %width = arith.constant 255 : index
+  %0 = call @generate_2D_source_f16(%height, %width) : (index, index) -> tensor<?x?xf16>
+  %source = tensor.cast %0 : tensor<?x?xf16> to tensor<129x255xf16>
+
+  %1 = iree_encoding.set_encoding %source : tensor<129x255xf16> -> tensor<129x255xf16, #encoding_f16f16f32_lhs>
+  %barrier = util.optimization_barrier %1 : tensor<129x255xf16, #encoding_f16f16f32_lhs>
+  %2 = iree_encoding.unset_encoding %1 : tensor<129x255xf16, #encoding_f16f16f32_lhs> -> tensor<129x255xf16>
+  check.expect_eq(%2, %source) : tensor<129x255xf16>
+  return
+}
+
+func.func @set_encoding_f16f16f32_rhs() {
+  %height = arith.constant 129 : index
+  %width = arith.constant 255 : index
+  %0 = call @generate_2D_source_f16(%height, %width) : (index, index) -> tensor<?x?xf16>
+  %source = tensor.cast %0 : tensor<?x?xf16> to tensor<129x255xf16>
+
+  %1 = iree_encoding.set_encoding %source : tensor<129x255xf16> -> tensor<129x255xf16, #encoding_f16f16f32_rhs>
+  %barrier = util.optimization_barrier %1 : tensor<129x255xf16, #encoding_f16f16f32_rhs>
+  %2 = iree_encoding.unset_encoding %1 : tensor<129x255xf16, #encoding_f16f16f32_rhs> -> tensor<129x255xf16>
+  check.expect_eq(%2, %source) : tensor<129x255xf16>
+  return
+}
+
+func.func @set_encoding_f16f16f32_acc() {
+  %height = arith.constant 129 : index
+  %width = arith.constant 255 : index
+  %0 = call @generate_2D_source_f32(%height, %width) : (index, index) -> tensor<?x?xf32>
+  %source = tensor.cast %0 : tensor<?x?xf32> to tensor<129x255xf32>
+
+  %1 = iree_encoding.set_encoding %source : tensor<129x255xf32> -> tensor<129x255xf32, #encoding_f16f16f32_acc>
+  %barrier = util.optimization_barrier %1 : tensor<129x255xf32, #encoding_f16f16f32_acc>
+  %2 = iree_encoding.unset_encoding %1 : tensor<129x255xf32, #encoding_f16f16f32_acc> -> tensor<129x255xf32>
+  check.expect_eq(%2, %source) : tensor<129x255xf32>
+  return
+}
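Illustration (not part of the patch): operand_index 0/1/2 in these encodings mark the LHS, RHS, and accumulator roles of the matmul described by user_indexing_maps, so a fully data-tiled matmul would carry all three encodings at once, roughly (shapes illustrative):

    %lhs = iree_encoding.set_encoding %a : tensor<129x255xf32> -> tensor<129x255xf32, #encoding_f32f32f32_lhs>
    %rhs = iree_encoding.set_encoding %b : tensor<255x513xf32> -> tensor<255x513xf32, #encoding_f32f32f32_rhs>
    %acc = iree_encoding.set_encoding %c : tensor<129x513xf32> -> tensor<129x513xf32, #encoding_f32f32f32_acc>
    %mm = linalg.matmul
        ins(%lhs, %rhs : tensor<129x255xf32, #encoding_f32f32f32_lhs>, tensor<255x513xf32, #encoding_f32f32f32_rhs>)
        outs(%acc : tensor<129x513xf32, #encoding_f32f32f32_acc>) -> tensor<129x513xf32, #encoding_f32f32f32_acc>
    %out = iree_encoding.unset_encoding %mm : tensor<129x513xf32, #encoding_f32f32f32_acc> -> tensor<129x513xf32>

The tests in this file deliberately exercise only the set/unset round-trip, which is why each one checks the decoded result against the original source.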