Skip to content

Commit

Permalink
Merge pull request #79 from Microsoft/caffe_emitter
Browse files Browse the repository at this point in the history
Caffe emitter
  • Loading branch information
kitstar authored Feb 12, 2018
2 parents 0c46278 + 9c30d70 commit 13d909e
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 4 deletions.
9 changes: 5 additions & 4 deletions mmdnn/conversion/caffe/caffe_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def header_code(self):
from caffe import to_proto
from six import text_type as _text_type
n = caffe.NetSpec()
__weights_dict = dict()
Expand All @@ -58,14 +57,15 @@ def load_weights(weight_file):
def KitModel(weight_file = None):
n = caffe.NetSpec()
"""

@property
def end_code(self):
return """
return """ return n
def make_net(prototxt):
KitModel()
n = KitModel()
with open(prototxt, 'w') as fpb:
print(n.to_proto(), file=fpb)
Expand All @@ -88,6 +88,7 @@ def gen_weight(weight_file, model, prototxt):
if 'bias' in __weights_dict[key]:
net.params[key][1].data.flat = __weights_dict[key]['bias']
net.save(model)
return net
Expand Down
77 changes: 77 additions & 0 deletions mmdnn/conversion/examples/caffe/imagenet_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
#----------------------------------------------------------------------------------------------

import argparse
import numpy as np
import sys
import os
from six import text_type as _text_type
from mmdnn.conversion.examples.imagenet_test import TestKit
import caffe


class TestCaffe(TestKit):

def __init__(self):
super(TestCaffe, self).__init__()

self.truth['caffe']['alexnet'] = [(657, 0.41121054), (744, 0.20789708), (847, 0.086725503), (821, 0.05908291), (595, 0.058017164)]

if self.args.dump:
self.dump_net = self.args.dump + '.prototxt'
self.dump_weight = self.args.dump + '.caffemodel'
else:
self.dump_net = 'tmp.prototxt'
self.dump_weight = 'tmp.caffemodel'

self.MainModel.make_net(self.dump_net)
self.MainModel.gen_weight(self.args.w, self.dump_weight, self.dump_net)
self.model = caffe.Net(self.dump_net, self.dump_weight, caffe.TEST)

def preprocess(self, image_path):
x = super(TestCaffe, self).preprocess(image_path)
# caffe uses NCHW
x = np.transpose(x, [2, 0, 1])
self.data = np.expand_dims(x, 0)


def print_result(self):
self.model.blobs['data'].data[...] = self.data
predict = self.model.forward()['prob'][0]
super(TestCaffe, self).print_result(predict)


def print_intermediate_result(self, layer_name, if_transpose = False):
intermediate_output = self.model.blobs[layer_name].data[0]
super(TestCaffe, self).print_intermediate_result(intermediate_output, if_transpose)


def inference(self, image_path):
self.preprocess(image_path)

# self.print_intermediate_result('', False)

self.print_result()

self.test_truth()

# delete tmp model files
if os.path.isfile(self.dump_net):
os.remove(self.dump_net)
if os.path.isfile(self.dump_weight):
os.remove(self.dump_weight)


def dump(self):
print ('Caffe model files are saved as [{}] and [{}], generated by [{}.py] and [{}].'.format(
self.dump_net, self.dump_weight, self.args.n, self.args.w))


if __name__=='__main__':
tester = TestCaffe()
if tester.args.dump:
tester.dump()
else:
tester.inference(tester.args.image)
1 change: 1 addition & 0 deletions mmdnn/conversion/examples/imagenet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class TestKit(object):
truth = {
'caffe' : {
'alexnet' : [(821, 0.25088307), (657, 0.20857951), (744, 0.096812263), (595, 0.066312768), (847, 0.053720973)],
#'alexnet' : [(657, 0.41121086), (744, 0.20789686), (847, 0.086725488), (821, 0.059082959), (595, 0.058017101)],
'vgg19' : [(21, 0.37522122), (144, 0.28500062), (23, 0.099720284), (134, 0.036305398), (22, 0.033559237)],
'inception_v1' : [(21, 0.93591732), (23, 0.037170019), (22, 0.014315935), (128, 0.005050648), (749, 0.001965977)],
'resnet152' : [(144, 0.93159181), (23, 0.033074539), (21, 0.028599562), (99, 0.001878676), (146, 0.001557963)],
Expand Down

0 comments on commit 13d909e

Please sign in to comment.