Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VisualWebArena agent #6

Open
wants to merge 11 commits into
base: dev
Choose a base branch
from
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ contexttimer
ipython
pyyaml>=6
pandas
requests
pillow
gradio
gitpython # for the reproducibility script
requests
2 changes: 1 addition & 1 deletion src/agentlab/agents/most_basic_agent/most_basic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

import bgym

from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.chat_api import make_system_message, make_user_message
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry
from agentlab.llm.tracking import cost_tracker_decorator
from agentlab.agents.agent_args import AgentArgs

if TYPE_CHECKING:
from agentlab.llm.chat_api import BaseModelArgs
Expand Down
187 changes: 187 additions & 0 deletions src/agentlab/agents/visualwebarena/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import base64
import dataclasses
import io
import re
import tempfile
from io import BytesIO

from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments import Agent, AgentInfo
from browsergym.utils.obs import flatten_axtree_to_str, overlay_som
from PIL import Image

from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.chat_api import BaseModelArgs, make_system_message, make_user_message
from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry


def pil_to_b64(img: Image.Image) -> str:
    """Serialize an image to a base64 PNG ``data:`` URL.

    The image is written as PNG into an in-memory buffer, the bytes are
    base64-encoded, and the result is prefixed with
    ``data:image/png;base64,``.
    """
    with BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")


def b64_to_pil(img_b64: str) -> Image.Image:
    """Decode a base64 PNG ``data:`` URL back into a PIL image.

    Args:
        img_b64: a string starting with ``data:image/png;base64,`` followed
            by the base64-encoded image bytes.

    Returns:
        The image opened by ``Image.open`` from the decoded bytes.

    Raises:
        ValueError: if ``img_b64`` does not start with the expected prefix.
    """
    prefix = "data:image/png;base64,"
    if not img_b64.startswith(prefix):
        # Only show the head of the payload: the input may be a whole
        # encoded image, which would make the error message unreadable.
        preview = img_b64 if len(img_b64) <= 64 else img_b64[:64] + "..."
        raise ValueError(f"Unexpected base64 encoding: {preview}")
    img_data = base64.b64decode(img_b64.removeprefix(prefix))
    return Image.open(io.BytesIO(img_data))


