From 95d203ad89cdaba647572139b3575fa64ef9684c Mon Sep 17 00:00:00 2001 From: zhijxu-MS Date: Thu, 16 May 2019 12:26:34 +0800 Subject: [PATCH] fix bug new version numpy will choose dtype accroding to python int value. --- tf2onnx/graph_builder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tf2onnx/graph_builder.py b/tf2onnx/graph_builder.py index dcb7f211e..8f6eb0f2c 100644 --- a/tf2onnx/graph_builder.py +++ b/tf2onnx/graph_builder.py @@ -47,10 +47,10 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None): # input sequence should be "data", "starts", "ends", "axes", "steps" attr = {} data = self.convert_to_input(kwargs.pop("data")) - starts = self.convert_to_input(kwargs.pop("starts")) - ends = self.convert_to_input(kwargs.pop("ends")) - axes = self.convert_to_input(kwargs.pop("axes", None), is_optional=True) - steps = self.convert_to_input(kwargs.pop("steps", None), is_optional=True) + starts = self.convert_to_input(kwargs.pop("starts"), dtype=np.int64) + ends = self.convert_to_input(kwargs.pop("ends"), dtype=np.int64) + axes = self.convert_to_input(kwargs.pop("axes", None), is_optional=True, dtype=np.int64) + steps = self.convert_to_input(kwargs.pop("steps", None), is_optional=True, dtype=np.int64) inputs = [data, starts, ends, axes, steps] # pro-process inputs and attr @@ -78,7 +78,7 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None): return self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr, name=name, outputs=outputs, shapes=shapes, dtypes=dtypes).output[0] - def convert_to_input(self, tensor, is_optional=False): + def convert_to_input(self, tensor, is_optional=False, dtype=None): """in ONNX, input shold come from node, so it must be a string""" if is_optional and tensor is None: return None @@ -87,7 +87,7 @@ def convert_to_input(self, tensor, is_optional=False): res = tensor if isinstance(tensor, list): - res = self.graph.make_const(utils.make_name("const_slice"), np.array(tensor)).output[0] + res = self.graph.make_const(utils.make_name("const_slice"), np.array(tensor, dtype)).output[0] utils.make_sure(isinstance(res, str), "input is a dynamic input, so a str is needed")