Skip to content

Commit

Permalink
Hotfix: Circumvent tf-2.12 breaking change on tflite subgraph API to …
Browse files Browse the repository at this point in the history
…unbreak UT (#2204)

TF-2.12.0 introduced API change that breaks tf2onnx UT tests on the
tflite paths, due to the addition of compulsory subgraph arg to several
function's input signature:
tensorflow/tensorflow@55d84d7

This commit is a temporary hotfix to unbreak related UT failure.
Existing tf2onnx's use cases get tflite Interpreter's tensors from model's
first subgraph only. The hotfix hard-codes subgraph index to `0` to
retain the same behavior while resolves API diff.

Signed-off-by: Yu Cong <congyc@amazon.com>
  • Loading branch information
q-ycong-p authored Jul 28, 2023
1 parent 5259b4a commit 0152029
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
11 changes: 3 additions & 8 deletions tests/backend_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from tf2onnx.tf_loader import tf_optimize, is_tf2, get_hash_table_info
from tf2onnx.tf_utils import compress_graph_def
from tf2onnx.graph import ExternalTensorStorage
from tf2onnx.tflite.Model import Model


if is_tf2():
Expand Down Expand Up @@ -249,14 +248,10 @@ def convert_to_tflite(self, graph_def, feed_dict, outputs):

def tflite_has_supported_types(self, tflite_path):
try:
with open(tflite_path, 'rb') as f:
buf = f.read()
buf = bytearray(buf)
model = Model.GetRootAsModel(buf, 0)
tensor_cnt = model.Subgraphs(0).TensorsLength()
interpreter = tf.lite.Interpreter(tflite_path)
for i in range(tensor_cnt):
dtype = interpreter._get_tensor_details(i)['dtype'] # pylint: disable=protected-access
tensor_details = interpreter.get_tensor_details()
for tensor_detail in tensor_details:
dtype = tensor_detail.get('dtype')
if np.dtype(dtype).kind == 'O':
return False
return True
Expand Down
16 changes: 8 additions & 8 deletions tf2onnx/tflite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,14 @@ def read_tflite_model(tflite_path):
try:
interpreter = tf.lite.Interpreter(tflite_path)
interpreter.allocate_tensors()
tensor_cnt = model.Subgraphs(0).TensorsLength()
for i in range(tensor_cnt):
name = model.Subgraphs(0).Tensors(i).Name().decode()
details = interpreter._get_tensor_details(i) # pylint: disable=protected-access
if "shape_signature" in details:
tensor_shapes[name] = details["shape_signature"].tolist()
elif "shape" in details:
tensor_shapes[name] = details["shape"].tolist()
tensor_details = interpreter.get_tensor_details()

for tensor_detail in tensor_details:
name = tensor_detail.get('name')
if "shape_signature" in tensor_detail:
tensor_shapes[name] = tensor_detail["shape_signature"].tolist()
elif "shape" in tensor_detail:
tensor_shapes[name] = tensor_detail["shape"].tolist()
except Exception as e: # pylint: disable=broad-except
logger.warning("Error loading model into tflite interpreter: %s", e)
tflite_graphs = get_model_subgraphs(model)
Expand Down

0 comments on commit 0152029

Please sign in to comment.