train_flags.py
import gin
from absl import flags
FLAGS = flags.FLAGS
FAKE_DATA_DIR = 'gs://cloud-tpu-test-datasets/fake_imagenet'
flags.DEFINE_bool(
'use_tpu', default=True,
help=('Use TPU to execute the model for training and evaluation. If'
' --use_tpu=false, will use whatever devices are available to'
' TensorFlow by default (e.g. CPU and GPU)'))
# Cloud TPU Cluster Resolvers
flags.DEFINE_string(
'tpu', default=None,
help='The Cloud TPU to use for training. This should be either the name '
'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.')
flags.DEFINE_string(
'master',
default=None,
help='The Cloud TPU to use for training. This should be either the name '
'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.')
flags.DEFINE_string('tpu_job_name', default=None, help='The TPU worker name.')
flags.DEFINE_string(
'gcp_project', default=None,
help='Project name for the Cloud TPU-enabled project. If not specified, we '
'will attempt to automatically detect the GCE project from metadata.')
flags.DEFINE_string(
'tpu_zone', default=None,
    help='GCE zone where the Cloud TPU is located in. If not specified, we '
    'will attempt to automatically detect the GCE zone from metadata.')
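
# Illustrative sketch (not part of the original flag set): the flags above are
# what a TPUClusterResolver needs. Assuming a TensorFlow build that provides
# tf.distribute.cluster_resolver.TPUClusterResolver, resolving the master
# address from --master/--tpu/--tpu_zone/--gcp_project could look like the
# hypothetical helper below (it is only defined here, never called).
def _resolve_tpu_master():
  """Hypothetical helper: master address implied by the TPU flags."""
  import tensorflow as tf  # local import keeps the sketch self-contained
  if FLAGS.master:
    return FLAGS.master
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  return resolver.get_master()
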
# Model specific flags
flags.DEFINE_string(
'data_dir', default=FAKE_DATA_DIR,
help=('The directory where the ImageNet input data is stored. Please see'
' the README.md for the expected data format.'))
flags.DEFINE_string(
'model_dir', default=None,
help=('The directory where the model and training/evaluation summaries are'
' stored.'))
flags.DEFINE_string(
'restore_dir', default=None,
    help=('The directory where the model should be restored from.'))
flags.DEFINE_bool(
'restore_trainable_variables', default=True,
help=('Only restore trainable variables'))
flags.DEFINE_string(
'params', default=None,
help=('The file to read model parameters from'))
flags.DEFINE_string(
'dataset', default=None,
help=('The file to read examples from'))
flags.DEFINE_string(
'export_dataset', default=None,
help=('Export the dataset as a .tfrecord file'))
flags.DEFINE_integer(
'resnet_depth', default=50,
help=('Depth of ResNet model to use. Must be one of {18, 34, 50, 101, 152,'
' 200}. ResNet-18 and 34 use the pre-activation residual blocks'
' without bottleneck layers. The other models use pre-activation'
' bottleneck layers. Deeper models require more training time and'
' more memory and may require reducing --train_batch_size to prevent'
' running out of memory.'))
flags.DEFINE_string(
'mode', default='in_memory_eval',
    help='One of {"train_and_eval", "train", "eval", "in_memory_eval"}.')
flags.DEFINE_integer(
'train_steps', default=112590,
help=('The number of steps to use for training. Default is 112590 steps'
' which is approximately 90 epochs at batch size 1024. This flag'
' should be adjusted according to the --train_batch_size flag.'))
flags.DEFINE_integer(
'train_batch_size', default=1024, help='Batch size for training.')
flags.DEFINE_integer(
'eval_batch_size', default=1024, help='Batch size for evaluation.')
flags.DEFINE_integer(
'num_train_images', default=1281167, help='Size of training data set.')
flags.DEFINE_integer(
'num_eval_images', default=50000, help='Size of evaluation data set.')
flags.DEFINE_integer(
'num_label_classes', default=1000, help='Number of classes, at least 2')
flags.DEFINE_integer(
'steps_per_eval', default=1251,
help=('Controls how often evaluation is performed. Since evaluation is'
' fairly expensive, it is advised to evaluate as infrequently as'
' possible (i.e. up to --train_steps, which evaluates the model only'
' after finishing the entire training regime).'))
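
# Illustrative note (not part of the original flag set): the defaults above fit
# together as 1281167 // 1024 = 1251 steps per epoch, so --steps_per_eval=1251
# is one epoch and --train_steps=112590 is 1251 * 90, i.e. ~90 epochs at batch
# size 1024. A hypothetical helper for rescaling when --train_batch_size
# changes:
def _steps_for_epochs(epochs, num_images=1281167, batch_size=1024):
  """Hypothetical helper: training steps needed for `epochs` epochs."""
  return (num_images // batch_size) * epochs  # _steps_for_epochs(90) == 112590
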
flags.DEFINE_integer(
'eval_timeout',
default=None,
help='Maximum seconds between checkpoints before evaluation terminates.')
flags.DEFINE_bool(
'skip_host_call',
default=True,
help=('Skip the host_call which is executed every training step. This is'
' generally used for generating training summaries (train loss,'
' learning rate, etc...). When --skip_host_call=false, there could'
' be a performance drop if host_call function is slow and cannot'
' keep up with the TPU-side computation.'))
flags.DEFINE_integer(
'iterations_per_loop', default=1251,
help=('Number of steps to run on TPU before outfeeding metrics to the CPU.'
' If the number of iterations in the loop would exceed the number of'
' train steps, the loop will exit before reaching'
' --iterations_per_loop. The larger this value is, the higher the'
' utilization on the TPU.'))
flags.DEFINE_integer(
'num_parallel_calls', default=64,
help=('Cycle length of the parallel interleave in tf.data.dataset.'))
flags.DEFINE_integer(
'num_prefetch_threads',
default=16,
help=('Number of prefetch threads in CPU for the input pipeline'))
flags.DEFINE_bool(
'prefetch_depth_auto_tune',
default=True,
    help=('Autotune the prefetch buffer depth in the input pipeline.'))
flags.DEFINE_integer(
'num_cores', default=8,
help=('Number of TPU cores. For a single TPU device, this is 8 because each'
' TPU has 4 chips each with 2 cores.'))
flags.DEFINE_string(
'bigtable_project', None,
'The Cloud Bigtable project. If None, --gcp_project will be used.')
flags.DEFINE_string(
'bigtable_instance', None,
'The Cloud Bigtable instance to load data from.')
flags.DEFINE_string(
'bigtable_table', 'imagenet',
'The Cloud Bigtable table to load data from.')
flags.DEFINE_string(
'bigtable_train_prefix', 'train_',
'The prefix identifying training rows.')
flags.DEFINE_string(
'bigtable_eval_prefix', 'validation_',
'The prefix identifying evaluation rows.')
flags.DEFINE_string(
'bigtable_column_family', 'tfexample',
'The column family storing TFExamples.')
flags.DEFINE_string(
'bigtable_column_qualifier', 'example',
'The column name storing TFExamples.')
flags.DEFINE_string(
'data_format', default='channels_last',
help=('A flag to override the data format used in the model. The value'
' is either channels_first or channels_last. To run the network on'
' CPU or TPU, channels_last should be used. For GPU, channels_first'
' will improve performance.'))
# TODO(chrisying): remove this flag once --transpose_tpu_infeed flag is enabled
# by default for TPU
flags.DEFINE_bool(
'transpose_input', default=True,
help='Use TPU double transpose optimization')
flags.DEFINE_string(
'export_dir',
default=None,
help=('The directory where the exported SavedModel will be stored.'))
flags.DEFINE_string(
'precision', default='bfloat16',
help=('Precision to use; one of: {bfloat16, float32}'))
flags.DEFINE_float(
'base_learning_rate', default=0.1,
help=('Base learning rate when train batch size is 256.'))
flags.DEFINE_float(
'momentum', default=0.9,
help=('Momentum parameter used in the MomentumOptimizer.'))
flags.DEFINE_float(
'weight_decay', default=1e-4,
    help=('Weight decay coefficient for l2 regularization.'))
flags.DEFINE_float(
'label_smoothing', default=0.0,
    help=('Label smoothing parameter used in the softmax cross-entropy loss.'))
flags.DEFINE_integer('log_step_count_steps', 64, 'The number of steps at '
'which the global step information is logged.')
flags.DEFINE_bool('enable_lars',
default=False,
help=('Enable LARS optimizer for large batch training.'))
flags.DEFINE_float('poly_rate', default=0.0,
help=('Set LARS/Poly learning rate.'))
flags.DEFINE_bool(
'use_cache', default=True, help=('Enable cache for training input.'))
flags.DEFINE_bool(
'cache_decoded_image', default=False, help=('Cache decoded images.'))
flags.DEFINE_bool(
    'use_async_checkpointing', default=False, help=('Enable asynchronous checkpointing.'))
flags.DEFINE_float(
'stop_threshold', default=0.759, help=('Stop threshold for MLPerf.'))
flags.DEFINE_bool(
'use_eval_runner', default=True, help=('Bypass estimator on eval.'))
flags.DEFINE_bool(
'use_train_runner', default=False, help=('Bypass estimator on train.'))
flags.DEFINE_integer(
'tpu_cores_per_host', default=8, help=('Number of TPU cores per host.'))
flags.DEFINE_integer('image_size', 224, 'The input image size.')
flags.DEFINE_integer(
'distributed_group_size',
default=1,
help=('When set to > 1, it will enable distributed batch normalization'))
# Learning rate schedule
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
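
# Illustrative sketch (not part of the original file): one common reading of a
# (multiplier, start_epoch) schedule in ResNet reference code is a stepped
# decay of a base rate that is scaled linearly with batch size. The behaviour
# below epoch 5 (warmup) is an assumption; this repo's trainer may differ.
def _learning_rate_for_epoch(epoch, batch_size=1024, base_lr=0.1):
  """Hypothetical helper: stepped learning rate implied by LR_SCHEDULE."""
  scaled = base_lr * batch_size / 256.0
  multiplier = LR_SCHEDULE[0][0]
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      multiplier = mult
  return scaled * multiplier
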
flags.DEFINE_boolean(
'output_summaries',
default=False,
    help=('When set to true, outputs TensorBoard logs.'))
flags.DEFINE_boolean(
'enable_auto_tracing',
default=False,
    help=('When set to true, traces are collected from worker-0 on every run.'))
# tflex provides the Dictator attribute container used by the gin helpers below.
import tflex
flags.DEFINE_multi_string(
"gin_config", [],
"List of paths to the config files.")
flags.DEFINE_multi_string(
"gin_bindings", [],
"Newline separated list of Gin parameter bindings.")
@gin.configurable
def run_config(*,
               iterations_per_loop,
               save_checkpoints_steps,
               train_steps=-1,
               **kwargs):
  """Gin-configurable bundle of runner settings, returned as a tflex.Dictator."""
  return tflex.Dictator(
      iterations_per_loop=iterations_per_loop,
      save_checkpoints_steps=save_checkpoints_steps,
      train_steps=train_steps,
      **kwargs)
@gin.configurable
def options(*,
            dataset,
            batch_per_core,
            use_tpu=True,
            **kwargs):
  """Gin-configurable input/model options, returned as a tflex.Dictator."""
  return tflex.Dictator(
      dataset=dataset,
      batch_per_core=batch_per_core,
      use_tpu=use_tpu,
      **kwargs)
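
# Illustrative usage (not part of the original file): run_config() and
# options() read their arguments from gin, so a --gin_config file or
# --gin_bindings entries along the following lines would be expected (the
# specific values are assumptions, and tflex.Dictator is treated as a simple
# attribute container):
#
#   run_config.iterations_per_loop = 100
#   run_config.save_checkpoints_steps = 1000
#   options.dataset = 'imagenet'
#   options.batch_per_core = 8
#
# After gin parsing, run_config() and options() can then be called with no
# arguments, e.g. options().batch_per_core == 8.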