[WIP] RL multinode #152

Open · wants to merge 23 commits into main
Changes from 6 commits
2 changes: 1 addition & 1 deletion conf/rl_gsm8k.yaml
@@ -42,7 +42,7 @@ vllm_config:
    # VLLM get log probs OOM https://github.com/vllm-project/vllm/issues/5907
    --enable-chunked-prefill: ""

-output_dir: outputs/rl_gsm8k_deepspeed
+output_dir: outputs/rl_gsm8k_deepspeed_llama31_70b_new_save_rl_no_kl_lr1e-6_1node
accelerate_cfg_path: conf/accelerate/accelerate_base.yaml
use_deepspeed: false
25 changes: 25 additions & 0 deletions conf/rl_llama31_70b.yaml
@@ -0,0 +1,25 @@
defaults:
  - rl_gsm8k
  - _self_

finetune:
  rl:
    algo: reinforce
    kl_coef: 0.0
    reward_minus_kl_coef: 0.0
    use_advantages: false
    relu_log_p_weights: true
  train_batch_size: 4
  gradient_accumulation_passes: 16
  learning_rate: 1e-6
force_restart: true
max_agent_forks: 5000
model_path: /mnt/llmd/base_models/Meta-Llama-3.1-70B-Instruct
n_workers_per_gpu: 16
get_logprobs_workers_per_gpu: 1
gpus_per_model_instance: 4
use_rejection_sampling: true
test_every_n_iterations: 10
attempts: 8
dataset_name: gsm8k
use_deepspeed: true
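Not part of the diff, but a quick way to sanity-check how this file composes on top of rl_gsm8k (assuming Hydra and OmegaConf are installed and the snippet is run from the repo root, next to conf/):

# Sketch only: compose the config and inspect the overrides introduced by this file.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="conf"):
    cfg = compose(config_name="rl_llama31_70b")
    print(OmegaConf.to_yaml(cfg.finetune.rl))  # algo: reinforce, kl_coef: 0.0, ...
    print(cfg.use_deepspeed)  # True (overridden from false in rl_gsm8k.yaml)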
File renamed without changes.
27 changes: 18 additions & 9 deletions examples/rl_gsm8k/orchestrate_rl.py
@@ -315,9 +315,21 @@ def main(cfg: DictConfig):
    setup_logging(exp_path)
    logger.info(f"Current dir: {os.getcwd()}, output dir: {cfg.output_dir}")
    cfg.finetune.wandb_id = exp_path.name
-    run = init_wandb(cfg, exp_path, flatten_dict_config(cfg))
-    if run is None:
-        raise ValueError("Failed to initialize wandb run")
+
+    # Initialize wandb and handle failure gracefully
+    try:
+        run = init_wandb(cfg, exp_path, flatten_dict_config(cfg))
+    except Exception as e:
+        logger.warning(f"Failed to initialize wandb: {e}. Continuing without wandb logging.")
+        run = None
+
+    def safe_wandb_log(metrics, step):
+        if run is not None:
+            try:
+                wandb.log(metrics, step=step)
+            except Exception as e:
+                logger.warning(f"Failed to log to wandb: {e}")
+
    state_path = exp_path / "rl_state.json"
    state = load_state(state_path)
    # optionally clean all data at start time

Collaborator (on `if run is not None:`): Can we just check the rank there and do nothing unless it is 0?
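For reference, a minimal sketch of the rank check suggested above, reusing the RANK environment variable the rest of this PR relies on (a suggestion only, not what the branch currently does):

import logging
import os

import wandb

logger = logging.getLogger(__name__)


def safe_wandb_log(metrics: dict, step: int) -> None:
    # Only global rank 0 talks to wandb; every other rank returns immediately.
    if int(os.environ.get("RANK", 0)) != 0:
        return
    try:
        wandb.log(metrics, step=step)
    except Exception as e:
        logger.warning(f"Failed to log to wandb: {e}")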
@@ -429,10 +441,7 @@ def main(cfg: DictConfig):
            time_evaluation = stats["execution_time/test_make_data"]
        else:
            time_evaluation = 0
-        wandb.log(
-            stats,
-            step=state["iteration"],
-        )
+        safe_wandb_log(stats, step=state["iteration"])

        start_basemodel_logprobs = time.time()
        try:
@@ -472,7 +481,7 @@ def main(cfg: DictConfig):
            raise e

        time_populating_ref_logprobs = time.time() - start_basemodel_logprobs
-        wandb.log(
+        safe_wandb_log(
            {
                "execution_time/populating_ref_logprobs": time_populating_ref_logprobs,
                "execution_time/starting_assistantmodel_vllm": assistant_vllm_stats["starting_time"],
@@ -512,7 +521,7 @@ def main(cfg: DictConfig):
        )
        time_finetune = time.time() - start_finetune
        time_iteration = time.time() - start_iteration
-        wandb.log(
+        safe_wandb_log(
            {
                "execution_time/finetune": time_finetune,
                "execution_time/iteration": time_iteration,
34 changes: 32 additions & 2 deletions examples/rl_gsm8k/utils.py
@@ -69,6 +69,11 @@ def __init__(
        self.stderr_file: Optional[TextIO] = None
        self.stats = {}

        # Add node rank awareness
        self.node_rank = int(os.environ.get("RANK", 0))
        self.port_offset = self.node_rank * 1000  # Ensure different port ranges for each node
        self.port = port + self.port_offset

    def get_base_urls(self) -> list[str]:
        return [
            f"http://127.0.0.1:{port}" for port in self.ports

Collaborator (on the port offset): Why different port ranges for each node?

Collaborator (author): Theoretically we should be able to update the sync port if the one selected by the toolkit environment is already in use. But that turned out not to be the case, and I found out that the real reason for those clashes is the subprocess subshell.

@rizar (Collaborator, Dec 21, 2024): I'm not sure I understand. These vLLMs are running on different nodes, aren't they?
@@ -133,9 +138,9 @@ def _start_service(self) -> None:

        threads = []

-        for i, device_number in enumerate(generate_cuda_device_strings(torch.cuda.device_count(), self.gpus_per_model_instance )):
+        for i, device_number in enumerate(generate_cuda_device_strings(torch.cuda.device_count(), self.gpus_per_model_instance)):
+            # Adjust port based on both node rank and GPU index
            port = self.port + i
-            # start_llm(device_number, port, assistant_procs, ports)
            thread = threading.Thread(target=self._start_llm, args=(device_number, port))
            threads.append(thread)
            thread.start()
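generate_cuda_device_strings is defined elsewhere in utils.py and not shown in this diff; purely as an illustration of how the loop consumes it (one comma-separated CUDA device string per model instance), a hypothetical version might look like:

def generate_cuda_device_strings(num_gpus: int, gpus_per_instance: int) -> list[str]:
    # Hypothetical sketch: e.g. num_gpus=8, gpus_per_instance=4 -> ["0,1,2,3", "4,5,6,7"]
    return [
        ",".join(str(g) for g in range(start, start + gpus_per_instance))
        for start in range(0, num_gpus - gpus_per_instance + 1, gpus_per_instance)
    ]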
@@ -354,8 +359,19 @@ def launch_training(
        ValueError: If no GPUs are available
        RuntimeError: If training process fails
    """
    # environment variables
    GLOBAL_RANK = int(os.environ.get("RANK", 0))
    MASTER_PORT = int(os.environ.get("MASTER_PORT"))
    MASTER_ADDRESS = os.environ.get("MASTER_ADDR")
    # this is the same as number_of_replicas
    WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 2))

    # Check GPU availability
    num_gpus = torch.cuda.device_count()
    print('###############################')
    print(f"Number of GPUs: {num_gpus}")
    print('###############################')
    is_multinode = num_gpus > 8
    if num_gpus == 0:
        raise ValueError("No GPUs available for finetuning")

Collaborator (on the print statements): Please use logger.info, and maybe a message like "I'm rank X, training on Y GPUs".

Collaborator (author): That's just sanity checking and it will be deleted soon. It was there to check that this number matches Num processes, but things are clear now.
@@ -381,6 +397,20 @@
            "--deepspeed_config_file",
            "conf/accelerate/deepspeed_stage3_bf16.json",
        ]
        if is_multinode:
            base_cmd.extend([
                "--num_machines",
                str(WORLD_SIZE),
                "--machine_rank",
                str(GLOBAL_RANK),
                "--main_process_ip",
                MASTER_ADDRESS,
                "--main_process_port",
                str(MASTER_PORT),
                "--deepspeed_multinode_launcher",
                "standard",
                "--same_network",
            ])
    else:
        base_cmd[2:2] = [
            "--multi_gpu",

Collaborator (on the is_multinode branch): Would multi-node training with accelerate work without DeepSpeed? If not, we should add an exception.
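A sketch of the guard the comment above asks for, assuming a use_deepspeed flag (name assumed here) is in scope where base_cmd is built:

def check_multinode_prerequisites(is_multinode: bool, use_deepspeed: bool, num_gpus: int) -> None:
    # Sketch of the suggested exception: fail fast instead of silently launching
    # without the multi-node flags, which are only added on the DeepSpeed path above.
    if is_multinode and not use_deepspeed:
        raise RuntimeError(
            f"Detected {num_gpus} GPUs spanning multiple machines, but multi-node launcher "
            "flags are only passed with DeepSpeed; set use_deepspeed=true for multi-node runs."
        )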
10 changes: 7 additions & 3 deletions tapeagents/finetune/logging_.py
@@ -18,12 +18,16 @@ def init_wandb(
    cfg: DictConfig,
    run_dir: Path,
    config_for_wandb: DictConfig | dict,
-) -> wandb_run.Run:
-    """Initialize W&B.
+) -> wandb_run.Run | None:
+    """Initialize W&B on the main process only.

    config_for_wandb is the configuration that will be logged to W&B.
+
+    Returns None if not on main process.
    """
+    # Only initialize on main process (rank 0)
+    if os.environ.get('RANK', '0') != '0':
+        return None
+
    if config_for_wandb is None:
        config_for_wandb = cfg.dict()
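As an aside on the rank gating: an alternative to reading RANK directly, offered only as a suggestion rather than something this PR uses, is accelerate's PartialState, which degrades gracefully when no distributed environment is initialized:

from accelerate import PartialState


def is_main_process() -> bool:
    # Falls back to a single-process state (rank 0) when torch.distributed is not initialized.
    return PartialState().is_main_process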