[WIP] RL multinode #152

Open · wants to merge 23 commits into main
85 changes: 85 additions & 0 deletions conf/accelerate/ds_multinode.json
@@ -0,0 +1,85 @@
{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_scatter": true,
    "sub_group_size": 8e7,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": 1e6,
    "stage3_max_live_parameters": 1.5e8,
    "stage3_max_reuse_distance": 1.5e8,
    "gather_16bit_weights_on_model_save": true,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true,
      "fast_init": true,
      "ratio": 0.45,
      "buffer_count": 8,
      "pipeline_read": true,
      "pipeline_write": true
    },
    "offload_param": {
      "device": "none",
      "pin_memory": true
    },
    "memory_efficient_linear": true,
    "round_robin_gradients": true
  },
  "bf16": {
    "enabled": "auto"
  },
  "gradient_clipping": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "train_batch_size": "auto",
  "activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": false,
    "contiguous_memory_optimization": true,
    "synchronize_checkpoint_boundary": true,
    "profile": false
  },
  "communication_data_type": "bf16",
  "communication_data_parallel": {
    "nccl_start_collective_timeout": 3600,
    "nccl_timeout": 3600,
    "hierarchical_allreduce": true,
    "allreduce_always_fp32": true
  },
  "wall_clock_breakdown": true,
  "monitor_config": {
    "enabled": true,
    "tag": "ds_train",
    "csv_monitor": {
      "enabled": true,
      "output_path": "/mnt/results/logs/",
      "job_name": "llama_train"
    }
  },
  "torch_cuda_alloc_conf": {
    "max_split_size_mb": 128,
    "garbage_collection_threshold": 0.8
  },
  "aio": {
    "block_size": 1048576,
    "queue_depth": 24,
    "single_submit": false,
    "overlap_events": true,
    "thread_count": 6
  },
  "wandb": {
    "enabled": false
  },
  "elastic_training": {
    "enabled": true,
    "max_out_of_sync_steps": 200,
    "sync_timeout": 3600,
    "num_retries": 3
  },
  "distributed": {
    "timeout": 3600,
    "initialization_timeout": 3600
  }
}
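
Not part of the diff: a minimal Python sketch of how a ZeRO-3 JSON like the one above can be handed to Hugging Face Accelerate. It assumes the DeepSpeedPlugin(hf_ds_config=...) route; the PR itself points at an Accelerate YAML through accelerate_cfg_path, and that wiring is not shown here.

# Minimal sketch (not part of this PR): wiring the ZeRO-3 JSON above into Accelerate.
# Assumes the config is consumed via DeepSpeedPlugin.hf_ds_config; the repo's actual
# launch path may differ.
import json

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

with open("conf/accelerate/ds_multinode.json") as f:
    ds_config = json.load(f)

# The "auto" entries (bf16, batch sizes, bucket sizes, gradient clipping) are left
# untouched here; they are resolved later from the actual training setup.
plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
accelerator = Accelerator(deepspeed_plugin=plugin)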
2 changes: 1 addition & 1 deletion conf/rl_gsm8k.yaml
@@ -43,7 +43,7 @@ vllm_config:
--enable-chunked-prefill: ""
--max-num-batched-tokens: 256

-output_dir: outputs/rl_gsm8k_deepspeed
+output_dir: outputs/rl_gsm8k_deepspeed_llama31_70b_new_save_rl_no_kl_lr1e-6_1node
accelerate_cfg_path: conf/accelerate/accelerate_base.yaml
use_deepspeed: false

26 changes: 26 additions & 0 deletions conf/rl_llama31_70b.yaml
@@ -0,0 +1,26 @@
defaults:
- rl_gsm8k
- _self_

finetune:
rl:
algo: reinforce
kl_coef: 0.0
reward_minus_kl_coef: 0.0
use_advantages: false
relu_log_p_weights: true
train_batch_size: 2
gradient_accumulation_passes: 16
learning_rate: 1e-6
force_restart: true
max_agent_forks: 5000
model_path: /mnt/llmd/base_models/Meta-Llama-3.1-70B-Instruct
n_workers_per_gpu: 16
get_logprobs_workers_per_gpu: 1
gpus_per_model_instance: 4
use_rejection_sampling: true
test_every_n_iterations: 10
attempts: 8
dataset_name: gsm8k
use_deepspeed: true
keep_intermediate_checkpoints: false
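
For context, a minimal sketch (not from the PR) of how Hydra composes this file: the defaults list loads conf/rl_gsm8k.yaml first, then _self_ applies the overrides above (e.g. use_deepspeed: true, gpus_per_model_instance: 4). The compose() call and the relative config_path are assumptions; the repo's actual entry point is not part of this diff.

# Minimal sketch (not part of this PR): inspect the composed 70B config.
# Assumes Hydra's compose API; the repo's entry point may load it differently.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="conf"):
    # rl_llama31_70b pulls in rl_gsm8k first, then applies its own overrides last.
    cfg = compose(config_name="rl_llama31_70b")

print(OmegaConf.to_yaml(cfg))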
175 changes: 175 additions & 0 deletions examples/rl_gsm8k/dist_utils.py
@@ -0,0 +1,175 @@
import datetime
import gc
import logging
import os
import time

import torch

logger = logging.getLogger(__name__)


class DistributedManager:
    @staticmethod
    def is_main_process() -> bool:
        return int(os.environ.get("RANK", "0")) == 0

    @staticmethod
    def get_world_size() -> int:
        return int(os.environ.get("WORLD_SIZE", "1"))

    @staticmethod
    def get_rank() -> int:
        return int(os.environ.get("RANK", "0"))

    @staticmethod
    def get_master_address() -> str:
        return os.environ.get("MASTER_ADDR", "localhost")

    @staticmethod
    def get_master_port() -> str:
        return os.environ.get("MASTER_PORT", "29500")

    @classmethod
    def cleanup_gpu_resources(cls):
        if torch.cuda.is_available():
            try:
                # Force garbage collection of lingering references to CUDA tensors
                gc.collect()

                # Clear CUDA cache
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

                # Monitor cleanup effectiveness
                cls.check_memory_status()

            except Exception as e:
                logger.error(f"Cleanup failed: {e}")

    @classmethod
    def check_memory_status(cls):
        """Log current GPU memory usage for this rank.

        Assumed minimal implementation: this helper is called by
        cleanup_gpu_resources but is not defined anywhere in the diff.
        """
        if torch.cuda.is_available():
            allocated_gb = torch.cuda.memory_allocated() / 1024**3
            reserved_gb = torch.cuda.memory_reserved() / 1024**3
            logger.info(
                f"[Rank {cls.get_rank()}] GPU memory: "
                f"{allocated_gb:.2f} GiB allocated, {reserved_gb:.2f} GiB reserved"
            )

    @classmethod
    def reinit_process_group(cls, timeout_mins: int = 30) -> bool:
        """Destroy and re-create the default process group via the env:// rendezvous.

        Assumed minimal implementation: this helper is called by robust_barrier
        but is not defined anywhere in the diff.
        """
        try:
            if torch.distributed.is_initialized():
                torch.distributed.destroy_process_group()
            torch.distributed.init_process_group(
                backend="nccl" if torch.cuda.is_available() else "gloo",
                init_method="env://",
                rank=cls.get_rank(),
                world_size=cls.get_world_size(),
                timeout=datetime.timedelta(minutes=timeout_mins),
            )
            return True
        except Exception as e:
            logger.error(f"[Rank {cls.get_rank()}] Failed to reinitialize process group: {e}")
            return False

    @classmethod
    def robust_barrier(cls, message: str = "", timeout_mins: int = 30, max_retries: int = 3) -> bool:
        """More robust barrier implementation with retries and cleanup."""
        if not torch.distributed.is_initialized():
            return True

        retry_delay = 5
        for attempt in range(max_retries):
            try:
                logger.info(f"[Rank {cls.get_rank()}] Barrier attempt {attempt + 1}/{max_retries}: {message}")

                # Clear GPU memory before barrier
                cls.cleanup_gpu_resources()

                # Attempt the barrier. torch.distributed.barrier() accepts no timeout
                # argument; the effective timeout is the one the process group was
                # initialized with.
                torch.distributed.barrier()

                logger.info(f"[Rank {cls.get_rank()}] Barrier successful: {message}")
                return True

            except Exception as e:
                logger.warning(f"[Rank {cls.get_rank()}] Barrier attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    logger.info(f"Waiting {retry_delay}s before retry...")
                    time.sleep(retry_delay)

                    # Attempt to reinit process group
                    if cls.reinit_process_group(timeout_mins=timeout_mins):
                        logger.info("Successfully reinitialized process group")
                    else:
                        logger.warning("Failed to reinitialize process group")

        logger.error(f"[Rank {cls.get_rank()}] Failed all barrier attempts: {message}")
        return False

    @classmethod
    def sync_nodes(cls, message: str = "", timeout_mins: int = 30) -> bool:
        """High-level sync function with additional safeguards."""
        if not torch.distributed.is_initialized():
            return True

        try:
            # Ensure GPU operations are finished
            cls.cleanup_gpu_resources()

            # Additional small wait to ensure all processes are ready
            time.sleep(cls.get_rank() * 0.1)  # Stagger by rank

            return cls.robust_barrier(message, timeout_mins)

        except Exception as e:
            logger.error(f"[Rank {cls.get_rank()}] Failed to sync nodes: {e}")
            return False

    @classmethod
    def all_gather_object(cls, obj):
        """Gather objects from all processes and return them in a list.

        Args:
            obj: Any picklable Python object
        Returns:
            list: List of objects gathered from all processes
        """
        if not torch.distributed.is_initialized():
            return [obj]

        try:
            gathered_objects = [None] * cls.get_world_size()
            torch.distributed.all_gather_object(gathered_objects, obj)
            return gathered_objects

        except Exception as e:
            logger.error(f"[Rank {cls.get_rank()}] Failed to gather objects: {e}")
            # Return a list with just this process's object on failure
            return [obj]

    @classmethod
    def broadcast_object(cls, obj, src=0):
        """Broadcast an object from the source rank to all other processes."""
        if not torch.distributed.is_initialized():
            return obj

        try:
            # Debug logging before broadcast
            logger.info(f"[Rank {cls.get_rank()}] Starting broadcast operation")

            # Ensure all processes are ready for broadcast
            if not cls.sync_nodes("before broadcast"):
                raise RuntimeError(f"Failed sync before broadcast on rank {cls.get_rank()}")

            # Create object list with None for non-source ranks
            object_list = [obj if cls.get_rank() == src else None]

            # Log object details before broadcast
            logger.info(
                f"[Rank {cls.get_rank()}] Before broadcast: "
                f"object_list[0] type: {type(object_list[0])}, "
                f"length: {len(object_list[0]) if object_list[0] and hasattr(object_list[0], '__len__') else 'N/A'}"
            )

            # Perform the broadcast. broadcast_object_list() accepts no timeout
            # argument; the process-group timeout applies here as well.
            torch.distributed.broadcast_object_list(object_list, src=src)

            # Get result and verify
            result = object_list[0]

            # Log result details
            logger.info(
                f"[Rank {cls.get_rank()}] After broadcast: "
                f"result type: {type(result)}, "
                f"length: {len(result) if result and hasattr(result, '__len__') else 'N/A'}"
            )

            # Verify broadcast result
            if result is None:
                raise RuntimeError(f"Broadcast resulted in None on rank {cls.get_rank()}")

            # Ensure all processes received the data
            if not cls.sync_nodes("after broadcast"):
                raise RuntimeError(f"Failed sync after broadcast on rank {cls.get_rank()}")

            logger.info(f"[Rank {cls.get_rank()}] Broadcast operation completed successfully")

            return result

        except Exception as e:
            logger.error(f"[Rank {cls.get_rank()}] Failed to broadcast object: {e}")
            raise RuntimeError(f"Broadcast failed on rank {cls.get_rank()}: {e}")
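
A hypothetical usage sketch (not from the PR) of how DistributedManager could sit in the multinode RL loop: rank 0 produces the training data and broadcasts it to the other ranks, with a sync point at the end of each iteration. collect_rollouts and filter_by_reward are placeholder names, not functions from this repository.

# Hypothetical usage sketch (not part of this PR) for DistributedManager in a
# multinode RL iteration. collect_rollouts() and filter_by_reward() are placeholders.
from examples.rl_gsm8k.dist_utils import DistributedManager as dm


def training_iteration():
    if dm.is_main_process():
        rollouts = collect_rollouts()              # placeholder: agent forks on rank 0
        training_set = filter_by_reward(rollouts)  # placeholder: rejection sampling
    else:
        training_set = None

    # Every rank blocks here until rank 0 has produced the data, then receives a copy.
    training_set = dm.broadcast_object(training_set, src=0)

    # ... run the local DeepSpeed/Accelerate training step on training_set ...

    # Make sure all ranks have finished the step before starting the next iteration.
    dm.sync_nodes("end of iteration")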