-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathapprox_lipschitz.py
64 lines (52 loc) · 2.14 KB
/
approx_lipschitz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
class LipschitzCalculator:
    """Search-based, per-sample estimate of the local Lipschitz constant of ``net``.

    For every output column ``label`` a sign-gradient ascent (PGD-style) search
    inside the L-infinity ball of radius ``eps`` around the input looks for a
    point ``x'`` that maximises ``|net(x')[:, label] - net(x)[:, label]|``.
    The result for each sample is the largest such change over all labels,
    divided by ``eps``. Every probed point stays inside the ball, so the value
    is a lower estimate of the true local constant over that ball.

    Args:
        net: ``torch.nn.Module`` to probe. Its output is indexed as
            ``out[:, label]``; presumably shape (batch, num_outputs) — confirm
            for multi-dimensional outputs.
        eps: radius of the L-infinity search ball.
        step_size: magnitude of each sign-gradient step.
        num_steps: number of ascent steps per label.
    """

    def __init__(self, net, eps, step_size, num_steps):
        super().__init__()
        self.net = net
        self.step_size = step_size
        self.num_steps = num_steps
        self.eps = eps

    def __call__(self, inputs):
        """Return a tensor with one Lipschitz estimate per sample in ``inputs``.

        Parameters of ``self.net`` are frozen (``requires_grad=False``) during
        the search. Their original flags are restored afterwards, even if the
        search raises.
        """
        # Remember each parameter's flag so the network is handed back unchanged.
        requires_grads = [p.requires_grad for p in self.net.parameters()]
        # Only input gradients are needed; freezing parameters also keeps
        # backward() from accumulating into parameter .grad buffers.
        self.net.requires_grad_(False)
        try:
            # Detach so neither the reference outputs nor the projection bounds
            # carry the caller's autograd history. Otherwise the projected x
            # would become a non-leaf tensor whose .grad stays None.
            inputs = inputs.detach()
            targets = self.net(inputs)
            ans = torch.zeros_like(targets)
            lower = inputs - self.eps
            upper = inputs + self.eps
            for label in range(targets.size(1)):
                # Fresh random start around the *clean* input for each label.
                # The search per label is independent, and the start is within
                # eps/2 of the input, so it lies inside the ball even when
                # num_steps == 0.
                init_noise = torch.zeros_like(inputs).normal_(0, self.eps / 4)
                x = inputs + torch.clamp(init_noise, -self.eps / 2, self.eps / 2)
                for _ in range(self.num_steps):
                    x.requires_grad_()
                    logits = self.net(x)
                    loss = (logits[:, label] - targets[:, label]).abs().sum()
                    loss.backward()
                    # Sign-gradient ascent step, then projection back onto the ball.
                    x = torch.add(x.detach(), torch.sign(x.grad.detach()), alpha=self.step_size)
                    x = torch.min(torch.max(x, lower), upper)
                with torch.no_grad():
                    logits = self.net(x)
                    ans[:, label] = logits[:, label]
            # Largest output change over labels, normalised by the ball radius.
            return torch.norm(ans - targets, dim=1, p=float('inf')) / self.eps
        finally:
            for p, r in zip(self.net.parameters(), requires_grads):
                p.requires_grad_(r)
def cal_Lipschitz(net, data_loader, eps=1/255, num_steps=20, restart=5, gpu=0):
    """Estimate a Lipschitz constant for every sample of ``data_loader`` and save them.

    Each batch is probed ``restart`` times with a ``LipschitzCalculator``
    (step size ``eps / 4``), and the per-sample maximum over restarts is kept.
    The estimates are written one per line (``%.4f``) to the first unused file
    ``lipschitz<j>.txt`` (j = 0, 1, 2, ...) in the current working directory.
    An existing file is never overwritten.

    Args:
        net: network to probe. It is switched to ``eval()`` mode as a side effect.
        data_loader: iterable of ``(inputs, targets)`` pairs; ``targets`` is ignored.
        eps: L-infinity radius of the search ball.
        num_steps: sign-gradient steps per label per restart.
        restart: number of random restarts per batch.
        gpu: CUDA device index the inputs are moved to. ``None`` leaves the
            inputs on their current device; ``net`` must already live there.

    Returns:
        list of float: one estimate per sample, in data-loader order.
    """
    import os

    # First unused index, without the old 100-file cap that made the last
    # name get silently overwritten.
    j = 0
    while os.path.exists('lipschitz%d.txt' % j):
        j += 1
    file_name = 'lipschitz%d.txt' % j

    # Open up-front (reserving the name, as before); `with` guarantees the
    # handle is closed even if the computation fails.
    with open(file_name, 'w') as fp:
        cal = LipschitzCalculator(net, eps, eps / 4, num_steps)
        net.eval()
        ans_list = []
        for batch_idx, (inputs, _) in enumerate(data_loader):
            if gpu is not None:
                inputs = inputs.cuda(gpu)
            ans = torch.zeros(inputs.size(0), device=inputs.device)
            print(batch_idx)
            # Each call uses a new random start; keep the best estimate per sample.
            for _ in range(restart):
                ans = torch.maximum(ans, cal(inputs))
            ans_list.extend(ans.cpu().tolist())
        fp.writelines(['%.4f\n' % x for x in ans_list])
    return ans_list