diff --git a/tests/test_backend.py b/tests/test_backend.py index 99f89bb88..550d21dde 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -72,6 +72,7 @@ matrix_diag_part = tf.compat.v1.matrix_diag_part fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args fake_quant_with_min_max_vars = tf.quantization.fake_quant_with_min_max_vars + extract_image_patches = tf.image.extract_patches elif Version(tf.__version__) >= Version("1.13"): conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input conv3d_transpose = tf.compat.v1.nn.conv3d_transpose @@ -94,6 +95,7 @@ matrix_diag_part = tf.compat.v1.matrix_diag_part fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args fake_quant_with_min_max_vars = tf.compat.v1.quantization.fake_quant_with_min_max_vars + extract_image_patches = tf.compat.v1.extract_image_patches else: conv2d_backprop_input = tf.nn.conv2d_backprop_input conv3d_transpose = tf.nn.conv3d_transpose @@ -111,6 +113,7 @@ is_inf = tf.is_inf floormod = tf.floormod matrix_diag_part = tf.matrix_diag_part + extract_image_patches = tf.extract_image_patches def make_xval(shape): @@ -6248,5 +6251,21 @@ def func(tensor, indices, updates): self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices_val, _INPUT2: updates_val}) self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices64_val, _INPUT2: updates_val}) + def test_extract_image_patches(self): + for rates in [[1, 1], [1, 4], [4, 1], [3, 3]]: + for _, padding, x_shape, sizes, strides in get_conv_getdata(): + def func(x): + return extract_image_patches( + x, + sizes=sizes, + strides=strides, + rates=[1] + rates + [1], + padding=padding, + name=_TFOUTPUT + ) + + x_val = make_xval(x_shape) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + if __name__ == '__main__': unittest_main() diff --git a/tf2onnx/rewriter/__init__.py b/tf2onnx/rewriter/__init__.py index fc551a4ad..16f7670e2 100644 --- a/tf2onnx/rewriter/__init__.py +++ b/tf2onnx/rewriter/__init__.py @@ -24,6 +24,7 @@ from tf2onnx.rewriter.lstm_tf2_rewriter import rewriter_lstm_tf2 from tf2onnx.rewriter.gru_tf2_rewriter import rewrite_gru_tf2 from tf2onnx.rewriter.fused_op_rewriter import rewrite_fused_ops +from tf2onnx.rewriter.extract_image_patches_rewriter import rewrite_extract_image_patches __all__ = [ @@ -53,4 +54,5 @@ "rewriter_lstm_tf2", "rewrite_gru_tf2", "rewrite_fused_ops", + "rewrite_extract_image_patches", ] diff --git a/tf2onnx/rewriter/extract_image_patches_rewriter.py b/tf2onnx/rewriter/extract_image_patches_rewriter.py new file mode 100644 index 000000000..d941a11d7 --- /dev/null +++ b/tf2onnx/rewriter/extract_image_patches_rewriter.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 + + +""" +tf2onnx.rewriter.extract_image_patches_rewriter - Rewrites ExtractImagePatches into supported operations. +""" + +import numpy as np +from tf2onnx import utils +from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher + + +def rewrite_extract_image_patches(g, ops): + pattern = OpTypePattern("ExtractImagePatches", name="extract_image_patches") + matcher = GraphMatcher(pattern) + match_results = list(matcher.match_ops(ops)) + for match_result in match_results: + operation = match_result.get_op("extract_image_patches") + input_shape = g.get_shape(operation.input[0]) + output_shape = operation.output_shapes[0] + + sizes = operation.get_attr_value("ksizes") + strides = operation.get_attr_value("strides") + rates = operation.get_attr_value("rates") + padding = operation.get_attr_str("padding") + + # Our constraints. + utils.make_sure(0 not in output_shape, "Empty ExtractImagePatches output is unsupported.") + [_, size_rows, size_cols, _] = sizes + + # Transform input into [N * C, H, W, 1]. + transformed_input = g.make_node("Reshape", inputs=[ + g.make_node("Transpose", inputs=operation.input, attr=dict(perm=[0, 3, 1, 2])).output[0], + g.make_const(utils.make_name("new_shape"), np.int64([ + input_shape[0] * input_shape[3], + input_shape[1], + input_shape[2], + 1, + ])).output[0], + ]) + + # Create identity kernel. + k = size_rows * size_cols + identity_kernel = g.make_node("Reshape", inputs=[ + g.make_node("EyeLike", inputs=[ + g.make_node("ConstantOfShape", inputs=[ + g.make_const(utils.make_name("eye_size"), np.array([k, k], dtype=np.int64)).output[0], + ]).output[0], + ]).output[0], + g.make_const(utils.make_name("new_shape"), np.array([ + size_rows, + size_cols, + 1, + k, + ], dtype=np.int64)).output[0], + ]) + + # Convolve into [N * C, ?H, ?W, K]. + convolution = g.make_node("Conv2D", inputs=[transformed_input.output[0], identity_kernel.output[0]], + attr=dict(strides=strides, dilations=rates, padding=padding, data_format="NHWC"), + shapes=[[input_shape[0] * input_shape[3], output_shape[1], output_shape[2], k]], + dtypes=operation.output_dtypes, skip_conversion=False) + + # Transform into [N, ?H, ?W, C * K]. + output_node = g.make_node("Reshape", inputs=[ + g.make_node("Transpose", inputs=[ + g.make_node("Reshape", inputs=[ + convolution.output[0], + g.make_const(utils.make_name("new_shape"), np.array([ + input_shape[0], + input_shape[3], + output_shape[1], + output_shape[2], + k, + ], dtype=np.int64)).output[0], + ]).output[0], + ], attr=dict(perm=[0, 2, 3, 4, 1])).output[0], + g.make_const(utils.make_name("new_shape"), np.array(output_shape, dtype=np.int64)).output[0], + ]) + + # Replace node. + g.replace_all_inputs(operation.output[0], output_node.output[0]) + g.remove_node(operation.name) + return g.get_nodes() diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index c2c881e77..339cc6f18 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -598,6 +598,7 @@ def compat_handler(ctx, node, **kwargs): rewriter_lstm_tf2, rewrite_gru_tf2, rewrite_single_direction_lstm, + rewrite_extract_image_patches, # bi-directional rewrite_bi_direction_lstm, rewrite_single_direction_gru,