Rejection Sampling on GSM8k #78

Merged
merged 104 commits on Nov 20, 2024
0044902
run 70b
AlexPiche Oct 31, 2024
09f7fb2
clean up debug code
AlexPiche Oct 31, 2024
0537e02
conf_dir as str
AlexPiche Oct 31, 2024
4b9407a
chunked prefill vllm
AlexPiche Oct 31, 2024
04279b0
iteration as str
AlexPiche Oct 31, 2024
537f9a4
no testing when test_every_n_iterations is -1
AlexPiche Oct 31, 2024
819a161
no testing with -1
AlexPiche Oct 31, 2024
b772d97
print finetuning output
AlexPiche Oct 31, 2024
4317fcb
rm seq length tokens
AlexPiche Oct 31, 2024
d781b74
fp8 quant
AlexPiche Oct 31, 2024
8657767
use deepspeed
AlexPiche Oct 31, 2024
d050cf4
better vllm logging
AlexPiche Oct 31, 2024
6b0bc26
deepspeed training
AlexPiche Oct 31, 2024
2a3f0ee
update accelerate version
AlexPiche Oct 31, 2024
1503884
negative rewards for too many steps
AlexPiche Nov 1, 2024
5600f85
discounting
AlexPiche Nov 1, 2024
743a3e6
try to get lora working
AlexPiche Nov 2, 2024
5e11b32
step norm
AlexPiche Nov 2, 2024
b42a785
penalize 20 steps tape
AlexPiche Nov 2, 2024
e8872c3
step discount
AlexPiche Nov 2, 2024
7be219b
discount and max steps
AlexPiche Nov 3, 2024
f46c0ea
fix discount typo
AlexPiche Nov 4, 2024
5d87a2a
merge 8b branch
AlexPiche Nov 4, 2024
4e79c5f
use adafactor with accelerate
AlexPiche Nov 5, 2024
1dfe78d
implicit kl
AlexPiche Nov 5, 2024
7a2961c
implicit kl
AlexPiche Nov 5, 2024
c737e72
better variable names
AlexPiche Nov 5, 2024
d9857c5
update docstring
AlexPiche Nov 5, 2024
4273428
update docstring
AlexPiche Nov 5, 2024
efc38f5
get log probs worker
AlexPiche Nov 5, 2024
2df67bd
WIP commit for browsing math tapes
rizar Nov 5, 2024
a2427fe
samples -> problems
rizar Nov 5, 2024
e805a33
dataset stats logging
AlexPiche Nov 5, 2024
575df66
refactoring of the rl code
AlexPiche Nov 5, 2024
2e332a7
conf/accelerate
AlexPiche Nov 5, 2024
540e078
clean up conf
AlexPiche Nov 5, 2024
2d6ff26
clean up conf
AlexPiche Nov 5, 2024
e240688
better default hps
AlexPiche Nov 5, 2024
db225e0
clean up
AlexPiche Nov 5, 2024
fc204eb
bigger batch size
AlexPiche Nov 5, 2024
7679984
small refactor
AlexPiche Nov 5, 2024
193e94b
dataset stats on main process only
AlexPiche Nov 5, 2024
4bdf653
fix dataset stats when training is resumed
AlexPiche Nov 6, 2024
5ac8f14
try RS
AlexPiche Nov 7, 2024
ebab454
typo in config
AlexPiche Nov 7, 2024
d09d7ff
rs
AlexPiche Nov 7, 2024
1696a51
read 100 llm calls at a time
rizar Nov 8, 2024
43b7636
Merge branch 'llama70b_gsm8k' into browse_math_tapes
rizar Nov 8, 2024
127c7f3
changeable port for tape_browser
rizar Nov 8, 2024
631ebb0
tape browser launch script and json gathering script
rizar Nov 8, 2024
961a429
better way to save tapes
rizar Nov 8, 2024
98656ea
cute little renaming
rizar Nov 8, 2024
3af4b2b
Merge pull request #92 from ServiceNow/browse_math_tapes
AlexPiche Nov 12, 2024
056592c
simple rl gsm8k
AlexPiche Nov 12, 2024
8605fe5
produce more tokens
AlexPiche Nov 13, 2024
7ecc197
increase max tokens
AlexPiche Nov 13, 2024
196c9b4
clean up agent architecture
AlexPiche Nov 13, 2024
bb96ae8
simpler agent
AlexPiche Nov 14, 2024
7499c8f
rm loop
AlexPiche Nov 14, 2024
ac02145
clean string
AlexPiche Nov 14, 2024
c241214
hack simple agent
AlexPiche Nov 14, 2024
000ade0
improve simple agent
AlexPiche Nov 14, 2024
9ce7a5b
clean up string
AlexPiche Nov 14, 2024
64ac725
2000 tokens
AlexPiche Nov 14, 2024
3379173
better logging
AlexPiche Nov 14, 2024
294d2bf
fix browser to read new tapes
AlexPiche Nov 14, 2024
54dc1b7
assert its the right completion
AlexPiche Nov 15, 2024
73ad382
rm assert
AlexPiche Nov 15, 2024
f8e24d8
relu weights
AlexPiche Nov 15, 2024
a30dbc6
klcoef for reinforce
AlexPiche Nov 15, 2024
bbabd3e
revert to adam
AlexPiche Nov 16, 2024
4cfa51e
typo
AlexPiche Nov 16, 2024
53ac342
rm bos token
AlexPiche Nov 16, 2024
7e903c0
clean up nodes
AlexPiche Nov 17, 2024
a795ac1
clipping loss to 0
AlexPiche Nov 18, 2024
4854cdc
better handling of bos token
AlexPiche Nov 18, 2024
e0f8203
Merge pull request #101 from ServiceNow/simple_rl_gsm8k
AlexPiche Nov 18, 2024
f7b66bf
rm BOS from tests
AlexPiche Nov 18, 2024
94f7c21
Merge remote-tracking branch 'origin/simple_rl_gsm8k' into llama70b_g…
AlexPiche Nov 18, 2024
8608726
fix typo
AlexPiche Nov 18, 2024
e97ed41
clean up example
AlexPiche Nov 18, 2024
8e19489
Merge remote-tracking branch 'origin/main' into llama70b_gsm8k
AlexPiche Nov 18, 2024
3b09425
Merge remote-tracking branch 'origin/fix_test' into llama70b_gsm8k
AlexPiche Nov 18, 2024
1303602
rm run training in process
AlexPiche Nov 18, 2024
73ee354
rm if rl start from base model
AlexPiche Nov 18, 2024
35cb007
clean up
AlexPiche Nov 19, 2024
03174da
typo
AlexPiche Nov 19, 2024
6ffc6b0
clean up
AlexPiche Nov 19, 2024
397e0b8
better docs
AlexPiche Nov 19, 2024
8b60ba6
clean up
AlexPiche Nov 19, 2024
0200482
Update tapeagents/observe.py
AlexPiche Nov 19, 2024
4c4e71a
dima changes
AlexPiche Nov 19, 2024
5618143
clean up
AlexPiche Nov 19, 2024
6268098
update readme
AlexPiche Nov 19, 2024
9dd4cf7
reverse change
AlexPiche Nov 19, 2024
a59aaad
improve doc
AlexPiche Nov 19, 2024
251ce78
improve doc
AlexPiche Nov 19, 2024
cbae73f
rm debug code
AlexPiche Nov 20, 2024
94109b0
fix logging of max min dataset len
AlexPiche Nov 20, 2024
6ef2ea3
fix min seq length logging
AlexPiche Nov 20, 2024
2d698cd
fix naming of variables
AlexPiche Nov 20, 2024
18b1b02
typo
AlexPiche Nov 20, 2024
1291be2
fix naming
AlexPiche Nov 20, 2024
90b080f
hf dataset
AlexPiche Nov 20, 2024
11 changes: 6 additions & 5 deletions conf/finetune/rl_llama31_8b.yaml
@@ -17,14 +17,14 @@ wandb_resume: always
# Whether to use only the basename or the full path as the run name
wandb_use_basename: false
config_name: meta-llama/Meta-Llama-3.1-8B-Instruct
learning_rate: 0.000005
train_batch_size: 1
gradient_accumulation_passes: 1024
learning_rate: 0.0000025
train_batch_size: 4
gradient_accumulation_passes: 256
seq_length: 4096
load_as_bf16: True
max_train_steps: 100000
save_checkpoint_steps: ???
optim: adamw_torch
optim: adafactor # FIXME: adamw runs OOM with accelerate
objective: rl
log_each_n_steps: 1
resume_dataloader: false
Expand All @@ -33,6 +33,7 @@ use_safetensors: true
weight_decay: 0.1
gradient_clipping_threshold: 1
rl:
  kl_coef: 0.05
  kl_coef: 0.0
  reward_minus_kl_coef: 0.0
  use_advantages: true
  algo: reinforce
34 changes: 34 additions & 0 deletions conf/finetune/rs_llama31_8b.yaml
@@ -0,0 +1,34 @@
defaults:
  - base
  - _self_

# Use W&B experiment logging
use_wandb: True
# W&B id; if given, will resume this run
wandb_id: null
# W&B name; if not given will use run dir
wandb_name: null
# W&B entity name
wandb_entity_name: null
# W&B project name
wandb_project_name: tapeagents
# W&B resume policy
wandb_resume: always
# Whether to use only the basename or the full path as the run name
wandb_use_basename: false
config_name: meta-llama/Meta-Llama-3.1-8B-Instruct
learning_rate: 0.0000025
train_batch_size: 4
gradient_accumulation_passes: 256
seq_length: 4096
load_as_bf16: True
max_train_steps: 100000
save_checkpoint_steps: ???
optim: adafactor # FIXME: adamw runs OOM with accelerate
objective: nll
log_each_n_steps: 1
resume_dataloader: false
cuda_empty_cache: true
use_safetensors: true
weight_decay: 0.1
gradient_clipping_threshold: 1
4 changes: 3 additions & 1 deletion conf/rl_debug.yaml
@@ -1,8 +1,10 @@
defaults:
  - rl_gsm8k
  - _self_

max_agent_forks: 16
attempts: 1

test_every_n_iterations: -1
finetune:
  save_checkpoint_steps: 2
  gradient_accumulation_passes: 16
15 changes: 9 additions & 6 deletions conf/rl_gsm8k.yaml
@@ -3,20 +3,22 @@ defaults:
- _self_

n_workers: 32
max_loops: 10
get_log_probs_workers: 1
max_loops: 1
test_every_n_iterations: 5
model_path: meta-llama/Meta-Llama-3.1-8B-Instruct
max_agent_forks: 1024
attempts: 64
force_restart: false
max_iterations: 100
max_iterations: 1000
use_rejection_sampling: false
llm:
  parameters:
    max_tokens: 1024
    max_tokens: 2000
    temperature: 0.7
test_llm:
  parameters:
    max_tokens: 1024
  parameters:
    max_tokens: ${...llm.parameters.max_tokens}
    temperature: 0.
finetune:
@@ -27,16 +29,17 @@ finetune:
  # One step is one weight update. See the finetuning configuration
  # for the info on how many sequences are used for each weight update.
  save_checkpoint_steps: 10
  seq_length: 2000

vllm_config:
  vllm_kwargs:
    --download-dir: /mnt/llmd/base_models/
    --max-model-len: 8000
    --gpu-memory-utilization: 0.9
    # vLLM log-probs OOM: https://github.com/vllm-project/vllm/issues/5907
    --enable-chunked-prefill: ""

output_dir: outputs/rl_gsm8k
accelerate_cfg_path: conf/accelerate/accelerate_base.yaml

hydra:
  run:
6 changes: 6 additions & 0 deletions conf/rs_gsm8k.yaml
@@ -0,0 +1,6 @@
defaults:
  - rl_gsm8k
  - override finetune: rs_llama31_8b
  - _self_

use_rejection_sampling: true
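The defaults list above composes `rs_gsm8k` on top of `rl_gsm8k`, with the `finetune` group swapped for `rs_llama31_8b`. Conceptually this is a deep merge in which the later config wins; the sketch below is a minimal stand-in for the Hydra/OmegaConf machinery (the `deep_merge` helper and the trimmed-down config dicts are illustrative, not the real loader):

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge override into base (override wins), mimicking
    how a Hydra defaults list layers rs_gsm8k.yaml on top of rl_gsm8k.yaml."""
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_merge(out[key], value)
        else:
            out[key] = value
    return out


# Trimmed-down stand-ins for the two configs in this PR.
rl_gsm8k = {
    "use_rejection_sampling": False,
    "finetune": {"objective": "rl", "optim": "adafactor"},
}
rs_overrides = {
    "use_rejection_sampling": True,
    "finetune": {"objective": "nll"},  # rs_llama31_8b switches to plain NLL
}

rs_gsm8k = deep_merge(rl_gsm8k, rs_overrides)
# Keys not overridden (e.g. optim) are inherited from the base config.
```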
11 changes: 8 additions & 3 deletions examples/gsm8k_tuning/README.md
@@ -5,16 +5,21 @@ This example demonstrates how to improve the math skills of a small LLama model.
We built a basic math agent that uses the LLama 3.1 70B model, equipped with reasoning and a calculator tool, to solve math problems: [math_agent.py](math_agent.py).

Steps to distill the math agent:
- [run it as a teacher](produce_teacher_tapes.py), collect the tapes of the successful solutions, and produce training data for the Math Agent from them. How to run: `python -m examples.gsm8k_tuning.produce_teacher_tapes`
- [run it as a teacher](produce_teacher_tapes.py), collect the tapes of the successful solutions, and produce training data for the Math Agent from them. How to run: `python -m examples.gsm8k_tuning.produce_teacher_tapes` or download the data generated by `meta-llama/Llama-3.1-70B-Instruct`
```python
df = pd.read_json("hf://datasets/ServiceNow/llama31_70b_gsm8k_3k/training_samples_3k.jsonl", lines=True)
df.to_json("training_samples_3k.jsonl", orient="records", lines=True)
```
- [fine-tune smaller LLama 3.1 8B model](finetune_student.py) on the training data to get a tuned Math Agent. How to run: `python -m examples.gsm8k_tuning.finetune_student`
- [merge the lora weights](../../tapeagents/finetune/lora.py) to be able to serve the model with vLLM. How to run: `python -m tapeagents.finetune.lora PATH/TO/WEIGHTS`
- [evaluate the tuned Math Agent](evaluate_student.py) on a subset of the GSM8K test set, comparing the accuracy of the teacher agent, the student agent before tuning, and the student agent after tuning. How to run: `python -m examples.gsm8k_tuning.evaluate_student`

<img width="526" alt="image" src="https://github.com/user-attachments/assets/a7aa2908-2a86-4b85-92d2-8c133e9ac0ff">
<img width="526" alt="image" src="https://github.com/user-attachments/assets/55d099ab-ff5c-480b-b5b3-504b4206e677">

| Model | Test accuracy |
| ----- | ------------- |
| 8B student before tuning | 0.662 |
| 8B student after tuning | 0.775 |
| 8B student after tuning | 0.785 |
| 70B teacher | 0.931 |

RL tuning on both successful and unsuccessful solutions is coming soon. Stay tuned!
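The rejection-sampling recipe this PR enables with `use_rejection_sampling: true` can be sketched as follows; `sample_fn` and `is_correct_fn` are toy stand-ins for the vLLM rollout and the GSM8K answer check, and the helper name is hypothetical:

```python
def rejection_sample(problems, sample_fn, is_correct_fn, attempts=64):
    """Collect (problem, solution) pairs, keeping only solutions whose
    final answer checks out; the kept pairs become the fine-tuning data
    (trained with a plain NLL objective rather than an RL objective)."""
    dataset = []
    for problem in problems:
        for _ in range(attempts):
            solution = sample_fn(problem)
            if is_correct_fn(problem, solution):
                dataset.append((problem, solution))
    return dataset


# Toy demo: each "problem" carries its ground-truth answer, and the
# "model" just replays a fixed sequence of guesses (2 attempts each).
problems = [("2+2", 4), ("3*3", 9)]
samples = iter([3, 4, 9, 9])
data = rejection_sample(
    problems,
    sample_fn=lambda p: next(samples),
    is_correct_fn=lambda p, s: s == p[1],
    attempts=2,
)
# Only the 3 correct guesses survive; the wrong guess (3) is rejected.
```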
19 changes: 19 additions & 0 deletions examples/rl_gsm8k/browse.py
@@ -0,0 +1,19 @@
import os
import sys
from pathlib import Path

from examples.rl_gsm8k.cot_math_agent import MathTape
from tapeagents.renderers.camera_ready_renderer import CameraReadyRenderer
from tapeagents.tape_browser import TapeBrowser

# Comment this block out if loading the prompts and completions takes too long for you.
tape_dir = Path(sys.argv[1])
# Walk up from tape_dir to find the experiment directory that contains llm_calls.sqlite.
exp_dir = tape_dir
while not os.path.exists(exp_dir / "llm_calls.sqlite") and exp_dir != exp_dir.parent:
    exp_dir = exp_dir.parent
os.environ["TAPEAGENTS_SQLITE_DB"] = os.path.join(exp_dir, "llm_calls.sqlite")

browser = TapeBrowser(MathTape, sys.argv[1], CameraReadyRenderer(), file_extension=".json")
browser.launch(port=7680 if len(sys.argv) < 3 else int(sys.argv[2]))
109 changes: 109 additions & 0 deletions examples/rl_gsm8k/cot_math_agent.py
@@ -0,0 +1,109 @@
import logging
from typing import Annotated, Generator, Literal, TypeAlias, Union

from pydantic import Field

from examples.gsm8k_tuning.math_agent import extract_result_value  # noqa
from tapeagents.agent import Agent
from tapeagents.core import (
    Action,
    LLMOutputParsingFailureAction,
    Observation,
    Step,
    Tape,
    Thought,
)
from tapeagents.environment import Environment
from tapeagents.llms import LLM
from tapeagents.nodes import MonoNode

logger = logging.getLogger(__name__)

COT_GUIDANCE = "Think step by step. When you know the answer to the question, provide it in the following format: The answer is: <number>"


class Task(Observation):
    kind: Literal["task"] = "task"
    task: str

    def llm_view(self, indent: int | None = 2) -> str:
        return f"{self.task} {COT_GUIDANCE}"


class ReasoningThoughtwithValue(Thought):
    """
    Thought produced by the agent during the reasoning process.
    """

    kind: Literal["reasoning_thought_with_value"] = "reasoning_thought_with_value"
    reasoning: str = Field(description="chain of thoughts")
    value: float = Field(description="numeric answer extracted from the reasoning")


MathAgentStep: TypeAlias = Annotated[
    ReasoningThoughtwithValue,
    Field(discriminator="kind"),
]

MathTape = Tape[
    None,
    Union[
        Task,
        ReasoningThoughtwithValue,
        LLMOutputParsingFailureAction,
    ],
]


class ReasoningNode(MonoNode):
    def parse_completion(self, completion: str, prompt_id: str) -> Generator[Step, None, None]:
        if "The answer is" not in completion:
            yield LLMOutputParsingFailureAction(
                error=f"Failed to parse agent output: {completion}", llm_output=completion
            )
            return
        try:
            value = completion.split("The answer is")[-1]
            # strip formatting characters such as "$1,234" or "42%"
            for char in (",", " ", ":", "$", "%", "€"):
                value = value.replace(char, "")
            step = ReasoningThoughtwithValue(reasoning=completion, value=float(value.strip()))
        except Exception as e:
            logger.info(f"Failed to parse agent output: {completion}\n\nError: {e}")
            yield LLMOutputParsingFailureAction(
                error=f"Failed to parse agent output: {completion}\n\nError: {e}", llm_output=completion
            )
            return
        yield step


#### Agent and Environment ####
class CoTMathAgent(Agent):
    @classmethod
    def create(cls, llm: LLM):
        return super().create(
            llm,
            nodes=[
                ReasoningNode(
                    name="cot",
                    agent_step_cls=MathAgentStep,
                ),
            ],
            max_iterations=1,
        )


class MathEnvironment(Environment):
    def __init__(self) -> None:
        super().__init__()

    def react(self, tape: MathTape) -> MathTape:
        # The math environment is passive: no action requires a tool call,
        # so the tape is returned unchanged.
        return tape
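The value-extraction logic inside `parse_completion` can be tried on its own; `extract_answer_value` below is a hypothetical standalone restatement of the same stripping-and-cast steps:

```python
def extract_answer_value(completion: str) -> float:
    """Parse the numeric answer from a completion ending in
    'The answer is: <number>', stripping common formatting characters."""
    value = completion.split("The answer is")[-1]
    # remove thousands separators, currency symbols, etc.
    for char in (",", " ", ":", "$", "%", "€"):
        value = value.replace(char, "")
    return float(value.strip())
```

For example, a completion ending in `The answer is: $1,234` yields `1234.0` after the `$` and `,` are stripped.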
25 changes: 25 additions & 0 deletions examples/rl_gsm8k/gather_jsons.py
@@ -0,0 +1,25 @@
"""Read the JSON tape files under a folder and merge them into a single tapes.json."""

import json
import os
import sys


def gather_jsons(folder: str) -> None:
    all_jsons = []
    for root, _, files in os.walk(folder):
        for file in files:
            if file.endswith(".json"):
                with open(os.path.join(root, file)) as f:
                    all_jsons.append(json.load(f))

    # Note: the merged file is written under the input folder, so a second
    # run would pick it up as an input as well.
    dst_dir = os.path.join(folder, "all")
    os.makedirs(dst_dir, exist_ok=True)
    with open(os.path.join(dst_dir, "tapes.json"), "w") as f:
        json.dump(all_jsons, f, indent=4)


if __name__ == "__main__":
    gather_jsons(sys.argv[1])
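A quick self-contained check of the gathering behavior (the function is repeated here so the snippet runs on its own, against a temporary directory):

```python
import json
import os
import tempfile


def gather_jsons(folder: str) -> None:
    # same logic as the script above
    all_jsons = []
    for root, _, files in os.walk(folder):
        for file in files:
            if file.endswith(".json"):
                with open(os.path.join(root, file)) as f:
                    all_jsons.append(json.load(f))
    dst_dir = os.path.join(folder, "all")
    os.makedirs(dst_dir, exist_ok=True)
    with open(os.path.join(dst_dir, "tapes.json"), "w") as f:
        json.dump(all_jsons, f, indent=4)


with tempfile.TemporaryDirectory() as d:
    # write two individual tape files, then merge them
    for i in range(2):
        with open(os.path.join(d, f"tape_{i}.json"), "w") as f:
            json.dump({"id": i}, f)
    gather_jsons(d)
    with open(os.path.join(d, "all", "tapes.json")) as f:
        merged = json.load(f)
```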