Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/Microsoft/MMdnn
Browse files Browse the repository at this point in the history
  • Loading branch information
wangqianwen0418 committed Nov 29, 2017
2 parents 2fd18eb + b22ba1c commit d4786dc
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 103 deletions.
41 changes: 24 additions & 17 deletions mmdnn/conversion/examples/imagenet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
#----------------------------------------------------------------------------------------------

from __future__ import absolute_import

import argparse
import numpy as np
import sys
import os
from six import text_type as _text_type
from tensorflow.contrib.keras.python.keras.preprocessing import image

# work for tf 1.4 in windows & linux
from tensorflow.contrib.keras.api.keras.preprocessing import image

# work for tf 1.3 & 1.4 in linux
# from tensorflow.contrib.keras.python.keras.preprocessing import image


class TestKit(object):
Expand Down Expand Up @@ -92,15 +99,15 @@ def __init__(self):
default = "mmdnn/conversion/examples/data/seagull.jpg",
help = 'Test image path.'
)

parser.add_argument('--dump',
type = _text_type,
default = None,
help = 'Target model path.')

self.args = parser.parse_args()
if self.args.n.endswith('.py'):
self.args.n = self.args.n[:-3]
self.args.n = self.args.n[:-3]
self.MainModel = __import__(self.args.n)


Expand All @@ -125,7 +132,7 @@ def Standard(path, size):
x *= 2.0
return x


@staticmethod
def Identity(path, size, BGRTranspose=False):
img = image.load_img(path, target_size = (size, size))
Expand All @@ -135,51 +142,51 @@ def Identity(path, size, BGRTranspose=False):
return x


def preprocess(self, image_path):
def preprocess(self, image_path):
func = self.preprocess_func[self.args.s][self.args.preprocess]
return func(image_path)


def print_result(self, predict):
predict = np.squeeze(predict)
top_indices = predict.argsort()[-5:][::-1]
self.result = [(i, predict[i]) for i in top_indices]
print (self.result)

def print_intermediate_result(self, intermediate_output, if_transpose = False):

def print_intermediate_result(self, intermediate_output, if_transpose = False):
intermediate_output = np.squeeze(intermediate_output)

if if_transpose == True:
intermediate_output = np.transpose(intermediate_output, [2, 0, 1])

print (intermediate_output)
print (intermediate_output.shape)


def test_truth(self):
this_truth = self.truth[self.args.s][self.args.preprocess]
for index, i in enumerate(self.result):
for index, i in enumerate(self.result):
assert this_truth[index][0] == i[0]
assert np.isclose(this_truth[index][1], i[1], atol = 1e-6)

print ("Test model [{}] from [{}] passed.".format(
self.args.preprocess,
self.args.s
))


def inference(self, image_path):
self.preprocess(image_path)
self.preprocess(image_path)
self.print_result()


def dump(self, path = None):
raise NotImplementError()


'''
if __name__=='__main__':
if __name__=='__main__':
tester = TestKit()
tester.inference('examples/data/elephant.jpg')
'''
4 changes: 2 additions & 2 deletions mmdnn/conversion/tensorflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ We will use the **resnet_v2_152** model as an example.

```bash
$ wget http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz -P examples/tf/
$ tar -xvf examples/tf/inception_v3_2016_08_28.tar.gz
$ rm examples/tf/inception_v3_2016_08_28.tar.gz
$ tar -xvf examples/tf/resnet_v2_152_2017_04_14.tar.gz
$ rm examples/tf/resnet_v2_152_2017_04_14.tar.gz
$ mv *.ckpt *.graph examples/tf/
```

Expand Down
Loading

0 comments on commit d4786dc

Please sign in to comment.