Skip to content

Commit

Permalink
Add LSTM sigmoid dropout test for dependency optimization (#2027)
Browse files Browse the repository at this point in the history
* Add lstm sigmoid_dropout test for dependency optimization

Signed-off-by: Deyu Huang <deyhuang@microsoft.com>
  • Loading branch information
hwangdeyu authored Aug 23, 2022
1 parent 7e3fd01 commit f4902a4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def group_nodes_by_type(graph):


def check_op_count(graph, op_type, expected_count, disabled=True):
    """Return True when *graph* contains exactly *expected_count* nodes of *op_type*.

    When *disabled* is truthy (the default) the check is skipped and a truthy
    value is returned unconditionally.
    """
    # The grappler optimization may change some of the op counts.
    # Use .get() so an op_type absent from the graph counts as zero occurrences
    # instead of raising KeyError (needed for expected_count == 0 checks).
    return disabled or len(group_nodes_by_type(graph).get(op_type, [])) == expected_count


Expand Down
21 changes: 20 additions & 1 deletion tests/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tensorflow.python.ops import variable_scope
from backend_test_base import Tf2OnnxBackendTestBase
from common import check_tf_min_version, unittest_main, check_opset_after_tf_version, \
skip_tf2, skip_tf_versions, check_op_count
skip_tf2, skip_tf_versions, check_op_count, skip_tfjs

from tf2onnx.tf_loader import is_tf2

Expand Down Expand Up @@ -51,6 +51,7 @@ def new_graph_validator(g):
# Skip checks for tflite graphs (no ":" in outputs)
return good
good = good and check_op_count(g, "LSTM", require_lstm_count, disabled=False)
# If LSTM op rewriter failed to work, Loop op will be shown in general.
good = good and check_op_count(g, "Loop", 0, disabled=False)
return good
try:
Expand Down Expand Up @@ -774,5 +775,23 @@ def func(x):
return tf.identity(y[0], name="output"), tf.identity(y[1], name="output1")
self.run_test_case(func, {"input:0": x_val}, [], ["output:0", "output1:0"], rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tfjs("TFJS converts model incorrectly")
def test_keras_lstm_sigmoid_dropout(self):
    """Keras LSTM with sigmoid activation and dropout should convert cleanly."""
    input_shape = [16, 16]
    batch_size = 2
    feed_val = np.random.uniform(size=[batch_size] + input_shape).astype(np.float32)

    # Build the model layer by layer; dropout is inactive at inference time,
    # so the converted graph stays deterministic.
    keras_model = tf.keras.models.Sequential()
    keras_model.add(tf.keras.layers.Input(shape=tuple(input_shape), name="input"))
    keras_model.add(tf.keras.layers.LSTM(16, activation='sigmoid', dropout=0.1))

    def func(x):
        out = keras_model(x)
        return tf.identity(out[0], name="output")
    self.run_test_case(func, {"input:0": feed_val}, [], ["output:0"], rtol=1e-05, atol=1e-06)

if __name__ == '__main__':
    # Run all tests in this module when executed as a script.
    unittest_main()

0 comments on commit f4902a4

Please sign in to comment.