-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
71 lines (53 loc) · 1.69 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from visualizers import surface_plot
from data.batch_generator import create_batch_generator
import data.sample
import decision_functions
import graph
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def print_dict(dictionary):
    """Print every ``key : value`` pair of *dictionary*, one per line, in key order.

    Args:
        dictionary: Mapping to print. Its keys must be mutually comparable,
            because they are sorted before printing.
    """
    # ``sorted`` accepts any iterable (iterating a dict yields its keys) and
    # returns a new list. ``dictionary.keys().sort()`` only worked on Python 2;
    # on Python 3 ``keys()`` is a view object with no ``sort`` method.
    for key in sorted(dictionary):
        print('{} : {}'.format(key, dictionary[key]))
def train_model(
        network,
        sess,
        data,
        labels,
        batch_size,
        num_epochs,
        seed=None):
    """Train *network* on (*data*, *labels*) mini-batch by mini-batch.

    For each batch yielded by ``create_batch_generator`` this function:

    1. runs the network's train op once;
    2. evaluates the loss terms and predicted probabilities on the same feed;
    3. prints the iteration index and the total loss.

    Args:
        network: Model object exposing ``get_train_op()``, ``get_loss(name)``,
            ``inference.get_prob()``, ``initialize(sess)``,
            ``create_feed_dict(inputs)`` and ``eval_vars(vars, feed_dict, sess)``.
        sess: Session in which the train op is run (``sess.run``).
        data: Training inputs, forwarded to ``create_batch_generator``.
        labels: Training targets, forwarded to ``create_batch_generator``.
        batch_size: Forwarded to ``create_batch_generator``.
        num_epochs: Forwarded to ``create_batch_generator``. Presumably this is
            the number of passes over *data*; confirm in data.batch_generator.
        seed: Optional value forwarded to ``create_batch_generator``. Presumably
            it seeds batch shuffling; confirm in data.batch_generator.
    """
    batches = create_batch_generator(
        data,
        batch_size,
        num_epochs,
        labels,
        seed)

    net_vars = {
        'train_op': network.get_train_op(),
        'cross_entropy_loss': network.get_loss('cross_entropy'),
        'regularization_loss': network.get_loss('regularization'),
        'total_loss': network.get_loss('total'),
        'prob': network.inference.get_prob(),
    }
    # Everything except the train op is evaluated after each step. The keys
    # deliberately match those of ``net_vars``.
    vars_to_eval = {
        'cross_entropy_loss': net_vars['cross_entropy_loss'],
        'regularization_loss': net_vars['regularization_loss'],
        'total_loss': net_vars['total_loss'],
        'prob': net_vars['prob'],
    }

    network.initialize(sess)

    # ``step`` rather than ``iter`` so the builtin is not shadowed.
    for step, batch in enumerate(batches):
        print('-' * 10)
        print('Iter: ' + str(step))
        inputs = {
            'data': batch['data'],
            'labels': batch['labels'],
            'is_training': True,
        }
        feed_dict = network.create_feed_dict(inputs)
        sess.run(net_vars['train_op'], feed_dict=feed_dict)
        # Evaluation happens after the train op has run on this same feed.
        # The printed loss is therefore presumably the post-update value on
        # this batch, computed with is_training=True; confirm in
        # network.eval_vars.
        eval_vars = network.eval_vars(vars_to_eval, feed_dict, sess)
        print(eval_vars['total_loss'])