fix: dpo training (reference model)
okoge-kaz committed Aug 3, 2024
1 parent f09a3d6 commit 6b98926
Showing 4 changed files with 38 additions and 4 deletions.
scripts/tsubame/dpo/Llama-3-8B/Llama-3-8B-chat-v0.2.sh (2 changes: 1 addition & 1 deletion)
@@ -42,7 +42,7 @@ done <"$PE_HOSTFILE" >"$HOSTFILE_NAME"
SEQ_LENGTH=8192
DATA_PARALLEL_SIZE=$NUM_GPUS

-MICRO_BATCH_SIZE=1
+MICRO_BATCH_SIZE=2
GLOBAL_BATCH_SIZE=128

# optimizer config
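Note, not part of the diff: with the usual data-parallel batch accounting, raising MICRO_BATCH_SIZE from 1 to 2 halves the number of gradient-accumulation steps needed to reach GLOBAL_BATCH_SIZE. A minimal sketch of that arithmetic, assuming the training code derives accumulation steps in the standard way (DATA_PARALLEL_SIZE below is a made-up example value standing in for NUM_GPUS):

# Sketch only: the standard relationship between the script's batch-size knobs.
GLOBAL_BATCH_SIZE = 128
DATA_PARALLEL_SIZE = 8  # stand-in for NUM_GPUS; value assumed for illustration

for micro_batch_size in (1, 2):  # before / after this commit
    grad_accum_steps = GLOBAL_BATCH_SIZE // (micro_batch_size * DATA_PARALLEL_SIZE)
    print(micro_batch_size, grad_accum_steps)  # 1 -> 16 steps, 2 -> 8 steps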
src/llama_recipes/finetuning.py (30 changes: 30 additions & 0 deletions)
@@ -1,3 +1,4 @@
+import copy
import os
import sys

@@ -101,6 +102,10 @@ def main() -> None:
model = get_model(
    model_name=args.base_model, use_cache=use_cache
)
+if args.direct_preference_optimization:
+    reference_model = copy.deepcopy(model)
+    for param in reference_model.parameters():
+        param.requires_grad = False

if args.load:
    load_model_state_dict(model, args.load) # type: ignore
@@ -115,6 +120,13 @@ def main() -> None:
elif args.fp16:
    model.to(torch.float16) # type: ignore

+if args.direct_preference_optimization:
+    with preserve_fp32_buffers(reference_model):
+        if args.bf16:
+            reference_model.to(torch.bfloat16) # type: ignore
+        elif args.fp16:
+            reference_model.to(torch.float16) # type: ignore

if args.use_freeze_layers:
    print_rank_0("NOTE: freeze transformer layers")
    freeze_transformer_layers(model=model, layer_ranges=args.freeze_layers)
@@ -142,6 +154,23 @@ def main() -> None:
if args.fsdp_activation_checkpointing:
    apply_fsdp_checkpointing(model=model, model_name=args.base_model)

+if args.direct_preference_optimization:
+    reference_model = FSDP(
+        reference_model, # type: ignore
+        auto_wrap_policy=wrapping_policy,
+        cpu_offload=CPUOffload(offload_params=True) if args.fsdp_cpu_offload else None,
+        mixed_precision=mixed_precision_policy,
+        sharding_strategy=get_sharding_strategy(),
+        device_id=torch.cuda.current_device(),
+        limit_all_gathers=True,
+        sync_module_states=args.low_cpu_fsdp,
+        param_init_fn=lambda module: module.to_empty( # type: ignore
+            device=torch.cuda.current_device(), recurse=False, # type: ignore
+        )
+        if args.low_cpu_fsdp and rank != 0
+        else None,
+    )

if not args.instruction_tuning and not args.direct_preference_optimization:
    args.continual_pretraining = True

@@ -269,6 +298,7 @@ def main() -> None:
local_rank=get_local_rank(),
rank=get_rank(),
dpo_loss_fn=dpo_loss_fn,
+reference_model=reference_model if args.direct_preference_optimization else None,
)


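Taken together, the hunks above give DPO a reference model: deep-copy the freshly built policy, freeze it, cast it to the same dtype, wrap it in FSDP, and pass it into the training loop. A condensed sketch of that setup, not part of the commit; the .eval() call is an extra assumption, the rest mirrors the code above:

import copy
import torch

def build_dpo_reference_model(policy_model: torch.nn.Module) -> torch.nn.Module:
    """Sketch: DPO compares the trained policy against a frozen snapshot of
    its starting point, so the copy must be taken before any updates."""
    reference_model = copy.deepcopy(policy_model)
    reference_model.eval()  # assumption: disable dropout so reference log-probs are deterministic
    for param in reference_model.parameters():
        param.requires_grad = False  # the reference is never updated
    return reference_model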
src/llama_recipes/utils/dpo.py (2 changes: 0 additions & 2 deletions)
@@ -71,8 +71,6 @@ def concatenated_forward(
len_chosen = concatenated_input_ids.shape[0] // 2

all_logits = model(concatenated_input_ids).logits
print(f"DEBUG: all_logits={all_logits}, type(all_logits)={type(all_logits)}")
print(f"DEBUG: concatenated_labels={concatenated_labels}, type(concatenated_labels)={type(concatenated_labels)}")

all_log_probs = get_batch_log_probs(
    all_logits, concatenated_labels # type: ignore
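For context, not part of this diff: get_batch_log_probs reduces the concatenated logits to one log-probability per sequence before they reach the DPO loss. A minimal sketch of the standard computation, assuming labels mark ignored positions with -100 (the repo's exact masking convention is an assumption):

import torch

def get_batch_log_probs_sketch(
    logits: torch.Tensor,  # (batch, seq_len, vocab)
    labels: torch.Tensor,  # (batch, seq_len)
    ignore_index: int = -100,  # assumed ignore label
) -> torch.Tensor:
    # Token at position t is predicted from the logits at position t - 1.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]

    mask = labels != ignore_index
    labels = labels.masked_fill(~mask, 0)  # any valid index; zeroed out by the mask below

    # Log-probability of each realized token, summed over unmasked positions.
    per_token = torch.gather(
        logits.log_softmax(dim=-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    return (per_token * mask).sum(dim=-1)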
src/llama_recipes/utils/train_utils.py (8 changes: 7 additions & 1 deletion)
@@ -40,6 +40,7 @@ def train(
local_rank: Optional[int] = None,
rank: Optional[int] = None,
dpo_loss_fn: Optional[DPOLoss] = None,
+reference_model: Optional[torch.nn.Module] = None,
) -> None:
"""
Trains the model on the given dataloader
@@ -108,6 +109,11 @@
    raise ValueError(
        "DPO(Direct Preference Optimization) is enabled, but dpo loss function is None"
    )
+if reference_model is None:
+    raise ValueError(
+        "DPO(Direct Preference Optimization) is enabled, but reference model is None"
+    )

# forward
(
policy_chosen_log_probs,
@@ -128,7 +134,7 @@
reference_rejected_log_probs,
_,
_,
-) = concatenated_forward(model=model, batch=batch, local_rank=local_rank)
+) = concatenated_forward(model=reference_model, batch=batch, local_rank=local_rank)

loss, chosen_rewards, rejected_rewards = dpo_loss_fn(
    policy_chosen_log_probs,
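Also for context, not part of the diff: the four log-probability tensors feed the DPO objective. DPOLoss itself is not shown in this commit, so the sketch below is the textbook formulation (Rafailov et al., 2023), with return values matching the call site's loss, chosen_rewards, rejected_rewards:

import torch
import torch.nn.functional as F

def dpo_loss_sketch(
    policy_chosen_log_probs: torch.Tensor,
    policy_rejected_log_probs: torch.Tensor,
    reference_chosen_log_probs: torch.Tensor,
    reference_rejected_log_probs: torch.Tensor,
    beta: float = 0.1,  # assumed default; the repo's beta is configured elsewhere
):
    # Preference margin of the policy and of the frozen reference.
    policy_ratio = policy_chosen_log_probs - policy_rejected_log_probs
    reference_ratio = reference_chosen_log_probs - reference_rejected_log_probs

    # Standard DPO loss: push the policy's margin above the reference's margin.
    loss = -F.logsigmoid(beta * (policy_ratio - reference_ratio)).mean()

    # Implicit rewards, handy for logging; detached because they are reporting only.
    chosen_rewards = beta * (policy_chosen_log_probs - reference_chosen_log_probs).detach()
    rejected_rewards = beta * (policy_rejected_log_probs - reference_rejected_log_probs).detach()
    return loss, chosen_rewards, rejected_rewards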
