forked from wuyang0329/unet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
64 lines (54 loc) · 2.18 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#encoding:utf-8
from model_v2 import *
from data import *
import os
import keras
from keras.callbacks import TensorBoard
import tensorflow as tf
import keras.backend.tensorflow_backend as K
import matplotlib.pyplot as plt
# Restrict TensorFlow to GPU 0 unless the caller already chose devices via the
# environment. This must run BEFORE the session below is created: once
# TensorFlow has initialised CUDA, later changes to CUDA_VISIBLE_DEVICES are
# ignored (the original assignment came after tf.Session and had no effect).
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
# Grow GPU memory on demand instead of reserving it all up front, and make
# Keras use this configured session.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)
def _metric_key(history_dict, candidates):
    """Return the first name in *candidates* that is a key of *history_dict*.

    The key Keras uses for accuracy in ``History.history`` depends on how the
    model was compiled (``metrics=['acc']`` vs ``['accuracy']``) and, from
    Keras 2.3 on, on the Keras version, so it cannot be hard-coded.

    Raises:
        KeyError: if none of *candidates* is present.
    """
    for name in candidates:
        if name in history_dict:
            return name
    raise KeyError('none of %r found in training history (available: %r)'
                   % (tuple(candidates), sorted(history_dict)))


def _plot_history(history):
    """Plot train/validation loss (top panel) and accuracy (bottom panel).

    Args:
        history: the ``History`` object returned by ``model.fit_generator``;
            its ``history`` dict must contain per-epoch values for each
            training metric and its ``val_``-prefixed validation counterpart.
    """
    curves = history.history
    acc_key = _metric_key(curves, ('acc', 'accuracy'))
    plt.figure(12, figsize=(6, 6), dpi=60)
    for row, (key, title) in enumerate((('loss', 'loss'), (acc_key, 'acc')), start=1):
        plt.subplot(2, 1, row)
        plt.plot(curves[key], label='train')
        plt.plot(curves['val_' + key], label='val')
        plt.title(title)
        plt.legend()
    plt.show()


if __name__ == '__main__':
    # Paths to the images/labels prepared for training and validation.
    train_path = "CamVid"
    image_folder = "train"
    label_folder = "trainannot"
    valid_path = "CamVid"
    valid_image_folder = "val"
    valid_label_folder = "valannot"
    log_filepath = './log'
    model_path = './model/CamVid_model_v1.hdf5'
    flag_multi_class = True
    num_classes = 12
    batch_size = 2

    dp = data_preprocess(train_path=train_path, image_folder=image_folder, label_folder=label_folder,
                         valid_path=valid_path, valid_image_folder=valid_image_folder,
                         valid_label_folder=valid_label_folder,
                         flag_multi_class=flag_multi_class,
                         num_classes=num_classes)

    # Train your own model.
    train_data = dp.trainGenerator(batch_size=batch_size)
    valid_data = dp.validLoad(batch_size=batch_size)
    model = unet(num_class=num_classes)

    # Saving an HDF5 checkpoint does not create missing parent directories,
    # so without this the first save at the end of epoch 1 would fail.
    model_dir = os.path.dirname(model_path)
    if model_dir and not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    tb_cb = TensorBoard(log_dir=log_filepath)
    # Keep only the weights with the best validation loss seen so far.
    model_checkpoint = keras.callbacks.ModelCheckpoint(model_path, monitor='val_loss',
                                                       verbose=1, save_best_only=True)
    history = model.fit_generator(train_data,
                                  steps_per_epoch=200, epochs=30,
                                  validation_steps=10,
                                  validation_data=valid_data,
                                  callbacks=[model_checkpoint, tb_cb])

    # Draw the loss and accuracy curves.
    _plot_history(history)