Skip to content

Commit

Permalink
Merge pull request #524 from onnx/gs/opt-relu6
Browse files Browse the repository at this point in the history
optimize relu6
  • Loading branch information
nbcsm authored May 16, 2019
2 parents 0f880fc + 965e988 commit 61ff34f
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ tf2onnx - convert TensorFlow models to ONNX models.
# Supported ONNX version
tensorflow-onnx will use the ONNX version installed on your system and installs the latest ONNX version if none is found.

By default we use opset 7 for the resulting ONNX graph since most runtimes will support opset 7. Opset 7 was introduced in onnx-1.2.
We support opset 6 to 10. By default we use opset 7 for the resulting ONNX graph since most runtimes will support opset 7.

Newer releases of ONNX support higher opsets. For example, to create an ONNX graph for opset 8 use in the command line ```--opset 8```.
If you want the graph to be generated with a newer opset, use ```--opset``` in the command line, for example ```--opset 10```.

# Status
We support many TensorFlow models. Support for Fully Connected and Convolutional networks is mature. Dynamic LSTM/GRU/Attention networks should work but the code for this is evolving.
Expand Down Expand Up @@ -41,7 +41,7 @@ For pytorch/caffe2, follow the instructions here:
We tested with pytorch/caffe2 and onnxruntime and unit tests are passing for those.

## Supported Tensorflow and Python Versions
We tested with tensorflow 1.5-1.13 and anaconda **3.5,3.6**.
We are testing with tensorflow 1.5-1.13 and anaconda **3.5,3.6,3.7**.

# Installation
## From pypi
Expand All @@ -55,7 +55,7 @@ python setup.py install
or
python setup.py develop
```
tensorflow-onnx requires onnx-1.2.2 or better and will install/upgrade onnx if needed.
tensorflow-onnx requires onnx-1.5 or better and will install/upgrade onnx if needed.

To create a distribution:
```
Expand All @@ -69,10 +69,10 @@ names with ```--inputs INPUTS``` and ```--outputs OUTPUTS```.

```
python -m tf2onnx.convert
--input SOURCE_GRAPHDEF_PB
--graphdef SOURCE_GRAPHDEF_PB
--checkpoint SOURCE_CHECKPOINT
--saved-model SOURCE_SAVED_MODEL
[--input SOURCE_GRAPHDEF_PB]
[--graphdef SOURCE_GRAPHDEF_PB]
[--checkpoint SOURCE_CHECKPOINT]
[--saved-model SOURCE_SAVED_MODEL]
[--output TARGET_ONNX_MODEL]
[--inputs GRAPH_INPUTS]
[--outputs GRAPH_OUTPUS]
Expand Down
5 changes: 2 additions & 3 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,10 +900,9 @@ def test_tanh(self):
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=1e-05)

@check_onnxruntime_incompatibility("Max")
def test_relu6(self):
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
x = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
x_val = np.array([0.5, 1.0, -0.5, -1.0, 6, 7], dtype=np.float32).reshape((2, 3))
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
x_ = tf.nn.relu6(x)
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val})
Expand Down
4 changes: 2 additions & 2 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def test_relu6(self):
_ = tf.identity(x_, name="output")
g = process_tf_graph(sess.graph, opset=self.config.opset)
self.assertEqual(
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Relu] Relu6__2 [op_type=Clip] '
'output [op_type=Identity] input1:0 -> Relu6 Relu6:0 -> Relu6__2 Relu6__2:0 -> output }',
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Clip] output [op_type=Identity] '
'input1:0 -> Relu6 Relu6:0 -> output }',
onnx_to_graphviz(g))

def test_conv2d(self):
Expand Down
8 changes: 4 additions & 4 deletions tf2onnx/onnx_opset/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ class Relu6:
@classmethod
def version_4(cls, ctx, node, **kwargs):
# relu6 = min(max(features, 0), 6)
node.type = "Relu"
clip_name = utils.make_name(node.name)
clip_node = ctx.insert_new_node_on_output("Clip", node.output[0], name=clip_name, min=0.0, max=6.0)
ctx.copy_shape(node.output[0], clip_node.output[0])
# relu6 = min(max(features, 0), 6)
node.type = "Clip"
node.set_attr("min", 0.0)
node.set_attr("max", 6.0)


@tf_op("Rsqrt")
Expand Down

0 comments on commit 61ff34f

Please sign in to comment.