Skip to content

Commit

Permalink
Merge pull request #526 from zhijxu-MS/bert_bug
Browse files Browse the repository at this point in the history
fix bug
  • Loading branch information
nbcsm authored May 16, 2019
2 parents 61ff34f + 95d203a commit 64423cd
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tf2onnx/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
# input sequence should be "data", "starts", "ends", "axes", "steps"
attr = {}
data = self.convert_to_input(kwargs.pop("data"))
starts = self.convert_to_input(kwargs.pop("starts"))
ends = self.convert_to_input(kwargs.pop("ends"))
axes = self.convert_to_input(kwargs.pop("axes", None), is_optional=True)
steps = self.convert_to_input(kwargs.pop("steps", None), is_optional=True)
starts = self.convert_to_input(kwargs.pop("starts"), dtype=np.int64)
ends = self.convert_to_input(kwargs.pop("ends"), dtype=np.int64)
axes = self.convert_to_input(kwargs.pop("axes", None), is_optional=True, dtype=np.int64)
steps = self.convert_to_input(kwargs.pop("steps", None), is_optional=True, dtype=np.int64)
inputs = [data, starts, ends, axes, steps]

# pro-process inputs and attr
Expand Down Expand Up @@ -78,7 +78,7 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
return self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr, name=name,
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]

def convert_to_input(self, tensor, is_optional=False):
def convert_to_input(self, tensor, is_optional=False, dtype=None):
"""in ONNX, input shold come from node, so it must be a string"""
if is_optional and tensor is None:
return None
Expand All @@ -87,7 +87,7 @@ def convert_to_input(self, tensor, is_optional=False):

res = tensor
if isinstance(tensor, list):
res = self.graph.make_const(utils.make_name("const_slice"), np.array(tensor)).output[0]
res = self.graph.make_const(utils.make_name("const_slice"), np.array(tensor, dtype)).output[0]

utils.make_sure(isinstance(res, str), "input is a dynamic input, so a str is needed")

Expand Down

0 comments on commit 64423cd

Please sign in to comment.