-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
executable file
·133 lines (87 loc) · 3.47 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
import math
import sys
import warnings
from collections import namedtuple, OrderedDict
from functools import partial
import pandas as pd
class ListShouldBeEmptyWarning(UserWarning):
    """Warning category used by ``compute_ranks`` when gold targets remain unmatched."""
# A gold question: its identifier plus a mapping of explanation id -> Explanation.
Question = namedtuple('Question', ['id', 'explanations'])

# A single gold explanation entry: its identifier and the role string attached to it.
Explanation = namedtuple('Explanation', ['id', 'role'])
def load_gold(filepath_or_buffer, sep='\t'):
    """Read the gold explanations table into an ordered mapping of questions.

    The input is parsed with ``pandas.read_csv`` (all values as strings) and
    must provide the columns ``questionID`` and ``explanation``.  Rows where
    either of the two is missing are skipped.  The ``explanation`` cell is a
    whitespace-separated list of ``uid|role`` tokens.

    Args:
        filepath_or_buffer: Path or file-like object handed to ``pandas.read_csv``.
        sep: Column separator, a tab by default.

    Returns:
        An ``OrderedDict`` mapping each lower-cased question id to a
        ``Question`` whose ``explanations`` field is an ``OrderedDict`` from
        lower-cased explanation id to ``Explanation(id, role)``.

    Raises:
        ValueError: If an explanation token does not contain a ``|``.
    """
    table = pd.read_csv(filepath_or_buffer, sep=sep, dtype=str)
    usable_rows = table[['questionID', 'explanation']].dropna()

    gold = OrderedDict()

    for _, record in usable_rows.iterrows():
        facts = OrderedDict()

        for token in record['explanation'].split():
            # Split on the first '|' only, so the role part may itself contain '|'.
            uid, role = token.split('|', 1)
            fact_id = uid.lower()
            facts[fact_id] = Explanation(fact_id, role)

        entry = Question(record['questionID'].lower(), facts)
        gold[entry.id] = entry

    return gold
def load_pred(filepath_or_buffer, sep='\t'):
    """Read ranked predictions from a header-less two-column file.

    Each line holds a question id and one predicted explanation id, parsed
    with ``pandas.read_csv`` as strings.  Both ids are lower-cased.  Within a
    question the explanations keep their order of appearance and repeated ids
    are dropped, keeping the first occurrence.

    Question ids are lower-cased *before* grouping, so rows whose ids differ
    only by letter case contribute to the same ranking instead of one group
    silently overwriting the other.

    Args:
        filepath_or_buffer: Path or file-like object handed to ``pandas.read_csv``.
        sep: Column separator, a tab by default.

    Returns:
        An ``OrderedDict`` mapping each lower-cased question id to the list of
        its lower-cased, de-duplicated explanation ids.

    Raises:
        ValueError: If one of the two columns is entirely empty, which usually
            means the separator did not match the file.
    """
    df = pd.read_csv(filepath_or_buffer, sep=sep, names=('question', 'explanation'), dtype=str)

    # A column with no values at all indicates the lines were not split as expected.
    if any(df[field].isnull().all() for field in df.columns):
        raise ValueError('invalid format of the prediction dataset, possibly the wrong separator')

    # Normalise the grouping key first; lower-casing after grouping would let
    # 'Q1' and 'q1' form two groups that then collide on the same dict key.
    question_ids = df['question'].str.lower()

    pred = OrderedDict()

    for question_id, df_explanations in df.groupby(question_ids):
        # OrderedDict.fromkeys de-duplicates while keeping first-seen order.
        pred[question_id] = list(OrderedDict.fromkeys(df_explanations['explanation'].str.lower()))

    # NOTE(review): this writes the number of loaded questions to stdout; it
    # looks like leftover debug output - confirm whether it should stay.
    print(len(pred))

    return pred
