Skip to content

Commit

Permalink
update metric calculation and requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
chujiezheng committed Dec 17, 2024
1 parent 8b66394 commit 519df16
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 0 deletions.
5 changes: 5 additions & 0 deletions code/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ def main():
with open(os.path.join(output_dir, f'{config}_correct.jsonl'), 'w') as f:
for e in correct_data:
f.write(json.dumps(e) + '\n')

acc1 = np.mean([e['match'] for e in error_data]) * 100
acc2 = np.mean([e['match'] for e in correct_data]) * 100
f1 = 2 * acc1 * acc2 / (acc1 + acc2)
print(f'{config} error acc: {acc1:.1f}, correct acc: {acc2:.1f}, f1: {f1:.1f}')


if __name__ == '__main__':
Expand Down
6 changes: 6 additions & 0 deletions code/run_eval_prm_rlhflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import os
import numpy as np
import json
from tqdm import tqdm
from multiprocessing import Pool
Expand Down Expand Up @@ -74,6 +75,11 @@ def single_process(d):
with open(f'outputs/Llama3.1-8B-PRM-Mistral-Data/{config}_correct.jsonl', 'w') as f:
for e in data2:
f.write(json.dumps(e) + '\n')

acc1 = np.mean([e['match'] for e in data1]) * 100
acc2 = np.mean([e['match'] for e in data2]) * 100
f1 = 2 * acc1 * acc2 / (acc1 + acc2)
print(f'{config} error acc: {acc1:.1f}, correct acc: {acc2:.1f}, f1: {f1:.1f}')

if __name__ == '__main__':
main()
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
torch==2.4.0
transformers==4.46.1
vllm==0.6.3.post1

0 comments on commit 519df16

Please sign in to comment.