-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFALcon_config_imagenet.py
107 lines (90 loc) · 4.92 KB
/
FALcon_config_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 8 14:32:59 2021
@author: saketi, tibrayev
Defines all hyperparameters.
"""
class FALcon_config(object):
    """Hyperparameters for FALcon on ImageNet, stored as class attributes.

    Every value is computed once, when the class body executes (i.e. when the
    module is imported). Derived attributes (``init_factual``,
    ``valid_loader_type``, ``experiment_name``, ``save_dir``, ``ckpt_dir``,
    ``gt_bbox_dir``, ``pseudo_bbox_dir``) are built from the primary settings
    defined before them, so change the primary settings rather than the
    derived ones.

    Raises (at class-definition time):
        AssertionError: if ``initialize``, ``model_name`` or ``norm`` is not
            one of the supported options.
        ValueError: if ``model_name`` matches no supported architecture
            family, or if ``train_loader_type`` is not recognized.
    """

    # SEED
    seed = 16
    loader_random_seed = 1

    # dataset
    dataset = 'imagenet'
    dataset_dir = '/home/nano01/a/tibrayev/imagenet/annotated_imagenet2012'
    num_classes = 1000
    in_num_channels = 3
    full_res_img_size = (256, 256)  # (height, width) as used in transforms.Resize
    gt_bbox_dir = dataset_dir + '/anno_val'
    wsol_method = 'PSOL'
    pseudo_bbox_dir = './{}/results/ImageNet_train_set/predicted_bounding_boxes/'.format(wsol_method)
    valid_split_size = 0.1  # should be in range [0, 1)

    # model_M3
    model_name = 'vgg16'
    initialize = 'pretrained'
    # The message has to follow the comma inside the assert statement itself;
    # a string placed on the next line is a separate no-op statement and is
    # never shown when the check fails.
    assert initialize in ['pretrained', 'random', 'resume_from_pretrained', 'resume_from_random'], (
        "Specify which initialization method to choose. Options ('pretrained', 'random', 'resume_from_pretrained', 'resume_from_random')")

    # For 'resume_from_X' the value is split in two: `initialize` becomes
    # 'resume' and `init_factual` keeps the original initialization ('X'),
    # which is what ends up in `experiment_name`.
    if 'resume' in initialize:
        initialize, init_factual = initialize.split("_from_")
    else:
        init_factual = initialize

    if 'vgg' in model_name:
        downsampling = 'M'
        fc1 = 256
        fc2 = 128
        dropout = 0.5
        norm = 'none'
        init_weights = True
        adaptive_avg_pool_out = (1, 1)
        saccade_fc1 = 256
        saccade_dropout = False
        assert model_name in ['custom_vgg8_narrow_k2', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'], (
            "Specify which VGG model to use for training. Options ('custom_vgg8_narrow_k2', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn')")
        # Message lists exactly the options the check accepts for VGG.
        assert norm in ['none', 'batchnorm', 'evonorm'], (
            "Specify which normalization type to use for normalization layers. Options ('none', 'batchnorm', 'evonorm')")
    elif 'resnet' in model_name:
        norm = 'batchnorm'
        init_weights = True
        adaptive_avg_pool_out = (1, 1)
        fc1 = 512  # for RL head
        fc2 = 256  # for RL head
        # Message lists exactly the options the check accepts for ResNet.
        assert model_name in ['resnet50', 'resnet101'], (
            "Specify which ResNet model to use for training. Options ('resnet50', 'resnet101')")
        assert norm in ['batchnorm', 'instancenorm', 'layernorm', 'evonorm'], (
            "Specify which normalization type to use for normalization layers. Options ('batchnorm', 'instancenorm', 'layernorm', 'evonorm')")
    else:
        # Without this branch `norm` would be undefined and the failure would
        # only show up below as a NameError while building `experiment_name`.
        raise ValueError("Unrecognized model architecture: ({})".format(model_name))

    # training
    train_loader_type = 'train'
    if train_loader_type == 'train_and_val':
        valid_loader_type = 'train_and_val'
    elif train_loader_type == 'train':
        valid_loader_type = 'test'
        # Emitted once, when the class body runs.
        print("Warning: selected training on entire ImageNet train split, hence validation is going to be performed on test (ImageNet val) split!")
    else:
        raise ValueError("Unrecognized type of split to train on: ({})".format(train_loader_type))

    # Relative results path encoding dataset, WSOL method, training split,
    # architecture, initialization, normalization and seed.
    experiment_name = (dataset + '/wsol_method_{}'.format(wsol_method) +
                       '/trained_on_{}_split/'.format(train_loader_type) +
                       'arch_{}_{}_init_normalization_{}_seed_{}/'.format(model_name, init_factual, norm, seed))
    save_dir = './results/' + experiment_name
    batch_size_train = 512
    batch_size_eval = 512
    epochs = 100
    lr_start = 1e-2
    lr_min = 1e-5
    milestones = [30, 60, 90]
    weight_decay = 0.0001
    momentum = 0.9

    # testing
    # NOTE: despite the name, this is the path of the checkpoint file
    # (save_dir + 'model.pth'), not a directory.
    ckpt_dir = save_dir + 'model.pth'
    attr_detection_th = 0.0

    # AVS-specific parameters
    num_glimpses = 10
    fovea_control_neurons = 4
    glimpse_size_grid = (20, 20)  # (width, height) of each grid when initially dividing image into grid cells
    glimpse_size_init = (20, 20)  # (width, height) size of initial foveation glimpse at the selected grid cell (usually, the same as above)
    glimpse_size_fixed = (96, 96)  # (width, height) size of foveated glimpse as perceived by the network
    glimpse_size_step = (20, 20)  # step size of foveation in (x, y) direction at each action in each (+dx, -dx, +dy, -dy) directions
    glimpse_change_th = 0.5  # threshold, deciding whether or not to take the action based on post-sigmoid logit value
    iou_th = 0.5

    # switching cell behavior
    ratio_wrong_init_glimpses = 0.5  # ratio of the incorrect initial glimpses to the total glimpses in the batch
    switch_location_th = 0.5