-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbnn.py
118 lines (92 loc) · 4.14 KB
/
bnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from collections import OrderedDict
from torchbnn.modules.linear import BayesLinear
from torchbnn.modules.module import BayesModule
from torchbnn.utils import freeze, unfreeze
from torch.nn.functional import relu
import torch.nn.functional as F
import torch
class BayesLayerWithSample(BayesLinear):
    """torchbnn ``BayesLinear`` extended with explicit parameter sampling.

    ``forward`` is inherited unchanged. The extra methods draw one concrete,
    detached parameter set ``mu + exp(log_sigma) * eps`` so the layer can be
    evaluated functionally (via ``F.linear``) or have its parameters exported.
    ``eps`` is the stored ``weight_eps`` / ``bias_eps`` tensor when it is not
    ``None``; otherwise it is a fresh ``torch.randn_like`` draw.
    """

    def _sample_params(self, noise_scale=1.0):
        """Draw one detached ``(weight, bias)`` sample.

        Args:
            noise_scale: multiplier applied to ``eps``. It scales the deviation
                of the sample from the mean (the effective standard
                deviation), not the variance.

        Returns:
            ``(weight, bias)``; ``bias`` is ``None`` when ``self.bias`` is falsy.
        """
        weight_eps = self.weight_eps
        if weight_eps is None:
            weight_eps = torch.randn_like(self.weight_log_sigma)
        weight = (self.weight_mu
                  + torch.exp(self.weight_log_sigma) * weight_eps * noise_scale).detach()

        bias = None
        if self.bias:
            bias_eps = self.bias_eps
            if bias_eps is None:
                bias_eps = torch.randn_like(self.bias_log_sigma)
            bias = (self.bias_mu
                    + torch.exp(self.bias_log_sigma) * bias_eps * noise_scale).detach()
        return weight, bias

    def sample_layer_functional(self, device):
        """Return a closure that applies this layer with one fixed parameter sample.

        The parameters are sampled once, at call time, with unit noise scale and
        moved to ``device``. Every call of the returned function reuses that
        same sample.

        Args:
            device: target device for the sampled tensors (anything accepted
                by ``Tensor.to``).

        Returns:
            ``linear_step(x)``, which computes ``F.linear(x, weight, bias)``.
        """
        weight, bias = self._sample_params()
        weight = weight.to(device)
        if bias is not None:
            bias = bias.to(device)

        def linear_step(x):
            return F.linear(x, weight, bias)

        return linear_step

    def sample_weight(self, requires_grad=False, var_in_out=0.1):
        """Sample detached weight and bias tensors for export.

        Args:
            requires_grad: value assigned to ``requires_grad`` on the returned
                tensors (they are detached leaves, so this is always allowed).
            var_in_out: multiplier on the noise ``eps``. The default 0.1
                shrinks the deviation from the mean to 10% of the learned
                sigma. It scales the standard deviation, not the variance,
                despite the name.

        Returns:
            ``(weight, bias)``. ``bias`` is ``None`` for a layer without bias;
            in that case no ``requires_grad`` is set on it.
        """
        weight, bias = self._sample_params(var_in_out)
        weight.requires_grad = requires_grad
        # Only touch the bias when it exists; a bias-less layer yields None here.
        if bias is not None:
            bias.requires_grad = requires_grad
        return weight, bias
