diff --git a/.gitignore b/.gitignore index 56ea990..6bf040a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ data/ build dist __pycache__ +sacrebleu.egg-info +.sacrebleu +*~ +.DS_Store \ No newline at end of file diff --git a/sacrebleu/dataset/__init__.py b/sacrebleu/dataset/__init__.py index 5c3a7b8..19f4f16 100644 --- a/sacrebleu/dataset/__init__.py +++ b/sacrebleu/dataset/__init__.py @@ -74,6 +74,51 @@ DATASETS = { # wmt + "wmt22": WMTXMLDataset( + "wmt22", + data=["https://github.com/wmt-conference/wmt22-news-systems/archive/refs/tags/v1.1.tar.gz"], + description="Official evaluation and system data for WMT22.", + md5=["0840978b9b50b9ac3b2b081e37d620b9"], + langpairs={ + "cs-en": { + "path": "wmt22-news-systems-1.1/xml/wmttest2022.cs-en.all.xml", + "refs": ["B"], + }, + "cs-uk": ["wmt22-news-systems-1.1/xml/wmttest2022.cs-uk.all.xml"], + "de-en": ["wmt22-news-systems-1.1/xml/wmttest2022.de-en.all.xml"], + "de-fr": ["wmt22-news-systems-1.1/xml/wmttest2022.de-fr.all.xml"], + "en-cs": { + "path": "wmt22-news-systems-1.1/xml/wmttest2022.en-cs.all.xml", + "refs": ["B"], + }, + "en-de": ["wmt22-news-systems-1.1/xml/wmttest2022.en-de.all.xml"], + "en-hr": ["wmt22-news-systems-1.1/xml/wmttest2022.en-hr.all.xml"], + "en-ja": ["wmt22-news-systems-1.1/xml/wmttest2022.en-ja.all.xml"], + "en-liv": ["wmt22-news-systems-1.1/xml/wmttest2022.en-liv.all.xml"], + "en-ru": ["wmt22-news-systems-1.1/xml/wmttest2022.en-ru.all.xml"], + "en-uk": ["wmt22-news-systems-1.1/xml/wmttest2022.en-uk.all.xml"], + "en-zh": ["wmt22-news-systems-1.1/xml/wmttest2022.en-zh.all.xml"], + "fr-de": ["wmt22-news-systems-1.1/xml/wmttest2022.fr-de.all.xml"], + "ja-en": ["wmt22-news-systems-1.1/xml/wmttest2022.ja-en.all.xml"], + "liv-en": { + "path": "wmt22-news-systems-1.1/xml/wmttest2022.liv-en.all.xml", + # no translator because data is English-original + "refs": [""], + }, + "ru-en": ["wmt22-news-systems-1.1/xml/wmttest2022.ru-en.all.xml"], + "ru-sah": { + "path": 
"wmt22-news-systems-1.1/xml/wmttest2022.ru-sah.all.xml", + # no translator because data is Yakut-original + "refs": [""], + }, + "sah-ru": ["wmt22-news-systems-1.1/xml/wmttest2022.sah-ru.all.xml"], + "uk-cs": ["wmt22-news-systems-1.1/xml/wmttest2022.uk-cs.all.xml"], + "uk-en": ["wmt22-news-systems-1.1/xml/wmttest2022.uk-en.all.xml"], + "zh-en": ["wmt22-news-systems-1.1/xml/wmttest2022.zh-en.all.xml"], + }, + # the default reference to use with this dataset + refs=["A"], + ), "wmt21/systems": WMTXMLDataset( "wmt21/systems", data=["https://github.com/wmt-conference/wmt21-news-systems/archive/refs/tags/v1.3.tar.gz"], @@ -101,8 +146,9 @@ "xh-zu": ["wmt21-news-systems-1.3/xml/florestest2021.xh-zu.all.xml"], "zu-xh": ["wmt21-news-systems-1.3/xml/florestest2021.zu-xh.all.xml"], }, + # the reference to use with this dataset + refs=["A"], ), - "wmt21": WMTXMLDataset( "wmt21", data=["http://data.statmt.org/wmt21/translation-task/test.tgz"], @@ -210,6 +256,8 @@ "en-is": ["dev/xml/newsdev2021.en-is.xml"], "is-en": ["dev/xml/newsdev2021.is-en.xml"], }, + # datasets are bidirectional in origin, so use both refs + refs=["A", ""], ), "wmt20/tworefs": FakeSGMLDataset( "wmt20/tworefs", diff --git a/sacrebleu/dataset/wmt_xml.py b/sacrebleu/dataset/wmt_xml.py index c487a20..92c96d5 100644 --- a/sacrebleu/dataset/wmt_xml.py +++ b/sacrebleu/dataset/wmt_xml.py @@ -8,14 +8,19 @@ from collections import defaultdict +def _get_field_by_translator(translator): + if not translator: + return "ref" + else: + return f"ref:{translator}" + class WMTXMLDataset(Dataset): """ The 2021+ WMT dataset format. Everything is contained in a single file. Can be parsed with the lxml parser. """ - @staticmethod - def _unwrap_wmt21_or_later(raw_file, allowed_refs=[]): + def _unwrap_wmt21_or_later(raw_file): """ Unwraps the XML file from wmt21 or later. 
This script is adapted from https://github.com/wmt-conference/wmt-format-tools @@ -37,27 +42,20 @@ def _unwrap_wmt21_or_later(raw_file, allowed_refs=[]): for ref_doc in tree.getroot().findall(".//ref"): ref_langs.add(ref_doc.get("lang")) translator = ref_doc.get("translator") - if len(allowed_refs) == 0 or translator in allowed_refs: - translators.add(translator) + translators.add(translator) assert ( len(src_langs) == 1 ), f"Multiple source languages found in the file: {raw_file}" assert ( len(ref_langs) == 1 - ), f"Multiple reference languages found in the file: {raw_file}" + ), f"Found {len(ref_langs)} reference languages found in the file: {raw_file}" src = [] docids = [] orig_langs = [] - def get_field_by_translator(translator): - if not translator: - return "ref" - else: - return f"ref:{translator}" - - refs = {get_field_by_translator(translator): [] for translator in translators} + refs = { _get_field_by_translator(translator): [] for translator in translators } systems = defaultdict(list) @@ -97,7 +95,7 @@ def get_sents(doc): if not any([value.get(seg_id, "") for value in trans_to_ref.values()]): continue for translator in translators: - refs[get_field_by_translator(translator)].append( + refs[_get_field_by_translator(translator)].append( trans_to_ref.get(translator, {translator: {}}).get(seg_id, "") ) src.append(src_sents[seg_id]) @@ -109,6 +107,16 @@ def get_sents(doc): return {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems} + def _get_langpair_path(self, langpair): + """ + Returns the path for this language pair. + This is useful because in WMT22, the language-pair data structure can be a dict, + in order to allow for overriding which test set to use. 
+ """ + langpair_data = self._get_langpair_metadata(langpair)[langpair] + rel_path = langpair_data["path"] if type(langpair_data) == dict else langpair_data[0] + return os.path.join(self._rawdir, rel_path) + def process_to_text(self, langpair=None): """Processes raw files to plain text files. @@ -116,15 +124,14 @@ def process_to_text(self, langpair=None): """ # ensure that the dataset is downloaded self.maybe_download() - langpairs = self._get_langpair_metadata(langpair) - for langpair, files in langpairs.items(): - rawfile = os.path.join( - self._rawdir, files[0] - ) # all source and reference data in one file for wmt21 and later + for langpair in sorted(self._get_langpair_metadata(langpair).keys()): + # The data type can be a list of paths, or a dict, containing the "path" + # and an override on which labeled reference to use (key "refs") + rawfile = self._get_langpair_path(langpair) with smart_open(rawfile) as fin: - fields = self._unwrap_wmt21_or_later(fin, allowed_refs=self.kwargs.get("refs", [])) + fields = self._unwrap_wmt21_or_later(fin) for fieldname in fields: textfile = self._get_txt_file_path(langpair, fieldname) @@ -137,11 +144,37 @@ def process_to_text(self, langpair=None): for line in fields[fieldname]: print(self._clean(line), file=fout) + def _get_langpair_allowed_refs(self, langpair): + """ + Returns the preferred references for this language pair. + This can be set in the language pair block (as in WMT22), and backs off to the + test-set-level default, or nothing. + + There is one exception. In the metadata, sometimes there is no translator field + listed (e.g., wmt22:liv-en). In this case, the reference is set to "", and the + field "ref" is returned. 
+ """ + defaults = self.kwargs.get("refs", []) + langpair_data = self._get_langpair_metadata(langpair)[langpair] + if type(langpair_data) == dict: + allowed_refs = langpair_data.get("refs", defaults) + else: + allowed_refs = defaults + allowed_refs = [_get_field_by_translator(ref) for ref in allowed_refs] + + return allowed_refs + def get_reference_files(self, langpair): + """ + Returns the requested reference files. + This is defined as a default at the test-set level, and can be overridden per language. + """ + # Iterate through the (label, file path) pairs, looking for permitted labels + allowed_refs = self._get_langpair_allowed_refs(langpair) all_files = self.get_files(langpair) all_fields = self.fieldnames(langpair) ref_files = [ - f for f, field in zip(all_files, all_fields) if field.startswith("ref") + f for f, field in zip(all_files, all_fields) if field in allowed_refs ] return ref_files @@ -157,10 +190,9 @@ def fieldnames(self, langpair): :return: a list of field names """ self.maybe_download() - meta = self._get_langpair_metadata(langpair)[langpair] - rawfile = os.path.join(self._rawdir, meta[0]) + rawfile = self._get_langpair_path(langpair) with smart_open(rawfile) as fin: - fields = self._unwrap_wmt21_or_later(fin, allowed_refs=self.kwargs.get("refs", [])) + fields = self._unwrap_wmt21_or_later(fin) return list(fields.keys()) diff --git a/test/test_dataset.py b/test/test_dataset.py index d2dae07..cff796d 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -75,4 +75,40 @@ def test_source_and_references(): """ for ds in dataset.DATASETS.values(): for pair in ds.langpairs: - assert len(list(ds.source(pair))) == len(list(ds.references(pair))) + src_len = len(list(ds.source(pair))) + ref_len = len(list(ds.references(pair))) + assert src_len == ref_len, f"source/reference failure for {ds.name}:{pair} len(source)={src_len} len(references)={ref_len}" + + +def test_wmt22_references(): + """ + WMT21 added the ability to specify which reference to use 
(among many in the XML). + The default was "A" for everything. + WMT22 added the ability to override this default on a per-langpair basis, by + replacing the langpair list of paths with a dict that holds the path ("path") and + the reference override ("refs"). + """ + wmt22 = dataset.DATASETS["wmt22"] + + # make sure CS-EN returns all reference fields + cs_en_fields = wmt22.fieldnames("cs-en") + for ref in ["ref:B", "ref:C"]: + assert ref in cs_en_fields + assert "ref:A" not in cs_en_fields + + # make sure ref:B is the one used by default + assert wmt22._get_langpair_allowed_refs("cs-en") == ["ref:B"] + + # similar check for another language pair: there should be no default ("A"), + # and the only ref found should be the unannotated one + assert "ref:A" not in wmt22.fieldnames("liv-en") + assert "ref" in wmt22.fieldnames("liv-en") + + # and that ref:A is the default for all languages where it wasn't overridden + for langpair, langpair_data in wmt22.langpairs.items(): + if type(langpair_data) == dict: + assert wmt22._get_langpair_allowed_refs(langpair) != ["ref:A"] + else: + assert wmt22._get_langpair_allowed_refs(langpair) == ["ref:A"] + +