From 4871fc115d7b94b1787415f7955e6662c8ec8a37 Mon Sep 17 00:00:00 2001
From: Andrew
Date: Mon, 13 May 2024 09:22:55 +0800
Subject: [PATCH] Fix for termination issue (#38)

---
 scripts/inference/inference_hf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/inference/inference_hf.py b/scripts/inference/inference_hf.py
index 43ac2e8..7cf3ad7 100644
--- a/scripts/inference/inference_hf.py
+++ b/scripts/inference/inference_hf.py
@@ -191,7 +191,7 @@ def generate_prompt(instruction):
             generation_output = model.generate(
                 input_ids = inputs["input_ids"].to(device),
                 attention_mask = inputs['attention_mask'].to(device),
-                eos_token_id=tokenizer.eos_token_id,
+                eos_token_id=terminators,
                 pad_token_id=tokenizer.eos_token_id,
                 generation_config = generation_config
             )
@@ -215,4 +215,4 @@ def generate_prompt(instruction):
         with open(dirname+'/generation_config.json','w') as f:
             json.dump(generation_config,f,ensure_ascii=False,indent=2)
     else:
-        generation_config.save_pretrained('./')
\ No newline at end of file
+        generation_config.save_pretrained('./')
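
Note: the `terminators` value passed to `eos_token_id` is defined elsewhere in the script and is not part of this hunk. As background, `transformers`' `model.generate` accepts either a single token id or a list of ids for `eos_token_id`, so passing a list lets generation stop on any of several terminator tokens. A minimal sketch of how such a list is commonly built for Llama-3-style chat models follows; the model path is a placeholder and the "<|eot_id|>" end-of-turn token name is an assumption, since neither appears in this patch.

from transformers import AutoTokenizer

# Sketch only: "path/to/llama3-model" is a placeholder, "<|eot_id|>" is assumed
# to be the chat template's end-of-turn token (as in Llama-3-style models).
tokenizer = AutoTokenizer.from_pretrained("path/to/llama3-model")

terminators = [
    tokenizer.eos_token_id,                         # regular end-of-sequence token
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),  # end-of-turn token emitted by the chat template
]

# Passing the list makes generate() stop on whichever terminator appears first:
# model.generate(..., eos_token_id=terminators, pad_token_id=tokenizer.eos_token_id)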