Rejection Sampling on GSM8k #78

Merged: 104 commits into main on Nov 20, 2024
Conversation

@AlexPiche (Collaborator) commented on Oct 31, 2024

Rejection Sampling on GSM8k

Tape Browser PR from @rizar

  • add a script to launch the tape browser
  • save tapes in a single json file, which is a format that the default tape browser can load
  • add a script to gather legacy tapes in single file so that we can view legacy tapes in the browser
  • load LLM Calls 100 at a time to speed up tape browser loading (the speedup was huge when the browser did not know where to look for LLM Calls; now that browser.py searches for llm_calls.sqlite, loading takes some time again); see the sketch below
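
To illustrate the batched loading mentioned in the last bullet, here is a minimal sketch of reading LLM calls from llm_calls.sqlite in chunks of 100. The table name and schema below are assumptions for illustration, not the actual tapeagents code (the real logic lives in browser.py / observe.py).

```python
import sqlite3
from typing import Iterator, List

BATCH_SIZE = 100  # load LLM calls 100 at a time to keep the browser responsive


def iter_llm_calls(db_path: str, batch_size: int = BATCH_SIZE) -> Iterator[List[tuple]]:
    """Yield rows from an llm_calls.sqlite file in fixed-size batches.

    NOTE: the table and column layout here are hypothetical; the real schema
    is defined in tapeagents/observe.py.
    """
    conn = sqlite3.connect(db_path)
    try:
        offset = 0
        while True:
            rows = conn.execute(
                "SELECT * FROM llm_calls ORDER BY rowid LIMIT ? OFFSET ?",
                (batch_size, offset),
            ).fetchall()
            if not rows:
                break
            yield rows
            offset += batch_size
    finally:
        conn.close()
```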

Example:
Run the tape browser like this: python -m examples.rl_gsm8k.browse outputs/yolo2_4gpu_lr1e-6/tapes/train/0/all

Reasoning Architecture

[screenshot: reasoning architecture diagram]

Learning curve

python examples/rl_gsm8k/orchestrate_rl.py finetune.rl.algo=reinforce finetune.train_batch_size=4 finetune.gradient_accumulation_passes=1 finetune.rl.implicit_kl_coef=0.0 finetune.rl.kl_coef=0.0 finetune.rl.use_advantages=false +finetune.rl.relu_weights=true use_rejection_sampling=true test_every_n_iterations=5 finetune.learning_rate=0.000001 finetune.gradient_clipping_threshold=1.0 finetune.save_checkpoint_steps=8 finetune.weight_decay=0.1 max_agent_forks=5000 attempts=8

[screenshot: learning curve]
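
For readers unfamiliar with the setup: use_rejection_sampling=true with attempts=8 in the command above means each GSM8k problem is attempted several times and only attempts whose final answer is correct are kept for fine-tuning. Below is a minimal sketch of that selection step, with made-up helper names (not the actual orchestrate_rl.py code).

```python
from typing import Callable, List


def rejection_sample(
    problems: List[dict],
    generate: Callable[[dict], str],          # samples one solution tape for a problem
    is_correct: Callable[[dict, str], bool],  # checks the final answer against the reference
    attempts: int = 8,                        # matches attempts=8 in the command above
) -> List[tuple]:
    """Keep only (problem, solution) pairs whose final answer is correct."""
    kept = []
    for problem in problems:
        for _ in range(attempts):
            solution = generate(problem)
            if is_correct(problem, solution):
                kept.append((problem, solution))
    return kept  # the accepted samples form the fine-tuning batch
```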

Reproducing GSM8k

[screenshot: GSM8k reproduction results]

@AlexPiche changed the base branch from main to grpo_wild_chat on October 31, 2024 02:42
@rizar (Collaborator) commented on Oct 31, 2024

will address #77 and #76

Base automatically changed from grpo_wild_chat to main October 31, 2024 13:28
@AlexPiche changed the title from "RL training of Llama 3.1 8b on multi-gpus" to "Rejection Sampling on GSM8k" on Nov 18, 2024
@AlexPiche changed the base branch from main to fix_test on November 19, 2024 15:13
@AlexPiche changed the base branch from fix_test to main on November 19, 2024 16:28
@AlexPiche requested a review from rizar on November 19, 2024 16:40
@rizar (Collaborator) left a comment:
Looks mostly good! We should not traverse the whole training set every time the script is launched though.

Note for the re-review:

  • check ReLU weights citation (see the sketch after this list)
  • check running aggregation of input lengths
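
On the first bullet: the +finetune.rl.relu_weights=true flag in the training command is read here as clamping non-positive sample weights to zero, so only positively rewarded samples contribute to the REINFORCE loss. A minimal, assumed sketch (not the tapeagents implementation):

```python
import torch


def relu_weights_loss(logprobs: torch.Tensor, rewards: torch.Tensor) -> torch.Tensor:
    """REINFORCE-style loss with ReLU-ed weights (illustrative only).

    logprobs: per-sample sum of token log-probabilities under the current policy
    rewards:  per-sample scalar rewards (e.g. 1.0 for a correct GSM8k answer, 0.0 otherwise)
    """
    weights = torch.relu(rewards)          # drop non-positive samples, as in rejection sampling
    return -(weights * logprobs).mean()    # maximize likelihood of positively rewarded samples
```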

Review threads:
  • tapeagents/finetune/checkpoints.py (resolved)
  • tapeagents/finetune/finetune.py (outdated, resolved)
  • tapeagents/finetune/rl/__init__.py (outdated, resolved)
  • tapeagents/finetune/rl/__init__.py (outdated, resolved)
  • tapeagents/observe.py (outdated, resolved)
@rizar (Collaborator) left a comment:

lgtm!

@AlexPiche merged commit 74948b6 into main on Nov 20, 2024
2 checks passed