diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 3ee627aebabc..846e72c2267f 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -624,6 +625,21 @@ areNotFullTiles(ArrayRef inputShape, return false; } +static SmallVector getMixedValues(MLIRContext *context, + ArrayRef staticValues, + OperandRange dynamicValues) { + OpBuilder b(context); + return mlir::getMixedValues(staticValues, dynamicValues, b); +} + +static SmallVector +getStaticValues(SmallVector mixedValues) { + SmallVector dynamicTiles; + SmallVector staticTiles; + dispatchIndexOpFoldResults(mixedValues, dynamicTiles, staticTiles); + return staticTiles; +} + /// Utility function shared between Pack and UnPack to get the tile sizes as /// OpFoldResults. // TODO: interface or base class in .td @@ -631,17 +647,8 @@ template static SmallVector getMixedTiles(OpTy op) { static_assert(llvm::is_one_of::value, "applies to only pack or unpack operations"); - SmallVector mixedInnerTiles; - unsigned dynamicValIndex = 0; - OpBuilder b(op.getContext()); - for (int64_t tileSize : op.getStaticInnerTiles()) { - if (!ShapedType::isDynamic(tileSize)) { - mixedInnerTiles.push_back(b.getIndexAttr(tileSize)); - } else { - mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]); - } - } - return mixedInnerTiles; + return LinalgExt::getMixedValues(op.getContext(), op.getStaticInnerTiles(), + op.getInnerTiles()); } /// Return the tile sizes as `int64_t`. 
If a tile size is dynamic a sentinel @@ -650,10 +657,7 @@ template static SmallVector getStaticTiles(OpTy op) { static_assert(llvm::is_one_of::value, "applies to only pack or unpack operations"); - SmallVector dynamicTiles; - SmallVector staticTiles; - dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles); - return staticTiles; + return getStaticValues(op.getMixedTiles()); } /// Utility function shared between Pack and UnPack to get a map between @@ -1502,6 +1506,148 @@ SmallVector OnlineAttentionOp::getIndexingMapsArray() { getIndexingMaps().getAsValueRange()); } +//===----------------------------------------------------------------------===// +// Im2colOp +//===----------------------------------------------------------------------===// + +/// Return all static and dynamic kernel_size as OpFoldResults. +SmallVector Im2colOp::getMixedKernelSize() { + return LinalgExt::getMixedValues(getContext(), getStaticKernelSize(), + getKernelSize()); +} + +/// Return all static and dynamic k_offset as OpFoldResults. +SmallVector Im2colOp::getMixedKOffset() { + return LinalgExt::getMixedValues(getContext(), getStaticKOffset(), + getKOffset()); +} + +/// Return all static and dynamic k_offset as OpFoldResults. +SmallVector Im2colOp::getMixedMOffset() { + return LinalgExt::getMixedValues(getContext(), getStaticMOffset(), + getMOffset()); +} + +void Im2colOp::setMixedKOffset(SmallVector kOffset) { + SmallVector staticKOffset; + SmallVector dynamicKOffset; + dispatchIndexOpFoldResults(kOffset, dynamicKOffset, staticKOffset); + setStaticKOffset(staticKOffset); + getKOffsetMutable().assign(dynamicKOffset); +} + +void Im2colOp::setMixedMOffset(SmallVector mOffset) { + SmallVector staticMOffset; + SmallVector dynamicMOffset; + dispatchIndexOpFoldResults(mOffset, dynamicMOffset, staticMOffset); + setStaticMOffset(staticMOffset); + getMOffsetMutable().assign(dynamicMOffset); +} + +/// Custom builder methods for im2col op. 
+void Im2colOp::build(OpBuilder &builder, OperationState &state, Value input, + Value output, ArrayRef strides, + ArrayRef dilations, + ArrayRef kernelSize, + ArrayRef kOffset, + ArrayRef mOffset, ArrayRef batchPos, + ArrayRef mPos, ArrayRef kPos) { + assert(strides.size() == kernelSize.size() && + dilations.size() == kernelSize.size() && + mPos.size() == kernelSize.size() && + "strides, dilations, m_pos, and kernel expected to be the same rank"); + SmallVector staticKernelSize, staticMOffset, staticKOffset; + SmallVector dynamicKernelSize, dynamicMOffset, dynamicKOffset; + dispatchIndexOpFoldResults(kernelSize, dynamicKernelSize, staticKernelSize); + dispatchIndexOpFoldResults(mOffset, dynamicMOffset, staticMOffset); + dispatchIndexOpFoldResults(kOffset, dynamicKOffset, staticKOffset); + SmallVector resultType; + auto outputType = output.getType(); + if (isa(outputType)) { + resultType.push_back(outputType); + } + build(builder, state, resultType, input, output, + builder.getDenseI64ArrayAttr(strides), + builder.getDenseI64ArrayAttr(dilations), dynamicKernelSize, + builder.getDenseI64ArrayAttr(staticKernelSize), dynamicKOffset, + builder.getDenseI64ArrayAttr(staticKOffset), dynamicMOffset, + builder.getDenseI64ArrayAttr(staticMOffset), + builder.getDenseI64ArrayAttr(batchPos), + builder.getDenseI64ArrayAttr(mPos), builder.getDenseI64ArrayAttr(kPos)); +} + +LogicalResult Im2colOp::verify() { + Operation *op = getOperation(); + if (llvm::count_if(getDpsInputs(), [](Value v) { + return isa(v.getType()); + }) != 1) { + return op->emitOpError("expected only one ShapedType operand"); + } + if (getNumDpsInits() != 1) { + return op->emitOpError("expected one output operand"); + } + + // TODO(Max191): Support cases with more than 1 m or k dimension, and remove + // the check for a single m_offset and k_offset. 
+ if (getMixedMOffset().size() != 1) { + return op->emitOpError("expected one m_offset"); + } + if (getMixedKOffset().size() != 1) { + return op->emitOpError("expected one k_offset"); + } + auto inputType = getInputType(); + unsigned inputRank = inputType.getRank(); + ArrayRef batchPos = getBatchPos(); + ArrayRef mPos = getMPos(); + ArrayRef kPos = getKPos(); + if (inputRank != batchPos.size() + mPos.size() + kPos.size()) { + return op->emitOpError( + "expected input rank to be the sum of batch, m, and k ranks"); + } + ArrayRef strides = getStrides(); + ArrayRef dilations = getDilations(); + SmallVector kernelSize = getMixedKernelSize(); + if (kernelSize.size() != mPos.size()) { + return op->emitOpError( + "expected kernel rank to be equal to the m_pos rank"); + } + if (strides.size() != kernelSize.size()) { + return op->emitOpError( + "expected strides rank to be equal to the kernel rank"); + } + if (dilations.size() != kernelSize.size()) { + return op->emitOpError( + "expected dilations rank to be equal to the kernel rank"); + } + + ArrayRef inputShape = inputType.getShape(); + SmallVector expectedOutputShape; + for (auto pos : batchPos) { + expectedOutputShape.push_back(inputShape[pos]); + } + ArrayRef outputShape = getOutputType().getShape(); + // When the op is tiled, the m and k dimensions of the output are tiled, but + // they are not tiled in the input, so we cannot verify the output size of + // these dimensions. 
+ expectedOutputShape.push_back(outputShape[outputShape.size() - 2]); + expectedOutputShape.push_back(outputShape.back()); + if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { + return op->emitOpError("incompatible output shape"); + } + return success(); +} + +LogicalResult Im2colOp::fold(FoldAdaptor, SmallVectorImpl &) { + return memref::foldMemRefCast(*this); +} + +LogicalResult +Im2colOp::reifyResultShapes(OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + return cast(getOperation()) + .reifyResultShapes(b, reifiedReturnShapes); +} + #define DEFINE_OP_GET_EFFECTS(OP_NAME) \ void OP_NAME::getEffects( \ SmallVectorImpl> \ @@ -1522,6 +1668,7 @@ DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp) DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp) DEFINE_OP_GET_EFFECTS(AttentionOp) DEFINE_OP_GET_EFFECTS(OnlineAttentionOp) +DEFINE_OP_GET_EFFECTS(Im2colOp) } // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index 0eebd2e16976..0fd5317d6bc1 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -768,6 +768,134 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention", } }]; } +//===----------------------------------------------------------------------===// +// Im2col +//===----------------------------------------------------------------------===// + +def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col", + [DeclareOpInterfaceMethods]> { + let summary = "Im2col operation for convolutions"; + let description = [{ + Im2col op for convolutions. The operation performs a transformation on the + input to convert it from a convolution input to an equivalent gemm input. + The op is defined by its input, output, some conv metadata, and some + indexing metadata. 
The `strides`, `dilations`, and `kernel_size` are taken + from the convolution from which this op is generated, and they define how + the input operand is indexed when the operation is decomposed. The shape of + the output should be `tensor<BxMxK>`, and the `m_pos`, `k_pos`, and + `batch_pos` indicate which input dimensions map to which output dimensions. + + The `k_offset` is an offset within the output K dimension from which the + iteration space of the operation begins. This is used for tiling, since the + tiled implementation must leave the output K dimension untiled. Similarly, + `m_offset` is the offset within the output M dimension from which the + iteration space of the operation begins. + The iteration space is the full output shape of the im2col op, so if the + im2col op were tiled to loops with a scalar inner tile, it would look like + the following: + ``` + %im2col = iree_linalg_ext.im2col + strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] + m_offset = [0] k_offset = [0] + batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%in : tensor<2x34x34x640xf32>) + outs(%out : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> + ``` + becomes: + ``` + scf.for %arg0 = %c0 to %c2 step %c1 + scf.for %arg1 = %c0 to %c1024 step %c1 + scf.for %arg2 = %c0 to %c5760 step %c1 + %im2col = iree_linalg_ext.im2col + strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] + m_offset = [%arg1] k_offset = [%arg2] + batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%in_tile : tensor<1x34x34x640xf32>) + outs(%out_tile : tensor<1x1x1xf32>) -> tensor<1x1x1xf32> + ``` + Then, when the tiled op is decomposed, it becomes a loop over the iteration + space of the im2col op, with an extract_slice from the `%in_tile` followed + by an insert_slice to the `%out_tile`. 
The indices for the extract slice are + computed using the `m_offset` and `k_offset` as: + (b, m, k) -> (b, M / 32 + K / (640*3), M % 32 + K % (640*3) / 640, K % 640) + Where `(b, m, k)` are the indices of the tiled op's iteration space, and + `M = m + m_offset` and `K = k + K_offset`. + }]; + + let arguments = (ins AnyShaped:$input, AnyShaped:$output, + DenseI64ArrayAttr:$strides, + DenseI64ArrayAttr:$dilations, + Variadic:$kernel_size, + DenseI64ArrayAttr:$static_kernel_size, + Variadic:$m_offset, + DenseI64ArrayAttr:$static_m_offset, + Variadic:$k_offset, + DenseI64ArrayAttr:$static_k_offset, + DenseI64ArrayAttr:$batch_pos, + DenseI64ArrayAttr:$m_pos, + DenseI64ArrayAttr:$k_pos); + + let results = (outs Variadic:$results); + let hasFolder = 1; + let assemblyFormat = [{ + attr-dict + `strides` `=` $strides + `dilations` `=` $dilations + `kernel_size` `=` + custom($kernel_size, $static_kernel_size) + `m_offset` `=` + custom($m_offset, $static_m_offset) + `k_offset` `=` + custom($k_offset, $static_k_offset) + `batch_pos` `=` $batch_pos + `m_pos` `=` $m_pos + `k_pos` `=` $k_pos + `ins` `(` $input `:` type($input) `)` + `outs` `(` $output `:` type($output) `)` + (`->` type($results)^)? + }]; + + let builders = [ + OpBuilder<(ins "Value":$input, "Value":$output, + "ArrayRef":$strides, + "ArrayRef":$dilations, + "ArrayRef":$kernel_size, + "ArrayRef":$m_offset, + "ArrayRef":$k_offset, + "ArrayRef":$batch_dimensions, + "ArrayRef":$m_dimensions, + "ArrayRef":$k_dimensions)> + ]; + + let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ + ShapedType getInputType() { + return cast(getInput().getType()); + } + ShapedType getOutputType() { + return cast(getOutput().getType()); + } + int64_t getInputRank() { + return getInputType().getRank(); + } + int64_t getOutputRank() { + return getOutputType().getRank(); + } + // Return op metadata. + SmallVector getMixedKernelSize(); + SmallVector getMixedMOffset(); + SmallVector getMixedKOffset(); + + // Set op metadata. 
+ void setMixedKOffset(SmallVector kOffset); + void setMixedMOffset(SmallVector mOffset); + + // Method to implement for specifying output range for + // DestinationStyleOpInterface + MutableOperandRange getDpsInitsMutable() { + return getOutputMutable(); + } + }]; +} } // OpGroupNonStructuredOps diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir index fd0dea7e8d69..5fdfdf5dd4bc 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir @@ -582,6 +582,54 @@ func.func @unpack_mismatch_inner_tile_size_and_output_shape( // ----- +func.func @illegal_im2col_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> { + %0 = tensor.empty() : tensor<2x1024x5760xf32> + // expected-error @+1 {{expected strides rank to be equal to the kernel rank}} + %1 = iree_linalg_ext.im2col strides = [1] dilations = [1, 1] kernel_size = [3, 3] + m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%arg0 : tensor<2x34x34x640xf32>) + outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> + return %1 : tensor<2x1024x5760xf32> +} + +// ----- + +func.func @illegal_im2col_dilations(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> { + %0 = tensor.empty() : tensor<2x1024x5760xf32> + // expected-error @+1 {{expected dilations rank to be equal to the kernel rank}} + %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1, 1] kernel_size = [3, 3] + m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%arg0 : tensor<2x34x34x640xf32>) + outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> + return %1 : tensor<2x1024x5760xf32> +} + +// ----- + +func.func @illegal_im2col_kernel_size(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> { + %0 = tensor.empty() : tensor<2x1024x5760xf32> + // expected-error @+1 {{expected 
kernel rank to be equal to the m_pos rank}} + %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3] + m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%arg0 : tensor<2x34x34x640xf32>) + outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> + return %1 : tensor<2x1024x5760xf32> +} + +// ----- + +func.func @illegal_im2col_input_rank(%arg0: tensor<1x2x34x34x640xf32>) -> tensor<2x1024x5760xf32> { + %0 = tensor.empty() : tensor<2x1024x5760xf32> + // expected-error @+1 {{expected input rank to be the sum of batch, m, and k ranks}} + %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] + m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%arg0 : tensor<1x2x34x34x640xf32>) + outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> + return %1 : tensor<2x1024x5760xf32> +} + +// ----- + func.func @illegal_winograd_input_shape(%arg0: tensor<1x10x10x32xf32>) -> tensor<8x8x1x6x6x32xf32> { %0 = tensor.empty() : tensor<8x8x1x6x6x32xf32> // expected-error @+1 {{incompatible output shape}} diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir index 15540ee7c657..d01074ad86d2 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir @@ -895,6 +895,116 @@ func.func @unpack(%arg0: memref<128x256xf32>, %arg1: memref<32x4x32x8xf32>) { // ----- +func.func @im2col(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> { + %0 = tensor.empty() : tensor<2x1024x5760xf32> + %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] + m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%arg0 : tensor<2x34x34x640xf32>) + outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> + return %1 : tensor<2x1024x5760xf32> 
+} +// CHECK: func.func @im2col(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> +// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32> +// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] +// CHECK-SAME: m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] +// CHECK-SAME: ins(%[[ARG0]] : tensor<2x34x34x640xf32>) +// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> +// CHECK: return %[[D1]] : tensor<2x1024x5760xf32> + +// ----- + +func.func @im2col_dynamic(%arg0: tensor, %s0: index, %s1: index, %s2: index, + %mOffset: index, %kOffset: index) -> tensor { + %0 = tensor.empty(%s0, %s1, %s2) : tensor + %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] + m_offset = [%mOffset] k_offset = [%kOffset] batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%arg0 : tensor) + outs(%0 : tensor) -> tensor + return %1 : tensor +} +// CHECK: func.func @im2col_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor, +// CHECK-SAME: %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[MOFFSET:.+]]: index, %[[KOFFSET:.+]]: index +// CHECK: %[[D0:.+]] = tensor.empty({{.+}}) : tensor +// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] +// CHECK-SAME: m_offset = [%[[MOFFSET]]] k_offset = [%[[KOFFSET]]] batch_pos = [0] m_pos = [1, 2] k_pos = [3] +// CHECK-SAME: ins(%[[ARG0]] : tensor) +// CHECK-SAME: outs(%[[D0]] : tensor) -> tensor +// CHECK: return %[[D1]] : tensor + +// ----- + +func.func @im2col_strided(%arg0: tensor<2x65x96x640xf32>) -> tensor<2x1024x5760xf32> { + %0 = tensor.empty() : tensor<2x1024x5760xf32> + %1 = iree_linalg_ext.im2col strides = [2, 3] dilations = [1, 1] kernel_size = [3, 3] + m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%arg0 : tensor<2x65x96x640xf32>) + outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> + return %1 : 
tensor<2x1024x5760xf32> +} +// CHECK: func.func @im2col_strided(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x65x96x640xf32>) -> tensor<2x1024x5760xf32> +// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32> +// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [2, 3] dilations = [1, 1] kernel_size = [3, 3] +// CHECK-SAME: m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] +// CHECK-SAME: ins(%[[ARG0]] : tensor<2x65x96x640xf32>) +// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> +// CHECK: return %[[D1]] : tensor<2x1024x5760xf32> + +// ----- + +func.func @im2col_dilated(%arg0: tensor<2x44x46x640xf32>) -> tensor<2x1024x5760xf32> { + %0 = tensor.empty() : tensor<2x1024x5760xf32> + %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [6, 7] kernel_size = [3, 3] + m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%arg0 : tensor<2x44x46x640xf32>) + outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> + return %1 : tensor<2x1024x5760xf32> +} +// CHECK: func.func @im2col_dilated(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x44x46x640xf32>) -> tensor<2x1024x5760xf32> +// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32> +// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [6, 7] kernel_size = [3, 3] +// CHECK-SAME: m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] +// CHECK-SAME: ins(%[[ARG0]] : tensor<2x44x46x640xf32>) +// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> +// CHECK: return %[[D1]] : tensor<2x1024x5760xf32> + +// ----- + +func.func @im2col_strided_dilated_mixed_kernel(%arg0: tensor<2x172x101x640xf32>) -> tensor<2x1024x5760xf32> { + %0 = tensor.empty() : tensor<2x1024x5760xf32> + %1 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2] + m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] + ins(%arg0 : tensor<2x172x101x640xf32>) + outs(%0 : 
tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> + return %1 : tensor<2x1024x5760xf32> +} +// CHECK: func.func @im2col_strided_dilated_mixed_kernel(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x172x101x640xf32>) -> tensor<2x1024x5760xf32> +// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32> +// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2] +// CHECK-SAME: m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3] +// CHECK-SAME: ins(%[[ARG0]] : tensor<2x172x101x640xf32>) +// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> +// CHECK: return %[[D1]] : tensor<2x1024x5760xf32> + +// ----- + +func.func @im2col_transposed_m_pos(%arg0: tensor<640x2x101x172xf32>) -> tensor<2x1024x5760xf32> { + %0 = tensor.empty() : tensor<2x1024x5760xf32> + %1 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2] + m_offset = [0] k_offset = [0] batch_pos = [1] m_pos = [3, 2] k_pos = [0] + ins(%arg0 : tensor<640x2x101x172xf32>) + outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> + return %1 : tensor<2x1024x5760xf32> +} +// CHECK: func.func @im2col_transposed_m_pos(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<640x2x101x172xf32>) -> tensor<2x1024x5760xf32> +// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32> +// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2] +// CHECK-SAME: m_offset = [0] k_offset = [0] batch_pos = [1] m_pos = [3, 2] k_pos = [0] +// CHECK-SAME: ins(%[[ARG0]] : tensor<640x2x101x172xf32>) +// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32> +// CHECK: return %[[D1]] : tensor<2x1024x5760xf32> + +// ----- + func.func @winograd_filter_transform(%arg0: tensor<3x3x64x128xf32>) -> tensor<8x8x64x128xf32> { %0 = tensor.empty() : tensor<8x8x64x128xf32> %1 = iree_linalg_ext.winograd.filter_transform diff --git 
a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp index 73f31f21b253..0a9a1b542a62 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp @@ -326,6 +326,8 @@ void registerUtilExternalModels(DialectRegistry ®istry) { IREE::LinalgExt::WinogradOutputTransformOp::attachInterface< LinalgOpTiedOpInterface>( *context); + IREE::LinalgExt::Im2colOp::attachInterface< + LinalgOpTiedOpInterface>(*context); IREE::LinalgExt::AttentionOp::attachInterface< LinalgOpTiedOpInterface>(*context); });