diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py index 52343c43d..f2faa010c 100644 --- a/tests/backend_test_base.py +++ b/tests/backend_test_base.py @@ -28,7 +28,6 @@ from tf2onnx.tf_loader import tf_optimize, is_tf2, get_hash_table_info from tf2onnx.tf_utils import compress_graph_def from tf2onnx.graph import ExternalTensorStorage -from tf2onnx.tflite.Model import Model if is_tf2(): @@ -249,14 +248,10 @@ def convert_to_tflite(self, graph_def, feed_dict, outputs): def tflite_has_supported_types(self, tflite_path): try: - with open(tflite_path, 'rb') as f: - buf = f.read() - buf = bytearray(buf) - model = Model.GetRootAsModel(buf, 0) - tensor_cnt = model.Subgraphs(0).TensorsLength() interpreter = tf.lite.Interpreter(tflite_path) - for i in range(tensor_cnt): - dtype = interpreter._get_tensor_details(i)['dtype'] # pylint: disable=protected-access + tensor_details = interpreter.get_tensor_details() + for tensor_detail in tensor_details: + dtype = tensor_detail.get('dtype') if np.dtype(dtype).kind == 'O': return False return True diff --git a/tf2onnx/tflite_utils.py b/tf2onnx/tflite_utils.py index 6e3f2d024..223643fb1 100644 --- a/tf2onnx/tflite_utils.py +++ b/tf2onnx/tflite_utils.py @@ -196,14 +196,14 @@ def read_tflite_model(tflite_path): try: interpreter = tf.lite.Interpreter(tflite_path) interpreter.allocate_tensors() - tensor_cnt = model.Subgraphs(0).TensorsLength() - for i in range(tensor_cnt): - name = model.Subgraphs(0).Tensors(i).Name().decode() - details = interpreter._get_tensor_details(i) # pylint: disable=protected-access - if "shape_signature" in details: - tensor_shapes[name] = details["shape_signature"].tolist() - elif "shape" in details: - tensor_shapes[name] = details["shape"].tolist() + tensor_details = interpreter.get_tensor_details() + + for tensor_detail in tensor_details: + name = tensor_detail.get('name') + if "shape_signature" in tensor_detail: + tensor_shapes[name] = tensor_detail["shape_signature"].tolist() + elif "shape" in tensor_detail: + tensor_shapes[name] = tensor_detail["shape"].tolist() except Exception as e: # pylint: disable=broad-except logger.warning("Error loading model into tflite interpreter: %s", e) tflite_graphs = get_model_subgraphs(model)