-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathvgg_net.py
127 lines (101 loc) · 4.36 KB
/
vgg_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import cPickle
import numpy as np
from co_occurence import *
import tensorflow as tf
import tensorflow.contrib.layers as lays
from tensorflow.contrib.slim.nets import vgg
slim = tf.contrib.slim
def infer(inputs, is_training=True):
    """Build a truncated VGG-16 feature extractor with a 2-way classifier head.

    Args:
        inputs: image batch tensor, values expected in [0, 255]; it is cast to
            float32 and rescaled to [-1, 1] before entering the network.
        is_training: unused here; kept for interface compatibility with callers.

    Returns:
        (logits, features): 2-class logits from the new head and the flattened
        convolutional features it was computed from.
    """
    scaled = tf.cast(inputs, tf.float32)
    scaled = (scaled / 255.0 - 0.5) * 2  # map [0, 255] -> [-1, 1]

    # Reuse only the first two VGG-16 conv stages. Scope names must match the
    # pretrained checkpoint exactly so slim.assign_from_checkpoint_fn can
    # restore them (conv3-conv5 of the full VGG-16 are deliberately omitted).
    with tf.variable_scope("vgg_16"):
        with slim.arg_scope(vgg.vgg_arg_scope()):
            trunk = slim.repeat(scaled, 2, slim.conv2d, 64, [3, 3], scope='conv1')
            trunk = slim.max_pool2d(trunk, [2, 2], scope='pool1')
            trunk = slim.repeat(trunk, 2, slim.conv2d, 128, [3, 3], scope='conv2')
            trunk = slim.max_pool2d(trunk, [2, 2], scope='pool2')

    # Fresh fully connected head under the "finetune" scope, trained from
    # scratch (it is excluded from the vgg_16 checkpoint restore).
    flat = slim.flatten(trunk)
    hidden = slim.fully_connected(
        flat, 512,
        weights_initializer=tf.contrib.layers.xavier_initializer(),
        weights_regularizer=slim.l2_regularizer(0.0005),
        scope='finetune/fc1')
    logits = slim.fully_connected(
        hidden, 2,
        activation_fn=None,
        weights_initializer=tf.contrib.layers.xavier_initializer(),
        weights_regularizer=slim.l2_regularizer(0.0005),
        scope='finetune/fc2')
    return logits, flat
def losses(logits, labels):
    """Return the mean sparse softmax cross-entropy over the batch.

    Args:
        logits: unnormalized class scores, shape (batch, num_classes).
        labels: integer class ids, shape (batch,).
    """
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)
    return tf.reduce_mean(per_example)
def optimize(losses, lr=0.001, num_iter=1, decay_per=40, decay_rate=0.1):
    """Create an SGD train op with a staircase exponentially-decaying LR.

    Bug fix: the original read the module globals ``lr``, ``num_iter``,
    ``decay_per`` and ``decay_rate`` — but ``num_iter`` was never defined
    anywhere in the file, so calling this function raised NameError. The
    values are now keyword parameters whose defaults match the globals the
    script sets, keeping ``optimize(loss)`` calls backward-compatible.

    Args:
        losses: scalar loss tensor to minimize.
        lr: initial learning rate.
        num_iter: training steps per epoch; the LR decays every
            ``num_iter * decay_per`` global steps.
        decay_per: decay period, expressed in epochs.
        decay_rate: multiplicative decay factor applied each period.

    Returns:
        The train op; running it applies one SGD update and increments the
        global step.
    """
    global_step = tf.contrib.framework.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(
        lr, global_step, num_iter * decay_per, decay_rate, staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    # NOTE(review): to fine-tune only the new head, pass
    # var_list=slim.get_model_variables("finetune") to minimize() (the
    # original left this possibility in a commented-out argument).
    return optimizer.minimize(losses, global_step=global_step)
if __name__ == '__main__':
    # Python 2 training script: builds the truncated VGG graph, restores the
    # pretrained vgg_16 conv weights, runs a forward pass on one batch of the
    # VisTex data, then saves a checkpoint. The actual training ops are
    # commented out below, so this currently only exercises inference.
    tf.reset_default_graph()
    # Hyperparameters (also read as globals by losses()/optimize()).
    batch_size=32
    num_epochs=10
    lr = 0.001
    decay_rate=0.1
    decay_per=40 #epoch
    # Input placeholder; assumes 272x100 RGB images — TODO confirm this
    # matches the arrays produced by get_vistex() below.
    image = tf.placeholder(tf.float32, [None, 272, 100, 3])
    #Create the training graph
    # filename_queue = tf.train.string_input_producer([tfrecords_filename], num_epochs=num_epochs)
    # image, label = read_and_decode(filename_queue)
    prediction,net_x = infer(image)
    # loss = losses(prediction, label)
    # train_op = optimize(loss)
    # Load the dataset (get_vistex comes from the co_occurence wildcard
    # import); single-channel data is replicated into 3 channels so it fits
    # the RGB-expecting VGG input.
    data_X = get_vistex()
    data_X = np.expand_dims(data_X,-1)
    data_X = np.tile(data_X,(1,1,1,3))
    print data_X.shape
    # exit(0)
    # Shuffle the samples once before sequential batching.
    indices = np.random.permutation(np.arange(data_X.shape[0]))
    data_X = data_X[indices,:,:]
    print "Training started"
    with tf.Session() as sess:
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        # Restore only the pretrained vgg_16 variables from the checkpoint;
        # the "finetune" head keeps its fresh (xavier) initialization.
        restore = slim.assign_from_checkpoint_fn(
            'vgg_16.ckpt',
            slim.get_model_variables("vgg_16"))
        sess.run(init_op)
        restore(sess)
        # coord = tf.train.Coordinator()
        # threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for epoch in range(num_epochs):
            # Python 2 integer division: the final partial batch is dropped.
            for step in range(data_X.shape[0]/batch_size):
                batch_x = data_X[step*batch_size:(step+1)*batch_size]
                # batch_x = np.expand_dims(batch_x,-1)
                # print batch_x.shape
                # Run optimization op (backprop)
                # _,summary = sess.run([train_op,merged], feed_dict={X: batch_x})
                # train_writer.add_summary(summary, i)
                # Forward pass only: fetch the flattened conv features.
                result = sess.run(net_x, feed_dict={image: batch_x})
                print result
                # NOTE(review): both breaks limit the run to a single batch of
                # a single epoch — looks like a debugging leftover; confirm.
                break
            break
        # exit(0)
        # if step % display_step == 0:
        #     # Calculate batch loss and accuracy
        #     log("LR : "+str(learning_rate)+" Epoch : " + str(epoch) + " Step " + str(step))
        # coord.request_stop()
        # coord.join(threads)
        print 'Training Done'
        # Save every model variable (vgg_16 trunk + finetune head).
        saver = tf.train.Saver(slim.get_model_variables())
        saver.save(sess, 'vgg_logs/model.ckpt')
        sess.close()