-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #180 from ServiceNow/parallel-study
parallel study evaluation
- Loading branch information
Showing 6 changed files with 274 additions and 12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from copy import deepcopy | ||
from dataclasses import dataclass | ||
import os | ||
import sys | ||
from browsergym.webarena.instance import WebArenaInstance | ||
|
||
|
||
class BaseServer:
    """Default server description whose initialisation is a no-op.

    Acts as an identity placeholder when running in parallel against servers
    that do not require multiple instances.
    """

    def init(self):
        """Prepare the server for use; the base class has nothing to set up."""
        return None
|
||
|
||
@dataclass
class WebArenaInstanceVars(BaseServer):
    """Environment-variable description of one WebArena deployment.

    All services share ``base_url``. Each service field holds whatever follows
    ``base_url:`` in that service's URL (for example ``"8080"`` or
    ``"8083/admin"``). ``prefix`` is prepended to the upper-cased service name
    to form the environment variable name, and ``module_name`` is the module
    prefix that :meth:`init` un-imports before exporting the variables.
    """

    base_url: str
    shopping: str
    shopping_admin: str
    reddit: str
    gitlab: str
    wikipedia: str
    map: str
    homepage: str
    full_reset: str
    module_name: str = "webarena"
    prefix: str = "WA_"

    def make_env_vars(self):
        """Return a dictionary of environment variables"""
        return {
            f"{self.prefix}SHOPPING": f"{self.base_url}:{self.shopping}",
            f"{self.prefix}SHOPPING_ADMIN": f"{self.base_url}:{self.shopping_admin}",
            f"{self.prefix}REDDIT": f"{self.base_url}:{self.reddit}",
            f"{self.prefix}GITLAB": f"{self.base_url}:{self.gitlab}",
            f"{self.prefix}WIKIPEDIA": f"{self.base_url}:{self.wikipedia}",
            f"{self.prefix}MAP": f"{self.base_url}:{self.map}",
            f"{self.prefix}HOMEPAGE": f"{self.base_url}:{self.homepage}",
            f"{self.prefix}FULL_RESET": f"{self.base_url}:{self.full_reset}",
        }

    def init(self):
        """Export this instance's URLs to ``os.environ`` and verify they took effect.

        Raises:
            AssertionError: if the base url reported by a fresh
                ``WebArenaInstance`` differs from ``self.base_url``.
        """
        # necessary for webarena to re-import the env vars
        unimport_modules(self.module_name)
        for key, value in self.make_env_vars().items():
            os.environ[key] = value

        # this is just a dynamic check to see that the env vars are set correctly
        bgym_instance = WebArenaInstance()
        base_url, _ = _split_url(bgym_instance.urls["reddit"])
        assert base_url == self.base_url, f"Expected {self.base_url}, got {base_url}"

    @staticmethod
    def from_env_vars(prefix="WA_", module_name="webarena"):
        """Build an instance from the environment variables starting with ``prefix``.

        Each matching variable is split into a base url and a tail; the tail is
        stored under the lower-cased variable name with ``prefix`` removed.

        Args:
            prefix: prefix selecting the environment variables to read. It is
                also stored on the returned instance so that ``init`` writes
                the variables back under the same names.
            module_name: stored on the returned instance as ``module_name``.

        Returns:
            A new ``WebArenaInstanceVars``.

        Raises:
            ValueError: if no variable starts with ``prefix``, or if the
                matching variables do not all share the same base url.
        """
        # Forward the prefix: otherwise an instance read from a custom prefix
        # would silently fall back to the default "WA_" when re-exported.
        kwargs = {"module_name": module_name, "prefix": prefix}
        base_urls = set()
        for key, url in os.environ.items():
            if key.startswith(prefix):
                base_url, url_tail = _split_url(url)
                base_urls.add(base_url)
                kwargs[key[len(prefix) :].lower()] = url_tail

        if not base_urls:
            # Fail with an explicit message instead of the bare KeyError that
            # popping from an empty set would produce.
            raise ValueError(f"No environment variables starting with {prefix!r} were found")

        if len(base_urls) > 1:
            raise ValueError("Multiple base urls found in environment variables")

        kwargs["base_url"] = base_urls.pop()
        return WebArenaInstanceVars(**kwargs)

    def clone(self):
        """Return a deep copy of the instance"""
        return deepcopy(self)
