Skip to content

Commit

Permalink
Merge pull request #418 from lucienwang1009/optimization_backend
Browse files Browse the repository at this point in the history
fix transpose bug
  • Loading branch information
nbcsm authored Mar 29, 2019
2 parents df08059 + 8edc4c9 commit 9ffbbfd
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
30 changes: 28 additions & 2 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tf2onnx import utils
from tf2onnx.graph import GraphUtil
from backend_test_base import Tf2OnnxBackendTestBase
from common import unittest_main
from common import unittest_main, group_nodes_by_type


# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
Expand Down Expand Up @@ -162,6 +162,33 @@ def test_transpose_with_identity(self):
self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
model_proto, remaining_transpose_num=1)

def test_trans_output_as_graph_outputs(self):
"""
If transpose's output is graph's output, don't optimize it.
"""
trans = helper.make_node("Transpose", ["X"], ["Y"], name="trans", perm=[0, 2, 3, 1])
graph_proto = helper.make_graph(
[trans],
"trans-to-graph-output",
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 4, 5, 3))],
)

graph = GraphUtil.create_graph_from_onnx_graph(graph_proto)
# remove identity to graph output
identity_op = graph.get_node_by_output(graph.outputs[0])
graph.outputs = [identity_op.input[0]]
graph.remove_node(identity_op.name)

optimized_graph = GraphUtil.optimize_graph(graph, "onnx-tests")

self.assertTrue(optimized_graph, msg="graph after optimizer should not be None")

trans_cnt = len(group_nodes_by_type(optimized_graph)["Transpose"])

self.assertTrue(trans_cnt == 1, msg="Expect 1 Transpose ops left, but actually " +
str(trans_cnt) + " left")

# Tranpose Optimizer Tests End

# Identity Optimizer Tests Start
Expand Down Expand Up @@ -396,6 +423,5 @@ def test_duplicated_need_multiple_run(self):
op_type="Log", remaining_op_num=3)
# Merge Duplicated Nodes Optimizer Tests End


if __name__ == "__main__":
unittest_main()
3 changes: 3 additions & 0 deletions tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ def _switch_transpose_and_node(self, node, trans):
# if return value is True, then it means Transpose is handled as designed
# otherwise, it means that we skip handling since it is not in our support set
def _handle_nhwc_tranpose(self, trans):
if trans.output[0] in self._g.outputs:
log.debug("%s connects to graph outputs, skip", trans.output[0])
return False
out_nodes = self._g.find_output_consumers(trans.output[0])
if len(out_nodes) == 1:
p = out_nodes[0]
Expand Down

0 comments on commit 9ffbbfd

Please sign in to comment.