-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathtrain_vgg19.py
98 lines (79 loc) · 3.99 KB
/
train_vgg19.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Simple tester for the vgg19_trainable
"""
import argparse
import os

import numpy as np
import tensorflow as tf

from dataSetGenerator import append, picShow
from vgg19 import vgg19_trainable as vgg19
# Command-line interface, e.g.:
#   python train_vgg19.py --dataset UCMerced_LandUse --batch 10 --epochs 30
# Reads DataSets/<dataset>/<dataset>_{classes,dataTrain,labelsTrain}.npy,
# resumes from / checkpoints to Weights/VGG19_<dataset>.npy and appends the
# per-batch loss/accuracy curves to Data/cost19_<dataset>.txt and
# Data/acc19_<dataset>.txt.
# Datasets used with this script include SIRI-WHU, UCMerced_LandUse and RSSCN7.
parser = argparse.ArgumentParser(prog="Train vgg19",description="Simple tester for the vgg19_trainable")
parser.add_argument('--dataset', metavar='dataset', type=str, required=True,
                    help='DataSet Name')
parser.add_argument('--batch', metavar='batch', type=int, default=10, help='batch size ')
parser.add_argument('--epochs', metavar='epochs', type=int, default=30,
                    help='number of epoch to train the network')
# Previously hard-coded values; the defaults reproduce the original behaviour.
parser.add_argument('--lr', metavar='lr', type=float, default=0.0001,
                    help='learning rate for plain gradient descent')
parser.add_argument('--save-every', metavar='save_every', type=int, default=100,
                    help='checkpoint weights and curves every N batches within an epoch')
args = parser.parse_args()
classes_name = args.dataset
batch_size = args.batch
epochs = args.epochs
learning_rate = args.lr
save_every = args.save_every
# Both values are used as divisors / modulus below, so reject non-positive input
# up front instead of failing with ZeroDivisionError mid-run.
if batch_size < 1:
    parser.error('--batch must be a positive integer')
if save_every < 1:
    parser.error('--save-every must be a positive integer')

# Output locations are loop-invariant; build them once.
weights_path = 'Weights/VGG19_{}.npy'.format(classes_name)
cost_path = 'Data/cost19_{}.txt'.format(classes_name)
acc_path = 'Data/acc19_{}.txt'.format(classes_name)

classes = np.load("DataSets/{0}/{0}_classes.npy".format(classes_name))
batch = np.load("DataSets/{0}/{0}_dataTrain.npy".format(classes_name))
labels = np.load("DataSets/{0}/{0}_labelsTrain.npy".format(classes_name))
classes_num = len(classes)
# Only shape[1] is read and used for both height and width of the placeholder,
# so training images are assumed to be square RGB (N, rib, rib, 3).
rib = batch.shape[1]

# Pin the graph to the first GPU; allow_soft_placement lets TensorFlow place ops
# elsewhere when they cannot run there. Swap in tf.device('/cpu:0') to force CPU.
with tf.device('/device:GPU:0'):
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
        images = tf.placeholder(tf.float32, [None, rib, rib, 3])
        true_out = tf.placeholder(tf.float32, [None, classes_num])
        train_mode = tf.placeholder(tf.bool)

        # Resume from saved weights when the file exists, otherwise start fresh.
        # An explicit existence check (rather than a bare except) means a corrupt
        # or incompatible weights file raises instead of silently training from
        # scratch and then overwriting that file at the first checkpoint.
        if os.path.isfile(weights_path):
            vgg = vgg19.Vgg19(weights_path, classes_num)
        else:
            print('{} Not Exist'.format(weights_path))
            vgg = vgg19.Vgg19(None, classes_num)
        vgg.build(images, train_mode)

        # Objective: sum of squared errors between predicted class probabilities
        # and the targets (presumably one-hot labels), minimised with plain SGD.
        cost = tf.reduce_sum((vgg.prob - true_out) ** 2)
        train = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
        # Batch accuracy: predicted vs. true class along the class axis (axis 1).
        # This must use the graph tensor vgg.prob, not a numpy result of sess.run.
        correct_prediction = tf.equal(tf.argmax(vgg.prob, 1), tf.argmax(true_out, 1))
        acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        print('Data SHape used:', batch.shape)
        # Initialise only after the whole graph, optimizer included, is built so
        # that optimizers which create extra variables would be initialised too.
        sess.run(tf.global_variables_initializer())

        # Smoke test: predictions for the first 8 training images before training.
        prob = sess.run(vgg.prob, feed_dict={images: batch[:8], train_mode: False})
        # picShow(batch[:8], labels[:8], classes, None, prob, True)

        batche_num = len(batch)
        num_batches = batche_num // batch_size  # a trailing partial batch is dropped
        accs = []
        costs = []
        for epoch in range(epochs):
            indice = np.random.permutation(batche_num)
            for i in range(num_batches):
                min_batch = indice[i * batch_size:(i + 1) * batch_size]
                cur_cost, _, cur_acc = sess.run(
                    [cost, train, acc],
                    feed_dict={images: batch[min_batch],
                               true_out: labels[min_batch],
                               train_mode: True})
                print("Iteration :{} Batch :{} loss :{}".format(epoch, i, cur_cost))
                accs.append(cur_acc)
                costs.append(cur_cost)
                # Periodic checkpoint; the batch index restarts every epoch.
                if (i + 1) % save_every == 0:
                    vgg.save_npy(sess, weights_path)
                    # NOTE(review): costs/accs are cumulative and never cleared; if
                    # dataSetGenerator.append appends to the file, each checkpoint
                    # re-writes the whole history. Confirm append's semantics.
                    append(costs, cost_path)
                    append(accs, acc_path)

        # Final checkpoint after all epochs.
        vgg.save_npy(sess, weights_path)
        # Re-run the same 8 images to compare predictions after training.
        prob = sess.run(vgg.prob, feed_dict={images: batch[:8], train_mode: False})
        picShow(batch[:8], labels[:8], classes, None, prob)