Skip to content

Commit

Permalink
Release test set annotations for EfficientQA.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 386451859
  • Loading branch information
tomkwiat committed Jul 23, 2021
1 parent 141e243 commit be01ce2
Show file tree
Hide file tree
Showing 6 changed files with 3,632 additions and 37 deletions.
15 changes: 15 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

3 changes: 2 additions & 1 deletion make_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from absl import flags

import eval_utils as util
import six

flags.DEFINE_string('gold_path', None, 'Path to gold data.')
flags.DEFINE_string('output_path', None, 'Path to write JSON.')
Expand Down Expand Up @@ -107,7 +108,7 @@ def label_to_pred(labels):
return pred

predictions = []
for _, labels in nq_gold_dict.iteritems():
for _, labels in six.iteritems(nq_gold_dict):
predictions.append(label_to_pred(labels))

with open(FLAGS.output_path, 'w') as f:
Expand Down
11 changes: 3 additions & 8 deletions nq_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import gzip
import json
import os
import sys

import wsgiref.simple_server

Expand All @@ -45,8 +44,6 @@
import tornado.web
import tornado.wsgi

reload(sys)
sys.setdefaultencoding('utf-8')

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -87,11 +84,9 @@ def __init__(self, json_example):

# Whole example info.
self.url = json_example['document_url']
self.title = (
json_example['document_title']
if json_example.has_key('document_title') else 'Wikipedia')
self.title = json_example.get('document_title', 'Wikipedia')
self.example_id = base64.urlsafe_b64encode(
str(self.json_example['example_id']))
str(self.json_example['example_id']).encode('utf-8'))
self.document_html = self.json_example['document_html'].encode('utf-8')
self.document_tokens = self.json_example['document_tokens']
self.question_text = json_example['question_text']
Expand Down Expand Up @@ -210,7 +205,7 @@ def render_long_answer(self, long_answer):
long_answer['end_byte'])

def render_span(self, start, end):
return self.document_html[start:end]
return self.document_html[start:end].decode()

def get_candidates(self, json_candidates):
"""Returns a list of `LongAnswerCandidate` objects for top level candidates.
Expand Down
Loading

0 comments on commit be01ce2

Please sign in to comment.