diff --git a/requirements.txt b/requirements.txt
index e96fa61..249d5b7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,6 +16,8 @@ contexttimer
 ipython
 pyyaml>=6
 pandas
+requests
+pillow
 gradio
 gitpython # for the reproducibility script
 requests
\ No newline at end of file
diff --git a/src/agentlab/agents/most_basic_agent/most_basic_agent.py b/src/agentlab/agents/most_basic_agent/most_basic_agent.py
index 2e0cfcb..ec0efa6 100644
--- a/src/agentlab/agents/most_basic_agent/most_basic_agent.py
+++ b/src/agentlab/agents/most_basic_agent/most_basic_agent.py
@@ -4,11 +4,11 @@
 
 import bgym
 
+from agentlab.agents.agent_args import AgentArgs
 from agentlab.llm.chat_api import make_system_message, make_user_message
 from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
 from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry
 from agentlab.llm.tracking import cost_tracker_decorator
-from agentlab.agents.agent_args import AgentArgs
 
 if TYPE_CHECKING:
     from agentlab.llm.chat_api import BaseModelArgs
diff --git a/src/agentlab/agents/visualwebarena/agent.py b/src/agentlab/agents/visualwebarena/agent.py
new file mode 100644
index 0000000..d89fa2f
--- /dev/null
+++ b/src/agentlab/agents/visualwebarena/agent.py
@@ -0,0 +1,187 @@
+import base64
+import dataclasses
+import io
+import re
+import tempfile
+from io import BytesIO
+
+from browsergym.core.action.highlevel import HighLevelActionSet
+from browsergym.experiments import Agent, AgentInfo
+from browsergym.utils.obs import flatten_axtree_to_str, overlay_som
+from PIL import Image
+
+from agentlab.agents.agent_args import AgentArgs
+from agentlab.llm.chat_api import BaseModelArgs, make_system_message, make_user_message
+from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry
+
+
+def pil_to_b64(img: Image.Image) -> str:
+    with BytesIO() as image_buffer:
+        img.save(image_buffer, format="PNG")
+        byte_data = image_buffer.getvalue()
+        img_b64 = base64.b64encode(byte_data).decode("utf-8")
+        img_b64 = "data:image/png;base64," + img_b64
+    return img_b64
+
+
+def b64_to_pil(img_b64: str) -> Image.Image:
+    if not img_b64.startswith("data:image/png;base64,"):
+        raise ValueError(f"Unexpected base64 encoding: {img_b64}")
+    img_b64 = img_b64.removeprefix("data:image/png;base64,")
+    img_data = base64.b64decode(img_b64)
+    img = Image.open(io.BytesIO(img_data))
+    return img
+
+
+class VWAAgent(Agent):
+    """
+    Re-implementation of the web agent from VisualWebArena.
+    Credits to Lawrence Jang (@ljang0)
+    https://github.com/web-arena-x/visualwebarena/blob/main/agent/agent.py
+    """
+
+    action_set = HighLevelActionSet(
+        subsets=["chat", "bid", "infeas", "nav", "tab"],
+        strict=False,
+        multiaction=False,
+        demo_mode="off",
+    )
+
+    def obs_preprocessor(self, obs: dict) -> dict:
+        return {
+            "goal_object": obs["goal_object"],
+            "last_action": obs["last_action"],
+            "axtree_txt": flatten_axtree_to_str(
+                obs["axtree_object"], obs["extra_element_properties"]
+            ),
+            "extra_properties": obs["extra_element_properties"],
+            "url": obs["url"],
+            "screenshot": obs["screenshot"],
+        }
+
+    def __init__(self, chat_model_args: BaseModelArgs, n_retry: int) -> None:
+        super().__init__()
+        self.model_name = chat_model_args.model_name
+        self.chat_llm = chat_model_args.make_model()
+        self.n_retry = n_retry
+
+        self.goal_images = None
+
+    def get_action(self, obs: dict) -> tuple[str, dict]:
+
+        system_prompt = f"""\
+Review the current state of the page and all other information to find the best
+possible next action to accomplish your goal. Your answer will be interpreted
+and executed by a program, make sure to follow the formatting instructions."""
+
+        user_prompt = f"""\
+# Goal:
+{obs["goal_object"][0]["text"]}
+
+# Current Accessibility Tree:
+{obs["axtree_txt"]}
+
+# Action Space
+{self.action_set.describe(with_long_description=False, with_examples=True)}
+
+Here is an example with chain of thought of a valid action when clicking on a button:
+"
+In order to accomplish my goal I need to click on the button with bid 12
+```click("12")```
+"
+
+If you have completed the task, use the chat to return an answer. For example, if you are asked what is the color of the sky, return
+"
+```send_msg_to_user("blue")```
+"
+"""
+        # prompt
+        user_msgs = [{"type": "text", "text": user_prompt}]
+
+        # screenshot
+        user_msgs += [
+            {
+                "type": "text",
+                "text": "IMAGES: current page screenshot",
+            },
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": pil_to_b64(
+                        Image.fromarray(overlay_som(obs["screenshot"], obs["extra_properties"]))
+                    )
+                },
+            },
+        ]
+        # additional images
+        user_msgs.extend(obs["goal_object"][1:])
+
+        messages = [
+            make_system_message(system_prompt),
+            make_user_message(user_msgs),
+        ]
+
+        def parser(response: str) -> dict:
+            pattern = r"```((.|\n)*?)```"
+            match = re.search(pattern, response)
+            if not match:
+                raise ParseError("No code block found in the response")
+            action = match.group(1).strip()
+            thought = response
+            return {"action": action, "think": thought}
+
+        response = retry(self.chat_llm, messages, n_retry=self.n_retry, parser=parser)
+
+        action = response.get("action", None)
+        stats = dict(response.usage)
+        return action, AgentInfo(
+            chat_messages=messages,
+            think=response.get("think", None),
+            stats=stats,
+        )
+
+
+@dataclasses.dataclass
+class VWAAgentArgs(AgentArgs):
+    """
+    This class is meant to store the arguments that define the agent.
+
+    By isolating them in a dataclass, this ensures serialization without storing
+    internal states of the agent.
+ """ + + agent_name: str = "vwa" + temperature: float = 0.1 + chat_model_args: BaseModelArgs = None + + def make_agent(self): + return VWAAgent() + + +CONFIG = VWAAgentArgs(model_name="gpt-4-1106-vision-preview") + + +def main(): + from pathlib import Path + + from browsergym.experiments import EnvArgs, ExpArgs, get_exp_result + + exp_args = ExpArgs( + agent_args=VWAAgentArgs(model_name="gpt-4-1106-preview"), + env_args=EnvArgs( + task_name="visualwebarena.423", + task_seed=42, + headless=False, # shows the browser + ), + ) + exp_args.prepare(exp_root=Path("./results")) + exp_args.run() + exp_result = get_exp_result(exp_args.exp_dir) + exp_record = exp_result.get_exp_record() + + for key, val in exp_record.items(): + print(f"{key}: {val}") + + +if __name__ == "__main__": + main() diff --git a/src/agentlab/experiments/study_generators.py b/src/agentlab/experiments/study_generators.py index e079ba7..ece04a1 100644 --- a/src/agentlab/experiments/study_generators.py +++ b/src/agentlab/experiments/study_generators.py @@ -1,18 +1,18 @@ +import logging from dataclasses import dataclass from datetime import datetime -import logging from pathlib import Path -from bgym import ExpArgs, EnvArgs +from bgym import EnvArgs, ExpArgs from agentlab.agents.agent_args import AgentArgs from agentlab.agents.generic_agent.agent_configs import RANDOM_SEARCH_AGENT, AGENT_4o_MINI from agentlab.analyze import inspect_results from agentlab.experiments import args +from agentlab.experiments import reproducibility_util as repro from agentlab.experiments import task_collections as tasks -from agentlab.experiments.launch_exp import run_experiments, relaunch_study from agentlab.experiments.exp_utils import RESULTS_DIR -from agentlab.experiments import reproducibility_util as repro +from agentlab.experiments.launch_exp import relaunch_study, run_experiments @dataclass @@ -267,3 +267,27 @@ def random_search( study = run_agents_on_benchmark(agents, benchmark, demo_mode=demo_mode) study.suffix = "random_search" return study + + +def final_run_vwa(agent: AgentArgs = AGENT_4o_MINI, benchmark: str = "miniwob"): + # agent.flags = miniwob_add_html(benchmark, agent.flags) + + env_args_list_reset, env_args_list_no_reset = tasks.get_benchmark_env_args( + "visualwebarena", meta_seed=43, max_steps=None, n_repeat=None, is_agent_curriculum=False + ) + + return args.expand_cross_product( + ExpArgs( + agent_args=args.CrossProd([agent]), + env_args=args.CrossProd(env_args_list_reset), + enable_debug=False, + logging_level=logging.DEBUG, + ) + ), args.expand_cross_product( + ExpArgs( + agent_args=args.CrossProd([agent]), + env_args=args.CrossProd(env_args_list_no_reset), + enable_debug=False, + logging_level=logging.DEBUG, + ) + ) diff --git a/src/agentlab/experiments/task_collections.py b/src/agentlab/experiments/task_collections.py index 6e91488..99c0f11 100644 --- a/src/agentlab/experiments/task_collections.py +++ b/src/agentlab/experiments/task_collections.py @@ -9,6 +9,10 @@ from browsergym.experiments import EnvArgs from browsergym.webarena import ALL_WEBARENA_TASK_IDS +from browsergym.visualwebarena import ( + VISUALWEBARENA_TASK_IDS_WITH_RESET, + VISUALWEBARENA_TASK_IDS_WITHOUT_RESET, +) df = pd.read_csv(Path(__file__).parent / "miniwob_tasks_all.csv") # append miniwob. 
@@ -122,6 +126,7 @@ def get_benchmark_env_args(
         "workarena.l2": 50,
         "workarena.l3": 50,
         "webarena": 15,
+        "visualwebarena": 15,
         "miniwob": 10,
         "miniwob_tiny_test": 5,
     }
@@ -131,6 +136,7 @@
         "workarena.l2": 1,
         "workarena.l3": 1,
         "webarena": 1,
+        "visualwebarena": 1,
         "miniwob": 5,
         "miniwob_tiny_test": 2,
     }
@@ -176,6 +182,14 @@ def get_benchmark_env_args(
         from browsergym.webarena import ALL_WEBARENA_TASK_IDS
 
         env_args_list = _make_env_args(ALL_WEBARENA_TASK_IDS, max_steps, n_repeat, rng)
+    elif benchmark_name == "visualwebarena":
+        env_args_list_reset = _make_env_args(
+            VISUALWEBARENA_TASK_IDS_WITH_RESET, max_steps, n_repeat, rng
+        )
+        env_args_list_no_reset = _make_env_args(
+            VISUALWEBARENA_TASK_IDS_WITHOUT_RESET, max_steps, n_repeat, rng
+        )
+        env_args_list = (env_args_list_reset, env_args_list_no_reset)
     elif benchmark_name.startswith("miniwob"):
         miniwob_benchmarks_map = {
             "miniwob": MINIWOB_ALL,