class BNN(BayesModule):
    """Bayesian MLP built from four ``BayesLayerWithSample`` layers.

    Layout:
        ``input_layer`` (in_features -> hidden_dim) and ``hidden1_layer`` /
        ``hidden2_layer`` (hidden_dim -> hidden_dim), each followed by ReLU,
        then ``output_layer`` (hidden_dim -> out_features) with no activation.

    Args:
        action_dim: added to ``obs_dim`` to form ``in_features``.
        obs_dim: contributes to both ``in_features`` and ``out_features``.
        reward_dim: added to ``obs_dim`` to form ``out_features``.
        weight: optional state dict passed to ``copy_params_from_model``;
            ignored when falsy (``None`` or empty).
        hidden_dim: width of the hidden layers (keyword-only, default 128).
        prior_mu, prior_sigma: prior arguments forwarded to every layer
            (keyword-only, defaults 0 and 0.5).

    NOTE(review): the input is presumably the concatenation of action and
    observation, and the output presumably ``(next observation, reward)``.
    Neither the concatenation order nor the meaning is visible here; confirm
    with the caller.
    """

    def __init__(self, action_dim, obs_dim, reward_dim, weight=None, *,
                 hidden_dim=128, prior_mu=0, prior_sigma=0.5):
        # Plain super() so BayesModule's own initializer (if any) is not skipped.
        super().__init__()
        self.in_features = action_dim + obs_dim
        self.out_features = obs_dim + reward_dim

        def make_layer(n_in, n_out):
            return BayesLayerWithSample(prior_mu=prior_mu, prior_sigma=prior_sigma,
                                        in_features=n_in, out_features=n_out)

        # Registration order matters: the sampling helpers below walk the
        # children in this order and treat the last one as the output layer.
        self.input_layer = make_layer(self.in_features, hidden_dim)
        self.hidden1_layer = make_layer(hidden_dim, hidden_dim)
        self.hidden2_layer = make_layer(hidden_dim, hidden_dim)
        self.output_layer = make_layer(hidden_dim, self.out_features)
        if weight:
            self.copy_params_from_model(weight)

    def forward(self, x):
        """Run one stochastic (or frozen) pass: ReLU on hidden layers, linear output."""
        x = relu(self.input_layer(x))
        x = relu(self.hidden1_layer(x))
        x = relu(self.hidden2_layer(x))
        x = self.output_layer(x)
        return x

    def copy_params_from_model(self, W):
        """Load state dict ``W`` into this network, best effort.

        Incompatible input (missing/unexpected keys, shape mismatches, or a
        non-mapping ``W``) is reported with a printed message instead of
        raising. PyTorch may already have copied some matching tensors before
        it raised. Interpreter-level exceptions such as ``KeyboardInterrupt``
        are no longer swallowed.
        """
        try:
            self.load_state_dict(W)
        except (RuntimeError, TypeError, AttributeError) as err:
            print(f'non compatible W: {err}')

    def sample_linear_net_functional(self, device):
        """Return a function evaluating the net with one fixed parameter sample.

        Each layer is sampled once via ``sample_layer_functional(device)``.
        The returned function mirrors ``forward``: ReLU after every layer
        except the last, which stays linear.
        """
        steps = [layer.sample_layer_functional(device) for layer in self.children()]
        *hidden_steps, output_step = steps

        def forward_with_sample(x):
            for step in hidden_steps:
                x = F.relu(step(x))
            # No activation on the output, matching forward().
            return output_step(x)

        return forward_with_sample

    # NOTE: the original author suggested this method might be removed
    # ("maybe removing this is a good idea").
    def sample_linear_net_weight(self):
        """Sample every layer's parameters and return them as a flat dict.

        Keys are ``'<layer_name>.weight'`` and, for layers with a bias,
        ``'<layer_name>.bias'``. Values come from each layer's
        ``sample_weight()`` with its default arguments.
        NOTE(review): these keys follow ``nn.Linear`` naming, not this module's
        own ``*_mu`` / ``*_log_sigma`` state-dict keys; confirm that consumers
        expect this layout.
        """
        params = OrderedDict()
        for name, layer in self.named_children():
            weight, bias = layer.sample_weight()
            params[name + '.weight'] = weight
            if bias is not None:
                params[name + '.bias'] = bias
        return params

    def deterministic_mode(self):
        """Deterministic output: apply torchbnn's ``freeze`` to this module."""
        freeze(self)

    def stochastic_mode(self):
        """Stochastic output: apply torchbnn's ``unfreeze`` to this module."""
        unfreeze(self)

    # Backward-compatible alias for the original misspelled name.
    stochatisc_mode = stochastic_mode
if __name__ == '__main__':
    # Manual smoke check: build a network with action_dim=4, obs_dim=39, reward_dim=1.
    bnn = BNN(action_dim=4, obs_dim=39, reward_dim=1)