-
Notifications
You must be signed in to change notification settings - Fork 266
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding MMLU and Winogrande human-translated into 11 African languages (…
- Loading branch information
1 parent
a271355
commit f4ea5c8
Showing
6 changed files
with
321 additions
and
0 deletions.
There are no files selected for viewing
65 changes: 65 additions & 0 deletions
65
src/helm/benchmark/run_specs/mmlu_clinical_afr_run_specs.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Run spec functions for three clinical sections of MMLU human-translated into 11 African languages | ||
Available subjects: "clinical_knowledge", "college_medicine", "virology" | ||
Available langs: "af", "zu", "xh", "am", "bm", "ig", "nso", "sn", "st", "tn", "ts" (see lang_map below for language code mapping to language name, or here for ISO code reference: https://huggingface.co/languages) | ||
""" | ||
|
||
from helm.benchmark.adaptation.adapter_spec import ( | ||
ADAPT_GENERATION, | ||
ADAPT_MULTIPLE_CHOICE_JOINT, | ||
AdapterSpec, | ||
) | ||
from helm.benchmark.adaptation.common_adapter_specs import ( | ||
get_generation_adapter_spec, | ||
get_machine_translation_adapter_spec, | ||
get_multiple_choice_adapter_spec, | ||
) | ||
from helm.benchmark.metrics.common_metric_specs import ( | ||
get_basic_generation_metric_specs, | ||
get_basic_metric_specs, | ||
get_exact_match_metric_specs, | ||
get_f1_metric_specs, | ||
get_generative_harms_metric_specs, | ||
get_generic_metric_specs, | ||
get_open_ended_generation_metric_specs, | ||
) | ||
from helm.benchmark.run_spec import RunSpec, run_spec_function | ||
from helm.benchmark.runner import get_benchmark_output_path | ||
from helm.benchmark.scenarios.scenario import ScenarioSpec, get_scenario_cache_path | ||
|
||
|
||
@run_spec_function("mmlu_clinical_afr") | ||
def get_mmlu_clinical_afr_spec(subject: str, lang: str, method: str = ADAPT_MULTIPLE_CHOICE_JOINT) -> RunSpec: | ||
scenario_spec = ScenarioSpec( | ||
class_name="helm.benchmark.scenarios.mmlu_clinical_afr_scenario.MMLU_Clinical_Afr_Scenario", args={"subject": subject, "lang": lang} | ||
) | ||
|
||
lang_map = { | ||
'af': 'Afrikaans', | ||
'zu': 'Zulu', | ||
'xh': 'Xhosa', | ||
'am': 'Amharic', | ||
'bm': 'Bambara', | ||
'ig': 'Igbo', | ||
'nso': 'Sepedi', | ||
'sn': 'Shona', | ||
'st': 'Sesotho', | ||
'tn': 'Setswana', | ||
'ts': 'Tsonga', | ||
} | ||
|
||
adapter_spec = get_multiple_choice_adapter_spec( | ||
method=method, | ||
instructions=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')} " | ||
f"in {lang_map[lang]}.", | ||
input_noun="Question", | ||
output_noun="Answer", | ||
) | ||
|
||
return RunSpec( | ||
name=f"mmlu_clinical_afr:subject={subject},lang={lang},method={method}", | ||
scenario_spec=scenario_spec, | ||
adapter_spec=adapter_spec, | ||
metric_specs=get_exact_match_metric_specs(), | ||
groups=["low_resource_languages"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
"""Run spec functions for Winogrande human-translated into 11 African languages | ||
Available langs: "af", "zu", "xh", "am", "bm", "ig", "nso", "sn", "st", "tn", "ts" (see lang_map below for language code mapping to language name, or here for ISO code reference: https://huggingface.co/languages) | ||
""" | ||
|
||
from helm.benchmark.adaptation.adapter_spec import ( | ||
ADAPT_GENERATION, | ||
ADAPT_MULTIPLE_CHOICE_JOINT, | ||
AdapterSpec, | ||
) | ||
from helm.benchmark.adaptation.common_adapter_specs import ( | ||
get_generation_adapter_spec, | ||
get_machine_translation_adapter_spec, | ||
get_multiple_choice_adapter_spec, | ||
) | ||
from helm.benchmark.metrics.common_metric_specs import ( | ||
get_basic_generation_metric_specs, | ||
get_basic_metric_specs, | ||
get_exact_match_metric_specs, | ||
get_f1_metric_specs, | ||
get_generative_harms_metric_specs, | ||
get_generic_metric_specs, | ||
get_open_ended_generation_metric_specs, | ||
) | ||
from helm.benchmark.run_spec import RunSpec, run_spec_function | ||
from helm.benchmark.runner import get_benchmark_output_path | ||
from helm.benchmark.scenarios.scenario import ScenarioSpec, get_scenario_cache_path | ||
|
||
|
||
@run_spec_function("winogrande_afr") | ||
def get_winogrande_afr_spec(lang: str, method: str = ADAPT_MULTIPLE_CHOICE_JOINT) -> RunSpec: | ||
scenario_spec = ScenarioSpec( | ||
class_name="helm.benchmark.scenarios.winogrande_afr_scenario.Winogrande_Afr_Scenario", args={"lang": lang} | ||
) | ||
|
||
lang_map = { | ||
'af': 'Afrikaans', | ||
'zu': 'Zulu', | ||
'xh': 'Xhosa', | ||
'am': 'Amharic', | ||
'bm': 'Bambara', | ||
'ig': 'Igbo', | ||
'nso': 'Sepedi', | ||
'sn': 'Shona', | ||
'st': 'Sesotho', | ||
'tn': 'Setswana', | ||
'ts': 'Tsonga', | ||
} | ||
|
||
adapter_spec = get_multiple_choice_adapter_spec( | ||
method=method, | ||
instructions=f"The following are binary choice fill-in-the-blank sentences (with answers), requiring common sense reasoning " | ||
f"in {lang_map[lang]}.", | ||
input_noun="Question", | ||
output_noun="Answer", | ||
) | ||
|
||
return RunSpec( | ||
name=f"winogrande_afr:lang={lang},method={method}", | ||
scenario_spec=scenario_spec, | ||
adapter_spec=adapter_spec, | ||
metric_specs=get_exact_match_metric_specs(), | ||
groups=["low_resource_languages"], | ||
) |
74 changes: 74 additions & 0 deletions
74
src/helm/benchmark/scenarios/mmlu_clinical_afr_scenario.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import csv | ||
import os | ||
from typing import Dict, List | ||
|
||
from helm.common.general import ensure_file_downloaded | ||
from helm.common.hierarchical_logger import hlog | ||
from .scenario import Scenario, Instance, Reference, TRAIN_SPLIT, VALID_SPLIT, TEST_SPLIT, CORRECT_TAG, Input, Output | ||
|
||
|
||
class MMLU_Clinical_Afr_Scenario(Scenario): | ||
""" | ||
https://github.com/InstituteforDiseaseModeling/Bridging-the-Gap-Low-Resource-African-Languages | ||
""" | ||
|
||
name = "mmlu_clinical_afr" | ||
description = "Massive Multitask Language Understanding (MMLU) translated into 11 African low-resource languages" | ||
tags = ["knowledge", "multiple_choice", "low_resource_languages"] | ||
|
||
def __init__(self, subject: str = "clinical_knowledge", lang: str = "af"): | ||
super().__init__() | ||
self.subject: str = subject | ||
self.lang: str = lang | ||
|
||
def download_mmlu_clinical_afr(self, path: str): | ||
ensure_file_downloaded( | ||
source_url="https://github.com/InstituteforDiseaseModeling/Bridging-the-Gap-Low-Resource-African-Languages/raw/refs/heads/main/data/evaluation_benchmarks_afr_release.zip", | ||
target_path=path, | ||
unpack=True, | ||
unpack_type='unzip' | ||
) | ||
|
||
def process_csv(self, csv_path: str, split: str) -> List[Instance]: | ||
instances: List[Instance] = [] | ||
hlog(f"Reading {csv_path}") | ||
with open(csv_path) as f: | ||
reader = csv.reader(f, delimiter=",") | ||
for row in reader: | ||
|
||
question, answers, correct_choice = row[0], row[1:-1], row[-1] | ||
answers_dict = dict(zip(["A", "B", "C", "D"], answers)) | ||
correct_answer: str = answers_dict[correct_choice] | ||
|
||
def answer_to_reference(answer: str) -> Reference: | ||
return Reference(Output(text=answer), tags=[CORRECT_TAG] if answer == correct_answer else []) | ||
|
||
instance = Instance( | ||
input=Input(text=question), | ||
references=list(map(answer_to_reference, answers)), | ||
split=split, | ||
) | ||
instances.append(instance) | ||
return instances | ||
|
||
def get_instances(self, output_path: str) -> List[Instance]: | ||
# Download the raw data | ||
desired_dir = 'mmlu_cm_ck_vir' | ||
data_path: str = os.path.join(output_path, desired_dir) | ||
self.download_mmlu_clinical_afr(data_path) | ||
|
||
# Read all the instances | ||
instances: List[Instance] = [] | ||
splits: Dict[str, str] = { | ||
"dev": TRAIN_SPLIT, | ||
"val": VALID_SPLIT, | ||
"test": TEST_SPLIT, | ||
} | ||
for split in splits: | ||
csv_path: str = os.path.join(data_path, desired_dir, f"{self.subject}_{split}_{self.lang}.csv") | ||
if not os.path.exists(csv_path): | ||
hlog(f"{csv_path} doesn't exist, skipping") | ||
continue | ||
instances.extend(self.process_csv(csv_path, splits[split])) | ||
|
||
return instances |
21 changes: 21 additions & 0 deletions
21
src/helm/benchmark/scenarios/test_mmlu_clinical_afr_scenario.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import pytest | ||
from tempfile import TemporaryDirectory | ||
|
||
from helm.benchmark.scenarios.mmlu_clinical_afr_scenario import MMLU_Clinical_Afr_Scenario | ||
from helm.benchmark.scenarios.scenario import CORRECT_TAG, Input, Output, Reference | ||
|
||
|
||
@pytest.mark.scenarios | ||
def test_mmlu_clinical_afr_scenario(): | ||
with TemporaryDirectory() as tmpdir: | ||
scenario = MMLU_Clinical_Afr_Scenario(subject="clinical_knowledge", lang="am") | ||
instances = scenario.get_instances(tmpdir) | ||
assert len(instances) == 299 | ||
assert instances[0].input == Input(text="ለሁሉም የጡንቻ መኮማተር ዓይነቶች የሚያስፈልገው ኢኔርጅ የሚቀርበው ከሚከተሉት ነው፦") | ||
assert instances[0].references == [ | ||
Reference(output=Output(text="ATP።"), tags=[CORRECT_TAG]), | ||
Reference(output=Output(text="ADP።"), tags=[]), | ||
Reference(output=Output(text="ፎስፎክሬቲን።"), tags=[]), | ||
Reference(output=Output(text="ኦክስዳቲቪ ፎስፎሪሌሽን።"), tags=[]), | ||
] | ||
assert instances[0].split == "train" |
19 changes: 19 additions & 0 deletions
19
src/helm/benchmark/scenarios/test_winogrande_afr_scenario.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import pytest | ||
from tempfile import TemporaryDirectory | ||
|
||
from helm.benchmark.scenarios.winogrande_afr_scenario import Winogrande_Afr_Scenario | ||
from helm.benchmark.scenarios.scenario import CORRECT_TAG, Input, Output, Reference | ||
|
||
|
||
@pytest.mark.scenarios | ||
def test_winogrande_afr_scenario(): | ||
with TemporaryDirectory() as tmpdir: | ||
scenario = Winogrande_Afr_Scenario(lang="am") | ||
instances = scenario.get_instances(tmpdir) | ||
assert len(instances) == 3674 | ||
assert instances[0].input == Input(text="ሳራ ከማሪያ በጣም የተሻለች የቀዶ ጥገና ሐኪም ስለነበረች ሁልጊዜ _ ቀላል ህመሞችን ታክማለች.") | ||
assert instances[0].references == [ | ||
Reference(output=Output(text="ሳራ"), tags=[]), | ||
Reference(output=Output(text="ማሪያ"), tags=[CORRECT_TAG]), | ||
] | ||
assert instances[0].split == "train" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import csv | ||
import os | ||
from typing import Dict, List | ||
|
||
from helm.common.general import ensure_file_downloaded | ||
from helm.common.hierarchical_logger import hlog | ||
from .scenario import Scenario, Instance, Reference, TRAIN_SPLIT, VALID_SPLIT, TEST_SPLIT, CORRECT_TAG, Input, Output | ||
|
||
|
||
class Winogrande_Afr_Scenario(Scenario): | ||
""" | ||
https://github.com/InstituteforDiseaseModeling/Bridging-the-Gap-Low-Resource-African-Languages | ||
""" | ||
|
||
name = "winogrande_afr" | ||
description = "Winogrande (S) translated into 11 African low-resource languages" | ||
tags = ["knowledge", "multiple_choice", "low_resource_languages"] | ||
|
||
def __init__(self, lang: str = "af"): | ||
super().__init__() | ||
self.lang: str = lang | ||
|
||
def download_winogrande_afr(self, path: str): | ||
ensure_file_downloaded( | ||
source_url="https://github.com/InstituteforDiseaseModeling/Bridging-the-Gap-Low-Resource-African-Languages/raw/refs/heads/main/data/evaluation_benchmarks_afr_release.zip", | ||
target_path=path, | ||
unpack=True, | ||
unpack_type='unzip' | ||
) | ||
|
||
def process_csv(self, csv_path: str, split: str, pseudo_split: str) -> List[Instance]: | ||
# Match naming in Winogrande | ||
if pseudo_split == 'val': | ||
pseudo_split = 'train_s' | ||
instances: List[Instance] = [] | ||
hlog(f"Reading {csv_path}") | ||
with open(csv_path) as f: | ||
reader = csv.reader(f, delimiter=",") | ||
next(reader, None) # skip the header | ||
for row in reader: | ||
if row[-1] != pseudo_split: # ensure correct split is taken | ||
continue | ||
question, answers, correct_choice = row[-5], row[-4:-2], row[-2] | ||
answers_dict = dict(zip(["1", "2"], answers)) | ||
correct_answer: str = answers_dict[correct_choice] | ||
|
||
def answer_to_reference(answer: str) -> Reference: | ||
return Reference(Output(text=answer), tags=[CORRECT_TAG] if answer == correct_answer else []) | ||
|
||
instance = Instance( | ||
input=Input(text=question), | ||
references=list(map(answer_to_reference, answers)), | ||
split=split, | ||
) | ||
instances.append(instance) | ||
return instances | ||
|
||
def get_instances(self, output_path: str) -> List[Instance]: | ||
# Download the raw data | ||
desired_dir = 'winogrande_s' | ||
data_path: str = os.path.join(output_path, desired_dir) | ||
self.download_winogrande_afr(data_path) | ||
|
||
# Read all the instances | ||
instances: List[Instance] = [] | ||
splits: Dict[str, str] = { | ||
"dev": TRAIN_SPLIT, | ||
"val": VALID_SPLIT, | ||
"test": TEST_SPLIT, | ||
} | ||
for split in splits: | ||
csv_path: str = os.path.join(data_path, desired_dir, f"winogrande_{self.lang}.csv") | ||
if not os.path.exists(csv_path): | ||
hlog(f"{csv_path} doesn't exist, skipping") | ||
continue | ||
instances.extend(self.process_csv(csv_path, splits[split], split)) | ||
|
||
return instances |