# cnns.py — forked from stonezwr/TSSL-BP
import torch
import torch.nn as nn
import layers.conv as conv
import layers.pooling as pooling
import layers.dropout as dropout
import layers.linear as linear
import functions.loss_f as f
import global_v as glv
class Network(nn.Module):
    """Spiking CNN assembled from a layer-configuration mapping.

    Layers are constructed in the order given by ``layers_config`` and are
    executed sequentially in :meth:`forward`.  Only the "LIF" neuron model
    and the "TSSLBP" learning rule are supported by this implementation.
    """

    def __init__(self, network_config, layers_config, input_shape):
        """Build the layer stack described by ``layers_config``.

        Args:
            network_config: global network/training settings (dict-like);
                expected to contain ``'model' == "LIF"`` and a ``'rule'``
                entry consumed by :meth:`forward`.
            layers_config: ordered mapping of layer name -> spec dict whose
                ``'type'`` is one of ``'conv'``, ``'linear'``, ``'pooling'``,
                ``'dropout'``.
            input_shape: shape fed to the first layer; threaded through the
                stack as each shape-changing layer reports its ``out_shape``.

        Raises:
            ValueError: if a layer spec carries an unknown ``'type'``.
        """
        super(Network, self).__init__()
        self.layers = []
        self.network_config = network_config
        self.layers_config = layers_config
        parameters = []
        print("Network Structure:")
        for key, c in layers_config.items():
            layer_type = c['type']
            if layer_type == 'conv':
                layer = conv.ConvLayer(network_config, c, key, input_shape)
            elif layer_type == 'linear':
                layer = linear.LinearLayer(network_config, c, key, input_shape)
            elif layer_type == 'pooling':
                layer = pooling.PoolLayer(network_config, c, key, input_shape)
            elif layer_type == 'dropout':
                layer = dropout.DropoutLayer(c, key)
            else:
                # ValueError is a subclass of Exception, so existing
                # `except Exception` handlers still catch this.
                raise ValueError('Undefined layer type. It is: {}'.format(c['type']))
            self.layers.append(layer)
            # Dropout layers are shape-preserving and stay off-device in the
            # original design; every other layer moves to the target device
            # and advances the running input shape.
            if layer_type != 'dropout':
                layer.to(glv.device)
                input_shape = layer.out_shape
            # Only conv and linear layers carry trainable parameters.
            if layer_type in ('conv', 'linear'):
                parameters.append(layer.get_parameters())
        self.my_parameters = nn.ParameterList(parameters)
        print("-----------------------------------------")

    def forward(self, spike_input, epoch, is_train):
        """Run one forward pass through the layer stack.

        Args:
            spike_input: raw input spike tensor; converted to post-synaptic
                potentials via ``f.psp`` before entering the stack.
            epoch: current training epoch, forwarded to each layer's
                ``forward_pass`` (used by the TSSLBP rule).
            is_train: True during training; dropout layers are applied only
                then and act as the identity at evaluation time.

        Returns:
            The spike output of the final layer.

        Raises:
            ValueError: if ``network_config['rule']`` is not ``"TSSLBP"``
                (raised at the first non-dropout layer).
        """
        spikes = f.psp(spike_input, self.network_config)
        # Only the LIF neuron model is supported here.
        assert self.network_config['model'] == "LIF"
        for layer in self.layers:
            if layer.type == "dropout":
                # Dropout is skipped (identity) outside of training.
                if is_train:
                    spikes = layer(spikes)
            elif self.network_config["rule"] == "TSSLBP":
                spikes = layer.forward_pass(spikes, epoch)
            else:
                raise ValueError('Unrecognized rule type. It is: {}'.format(self.network_config['rule']))
        return spikes

    def get_parameters(self):
        """Return the trainable parameters collected at construction time."""
        return self.my_parameters

    def weight_clipper(self):
        """Clip the weights of every layer in place (delegates to each layer)."""
        for l in self.layers:
            l.weight_clipper()