forked from erickrf/assin
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutils_submission.py
82 lines (65 loc) · 2.77 KB
/
utils_submission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import xml.etree.ElementTree as ET
import numpy as np
import pandas as pd
import json
import pathlib
class SubmissionWriter(object):
    """Fill the ``<pair>`` elements of an XML file with predicted labels.

    The source XML is parsed on construction. ``convert`` sets the
    ``entailment`` and ``similarity`` attributes of every ``<pair>`` element
    from the loaded predictions, and ``save`` writes the tree to
    ``target/target_filename``.

    Predictions are located by sentence text: each data table must have
    ``sentence1`` and ``sentence2`` columns and as many rows as its
    prediction array has entries.
    """

    def __init__(self, source=None, target=None, target_filename='submission.xml', entailment_preds=None, similarity_preds=None, entailment_data=None, similarity_data=None):
        """Parse the source XML and load whichever inputs were supplied.

        Args:
            source: XML file to annotate (anything ``ET.parse`` accepts).
            target: Directory the annotated XML is written to.
            target_filename: File name used inside ``target``.
            entailment_preds: Path of a ``.npy`` file loaded with
                ``np.load``, or a falsy value to leave it unset.
            similarity_preds: Same as ``entailment_preds`` for similarity.
            entailment_data: Path of a tab-separated file loaded with
                ``pd.read_csv``, or a falsy value to leave it unset.
            similarity_data: Same as ``entailment_data`` for similarity.
        """
        self.source = ET.parse(source)
        # Joining through pathlib (instead of ``target + '/' + name``) keeps
        # an empty target relative rather than pointing at the filesystem
        # root, and does not double a trailing separator.
        self.target = str(pathlib.Path(target) / target_filename)

        self.similarity_preds = np.load(similarity_preds) if similarity_preds else None
        self.entailment_preds = np.load(entailment_preds) if entailment_preds else None
        self.entailment_data = (
            pd.read_csv(entailment_data, sep='\t') if entailment_data else None
        )
        self.similarity_data = (
            pd.read_csv(similarity_data, sep='\t') if similarity_data else None
        )
        self.result = None

    @staticmethod
    def _task_available(data, preds, task):
        """Return True if both inputs of a task are loaded, False if neither is.

        Raises:
            ValueError: If only one of the two inputs was supplied.
        """
        if data is None and preds is None:
            return False
        if data is None or preds is None:
            raise ValueError(
                '{}: both the data table and the predictions are required.'.format(task)
            )
        return True

    def get_score(self, data, preds, test, hypothesis):
        """Return the prediction of the row holding the given sentence pair.

        Rows are searched in this order, and the first hit wins:
        exact match, exact match with the two sentences swapped,
        substring match, substring match with the sentences swapped.

        Args:
            data: DataFrame with ``sentence1`` and ``sentence2`` columns.
            preds: Sequence with one entry per row of ``data``.
            test: Text of the first sentence of the pair.
            hypothesis: Text of the second sentence of the pair.

        Returns:
            The element of ``preds`` at the position of the matching row.

        Raises:
            ValueError: If ``data`` and ``preds`` differ in length.
            LookupError: If no row matches (message ``'Sentence not found.'``).
        """
        # An explicit raise survives ``python -O``; an assert would not.
        if data.shape[0] != len(preds):
            raise ValueError(
                'data has {} rows but {} predictions were given.'.format(
                    data.shape[0], len(preds)
                )
            )

        first = data['sentence1']
        second = data['sentence2']

        def exact(column, text):
            return column.eq(text).fillna(False).to_numpy(dtype=bool)

        def contains(column, text):
            # na=False keeps missing cells from producing a non-boolean mask.
            return column.str.contains(text, regex=False, na=False).to_numpy(dtype=bool)

        # Exact matches are tried before substring matches so that a short
        # sentence cannot pick up an earlier row that merely contains it.
        for matcher in (exact, contains):
            for test_column, hypothesis_column in ((first, second), (second, first)):
                mask = matcher(test_column, test) & matcher(hypothesis_column, hypothesis)
                # Row positions, not index labels: preds is indexed by
                # position, which differs from the label on a non-default index.
                hits = np.flatnonzero(mask)
                if hits.size:
                    return preds[hits[0]]

        raise LookupError('Sentence not found.')

    def convert(self):
        """Annotate every ``<pair>`` element with the available predictions.

        The ``entailment`` attribute is the label mapped from the rounded
        entailment score; the ``similarity`` attribute is the similarity
        score converted with ``str``. A task whose data and predictions were
        both omitted is skipped and its attribute is left untouched.

        Returns:
            ``self``, so calls can be chained; the annotated root element is
            also stored in ``self.result``.

        Raises:
            ValueError: If a task has only one of its two inputs.
            LookupError: If a pair cannot be found in a data table.
        """
        entailment_dict = {
            -1: 'Unknown',
            0: 'None',
            1: 'Entailment',
            2: 'Paraphrase'
        }
        with_entailment = self._task_available(
            self.entailment_data, self.entailment_preds, 'entailment'
        )
        with_similarity = self._task_available(
            self.similarity_data, self.similarity_preds, 'similarity'
        )

        root = self.source.getroot()
        for pair in root.iter('pair'):
            test = pair.find('t').text
            hypothesis = pair.find('h').text

            # Look up both scores before touching the element so a failed
            # lookup never leaves a pair half-annotated.
            entailment_score = None
            similarity_score = None
            if with_entailment:
                entailment_score = self.get_score(
                    self.entailment_data, self.entailment_preds, test, hypothesis
                )
            if with_similarity:
                similarity_score = self.get_score(
                    self.similarity_data, self.similarity_preds, test, hypothesis
                )

            if with_entailment:
                pair.set('entailment', entailment_dict[round(entailment_score)])
            if with_similarity:
                pair.set('similarity', str(similarity_score))

        self.result = root
        return self

    def save(self):
        """Write the (possibly annotated) XML tree to ``self.target``.

        Returns:
            ``self``, so calls can be chained.
        """
        self.source.write(self.target)
        return self
def assin_json_writer(data):
    """Write one submission file per configuration record.

    Args:
        data: Iterable of dicts. Each dict is expanded into the keyword
            arguments of ``SubmissionWriter`` and must contain a
            ``'target'`` key naming the output directory.
    """
    for job in data:
        out_dir = pathlib.Path(job['target'])
        # The output directory has to exist before the writer saves into it.
        out_dir.mkdir(parents=True, exist_ok=True)
        writer = SubmissionWriter(**job)
        writer.convert().save()
# Running this module as a script does nothing; it only provides
# SubmissionWriter and assin_json_writer for import.
if __name__ == '__main__':
    pass