From 6b989268c6192ea764143152de87c1f8208a4ca9 Mon Sep 17 00:00:00 2001
From: kazuki
Date: Sun, 4 Aug 2024 01:30:11 +0900
Subject: [PATCH] fix: dpo training (reference model)

---
 .../dpo/Llama-3-8B/Llama-3-8B-chat-v0.2.sh |  2 +-
 src/llama_recipes/finetuning.py            | 30 +++++++++++++++++++
 src/llama_recipes/utils/dpo.py             |  2 --
 src/llama_recipes/utils/train_utils.py     |  8 ++++-
 4 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/scripts/tsubame/dpo/Llama-3-8B/Llama-3-8B-chat-v0.2.sh b/scripts/tsubame/dpo/Llama-3-8B/Llama-3-8B-chat-v0.2.sh
index c37514c..8e134ab 100644
--- a/scripts/tsubame/dpo/Llama-3-8B/Llama-3-8B-chat-v0.2.sh
+++ b/scripts/tsubame/dpo/Llama-3-8B/Llama-3-8B-chat-v0.2.sh
@@ -42,7 +42,7 @@ done <"$PE_HOSTFILE" >"$HOSTFILE_NAME"
 SEQ_LENGTH=8192
 DATA_PARALLEL_SIZE=$NUM_GPUS
 
-MICRO_BATCH_SIZE=1
+MICRO_BATCH_SIZE=2
 GLOBAL_BATCH_SIZE=128
 
 # optimizer config
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index d1d9528..8669558 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -1,3 +1,4 @@
+import copy
 import os
 import sys
 
@@ -101,6 +102,10 @@ def main() -> None:
         model = get_model(
             model_name=args.base_model, use_cache=use_cache
         )
+        if args.direct_preference_optimization:
+            reference_model = copy.deepcopy(model)
+            for param in reference_model.parameters():
+                param.requires_grad = False
 
     if args.load:
         load_model_state_dict(model, args.load)  # type: ignore
@@ -115,6 +120,13 @@ def main() -> None:
         elif args.fp16:
             model.to(torch.float16)  # type: ignore
 
+    if args.direct_preference_optimization:
+        with preserve_fp32_buffers(reference_model):
+            if args.bf16:
+                reference_model.to(torch.bfloat16)  # type: ignore
+            elif args.fp16:
+                reference_model.to(torch.float16)  # type: ignore
+
     if args.use_freeze_layers:
         print_rank_0("NOTE: freeze transformer layers")
         freeze_transformer_layers(model=model, layer_ranges=args.freeze_layers)
@@ -142,6 +154,23 @@ def main() -> None:
         if args.fsdp_activation_checkpointing:
             apply_fsdp_checkpointing(model=model, model_name=args.base_model)
 
+        if args.direct_preference_optimization:
+            reference_model = FSDP(
+                reference_model,  # type: ignore
+                auto_wrap_policy=wrapping_policy,
+                cpu_offload=CPUOffload(offload_params=True) if args.fsdp_cpu_offload else None,
+                mixed_precision=mixed_precision_policy,
+                sharding_strategy=get_sharding_strategy(),
+                device_id=torch.cuda.current_device(),
+                limit_all_gathers=True,
+                sync_module_states=args.low_cpu_fsdp,
+                param_init_fn=lambda module: module.to_empty(  # type: ignore
+                    device=torch.cuda.current_device(), recurse=False,  # type: ignore
+                )
+                if args.low_cpu_fsdp and rank != 0
+                else None,
+            )
+
     if not args.instruction_tuning and not args.direct_preference_optimization:
         args.continual_pretraining = True
 
@@ -269,6 +298,7 @@ def main() -> None:
         local_rank=get_local_rank(),
         rank=get_rank(),
         dpo_loss_fn=dpo_loss_fn,
+        reference_model=reference_model if args.direct_preference_optimization else None,
     )
 
 
diff --git a/src/llama_recipes/utils/dpo.py b/src/llama_recipes/utils/dpo.py
index ae46f71..ef9d325 100644
--- a/src/llama_recipes/utils/dpo.py
+++ b/src/llama_recipes/utils/dpo.py
@@ -71,8 +71,6 @@ def concatenated_forward(
     len_chosen = concatenated_input_ids.shape[0] // 2
 
     all_logits = model(concatenated_input_ids).logits
-    print(f"DEBUG: all_logits={all_logits}, type(all_logits)={type(all_logits)}")
-    print(f"DEBUG: concatenated_labels={concatenated_labels}, type(concatenated_labels)={type(concatenated_labels)}")
     all_log_probs = get_batch_log_probs(
         all_logits,
         concatenated_labels  # type: ignore
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index d34d82f..12b948e 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -40,6 +40,7 @@ def train(
     local_rank: Optional[int] = None,
     rank: Optional[int] = None,
    dpo_loss_fn: Optional[DPOLoss] = None,
+    reference_model: Optional[torch.nn.Module] = None,
 ) -> None:
     """
     Trains the model on the given dataloader
@@ -108,6 +109,11 @@ def train(
                 raise ValueError(
                     "DPO(Direct Preference Optimization) is enabled, but dpo loss function is None"
                 )
+            if reference_model is None:
+                raise ValueError(
+                    "DPO(Direct Preference Optimization) is enabled, but reference model is None"
+                )
+
             # forward
             (
                 policy_chosen_log_probs,
@@ -128,7 +134,7 @@ def train(
                 reference_rejected_log_probs,
                 _,
                 _,
-            ) = concatenated_forward(model=model, batch=batch, local_rank=local_rank)
+            ) = concatenated_forward(model=reference_model, batch=batch, local_rank=local_rank)
 
             loss, chosen_rewards, rejected_rewards = dpo_loss_fn(
                 policy_chosen_log_probs,
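Note (illustrative, not part of the patch): the change above follows the usual DPO recipe, in which a frozen copy of the policy, taken before preference training starts, acts as the reference model, and the loss compares policy and reference log-ratios on chosen vs. rejected responses. A minimal sketch of that pattern in plain PyTorch follows; the function names and the beta default are illustrative and not taken from this repository.

import copy

import torch
import torch.nn.functional as F


def make_reference_model(policy: torch.nn.Module) -> torch.nn.Module:
    # Snapshot the policy and freeze it so it never receives gradient updates.
    reference = copy.deepcopy(policy)
    for param in reference.parameters():
        param.requires_grad = False
    reference.eval()
    return reference


def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1,  # illustrative default; the repo's DPOLoss may differ
) -> torch.Tensor:
    # DPO objective: -log sigmoid(beta * (policy log-ratio - reference log-ratio)),
    # averaged over the batch of (chosen, rejected) pairs.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()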