Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tool-based actions, env as a tools collection #144

Closed
ollmer wants to merge 27 commits into main from tool_based_actions
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d1d761c
browser class that can execute relevant action steps and produce obse…
ollmer Dec 10, 2024
4db43ff
tool and multitool classes, actions based on tools
ollmer Dec 11, 2024
0a5bf4d
new common tools and browser as stateful multitool
ollmer Dec 11, 2024
a64b9f4
pydantic fixes
ollmer Dec 11, 2024
06a2d30
fix
ollmer Dec 11, 2024
612a69f
fix nodes
ollmer Dec 11, 2024
9348564
Merge branch 'main' into tool_based_actions
ollmer Dec 11, 2024
4358043
Merge branch 'main' into tool_based_actions
ollmer Dec 11, 2024
250939a
stop replay on error during test
ollmer Dec 12, 2024
51c2492
use cache in new tools
ollmer Dec 12, 2024
69c6f89
further simplify gaia agent
ollmer Dec 12, 2024
d11c3da
further simplify gaia agent
ollmer Dec 12, 2024
59c1840
Merge branch 'tool_based_actions' of github.com:ServiceNow/TapeAgents…
ollmer Dec 12, 2024
1a42993
fix code step and cache key
ollmer Dec 12, 2024
aedfc9b
fix
ollmer Dec 12, 2024
f34a873
even cleaner agent
ollmer Dec 12, 2024
b2f1250
exp config
ollmer Dec 12, 2024
6d08dbc
fix parsing error action handling
ollmer Dec 12, 2024
8fcd103
fix prompt
ollmer Dec 12, 2024
d36be3e
better locks
ollmer Dec 12, 2024
dbffe55
update exp conf
ollmer Dec 12, 2024
f73617f
common step tool env
ollmer Dec 12, 2024
49c17ab
better name
ollmer Dec 12, 2024
a75c4c4
agent name with version
ollmer Dec 12, 2024
d9c7ad1
action usage stat in tape browser
ollmer Dec 12, 2024
ad8aadb
Merge branch 'main' into tool_based_actions
ollmer Dec 13, 2024
1197d4e
browsergym markdown mode
ollmer Dec 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions conf/gaia_llama.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ agent:
subtasks: false

env:
attachment_dir: ${exp_path}/attachments/
image_observations: false
use_web_cache: true

hydra:
Expand Down
8 changes: 3 additions & 5 deletions conf/gaia_openai.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
defaults:
- llm: gpt4o
- llm: gpt4o_mini
- _self_

exp_name: gpt4o_val_search1
exp_name: gpt4o_mini_val_toolenv1
exp_path: outputs/gaia/runs/${exp_name}
split: validation
batch: 1
batch: 32

agent:
plain_code: false

env:
attachment_dir: ${exp_path}/attachments/
image_observations: true
use_web_cache: true

studio:
Expand Down
1 change: 0 additions & 1 deletion conf/workarena_demo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ exp_path: ../workarena/runs/${exp_name}
agent: baseline
env:
exp_path: ${exp_path}
baseline_obs: True
headless: False
seeds: [42]

Expand Down
1 change: 0 additions & 1 deletion conf/workarena_openai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ exp_path: ../workarena/runs/${exp_name}
agent: guided
env:
exp_path: ${exp_path}
baseline_obs: False
headless: True
seeds: [0, 42, 1337, 900, 103]

Expand Down
129 changes: 21 additions & 108 deletions examples/gaia_agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,49 @@
import logging
from enum import Enum
from typing import Any

from pydantic import Field

from tapeagents.agent import Agent
from tapeagents.environment import CodeExecutionResult, ExecuteCode
from tapeagents.llms import LLM
from tapeagents.nodes import MonoNode
from tapeagents.steps import VideoObservation
from tapeagents.tools.container_executor import extract_code_blocks
from tapeagents.steps import ActionExecutionFailure, VideoObservation
from tapeagents.tools.simple_browser import PageObservation

from .prompts import PromptRegistry
from .steps import (
ActionExecutionFailure,
CalculationResultObservation,
CodeResultObservation,
FinishSubtask,
GaiaAgentStep,
GaiaQuestion,
ListOfFactsThought,
NewFactThought,
PageObservation,
PlanThought,
SourcesThought,
all_steps,
nocode_steps,
plan_steps,
)
from .steps import AGENT_STEPS, STEPS_WITHOUT_CODE, FactsSurvey, Plan
from .tape import GaiaTape

logger = logging.getLogger(__name__)


class PlanningMode(str, Enum):
simple = "simple"
facts_and_sources = "facts_and_sources"
multiplan = "multiplan"
replan_after_sources = "replan_after_sources"
reflect = "reflect"


class GaiaNode(MonoNode):
system_prompt: str = PromptRegistry.system_prompt
steps_prompt: str = PromptRegistry.allowed_steps
agent_step_cls: Any = Field(exclude=True, default=GaiaAgentStep)
allowed_steps: str

