# cnns.py
import torch
from torch import nn
from torchvision.models.resnet import BasicBlock as BasicResidualBlock


# Each entry maps an architecture name to a list of per-layer keyword
# arguments. Entries with 'residual': True are built from torchvision's
# BasicBlock; the rest use the plain conv block defined below.
NETWORK_ARCHITECTURE_DEFINITIONS = {
'BasicCNN': [
{'out_dim': 32, 'kernel_size': 8, 'stride': 4},
{'out_dim': 64, 'kernel_size': 4, 'stride': 2},
{'out_dim': 64, 'kernel_size': 3, 'stride': 1},
],
'MAGICALCNN': [
{'out_dim': 32, 'kernel_size': 5, 'stride': 1, 'padding': 2},
{'out_dim': 64, 'kernel_size': 3, 'stride': 2, 'padding': 1},
{'out_dim': 64, 'kernel_size': 3, 'stride': 2, 'padding': 1},
{'out_dim': 64, 'kernel_size': 3, 'stride': 2, 'padding': 1},
{'out_dim': 64, 'kernel_size': 3, 'stride': 2, 'padding': 1},
],
'MAGICALCNN-resnet': [
{'out_dim': 64, 'stride': 4, 'residual': True},
{'out_dim': 128, 'stride': 2, 'residual': True},
],
'MAGICALCNN-resnet-128': [
{'out_dim': 64, 'stride': 4, 'residual': True},
{'out_dim': 128, 'stride': 2, 'residual': True},
{'out_dim': 128, 'stride': 2, 'residual': True},
],
'MAGICALCNN-resnet-256': [
{'out_dim': 64, 'stride': 4, 'residual': True},
{'out_dim': 128, 'stride': 2, 'residual': True},
{'out_dim': 256, 'stride': 2, 'residual': True},
],
'MAGICALCNN-small': [
{'out_dim': 32, 'kernel_size': 5, 'stride': 2, 'padding': 2},
{'out_dim': 64, 'kernel_size': 3, 'stride': 2, 'padding': 1},
{'out_dim': 64, 'kernel_size': 3, 'stride': 2, 'padding': 1},
{'out_dim': 64, 'kernel_size': 3, 'stride': 2, 'padding': 1},
{'out_dim': 64, 'kernel_size': 3, 'stride': 2, 'padding': 1},
]
}
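

# Illustrative only (an added sketch, not part of the original module):
# tracing one of the plain-conv definitions through the standard conv
# output-size formula, floor((size + 2*padding - kernel) / stride) + 1,
# shows the overall downsampling. A 96x96 input is an assumption here;
# with it, 'MAGICALCNN' ends at 6x6 feature maps. The resnet definitions
# omit 'kernel_size', so this helper does not apply to them.
def _example_spatial_size(arch='MAGICALCNN', size=96):
    for spec in NETWORK_ARCHITECTURE_DEFINITIONS[arch]:
        size = (size + 2 * spec.get('padding', 0)
                - spec['kernel_size']) // spec['stride'] + 1
    return size  # 96 -> 96 -> 48 -> 24 -> 12 -> 6 for the defaults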


def magical_conv_block(in_chans, out_chans, kernel_size, stride, padding,
                       use_bn, use_sn, dropout, activation_cls):
    """Build one conv block as a flat list of layers: Conv2d, optional
    Dropout2d, the activation, and optional BatchNorm2d."""
# We sometimes disable bias because batch norm has its own bias.
conv_layer = nn.Conv2d(
in_chans,
out_chans,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=not use_bn,
padding_mode='zeros')
if use_sn:
# apply spectral norm if necessary
conv_layer = nn.utils.spectral_norm(conv_layer)
layers = [conv_layer]
if dropout:
# dropout after conv, but before activation
# (doesn't matter for ReLU)
layers.append(nn.Dropout2d(dropout))
layers.append(activation_cls())
if use_bn:
# Insert BN layer after convolution (and optionally after
# dropout). I doubt order matters much, but see here for
# CONTROVERSY:
# https://github.com/keras-team/keras/issues/1802#issuecomment-187966878
layers.append(nn.BatchNorm2d(out_chans))
return layers
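
# A minimal usage sketch (argument values are assumptions, not from the
# original repo). The helper returns a plain list so that callers can
# splice several blocks into a single nn.Sequential:
#
#     block = nn.Sequential(*magical_conv_block(
#         3, 32, kernel_size=5, stride=1, padding=2,
#         use_bn=True, use_sn=False, dropout=0.1, activation_cls=nn.ReLU))
#     block(torch.zeros(1, 3, 96, 96)).shape  # torch.Size([1, 32, 96, 96])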


class MAGICALCNN(nn.Module):
"""The CNN from the MAGICAL paper."""
def __init__(self,
input_channels,
representation_dim=128,
fc_size=128,
use_bn=True,
                 use_ln=False,  # NOTE: accepted but unused in this module
dropout=None,
use_sn=False,
arch_str='MAGICALCNN-resnet-128',
ActivationCls=torch.nn.ReLU):
super().__init__()
        # Architectures with 'resnet' in their name are built from
        # torchvision's BasicBlock; the rest use the MAGICAL conv block
        # defined above.
        assert arch_str in NETWORK_ARCHITECTURE_DEFINITIONS
        # The plain MAGICAL variants double their channel and FC widths.
        width = 1 if 'resnet' in arch_str else 2
self.features_dim = representation_dim
w = width
self.architecture_definition = NETWORK_ARCHITECTURE_DEFINITIONS[arch_str]
conv_layers = []
in_dim = input_channels
block = magical_conv_block
if 'resnet' in arch_str:
block = BasicResidualBlock
        for layer_definition in self.architecture_definition:
            if layer_definition.get('residual', False):
                # torchvision's BasicBlock needs an explicit downsample
                # module whenever the stride or channel count changes.
                block_kwargs = {
                    'stride': layer_definition['stride'],
                    'downsample': nn.Sequential(
                        nn.Conv2d(in_dim,
                                  layer_definition['out_dim'],
                                  kernel_size=1,
                                  stride=layer_definition['stride']),
                        nn.BatchNorm2d(layer_definition['out_dim'])),
                }
                conv_layers += [block(in_dim,
                                      layer_definition['out_dim'] * w,
                                      **block_kwargs)]
else:
block_kwargs = {
'stride': layer_definition['stride'],
'kernel_size': layer_definition['kernel_size'],
                    # 'BasicCNN' entries omit 'padding'; default to 0
                    'padding': layer_definition.get('padding', 0),
'use_bn': use_bn,
'use_sn': use_sn,
'dropout': dropout,
'activation_cls': ActivationCls
}
conv_layers += block(in_dim,
layer_definition['out_dim'] * w,
**block_kwargs)
            in_dim = layer_definition['out_dim'] * w
        if 'resnet' in arch_str:
            # 1x1 conv to shrink the channel count before the FC layers
            conv_layers.append(nn.Conv2d(in_dim, 32, 1))
        conv_layers.append(nn.Flatten())
        # FC input size: for the resnet variants this is 32 channels x 6x6
        # feature maps, which assumes 96x96 inputs. The plain-conv variants
        # produce a different flattened size, so this constant would need
        # adjusting for them.
        fc_in_size = 1152
fc_layers = [
nn.Linear(fc_in_size, fc_size * w),
ActivationCls(),
nn.Linear(fc_size * w, representation_dim),
]
        if use_sn:
            # Apply spectral norm to the linear layers as well.
            fc_layers = [
                nn.utils.spectral_norm(layer)
                if isinstance(layer, nn.Linear) else layer
                for layer in fc_layers
            ]
all_layers = [*conv_layers, *fc_layers]
self.shared_network = nn.Sequential(*all_layers)

    def forward(self, x):
# warn_on_non_image_tensor(x)
return self.shared_network(x)
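

if __name__ == '__main__':
    # Smoke test (an added sketch, not from the original repository): batch
    # size, channel count, and the 96x96 resolution are assumptions, chosen
    # so the resnet trunk ends at 32 x 6 x 6 = 1152 features, matching
    # fc_in_size above.
    encoder = MAGICALCNN(input_channels=3)
    features = encoder(torch.zeros(8, 3, 96, 96))
    print(features.shape)  # expected: torch.Size([8, 128])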