-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
130 lines (105 loc) · 4.33 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import json
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import config
def batch_accuracy(predicted, true):
    """Score each prediction in a batch with the human-agreement accuracy.

    Args:
        predicted: score tensor with candidate answers along dim 1; the
            highest-scoring index of every row is taken as the prediction.
        true: tensor laid out like ``predicted`` along dim 1 whose entry at
            the predicted index is read as the number of agreeing human
            answers (the derivation below assumes 10 human answers per
            question -- TODO confirm against the dataset loader).

    Returns:
        Tensor keeping dim 1 (one column) with ``min(0.3 * agreeing, 1)``
        for every row.
    """
    # Top-scoring answer index per row, kept 2-D so it can index gather().
    best_answer = predicted.max(dim=1, keepdim=True)[1]
    # Number of human answers that match the prediction.
    votes = true.gather(1, best_answer)

    # Accuracy has to be averaged over all 10-choose-9 subsets of the human
    # answers. Discarding one answer falls into one of two cases:
    #   (1) a non-matching answer is dropped (10 - a such subsets): the score
    #       stays min(a / 3, 1);
    #   (2) a matching answer is dropped (a such subsets): the score becomes
    #       min((a - 1) / 3, 1).
    # Hence acc = ((10 - a) * min(a / 3, 1) + a * min((a - 1) / 3, 1)) / 10.
    #   a == 0  -> the case-1 min is 0 and the case-2 weight is 0, so acc = 0.
    #   a >= 4  -> both mins are 1, so acc = 1.
    #   a in {1, 2, 3} -> (a - 1) / 3 < a / 3 <= 1, so both mins drop out:
    #       acc = a * ((10 - a) + (a - 1)) / 3 / 10 = 0.3 * a.
    # Every case is therefore captured by min(0.3 * a, 1).
    scaled = votes * 0.3
    return torch.clamp(scaled, max=1)
def path_for(train=False, val=False, test=False, question=False, answer=False):
    """Build the path of a VQA v2 question or annotation JSON file.

    Exactly one of ``train`` / ``val`` / ``test`` and exactly one of
    ``question`` / ``answer`` must be set; both rules are checked with
    ``assert``.

    The file name is formatted from ``config.task``, ``config.dataset`` and
    the split name, then joined onto ``config.qa_path``. The test split name
    is taken from ``config.test_split``. Requesting answers for the test
    split returns the ``val2014`` annotation path instead.
    """
    assert train + val + test == 1
    assert question + answer == 1

    if train:
        split = 'train2014'
    elif val:
        split = 'val2014'
    else:
        split = config.test_split
        if not question:
            # test=answer=True: fall back to the validation annotations
            # (the original note says these are ignored anyway on test).
            split = 'val2014'

    template = ('v2_{0}_{1}_{2}_questions.json' if question
                else 'v2_{1}_{2}_annotations.json')
    filename = template.format(config.task, config.dataset, split)
    return os.path.join(config.qa_path, filename)
class Tracker:
    """Keep track of results over time, with monitors that summarise them.

    Every call to ``track`` creates a fresh ``ListStorage``; storages created
    under the same name accumulate, in order, in a list.
    """

    def __init__(self):
        # name -> list of ListStorage objects, in creation order
        self.data = {}

    def track(self, name, *monitors):
        """Start a new storage under ``name`` and return it.

        Items appended to the returned storage are passed to every monitor
        in ``monitors``; each monitor is also reachable as an attribute of
        the storage named after ``monitor.name`` (e.g. ``storage.mean``).
        """
        storage = Tracker.ListStorage(monitors)
        self.data.setdefault(name, []).append(storage)
        return storage

    def to_dict(self):
        """Return the tracked data as ``{name: [[item, ...], ...]}`` plain lists."""
        # turn list storages into regular lists
        return {
            name: [list(storage) for storage in storages]
            for name, storages in self.data.items()
        }

    class ListStorage:
        """Storage of data points that updates the given monitors."""

        def __init__(self, monitors=None):
            # A None default (instead of a shared ``[]`` literal) gives each
            # storage created without monitors its own fresh list.
            if monitors is None:
                monitors = []
            self.data = []
            self.monitors = monitors
            for monitor in self.monitors:
                # Expose each monitor as ``self.<monitor.name>``. A monitor
                # whose name matches an existing attribute (e.g. 'data')
                # would overwrite it.
                setattr(self, monitor.name, monitor)

        def append(self, item):
            """Feed ``item`` to every monitor, then store it."""
            for monitor in self.monitors:
                monitor.update(item)
            self.data.append(item)

        def __iter__(self):
            return iter(self.data)

    class MeanMonitor:
        """Arithmetic mean over all values seen so far."""
        name = 'mean'

        def __init__(self):
            self.n = 0
            self.total = 0

        def update(self, value):
            self.total += value
            self.n += 1

        @property
        def value(self):
            """Mean of the values seen so far, or None before any update."""
            # Mirrors MovingMeanMonitor (value is None until the first
            # update) instead of raising ZeroDivisionError.
            if self.n == 0:
                return None
            return self.total / self.n

    class MovingMeanMonitor:
        """Exponential moving mean: ``value = m * value + (1 - m) * new``."""
        # Same name as MeanMonitor, so either is reachable as ``storage.mean``.
        name = 'mean'

        def __init__(self, momentum=0.9):
            self.momentum = momentum
            self.first = True
            self.value = None

        def update(self, value):
            if self.first:
                # Seed the average with the first value rather than with 0.
                self.value = value
                self.first = False
            else:
                m = self.momentum
                self.value = m * self.value + (1 - m) * value