|
||
|
||
def unimport_modules(base_name):
    """Remove from ``sys.modules`` every entry whose name starts with base_name."""
    # Collect the matching names first so the registry is never mutated while
    # it is being iterated.
    doomed = [name for name in list(sys.modules) if name.startswith(base_name)]
    for name in doomed:
        del sys.modules[name]
|
||
|
||
def _split_url(url: str):
    """Split a url into its base (text up to the second ':') and the remainder.

    The remainder is the port/page part; it is an empty string when the url
    contains fewer than two ':' characters.
    """
    # Only the first two colons matter; later ones stay inside the tail.
    pieces = url.split(":", 2)
    head = ":".join(pieces[:2])
    tail = pieces[2] if len(pieces) > 2 else ""
    return head, tail
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from agentlab.experiments.multi_server import WebArenaInstanceVars | ||
from browsergym.webarena.instance import WebArenaInstance | ||
|
||
|
||
def test_webarena_multiserver():
    """Check that initialising a cloned instance switches the reported base url."""

    def reported_reddit_base_url():
        # Build a fresh WebArenaInstance and drop the trailing ":<port>" from
        # its reddit url.
        return WebArenaInstance().urls["reddit"].rsplit(":", 1)[0]

    first = WebArenaInstanceVars(
        base_url="http://webarena1.eastus.cloudapp.azure.com",
        shopping="8082/",
        shopping_admin="8083/admin",
        reddit="8080",
        gitlab="9001",
        wikipedia="8081/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing",
        map="443",
        homepage="80",
        full_reset="7565",
        module_name="webarena",
        prefix="WA_",
    )
    first.init()
    assert reported_reddit_base_url() == first.base_url

    # Same ports and paths, different host.
    second = first.clone()
    second.base_url = "http://webarena2.eastus.cloudapp.azure.com"
    second.init()
    assert reported_reddit_base_url() == second.base_url
|
||
|
||
if __name__ == "__main__":
    # Allow running the check directly as a script.
    test_webarena_multiserver()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import pytest | ||
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o | ||
from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs | ||
from agentlab.llm.chat_api import CheatMiniWoBLLMArgs | ||
from agentlab.experiments.study import ParallelStudies, make_study, Study | ||
from agentlab.experiments.multi_server import WebArenaInstanceVars | ||
|
||
|
||
def _make_agent_args_list():
    """Return two GenericAgentArgs whose names are suffixed with ``_0`` and ``_1``."""
    # CheatMiniWoB agents won't succeed on WebArena, this is just for testing parallelization

    def build(index):
        args = GenericAgentArgs(
            chat_model_args=CheatMiniWoBLLMArgs(),
            flags=FLAGS_GPT_4o,
        )
        # Give each agent its own name by appending the index.
        args.agent_name += f"_{index}"
        return args

    return [build(index) for index in range(2)]
|
||
|
||
@pytest.mark.skip(reason="This test requires WebArena instances to be running")
def manual_test_launch_parallel_study_webarena():
    """Launch a parallel study on ``webarena_tiny`` across two server instances."""
    agents = _make_agent_args_list()

    # The first server comes from the environment; the second is a copy of it
    # with a different base url.
    primary = WebArenaInstanceVars.from_env_vars()
    secondary = primary.clone()
    secondary.base_url = "http://webarena-slow.eastus.cloudapp.azure.com"
    servers = [primary, secondary]

    print(*servers, sep="\n")

    study = make_study(agents, benchmark="webarena_tiny", parallel_servers=servers)
    assert isinstance(study, ParallelStudies)

    study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)
|
||
|
||
def test_launch_parallel_study():
    """Run ``miniwob_tiny_test`` with two parallel servers and check the summary."""
    study = make_study(
        _make_agent_args_list(),
        benchmark="miniwob_tiny_test",
        parallel_servers=2,
    )
    assert isinstance(study, ParallelStudies)

    study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)

    # Only the summary dataframe (second element of the results) is inspected.
    summary_df = study.get_results()[1]
    assert len(summary_df) == 2
    assert all(done == "4/4" for done in summary_df["n_completed"])
|
||
|
||
if __name__ == "__main__":
    # Direct invocation runs the WebArena variant; call
    # test_launch_parallel_study() here instead to run the MiniWoB variant.
    manual_test_launch_parallel_study_webarena()