def get_steps_description(self, tape: GaiaTape, agent: Any) -> str:
"""
Allow different subset of steps based on the agent's configuration
"""
return self.steps_prompt.format(allowed_steps=self.allowed_steps)

def prepare_tape(self, tape: GaiaTape, max_chars: int = 200) -> GaiaTape:
"""
Trim long observations except for the last 3 steps
"""
tape = super().prepare_tape(tape) # type: ignore
steps = []
for step in tape.steps[:-3]:
if isinstance(step, PageObservation):
short_text = f"{step.text[:max_chars]}\n..." if len(step.text) > max_chars else step.text
new_step = step.model_copy(update=dict(text=short_text))
elif isinstance(step, ActionExecutionFailure):
short_error = f"{step.error[:max_chars]}\n..." if len(step.error) > max_chars else step.error
new_step = step.model_copy(update=dict(error=short_error))
steps_border = -3
for step in tape.steps[:steps_border]:
if isinstance(step, PageObservation) and len(step.text) > max_chars:
trimmed_step = step.model_copy(update=dict(text=f"{step.text[:max_chars]}\n..."))
elif isinstance(step, ActionExecutionFailure) and len(step.error) > max_chars:
trimmed_step = step.model_copy(update=dict(error=f"{step.error[:max_chars]}\n..."))
elif isinstance(step, VideoObservation):
new_step = step.model_copy(update=dict(video_contact_sheet_paths=None, subtitle_text=None))
trimmed_step = step.model_copy(update=dict(video_contact_sheet_paths=None, subtitle_text=None))
else:
new_step = step
steps.append(new_step)
trimmed_tape = tape.model_copy(update=dict(steps=steps + tape.steps[-3:]))
return trimmed_tape

def trim_tape(self, tape: GaiaTape) -> GaiaTape:
"""
Make tape shorter to fit llm context size limits
"""
finish_subtask_positions = [i for i, step in enumerate(tape) if isinstance(step, FinishSubtask)]
# trim either after last finished subtask or at 2/3 of the tape
summarization_border = (finish_subtask_positions[-1] + 1) if finish_subtask_positions else int(len(tape) * 0.66)
short_tape = tape.model_copy(update=dict(steps=[]))
pre_tape: GaiaTape = tape[:summarization_border] # type: ignore
for step in pre_tape.steps:
if isinstance(
step,
(
GaiaQuestion,
PlanThought,
SourcesThought,
ListOfFactsThought,
NewFactThought,
CalculationResultObservation,
CodeResultObservation,
CodeExecutionResult,
),
):
short_tape.steps.append(step)
for step in tape.steps[summarization_border:]:
short_tape.steps.append(step)
logger.info(f"Tape reduced from {len(tape)} to {len(short_tape)} steps")
return short_tape

def parse_completion(self, llm_output: str, prompt_id: str):
if llm_output.strip().startswith("```"):
code_blocks = extract_code_blocks(llm_output)
yield ExecuteCode(code=code_blocks)
else:
for step in super().parse_completion(llm_output, prompt_id):
yield step
trimmed_step = step
steps.append(trimmed_step)
return tape.model_copy(update=dict(steps=steps + tape.steps[steps_border:]))


class GaiaAgent(Agent):
plain_code: bool
name: str = "gaia_agent_v3"

@classmethod
def create(cls, llm: LLM, plain_code: bool = False, **kwargs):
steps_prompt = PromptRegistry.allowed_steps_code if plain_code else PromptRegistry.allowed_steps
steps = STEPS_WITHOUT_CODE if plain_code else AGENT_STEPS
nodes = [
GaiaNode(name="plan", guidance=PromptRegistry.plan, allowed_steps=plan_steps),
GaiaNode(name="facts_survey", guidance=PromptRegistry.facts_survey, allowed_steps=plan_steps),
GaiaNode(
name="start_execution",
guidance=PromptRegistry.start_execution,
steps_prompt=PromptRegistry.allowed_steps_code if plain_code else PromptRegistry.allowed_steps,
allowed_steps=nocode_steps if plain_code else all_steps,
),
GaiaNode(
name="act",
steps_prompt=PromptRegistry.allowed_steps_code if plain_code else PromptRegistry.allowed_steps,
allowed_steps=nocode_steps if plain_code else all_steps,
next_node="act",
),
GaiaNode(name="plan", guidance=PromptRegistry.plan, agent_steps=Plan),
GaiaNode(name="facts_survey", guidance=PromptRegistry.facts_survey, agent_steps=FactsSurvey),
GaiaNode(name="start", guidance=PromptRegistry.start, steps_prompt=steps_prompt, agent_steps=steps),
GaiaNode(name="act", steps_prompt=steps_prompt, agent_steps=steps, next_node="act"),
]
return super().create(llm, nodes=nodes, max_iterations=2, plain_code=plain_code, **kwargs)
return super().create(llm, nodes=nodes, max_iterations=2, **kwargs)
Loading
Loading