class VWAAgent(Agent):
    """
    Re-implementation of the web agent from VisualWebArena.
    Credits to Lawrence Jang (@ljang0)
    https://github.com/web-arena-x/visualwebarena/blob/main/agent/agent.py

    At each step the agent sends the goal text, the flattened accessibility
    tree, the action-space description, a set-of-marks screenshot and any
    extra goal items to the chat model, then extracts the action from the
    first triple-backtick block of the answer.
    """

    action_set = HighLevelActionSet(
        subsets=["chat", "bid", "infeas", "nav", "tab"],
        strict=False,
        multiaction=False,
        demo_mode="off",
    )

    def obs_preprocessor(self, obs: dict) -> dict:
        """Keep only the observation fields used by ``get_action``.

        The accessibility tree is flattened to text; the raw element
        properties are kept because the screenshot overlay needs them.
        """
        return {
            "goal_object": obs["goal_object"],
            "last_action": obs["last_action"],
            "axtree_txt": flatten_axtree_to_str(
                obs["axtree_object"], obs["extra_element_properties"]
            ),
            "extra_properties": obs["extra_element_properties"],
            "url": obs["url"],
            "screenshot": obs["screenshot"],
        }

    def __init__(self, chat_model_args: BaseModelArgs, n_retry: int) -> None:
        """Build the chat model and store the retry budget.

        Args:
            chat_model_args: model configuration; its ``make_model()`` result
                is the callable handed to ``retry``.
            n_retry: forwarded to ``retry`` as ``n_retry`` on every step.
        """
        super().__init__()
        self.model_name = chat_model_args.model_name
        self.chat_llm = chat_model_args.make_model()
        self.n_retry = n_retry

        # Initialized here; not read anywhere else in this class.
        self.goal_images = None

    def get_action(self, obs: dict) -> tuple[str, AgentInfo]:
        """Query the chat model for the next action.

        Args:
            obs: a pre-processed observation (see ``obs_preprocessor``).

        Returns:
            The action string and an ``AgentInfo`` carrying the chat
            messages, the raw model answer as ``think``, and usage stats.
        """

        system_prompt = """\
Review the current state of the page and all other information to find the best
possible next action to accomplish your goal. Your answer will be interpreted
and executed by a program, make sure to follow the formatting instructions."""

        user_prompt = f"""\
# Goal:
{obs["goal_object"][0]["text"]}

# Current Accessibility Tree:
{obs["axtree_txt"]}

# Action Space
{self.action_set.describe(with_long_description=False, with_examples=True)}

Here is an example with chain of thought of a valid action when clicking on a button:
"
In order to accomplish my goal I need to click on the button with bid 12
```click("12")```
"

If you have completed the task, use the chat to return an answer. For example, if you are asked what is the color of the sky, return
"
```send_msg_to_user("blue")```
"
"""
        # Screenshot with set-of-marks overlay, encoded as a PNG data URL.
        screenshot_url = pil_to_b64(
            Image.fromarray(overlay_som(obs["screenshot"], obs["extra_properties"]))
        )

        # Multimodal user content: prompt text first, then the screenshot.
        # Previously the prompt entry was overwritten by the screenshot list
        # and the list itself was never sent to the model.
        user_msgs = [
            {"type": "text", "text": user_prompt},
            {
                "type": "text",
                "text": "IMAGES: current page screenshot",
            },
            {
                "type": "image_url",
                "image_url": {"url": screenshot_url},
            },
        ]
        # additional goal items (everything after the leading text entry)
        user_msgs.extend(obs["goal_object"][1:])

        # NOTE(review): assumes make_user_message passes its content through
        # unchanged, so a list of content parts is accepted -- confirm
        # against agentlab.llm.chat_api.
        messages = [
            make_system_message(system_prompt),
            make_user_message(user_msgs),
        ]

        def parser(response: str) -> dict:
            """Extract the first triple-backtick block as the action."""
            # DOTALL lets "." span newlines so multi-line blocks match; the
            # former "(.|\\n)" alternative matched a literal backslash + "n".
            match = re.search(r"```(.*?)```", response, flags=re.DOTALL)
            if not match:
                raise ParseError("No code block found in the response")
            action = match.group(1).strip()
            thought = response
            return {"action": action, "think": thought}

        response = retry(self.chat_llm, messages, n_retry=self.n_retry, parser=parser)

        action = response.get("action", None)
        # The parsed answer is handled as a mapping above, so it may not
        # carry a ``usage`` attribute; fall back to empty stats.
        stats = dict(getattr(response, "usage", None) or {})
        return action, AgentInfo(
            chat_messages=messages,
            think=response.get("think", None),
            stats=stats,
        )


@dataclasses.dataclass
class VWAAgentArgs(AgentArgs):
    """
    This class is meant to store the arguments that define the agent.

    By isolating them in a dataclass, this ensures serialization without storing
    internal states of the agent.
    """

    agent_name: str = "vwa"
    temperature: float = 0.1
    chat_model_args: BaseModelArgs = None
    # Retry budget forwarded to VWAAgent; the agent's constructor requires it.
    # NOTE(review): the default of 4 is an arbitrary choice -- adjust as needed.
    n_retry: int = 4
    # Accepted because this module builds VWAAgentArgs(model_name=...).
    # NOTE(review): the agent reads the model from chat_model_args, so this
    # field alone does not select a model -- confirm the intended wiring.
    model_name: str = None

    def make_agent(self):
        """Instantiate the agent from the stored arguments.

        Raises:
            ValueError: if ``chat_model_args`` was not provided, since the
                agent cannot build its chat model without it.
        """
        if self.chat_model_args is None:
            raise ValueError("VWAAgentArgs.chat_model_args must be set before calling make_agent()")
        return VWAAgent(chat_model_args=self.chat_model_args, n_retry=self.n_retry)


# Module-level default agent configuration.
# NOTE(review): confirm that VWAAgentArgs accepts `model_name`; VWAAgent itself
# takes its model from `chat_model_args`, which is left unset here.
CONFIG = VWAAgentArgs(model_name="gpt-4-1106-vision-preview")


def main():
    """Run one VisualWebArena task with this agent and print its record.

    Results are written under ``./results``; each key/value pair of the
    resulting experiment record is printed on its own line.
    """
    from pathlib import Path

    from browsergym.experiments import EnvArgs, ExpArgs, get_exp_result

    # NOTE(review): this model name differs from the one used by CONFIG
    # ("gpt-4-1106-vision-preview") -- confirm which one is intended.
    agent_config = VWAAgentArgs(model_name="gpt-4-1106-preview")
    environment = EnvArgs(
        task_name="visualwebarena.423",
        task_seed=42,
        headless=False,  # shows the browser
    )
    experiment = ExpArgs(agent_args=agent_config, env_args=environment)

    experiment.prepare(exp_root=Path("./results"))
    experiment.run()

    record = get_exp_result(experiment.exp_dir).get_exp_record()
    for name, value in record.items():
        print(f"{name}: {value}")


