Merge pull request #180 from ServiceNow/parallel-study
parallel study evaluation
recursix authored Dec 18, 2024
2 parents fc4c62f + 46f84d0 commit 64c8bc9
Showing 6 changed files with 274 additions and 12 deletions.
10 changes: 10 additions & 0 deletions src/agentlab/agents/agent_args.py
@@ -3,6 +3,16 @@


class AgentArgs(AbstractAgentArgs):
"""Base class for agent arguments for instantiating an agent.
Define agent arguments as dataclass variables of this class. For example:
class MyAgentArgs(AgentArgs):
my_arg: str = "default_value"
my_other_arg: int = 42
Note: to work properly with AgentXRay, the arguments need to be serializable and hashable.
"""

def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode: bool):
"""Optional method to set benchmark specific flags.
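For reference, a minimal sketch of the subclass pattern this docstring describes, following the existing dataclass-based args in the repo (the two flag names are made up for illustration):

from dataclasses import dataclass

from agentlab.agents.agent_args import AgentArgs


@dataclass
class MyAgentArgs(AgentArgs):
    # Hypothetical flags; keep values serializable and hashable for AgentXRay.
    my_arg: str = "default_value"
    my_other_arg: int = 42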
90 changes: 90 additions & 0 deletions src/agentlab/experiments/multi_server.py
@@ -0,0 +1,90 @@
from copy import deepcopy
from dataclasses import dataclass
import os
import sys
from browsergym.webarena.instance import WebArenaInstance


class BaseServer:
"""Base class for server instances.
Behaves like an identity function for running in parallel on servers that don't need multiple
instances.
"""

def init(self):
pass


@dataclass
class WebArenaInstanceVars(BaseServer):
base_url: str
shopping: str
shopping_admin: str
reddit: str
gitlab: str
wikipedia: str
map: str
homepage: str
full_reset: str
module_name: str = "webarena"
prefix: str = "WA_"

def make_env_vars(self):
"""Return a dictionary of environment variables"""
return {
f"{self.prefix}SHOPPING": f"{self.base_url}:{self.shopping}",
f"{self.prefix}SHOPPING_ADMIN": f"{self.base_url}:{self.shopping_admin}",
f"{self.prefix}REDDIT": f"{self.base_url}:{self.reddit}",
f"{self.prefix}GITLAB": f"{self.base_url}:{self.gitlab}",
f"{self.prefix}WIKIPEDIA": f"{self.base_url}:{self.wikipedia}",
f"{self.prefix}MAP": f"{self.base_url}:{self.map}",
f"{self.prefix}HOMEPAGE": f"{self.base_url}:{self.homepage}",
f"{self.prefix}FULL_RESET": f"{self.base_url}:{self.full_reset}",
}

def init(self):
# necessary for webarena to re-import the env vars
unimport_modules(self.module_name)
for key, value in self.make_env_vars().items():
os.environ[key] = value

# this is just a dynamic check to see that the env vars are set correctly
bgym_instance = WebArenaInstance()
base_url, _ = _split_url(bgym_instance.urls["reddit"])
assert base_url == self.base_url, f"Expected {self.base_url}, got {base_url}"

@staticmethod
def from_env_vars(prefix="WA_", module_name="webarena"):
kwargs = {"module_name": module_name}
base_urls = set()
for key, url in os.environ.items():
if key.startswith(prefix):
base_url, url_tail = _split_url(url)
base_urls.add(base_url)
kwargs[key[len(prefix) :].lower()] = url_tail

if len(base_urls) > 1:
raise ValueError("Multiple base urls found in environment variables")

kwargs["base_url"] = base_urls.pop()
return WebArenaInstanceVars(**kwargs)

def clone(self):
"""Return a deep copy of the instance"""
return deepcopy(self)


def unimport_modules(base_name):
"""un-import any module starting with base_name"""
for module in sys.modules.copy():
if module.startswith(base_name):
del sys.modules[module]


def _split_url(url: str):
"""Extract the base url and the port/page from a url"""
parts = url.split(":")
base_url = ":".join(parts[0:2])
url_tail = ":".join(parts[2:])
return base_url, url_tail
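As a usage sketch, the helpers above can derive a pool of instances from one set of WA_* environment variables; the second hostname below is hypothetical:

from agentlab.experiments.multi_server import WebArenaInstanceVars

# Read the WA_* variables of the current shell into a server description.
server_1 = WebArenaInstanceVars.from_env_vars(prefix="WA_")

# Clone it and point the copy at a second (hypothetical) deployment.
server_2 = server_1.clone()
server_2.base_url = "http://webarena2.example.com"

# Each worker later calls .init() to re-export the env vars and re-import webarena.
parallel_servers = [server_1, server_2]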
89 changes: 78 additions & 11 deletions src/agentlab/experiments/study.py
@@ -1,5 +1,6 @@
import gzip
import logging
import os
import pickle
import uuid
from abc import ABC, abstractmethod
@@ -16,6 +17,8 @@
from agentlab.experiments import reproducibility_util as repro
from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
from agentlab.experiments.multi_server import BaseServer, WebArenaInstanceVars
from multiprocessing import Pool, Manager, Queue

logger = logging.getLogger(__name__)

@@ -27,6 +30,7 @@ def make_study(
suffix="",
comment=None,
ignore_dependencies=False,
parallel_servers=None,
):
"""Run a list of agents on a benchmark.
@@ -57,10 +61,18 @@
3x compared to sequential execution. To accelerate execution, you can ignore
dependencies and run in full parallel. This leads to a decrease in performance of about
1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
parallel_servers: list[WebArenaInstanceVars] | int
The servers to use when `"webarena" in benchmark.name`; an int creates that many generic
BaseServer placeholders. Use this to dispatch agent_args on a pool of servers in parallel.
If len(agent_args) > len(parallel_servers), a server is reused (with a reset) for the next
evaluation as soon as it is done.
Returns:
Study object or SequentialStudies object if the benchmark requires manual reset after each
evaluation such as WebArena and VisualWebArena.
Study | SequentialStudies | ParallelStudies object.
SequentialStudies: if the benchmark requires manual reset after each evaluation such as
WebArena and VisualWebArena.
ParallelStudies: if parallel_servers is provided, so studies are dispatched on multiple servers in parallel.
Study: otherwise.
"""

if not isinstance(agent_args, (list, tuple)):
@@ -69,7 +81,7 @@
if isinstance(benchmark, str):
benchmark = bgym.DEFAULT_BENCHMARKS[benchmark.lower()]()

if "webarena" in benchmark.name and len(agent_args) > 1:
if len(agent_args) > 1 and ("webarena" in benchmark.name or parallel_servers is not None):
logger.warning(
"*WebArena* requires manual reset after each evaluation. Running through SequentialStudies."
)
@@ -85,8 +97,10 @@
ignore_dependencies=ignore_dependencies,
)
)

return SequentialStudies(studies)
if parallel_servers is not None:
return ParallelStudies(studies, parallel_servers=parallel_servers)
else:
return SequentialStudies(studies)
else:
return Study(
agent_args,
@@ -164,7 +178,7 @@ class Study(AbstractStudy):
A suffix to add to the study name. This can be useful to keep track of your experiments.
By default the study name contains agent name, benchmark name and date.
uuid: str
A unique identifier for the study.
A unique identifier for the study. Will be generated automatically.
reproducibility_info: dict
Information about the study that may affect the reproducibility of the experiment. e.g.:
versions of BrowserGym, benchmark, AgentLab...
@@ -178,12 +192,12 @@
information. Leave any extra information that can explain why results could be different
than expected.
ignore_dependencies: bool
If True, ignore the dependencies of the tasks in the benchmark. *Use with caution.* So
If True, ignore the dependencies of the tasks in the benchmark. *Use with caution*. So
far, only WebArena and VisualWebArena have dependencies between tasks to minimize the
influence of solving one task before another one. This dependency graph allows
experiments to run in parallel while respecting task dependencies. However, it still
can't run more than 4 in parallel and, in practice, it speeds up evaluation by a factor of only
3x compare to sequential executionz. To accelerate execution, you can ignore
3x compared to sequential execution. To accelerate execution, you can ignore
dependencies and run in full parallel. This leads to a decrease in performance of about
1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
avg_step_timeout: int
@@ -455,13 +469,15 @@ def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_
study.make_dir(exp_root=self.dir)

self.save()

for study in self.studies:
study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
self._run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
_, summary_df, _ = self.get_results()
logger.info("\n" + str(summary_df))
logger.info(f"SequentialStudies {self.name} finished.")

def _run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3):
for study in self.studies:
study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)

def override_max_steps(self, max_steps):
for study in self.studies:
study.override_max_steps(max_steps)
@@ -471,6 +487,57 @@ def append_to_journal(self, strict_reproducibility=True):
study.append_to_journal(strict_reproducibility=strict_reproducibility)


def _init_worker(server_queue: Queue):
"""Run once at the initialization of the worker in the multiprocessing.Pool.
This is typically used to initialize different environment variables of the WebArena server for
multiple instances in parallel.
Args:
server_queue: Queue
A queue of object implementing BaseServer to initialize (or anything with a init
method).
"""
server_instance = server_queue.get() # type: "WebArenaInstanceVars"
logger.warning(f"Initializing server instance {server_instance} from process {os.getpid()}")
server_instance.init()


def _run_study(study: Study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch):
"""Wrapper to run a study remotely."""
study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)


@dataclass
class ParallelStudies(SequentialStudies):

parallel_servers: list[BaseServer] | int = None

def _run(
self,
n_jobs=1,
parallel_backend="ray",
strict_reproducibility=False,
n_relaunch=3,
):
parallel_servers = self.parallel_servers
if isinstance(parallel_servers, int):
parallel_servers = [BaseServer() for _ in range(parallel_servers)]

server_queue = Manager().Queue()
for server in parallel_servers:
server_queue.put(server)

with Pool(len(parallel_servers), initializer=_init_worker, initargs=(server_queue,)) as p:
p.starmap(
_run_study,
[
(study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
for study in self.studies
],
)


def get_most_recent_study(
root_dir: Path = None, date_format: str = "%Y-%m-%d_%H-%M-%S", contains=None
):
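ParallelStudies._run above uses the classic pool-initializer pattern: each worker pops one server from a shared queue at startup and keeps it for its lifetime. A self-contained sketch of that pattern, with placeholder server names standing in for BaseServer objects:

import os
from multiprocessing import Manager, Pool


def _bind_server(server_queue):
    # Runs once per worker process; binds this worker to one server.
    global _SERVER
    _SERVER = server_queue.get()


def _run_one(study_id):
    return f"study {study_id} ran on {_SERVER} in pid {os.getpid()}"


if __name__ == "__main__":
    servers = ["server-a", "server-b"]  # placeholders for BaseServer instances
    queue = Manager().Queue()
    for server in servers:
        queue.put(server)

    # One worker per server; studies are dispatched as workers become free.
    with Pool(len(servers), initializer=_bind_server, initargs=(queue,)) as pool:
        for line in pool.map(_run_one, range(4)):
            print(line)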
37 changes: 37 additions & 0 deletions tests/experiments/test_multi_server.py
@@ -0,0 +1,37 @@
from agentlab.experiments.multi_server import WebArenaInstanceVars
from browsergym.webarena.instance import WebArenaInstance


def test_webarena_multiserver():

instance_1 = WebArenaInstanceVars(
base_url="http://webarena1.eastus.cloudapp.azure.com",
shopping="8082/",
shopping_admin="8083/admin",
reddit="8080",
gitlab="9001",
wikipedia="8081/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing",
map="443",
homepage="80",
full_reset="7565",
module_name="webarena",
prefix="WA_",
)

instance_1.init()

bgym_instance = WebArenaInstance()
base_url_1 = bgym_instance.urls["reddit"].rsplit(":", 1)[0]
assert base_url_1 == instance_1.base_url

instance_2 = instance_1.clone()
instance_2.base_url = "http://webarena2.eastus.cloudapp.azure.com"
instance_2.init()

bgym_instance = WebArenaInstance()
base_url_2 = bgym_instance.urls["reddit"].rsplit(":", 1)[0]
assert base_url_2 == instance_2.base_url


if __name__ == "__main__":
test_webarena_multiserver()
2 changes: 1 addition & 1 deletion tests/experiments/test_ray.py
@@ -31,7 +31,7 @@ def test_execute_task_graph():
# Verify that parallel tasks (task2 and task3) started within a short time of each other
parallel_start_diff = abs(exp_args_list[1].start_time - exp_args_list[2].start_time)
print(f"parallel_start_diff: {parallel_start_diff}")
assert parallel_start_diff < 1.5 # Allow for a small delay
assert parallel_start_diff < 2 # Allow for a small delay

# Ensure that the entire task graph took the expected amount of time
total_time = exp_args_list[-1].end_time - exp_args_list[0].start_time
58 changes: 58 additions & 0 deletions tests/experiments/test_study.py
@@ -0,0 +1,58 @@
import pytest
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs
from agentlab.llm.chat_api import CheatMiniWoBLLMArgs
from agentlab.experiments.study import ParallelStudies, make_study, Study
from agentlab.experiments.multi_server import WebArenaInstanceVars


def _make_agent_args_list():
# CheatMiniWoB agents won't succeed on WebArena, this is just for testing parallelization
agent_args_list = []
for i in range(2):
agent_args = GenericAgentArgs(
chat_model_args=CheatMiniWoBLLMArgs(),
flags=FLAGS_GPT_4o,
)

agent_args.agent_name = agent_args.agent_name + f"_{i}"
agent_args_list.append(agent_args)
return agent_args_list


@pytest.mark.skip(reason="This test requires WebArena instances to be running")
def manual_test_launch_parallel_study_webarena():
agent_args_list = _make_agent_args_list()

server_instance_1 = WebArenaInstanceVars.from_env_vars()
server_instance_2 = server_instance_1.clone()
server_instance_2.base_url = "http://webarena-slow.eastus.cloudapp.azure.com"
parallel_servers = [server_instance_1, server_instance_2]

for server in parallel_servers:
print(server)

study = make_study(
agent_args_list, benchmark="webarena_tiny", parallel_servers=parallel_servers
)
assert isinstance(study, ParallelStudies)

study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)


def test_launch_parallel_study():
agent_args_list = _make_agent_args_list()

study = make_study(agent_args_list, benchmark="miniwob_tiny_test", parallel_servers=2)
assert isinstance(study, ParallelStudies)

study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)
_, summary_df, _ = study.get_results()
assert len(summary_df) == 2
for n_completed in summary_df["n_completed"]:
assert n_completed == "4/4"


if __name__ == "__main__":
# test_launch_parallel_study()
manual_test_launch_parallel_study_webarena()
