# callbacks.py
import os

from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

class PeftCallback(TrainerCallback):
    """Saves the PEFT adapter weights, adapter config, and tokenizer at every checkpoint."""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Mirror the Trainer's checkpoint naming: <output_dir>/checkpoint-<global_step>.
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        # For a PEFT-wrapped model, save_pretrained writes only the adapter weights,
        # keeping checkpoints small compared to full-model saves.
        kwargs["model"].save_pretrained(checkpoint_folder)
        kwargs["model"].config.save_pretrained(checkpoint_folder)
        kwargs["tokenizer"].save_pretrained(checkpoint_folder)
        return control
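
# A checkpoint written by PeftCallback can later be restored onto the base model.
# This is an illustrative sketch, not part of the original file; `base_model` and
# the checkpoint path below are hypothetical placeholders:
#
# from peft import PeftModel
#
# model = PeftModel.from_pretrained(base_model, "output/checkpoint-500")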

class DistilCallback(TrainerCallback):
    """Saves the full (distilled) model under a pytorch_model/ subfolder at every checkpoint."""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        # Unlike PeftCallback, this writes the entire model, not just an adapter.
        model_path = os.path.join(checkpoint_folder, "pytorch_model")
        kwargs["model"].save_pretrained(model_path)
        return control
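
# --- Usage sketch (illustrative assumption, not part of the original file) ---
# A minimal example of how these callbacks might be registered with a Trainer;
# `model`, `training_args`, `train_dataset`, and `tokenizer` are hypothetical
# placeholders assumed to be defined elsewhere.
#
# from transformers import Seq2SeqTrainer
#
# trainer = Seq2SeqTrainer(
#     model=model,                  # a PEFT-wrapped model when using PeftCallback
#     args=training_args,
#     train_dataset=train_dataset,
#     tokenizer=tokenizer,
#     callbacks=[PeftCallback()],   # or DistilCallback() for full-model saves
# )
# trainer.train()                   # on_save fires at each checkpoint step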