Skip to content

Commit

Permalink
[fix] Test on urartu v2, make fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Hovhannes Tamoyan committed Jun 27, 2024
1 parent 628f917 commit d85c7af
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 73 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
urartu
jsonlines
tiktoken
langchain
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def main(self):
model_inquirer.aim_run = self.aim_run
model_responder.aim_run = self.aim_run

for idx, sample in tqdm(enumerate(dataset), total=len(dataset), desc="samples"):
for idx, sample in tqdm(enumerate(dataset.dataset), total=len(dataset.dataset), desc="samples"):
for persona, persona_hash in tqdm(personas, desc="personas", leave=False):
self.aim_run["personas"][persona_hash] = persona

Expand All @@ -67,44 +67,44 @@ def main(self):
dialog = []
raw_dialog = []

instructions = [instruct.lstrip().rstrip() for instruct in sample["instruction"].split("\n")]
instructions = [instruct.lstrip().rstrip() for instruct in sample[task_cfg.dataset.input_key].split("\n")]

if self.action_cfg.task.model_inquirer.regenerate_tries:
regeneratinon_idx = 0
A_generate_cfg = None
B_output = None
inquirer_generate_cfg = None
responder_output = None
turn = 0
with tqdm(total=task_cfg.num_turns, desc="turns", leave=False) as pbar:
while turn < task_cfg.num_turns:
pbar.set_postfix(turn=turn + 1)
# ------------------------------------------ Model A ------------------------------------------
A_prompt = model_inquirer.get_prompt(
inquirer_prompt = model_inquirer.get_prompt(
turn=turn,
response_msg=B_output,
response_msg=responder_output,
persona=persona,
instructions=instructions,
)

self.track(
prompt=A_prompt,
name="A_input",
prompt=inquirer_prompt,
name="inquirer_input",
context={
"sample_id": idx,
"turn": turn,
"persona_hash": persona_hash,
},
)
A_output, _ = model_inquirer.generate(
prompt=A_prompt,
inquirer_output, _ = model_inquirer.generate(
prompt=inquirer_prompt,
generate_cfg=(
A_generate_cfg if A_generate_cfg else self.action_cfg.task.model_inquirer.generate
inquirer_generate_cfg if inquirer_generate_cfg else self.action_cfg.task.model_inquirer.generate
),
)
if not A_output:
if not inquirer_output:
break
self.track(
prompt=A_output,
name="A_output",
prompt=inquirer_output,
name="inquirer_output",
context={
"sample_id": idx,
"turn": turn,
Expand All @@ -113,21 +113,21 @@ def main(self):
)

# --------------------- if model_inquirer failed to provide coherent text ---------------------
if model_inquirer.is_non_coherent(A_output):
if model_inquirer.is_non_coherent(inquirer_output):
self.aim_run["num_non_coherent"] += 1
break

# --------------------- if model_inquirer wants to stop the dialog ---------------------
if model_inquirer.stop_dialog(A_output):
if model_inquirer.stop_dialog(inquirer_output):
break

A_output_extract, num_prompts = model_inquirer.extract_prompt(prompt=A_output)
inquirer_output_extract, num_prompts = model_inquirer.extract_prompt(prompt=inquirer_output)

if self.action_cfg.task.model_inquirer.regenerate_tries:
# --------------------- if model_inquirer failed to provide prompt ---------------------
if A_output_extract is None:
if inquirer_output_extract is None:
if regeneratinon_idx < self.action_cfg.task.model_inquirer.regenerate_tries:
A_generate_cfg = model_inquirer.get_generation_cfg()
inquirer_generate_cfg = model_inquirer.get_generation_cfg()
regeneratinon_idx += 1
continue
else:
Expand All @@ -137,15 +137,15 @@ def main(self):
if regeneratinon_idx != 0:
self.aim_run["num_regenerate_worked"] += 1
regeneratinon_idx = 0
A_generate_cfg = None
inquirer_generate_cfg = None

if A_output_extract is None:
if inquirer_output_extract is None:
self.aim_run["num_no_prompts"] += 1
break

self.track(
prompt=A_output_extract,
name="A_output_extract",
prompt=inquirer_output_extract,
name="inquirer_output_extract",
context={
"sample_id": idx,
"turn": turn,
Expand All @@ -155,31 +155,31 @@ def main(self):
)

# As the context for model_inquirer grows much faster, it starts answering its own questions
# To prevent this keep in the A_history only the output prompt(the thing that model_responder will see).
model_inquirer.update_history(prompt=A_prompt, output_extract=A_output_extract)
# To prevent this keep in the inquirer_history only the output prompt(the thing that model_responder will see).
model_inquirer.update_history(prompt=inquirer_prompt, output_extract=inquirer_output_extract)

# ------------------------------------------ Model B ------------------------------------------

B_prompt = model_responder.get_prompt(turn=turn, response_msg=A_output_extract)
responder_prompt = model_responder.get_prompt(turn=turn, response_msg=inquirer_output_extract)

self.track(
prompt=B_prompt,
name="B_input",
prompt=responder_prompt,
name="responder_input",
context={
"sample_id": idx,
"turn": turn,
"persona_hash": persona_hash,
},
)
B_output, B_model_output_template = model_responder.generate(
prompt=B_prompt,
responder_output, responder_model_output_template = model_responder.generate(
prompt=responder_prompt,
generate_cfg=self.action_cfg.task.model_responder.generate,
)
if not B_output:
if not responder_output:
break
self.track(
prompt=B_output,
name="B_output",
prompt=responder_output,
name="responder_output",
context={
"sample_id": idx,
"turn": turn,
Expand All @@ -188,22 +188,22 @@ def main(self):
)

# --------------------- if model_responder failed to provide coherent text ---------------------
if model_responder.is_non_coherent(B_output):
if model_responder.is_non_coherent(responder_output):
self.aim_run["num_non_coherent_model_responder"] += 1
break

model_responder.update_history(prompt=B_prompt, output_extract=B_model_output_template)
model_responder.update_history(prompt=responder_prompt, output_extract=responder_model_output_template)

# --------------------------------------- Save the dialog ---------------------------------------
dialog.append(
{
"turn": turn,
"model_inquirer": A_output_extract,
"model_responder": B_output,
"model_inquirer": inquirer_output_extract,
"model_responder": responder_output,
}
)
raw_dialog.append(A_output_extract)
raw_dialog.append(B_output)
raw_dialog.append(inquirer_output_extract)
raw_dialog.append(responder_output)

torch.cuda.empty_cache()
turn += 1
Expand Down
18 changes: 9 additions & 9 deletions roleplay/common/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import string
from typing import Any, Dict, List

from urartu.common.device import DEVICE
from urartu.common.device import Device


class Model:
Expand Down Expand Up @@ -83,15 +83,15 @@ def is_non_coherent(self, text):
return False

def get_generation_cfg(self) -> Dict[str, Any]:
A_generate_cfg = copy.deepcopy(self.cfg.generate)
generation_cfg = copy.deepcopy(self.cfg.generate)

A_generate_cfg["do_sample"] = True
A_generate_cfg["top_k"] = random.randint(5, 50)
A_generate_cfg["penalty_alpha"] = random.random()
A_generate_cfg["num_beams"] = random.randint(4, 10)
A_generate_cfg["temperature"] = random.uniform(0.5, 1)
generation_cfg["do_sample"] = True
generation_cfg["top_k"] = random.randint(5, 50)
generation_cfg["penalty_alpha"] = random.random()
generation_cfg["num_beams"] = random.randint(4, 10)
generation_cfg["temperature"] = random.uniform(0.5, 1)

return A_generate_cfg
return generation_cfg

@staticmethod
def collate_tokenize(data, tokenizer, input_key):
Expand All @@ -102,5 +102,5 @@ def collate_tokenize(data, tokenizer, input_key):
else:
input_text = element[input_key]
input_batch.append(input_text)
tokenized = tokenizer(input_batch, padding="longest", truncation=True, return_tensors="pt").to(DEVICE)
tokenized = tokenizer(input_batch, padding="longest", truncation=True, return_tensors="pt").to(Device.get_device())
return tokenized
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# @package _global_
action_name: roleplay
action_name: generate_dialogues
seed: 5

action_config:
workdir: "./"
experiment_name: roleplay
device: "auto" # auto, cuda, cpu (default)

task:
num_turns: 12
Expand All @@ -12,24 +14,28 @@ action_config:
model_responder: ???

dataset:
instruction:
- "You want to know how fast you run different distances. You use a stopwatch to measure the time it takes you to complete a 50-meter, 100-meter, and 200-meter race. You want to know how can you calculate your speed for each race? Based on that, you also want to calculate how many calories you burned during each race."
- "You can run at a rate of speed four times faster than you can walk, but you can skip at a rate of speed that is half as fast as you can run. You want to know If you can skip at 3 miles per hour, how many miles can you travel in six hours if you spend one-third of the time and two-thirds of the time running and walking, respectively. Also you are curious about the other way around (one-third of the time walking and two-thirds for running)."
- "Every day, you feed each of your chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy. You give the chickens their feed in three separate meals. In the morning, you give your flock of chickens 15 cups of feed. In the afternoon, you give your chickens another 25 cups of feed. You want to know how many cups of feed do you need to give your chickens in the final meal of the day if the size of your flock is 20 chickens? Also you want to know how much does the chicken egg production rate depend on the feed you give? And if you provide enough feed to your chickens for high-rate egg production."
- "You want to make this function better. You want the chatbot to make it recursive to have memory optimal function, but make sure that it doesn’t enter to an infinite loop. After that you want to plug a CLI (command line interface) to this function, so the user can insert a number and get the factorial of it as output: 'The factorial of the <NUMBER>, is <FACTORIAL>'.
```
def factorialize(num):
factorial = 1
for i in range(1, num):
factorial *= i
return factorial
```"
- "You have a little project where you need to use JavaScript, a language you don't use every day. You have a subtask to write a function that counts how many vowels are in a given string. And you need this functionality in OOP. Also you want the chatbot to develop the snippet it provided by getting the function input string via an API call. If the chatbot uses functions or operators you are not familiar with feel free to ask follow-up questions about it."
- "You want to draw a unicorn in python using the 'turtle' module. (There should be multiple lines of short function calls). After that substitute the 10th line, which includes number argument(s), with the value 73(s)."
- "You want to know what are the world's 10 oldest continuously inhabited cities. Pick the 3rd in that list find out who established the city, in which region it is located and what was the highest population."
- "You have written a content that disagrees with the following statement: 'Technology is the cause of all societal problems' And you want the chatbot to generate a response that agrees with the statement, to make your claims stronger."
- "You plan a trip to France and would like to do a walking tour. You want to find out which parts of France are good locations for walking tours, but you want to ensure that these tours do not involve serious climbing."
- "You want to use the chatbot to create a poem about cats. Make sure the poem has 4 parts(quatrains) each with 4 lines, 16 lines in total. Refine the poem until you are satisfied and it is coherent. Also you want to change the style of one of the quatrains to reflect the distinctive style of your favourite poet."
type:
_target_: roleplay.datasets.hf_datasets.HFDatasets
input_key: "instruction"
data:
instruction:
- "You want to know how fast you run different distances. You use a stopwatch to measure the time it takes you to complete a 50-meter, 100-meter, and 200-meter race. You want to know how can you calculate your speed for each race? Based on that, you also want to calculate how many calories you burned during each race."
- "You can run at a rate of speed four times faster than you can walk, but you can skip at a rate of speed that is half as fast as you can run. You want to know If you can skip at 3 miles per hour, how many miles can you travel in six hours if you spend one-third of the time and two-thirds of the time running and walking, respectively. Also you are curious about the other way around (one-third of the time walking and two-thirds for running)."
- "Every day, you feed each of your chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy. You give the chickens their feed in three separate meals. In the morning, you give your flock of chickens 15 cups of feed. In the afternoon, you give your chickens another 25 cups of feed. You want to know how many cups of feed do you need to give your chickens in the final meal of the day if the size of your flock is 20 chickens? Also you want to know how much does the chicken egg production rate depend on the feed you give? And if you provide enough feed to your chickens for high-rate egg production."
- "You want to make this function better. You want the chatbot to make it recursive to have memory optimal function, but make sure that it doesn’t enter to an infinite loop. After that you want to plug a CLI (command line interface) to this function, so the user can insert a number and get the factorial of it as output: 'The factorial of the <NUMBER>, is <FACTORIAL>'.
```
def factorialize(num):
factorial = 1
for i in range(1, num):
factorial *= i
return factorial
```"
- "You have a little project where you need to use JavaScript, a language you don't use every day. You have a subtask to write a function that counts how many vowels are in a given string. And you need this functionality in OOP. Also you want the chatbot to develop the snippet it provided by getting the function input string via an API call. If the chatbot uses functions or operators you are not familiar with feel free to ask follow-up questions about it."
- "You want to draw a unicorn in python using the 'turtle' module. (There should be multiple lines of short function calls). After that substitute the 10th line, which includes number argument(s), with the value 73(s)."
- "You want to know what are the world's 10 oldest continuously inhabited cities. Pick the 3rd in that list find out who established the city, in which region it is located and what was the highest population."
- "You have written a content that disagrees with the following statement: 'Technology is the cause of all societal problems' And you want the chatbot to generate a response that agrees with the statement, to make your claims stronger."
- "You plan a trip to France and would like to do a walking tour. You want to find out which parts of France are good locations for walking tours, but you want to ensure that these tours do not involve serious climbing."
- "You want to use the chatbot to create a poem about cats. Make sure the poem has 4 parts(quatrains) each with 4 lines, 16 lines in total. Refine the poem until you are satisfied and it is coherent. Also you want to change the style of one of the quatrains to reflect the distinctive style of your favourite poet."

spec_tokens:
persona_placeholder: "<PERSONA>"
Expand Down Expand Up @@ -198,4 +204,4 @@ action_config:
hydra:
sweeper:
params:
seed: 73,321,8479
seed: 73,321,8479
12 changes: 12 additions & 0 deletions roleplay/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any, Dict, List

from urartu.common.dataset import Dataset
from datasets import Dataset as HFDataset


class HFDatasets(Dataset):
    """Dataset adapter that builds a HuggingFace ``Dataset`` from the config's inline data."""

    def __init__(self, cfg: List[Dict[str, Any]]) -> None:
        """Initialize with the provided configuration object."""
        super().__init__(cfg)

    def _get_dataset(self):
        # Convert the config's `data` section to a plain mapping, then
        # materialize it as a HuggingFace Dataset on `self.dataset`.
        data_mapping = dict(self.cfg.data)
        self.dataset = HFDataset.from_dict(data_mapping)
Loading

0 comments on commit d85c7af

Please sign in to comment.