if __name__ == "__main__":
    main()
32 changes: 28 additions & 4 deletions src/agentlab/experiments/study_generators.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import logging
from dataclasses import dataclass
from datetime import datetime
import logging
from pathlib import Path

from bgym import ExpArgs, EnvArgs
from bgym import EnvArgs, ExpArgs

from agentlab.agents.agent_args import AgentArgs
from agentlab.agents.generic_agent.agent_configs import RANDOM_SEARCH_AGENT, AGENT_4o_MINI
from agentlab.analyze import inspect_results
from agentlab.experiments import args
from agentlab.experiments import reproducibility_util as repro
from agentlab.experiments import task_collections as tasks
from agentlab.experiments.launch_exp import run_experiments, relaunch_study
from agentlab.experiments.exp_utils import RESULTS_DIR
from agentlab.experiments import reproducibility_util as repro
from agentlab.experiments.launch_exp import relaunch_study, run_experiments


@dataclass
Expand Down Expand Up @@ -267,3 +267,27 @@ def random_search(
study = run_agents_on_benchmark(agents, benchmark, demo_mode=demo_mode)
study.suffix = "random_search"
return study


def final_run_vwa(agent: AgentArgs = AGENT_4o_MINI, benchmark: str = "miniwob"):
    """Build the experiment lists for a VisualWebArena run.

    The env args are always requested for the "visualwebarena" benchmark
    (meta_seed=43); the ``benchmark`` parameter is not read by this function.

    Returns:
        A pair ``(with_reset, without_reset)``: the expanded cross products of
        ``agent`` with the env args of the tasks that need a reset and of
        those that do not, in that order.
    """
    reset_env_args, no_reset_env_args = tasks.get_benchmark_env_args(
        "visualwebarena", meta_seed=43, max_steps=None, n_repeat=None, is_agent_curriculum=False
    )

    def _expand(env_args_list):
        # One ExpArgs per (agent, env_args) combination.
        return args.expand_cross_product(
            ExpArgs(
                agent_args=args.CrossProd([agent]),
                env_args=args.CrossProd(env_args_list),
                enable_debug=False,
                logging_level=logging.DEBUG,
            )
        )

    return _expand(reset_env_args), _expand(no_reset_env_args)
14 changes: 14 additions & 0 deletions src/agentlab/experiments/task_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

from browsergym.experiments import EnvArgs
from browsergym.webarena import ALL_WEBARENA_TASK_IDS
from browsergym.visualwebarena import (
VISUALWEBARENA_TASK_IDS_WITH_RESET,
VISUALWEBARENA_TASK_IDS_WITHOUT_RESET,
)

df = pd.read_csv(Path(__file__).parent / "miniwob_tasks_all.csv")
# append miniwob. to task_name column
Expand Down Expand Up @@ -122,6 +126,7 @@ def get_benchmark_env_args(
"workarena.l2": 50,
"workarena.l3": 50,
"webarena": 15,
"visualwebarena": 15,
"miniwob": 10,
"miniwob_tiny_test": 5,
}
Expand All @@ -131,6 +136,7 @@ def get_benchmark_env_args(
"workarena.l2": 1,
"workarena.l3": 1,
"webarena": 1,
"visualwebarena": 1,
"miniwob": 5,
"miniwob_tiny_test": 2,
}
Expand Down Expand Up @@ -176,6 +182,14 @@ def get_benchmark_env_args(
from browsergym.webarena import ALL_WEBARENA_TASK_IDS

env_args_list = _make_env_args(ALL_WEBARENA_TASK_IDS, max_steps, n_repeat, rng)
elif benchmark_name == "visualwebarena":
env_args_list_reset = _make_env_args(
VISUALWEBARENA_TASK_IDS_WITH_RESET, max_steps, n_repeat, rng
)
env_args_list_no_reset = _make_env_args(
VISUALWEBARENA_TASK_IDS_WITHOUT_RESET, max_steps, n_repeat, rng
)
env_args_list = (env_args_list_reset, env_args_list_no_reset)
elif benchmark_name.startswith("miniwob"):
miniwob_benchmarks_map = {
"miniwob": MINIWOB_ALL,
Expand Down