def compute_ranks(true, pred):
    """Return the 1-based positions in ``pred`` at which gold ids are found.

    ``pred`` is scanned in order.  Whenever an item equals one of the gold ids
    that is still unmatched, its 1-based position is recorded and that gold id
    is consumed, so a repeated prediction is only counted once per gold entry.

    Gold ids that are never matched trigger a ``ListShouldBeEmptyWarning`` and
    each contributes a rank of ``10**9`` at the end of the result.

    Args:
        true: Iterable of gold explanation ids.
        pred: Ordered iterable of predicted explanation ids.

    Returns:
        A list of integer ranks; empty when ``true`` or ``pred`` is empty.
    """
    if not true or not pred:
        return []

    ranks = []
    remaining = list(true)

    # The original author noted that this loop replaces a block of the
    # reference Scala implementation that they did not understand.
    for position, candidate in enumerate(pred, start=1):
        hit = next(
            (index for index, target in enumerate(remaining) if candidate == target),
            None,
        )

        if hit is not None:
            ranks.append(position)
            del remaining[hit]

    # Example: Mercury_SC_416133
    if remaining:
        warnings.warn('targets list should be empty, but it contains: ' + ', '.join(remaining), ListShouldBeEmptyWarning)

        # One placeholder rank per gold id that was never predicted.
        ranks.extend([10**9] * len(remaining))

    return ranks
def average_precision(ranks):
    """Average the precision values implied by a list of ranks.

    The item at 1-based position ``k`` with rank ``r`` contributes ``k / r``.
    A rank that is not positive contributes ``math.inf``, which makes the
    whole result infinite.

    Args:
        ranks: Sequence of numeric ranks.

    Returns:
        The mean of the per-item precisions as a float, or ``0.0`` when
        ``ranks`` is empty.
    """
    if not ranks:
        return 0.

    accumulated = 0.

    for position, rank in enumerate(ranks, start=1):
        if rank > 0:
            accumulated += float(position) / float(rank)
        else:
            accumulated += math.inf

    return accumulated / len(ranks)
def mean_average_precision_score(gold, pred, callback=None):
    """Compute the mean average precision over the questions present in ``pred``.

    Gold questions without an entry in ``pred`` are ignored and do not count
    towards the mean.  For every scored question the ranks are obtained with
    ``compute_ranks`` and turned into a score with ``average_precision``; a
    non-finite score is replaced by ``0.0``.

    The question id with its score, and the list of ranks, are printed to
    stdout for every scored question.

    Args:
        gold: Mapping of question id to ``Question``; only its values are used.
        pred: Mapping of question id to the ordered predicted explanation ids.
        callback: Optional callable invoked as ``callback(question_id, score)``
            for every scored question.

    Returns:
        The mean of the per-question scores, or ``0.0`` if nothing was scored.
    """
    score_sum = 0.
    scored_questions = 0

    for question in gold.values():
        if question.id not in pred:
            continue

        ranks = compute_ranks(list(question.explanations), pred[question.id])
        raw_score = average_precision(ranks)

        # Infinite or NaN scores are counted as zero.
        score = raw_score if math.isfinite(raw_score) else 0.

        score_sum += score
        scored_questions += 1

        print(question.id, score)
        print(ranks)

        if callback:
            callback(question.id, score)

    if scored_questions > 0:
        return score_sum / scored_questions

    return 0.
def main():
    """Command-line entry point: score a prediction file against a gold file.

    Expects ``--gold GOLD`` (required) and a positional ``pred`` argument, both
    opened for reading as UTF-8.  The final score is printed to stdout.
    """
    import argparse

    utf8_reader = argparse.FileType('r', encoding='UTF-8')

    cli = argparse.ArgumentParser()
    cli.add_argument('--gold', type=utf8_reader, required=True)
    cli.add_argument('pred', type=utf8_reader)
    options = cli.parse_args()

    gold = load_gold(options.gold)
    pred = load_pred(options.pred)

    # The callback is optional; here it echoes each question's score to STDERR.
    report_to_stderr = partial(print, file=sys.stderr)

    mean_ap = mean_average_precision_score(gold, pred, callback=report_to_stderr)

    print('MAP: ', mean_ap)
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()