From 9afc0e2a3ed086e3fb3e6ee235da4635e7b3c17e Mon Sep 17 00:00:00 2001 From: Thibault LSDC <78021491+ThibaultLSDC@users.noreply.github.com> Date: Wed, 16 Oct 2024 10:32:01 -0400 Subject: [PATCH 1/3] 0.2.2 Release (#67) * downgrading ubuntu version for github tests (#62) * Llm api update (#59) * getting rid of .invoke() * adding an AbstractChatModel * changing chat_api structure * Reproducibility again (#61) * core functions * switch to dask * removing joblib dependency and adding dask * fixing imports * handles multiple backends * ensure asyncio loop creation * more tests * setting dashboard address to None * minor * Finally found a way to make it work * initial reproducibility files * Seems to be superflus * adding a reproducibility journal * minor update * more robust * adding reproducibility tools * fix white listing * minor * minor * minor * minor * minor fix * more tests * more results yay * disabling this test * update * update * black * maybe fixing github workflow ? * make get_git_username great again * trigger change * new browsergym * GPT-4o result (and new comment column) * Seems like there was a change to 4o flags, trying these * minor comment * better xray * minor fix * addming a comment field * new agent * another test with GPT-4o * adding llama3 from openrouter * fix naming * unused import * new summary tools and remove "_args" from columns in results * add Llama * initial code for reproducibility agent * adjust inspect results * infer from benchmark * fix reproducibility agent * prevent the repro_dir to be an index variable * updating repro agent stats * Reproducibility agent * instructions to setup workarena * fixing tests * handles better a few edge cases * default progress function to None * minor formatting * minor * initial commit * refactoring with Study class * refactor to adapt for study class * minor * fix pricy test * fixing tests * tmp * print report * minor fix * refine little details about reproducibility * minor * no need for set_temp anymore * sanity check before running main * minor update * minor * new results with 4o on workarena.l1 * sharing is caring * add llama to main.py * new hournal entry * lamma 3 70B * minor * typo * black fix (wasn't configured) --------- Co-authored-by: Thibault Le Sellier de Chezelles * version bump --------- Co-authored-by: Alexandre Lacoste --- .github/workflows/unit_tests.yml | 2 +- reproducibility_journal.csv | 1 + src/agentlab/__init__.py | 2 +- .../generic_agent/reproducibility_agent.py | 22 +++- src/agentlab/experiments/reproduce_study.py | 8 +- src/agentlab/experiments/study_generators.py | 7 +- src/agentlab/llm/base_api.py | 33 ++++++ src/agentlab/llm/chat_api.py | 67 ++++++------ src/agentlab/llm/huggingface_utils.py | 101 +----------------- src/agentlab/llm/llm_utils.py | 2 +- tests/agents/test_agent.py | 10 +- tests/llm/test_chat_api.py | 4 +- tests/llm/test_llm_utils.py | 23 ++-- tests/llm/test_tracking.py | 6 +- 14 files changed, 114 insertions(+), 174 deletions(-) create mode 100644 src/agentlab/llm/base_api.py diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 1b09ad6e..a6b44f87 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -9,7 +9,7 @@ on: jobs: agentlab: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 defaults: run: diff --git a/reproducibility_journal.csv b/reproducibility_journal.csv index 8eeb33ba..df2ff747 100644 --- a/reproducibility_journal.csv +++ b/reproducibility_journal.csv @@ -9,3 +9,4 @@ recursix,GenericAgent-gpt-4o-mini-2024-07-18,miniwob,0.6.3,2024-10-01_11-45-23,0 recursix,GenericAgent-gpt-4o-mini-2024-07-18,workarena.l1,0.3.2,2024-10-05_13-21-27,0.23,0.023,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,aadf86b397cd36c581e1a61e491aec649ac5a140, M: main.py,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, recursix,GenericAgent-gpt-4o-2024-05-13,workarena.l1,0.3.2,2024-10-05_15-45-42,0.382,0.027,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,ab447e997af589bbd022de7a5189a7685ddfa6ef,,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, recursix,GenericAgent-meta-llama_llama-3.1-70b-instruct,miniwob_tiny_test,0.7.0,2024-10-05_17-49-15,1.0,0.0,0,4/4,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,a98fa24426a6ddde8443e8be44ed94cd9522e5ca,,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, +recursix,GenericAgent-meta-llama_llama-3-70b-instruct,workarena.l1,0.3.2,2024-10-09_21-16-37,0.176,0.021,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.1,c847dbd334184271b32b252409a1b6c1042d7442,,0.7.0,2a0ab7e8e8795f8ca35fe4d4d67c6892d635dc12, diff --git a/src/agentlab/__init__.py b/src/agentlab/__init__.py index 3ced3581..b5fdc753 100644 --- a/src/agentlab/__init__.py +++ b/src/agentlab/__init__.py @@ -1 +1 @@ -__version__ = "0.2.1" +__version__ = "0.2.2" diff --git a/src/agentlab/agents/generic_agent/reproducibility_agent.py b/src/agentlab/agents/generic_agent/reproducibility_agent.py index b484ac7d..c197b76e 100644 --- a/src/agentlab/agents/generic_agent/reproducibility_agent.py +++ b/src/agentlab/agents/generic_agent/reproducibility_agent.py @@ -43,7 +43,7 @@ def __init__(self, old_messages, delay=1) -> None: self.old_messages = old_messages self.delay = delay - def invoke(self, messages: list): + def __call__(self, messages: list): self.new_messages = copy(messages) if len(messages) >= len(self.old_messages): @@ -56,6 +56,9 @@ def invoke(self, messages: list): # return the next message in the list return old_response + def get_stats(self): + return {} + @dataclass class ReproAgentArgs(GenericAgentArgs): @@ -102,6 +105,14 @@ def get_action(self, obs): ) return None, agent_info + # an old bug prevented the response from being saved. + if len(old_chat_messages) == 2: + recorded_action = step_info.action + if recorded_action: + # Recreate the 3rd message based on the recorded action + assistant_message = make_assistant_message(f"{recorded_action}") + old_chat_messages.append(assistant_message) + self.chat_llm = ReproChatModel(old_chat_messages) action, agent_info = super().get_action(obs) @@ -128,27 +139,28 @@ def _format_messages(messages: list[dict]): return "\n".join(f"{m['role']} message:\n{m['content']}\n" for m in messages) -def reproduce_study(original_study_dir: Path | str): +def reproduce_study(original_study_dir: Path | str, log_level=logging.INFO): """Reproduce a study by running the same experiments with the same agent.""" original_study_dir = Path(original_study_dir) study_name = f"reproducibility_of_{original_study_dir.name}" - exp_args_list = [] + exp_args_list: list[ExpArgs] = [] for exp_result in yield_all_exp_results(original_study_dir, progress_fn=None): agent_args = make_repro_agent(exp_result.exp_args.agent_args, exp_dir=exp_result.exp_dir) exp_args_list.append( ExpArgs( agent_args=agent_args, env_args=exp_result.exp_args.env_args, - logging_level=logging.DEBUG, + logging_level=log_level, ) ) + benchmark_name = exp_args_list[0].env_args.task_name.split(".")[0] return Study( exp_args_list=exp_args_list, - benchmark_name="repro_study", + benchmark_name=benchmark_name, agent_names=[agent_args.agent_name], ) diff --git a/src/agentlab/experiments/reproduce_study.py b/src/agentlab/experiments/reproduce_study.py index 3c2dd0ae..93ef07fb 100644 --- a/src/agentlab/experiments/reproduce_study.py +++ b/src/agentlab/experiments/reproduce_study.py @@ -5,18 +5,14 @@ the diff in HTML format. """ -import logging - from agentlab.agents.generic_agent.reproducibility_agent import reproduce_study from agentlab.experiments.exp_utils import RESULTS_DIR -logging.getLogger().setLevel(logging.INFO) - if __name__ == "__main__": - old_study = "2024-06-02_18-16-17_final_run" - # old_study = "2024-09-12_08-39-16_GenericAgent-gpt-4o-mini_on_miniwob_tiny_test" + # old_study = "2024-06-03_13-53-50_final_run_workarena_L1_llama3-70b" + old_study = "2024-06-03_12-28-51_final_run_miniwob_llama3-70b" study = reproduce_study(RESULTS_DIR / old_study) n_jobs = 1 diff --git a/src/agentlab/experiments/study_generators.py b/src/agentlab/experiments/study_generators.py index e079ba7f..3a2567d5 100644 --- a/src/agentlab/experiments/study_generators.py +++ b/src/agentlab/experiments/study_generators.py @@ -153,7 +153,10 @@ def set_demo_mode(env_args_list: list[EnvArgs]): def run_agents_on_benchmark( - agents: list[AgentArgs] | AgentArgs = AGENT_4o_MINI, benchmark: str = "miniwob", demo_mode=False + agents: list[AgentArgs] | AgentArgs = AGENT_4o_MINI, + benchmark: str = "miniwob", + demo_mode=False, + log_level=logging.INFO, ): """Run one or multiple agents on a benchmark. @@ -190,7 +193,7 @@ def run_agents_on_benchmark( ExpArgs( agent_args=args.CrossProd(agents), env_args=args.CrossProd(env_args_list), - logging_level=logging.DEBUG, + logging_level=log_level, ) ) diff --git a/src/agentlab/llm/base_api.py b/src/agentlab/llm/base_api.py new file mode 100644 index 00000000..9c1ebf5f --- /dev/null +++ b/src/agentlab/llm/base_api.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + + +class AbstractChatModel(ABC): + @abstractmethod + def __call__(self, messages: list[dict]) -> dict: + pass + + def get_stats(self): + return {} + + +@dataclass +class BaseModelArgs(ABC): + """Base class for all model arguments.""" + + model_name: str + max_total_tokens: int = None + max_input_tokens: int = None + max_new_tokens: int = None + temperature: float = 0.1 + vision_support: bool = False + + @abstractmethod + def make_model(self) -> AbstractChatModel: + pass + + def prepare_server(self): + pass + + def close_server(self): + pass diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 139b2ca5..a4df0a97 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -2,14 +2,17 @@ import os import re import time -from abc import ABC, abstractmethod from dataclasses import dataclass +from functools import partial +from typing import Optional import openai +from huggingface_hub import InferenceClient from openai import AzureOpenAI, OpenAI import agentlab.llm.tracking as tracking -from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel +from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs +from agentlab.llm.huggingface_utils import HFBaseChatModel def make_system_message(content: str) -> dict: @@ -24,10 +27,10 @@ def make_assistant_message(content: str) -> dict: return dict(role="assistant", content=content) -class CheatMiniWoBLLM: +class CheatMiniWoBLLM(AbstractChatModel): """For unit-testing purposes only. It only work with miniwob.click-test task.""" - def invoke(self, messages) -> str: + def __call__(self, messages) -> str: prompt = messages[-1]["content"] match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) @@ -44,12 +47,6 @@ def invoke(self, messages) -> str: """ return make_assistant_message(answer) - def __call__(self, messages) -> str: - return self.invoke(messages) - - def get_stats(self): - return {} - @dataclass class CheatMiniWoBLLMArgs: @@ -68,28 +65,6 @@ def close_server(self): pass -@dataclass -class BaseModelArgs(ABC): - """Base class for all model arguments.""" - - model_name: str - max_total_tokens: int = None - max_input_tokens: int = None - max_new_tokens: int = None - temperature: float = 0.1 - vision_support: bool = False - - @abstractmethod - def make_model(self) -> "ChatModel": - pass - - def prepare_server(self): - pass - - def close_server(self): - pass - - @dataclass class OpenRouterModelArgs(BaseModelArgs): """Serializable object for instantiating a generic chat model with an OpenAI @@ -221,7 +196,7 @@ def handle_error(error, itr, min_retry_wait_time, max_retry): return error_type -class ChatModel: +class ChatModel(AbstractChatModel): def __init__( self, model_name, @@ -310,9 +285,6 @@ def __call__(self, messages: list[dict]) -> dict: return make_assistant_message(completion.choices[0].message.content) - def invoke(self, messages: list[dict]) -> dict: - return self(messages) - def get_stats(self): return { "n_retry_llm": self.retries, @@ -401,3 +373,26 @@ def __init__( client_args=client_args, pricing_func=tracking.get_pricing_openai, ) + + +class HuggingFaceURLChatModel(HFBaseChatModel): + def __init__( + self, + model_name: str, + model_url: str, + token: Optional[str] = None, + temperature: Optional[int] = 1e-1, + max_new_tokens: Optional[int] = 512, + n_retry_server: Optional[int] = 4, + ): + super().__init__(model_name, n_retry_server) + if temperature < 1e-3: + logging.warning("Models might behave weirdly when temperature is too low.") + + if token is None: + token = os.environ["TGI_TOKEN"] + + client = InferenceClient(model=model_url, token=token) + self.llm = partial( + client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens + ) diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py index ce4dae06..470324bd 100644 --- a/src/agentlab/llm/huggingface_utils.py +++ b/src/agentlab/llm/huggingface_utils.py @@ -1,17 +1,15 @@ import logging -import os import time -from functools import partial from typing import Any, List, Optional -from huggingface_hub import InferenceClient from pydantic import Field from transformers import AutoTokenizer, GPT2TokenizerFast +from agentlab.llm.base_api import AbstractChatModel from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template -class HFBaseChatModel: +class HFBaseChatModel(AbstractChatModel): """ Custom LLM Chatbot that can interface with HuggingFace models. @@ -94,101 +92,6 @@ def __call__( def _llm_type(self): return "huggingface" - def invoke(self, messages: list[dict]) -> dict: - return self(messages) - - def get_stats(self): - return {} - - -class HuggingFaceURLChatModel(HFBaseChatModel): - def __init__( - self, - model_name: str, - model_url: str, - token: Optional[str] = None, - temperature: Optional[int] = 1e-1, - max_new_tokens: Optional[int] = 512, - n_retry_server: Optional[int] = 4, - ): - super().__init__(model_name, n_retry_server) - if temperature < 1e-3: - logging.warning("Models might behave weirdly when temperature is too low.") - - if token is None: - token = os.environ["TGI_TOKEN"] - - client = InferenceClient(model=model_url, token=token) - self.llm = partial( - client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens - ) - - -# def _convert_messages_to_dict(messages, column_remap={}): -# """ -# Converts a list of message objects into a list of dictionaries, categorizing each message by its role. - -# Each message is expected to be an instance of one of the following types: SystemMessage, HumanMessage, AIMessage. -# The function maps each message to its corresponding role ('system', 'user', 'assistant') and formats it into a dictionary. - -# Args: -# messages (list): A list of message objects. -# column_remap (dict): A dictionary that maps the column names to the desired output format. - -# Returns: -# list: A list of dictionaries where each dictionary represents a message and contains 'role' and 'content' keys. - -# Raises: -# ValueError: If an unsupported message type is encountered. - -# Example: -# >>> messages = [SystemMessage("System initializing..."), HumanMessage("Hello!"), AIMessage("How can I assist?")] -# >>> _convert_messages_to_dict(messages) -# [ -# {"role": "system", "content": "System initializing..."}, -# {"role": "user", "content": "Hello!"}, -# {"role": "assistant", "content": "How can I assist?"} -# ] -# """ - -# human_key = column_remap.get("HumanMessage", "user") -# ai_message_key = column_remap.get("AIMessage", "assistant") -# role_key = column_remap.get("role", "role") -# text_key = column_remap.get("text", "content") -# image_key = column_remap.get("image", "media_url") - -# # Mapping of message types to roles -# message_type_to_role = { -# SystemMessage: "system", -# HumanMessage: human_key, -# AIMessage: ai_message_key, -# } - -# def convert_format_vision(message_content, role, text_key, image_key): -# result = {} -# result["type"] = role -# for item in message_content: -# if item["type"] == "text": -# result[text_key] = item["text"] -# elif item["type"] == "image_url": -# result[image_key] = item["image_url"] -# return result - -# chat = [] -# for message in messages: -# message_role = message_type_to_role.get(type(message)) -# if message_role: -# if isinstance(message.content, str): -# chat.append({role_key: message_role, text_key: message.content}) -# else: -# chat.append( -# convert_format_vision(message.content, message_role, text_key, image_key) -# ) -# else: -# raise ValueError(f"Message type {type(message)} not supported") - -# return chat - def _prepend_system_to_first_user(messages, column_remap={}): # Initialize an index for the system message diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 4b876b54..c3d75009 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -79,7 +79,7 @@ def retry( """ tries = 0 while tries < n_retry: - answer = chat.invoke(messages) + answer = chat(messages) messages.append(answer) # TODO: could we change this to not use inplace modifications ? try: diff --git a/tests/agents/test_agent.py b/tests/agents/test_agent.py index ae173289..0b2c31f2 100644 --- a/tests/agents/test_agent.py +++ b/tests/agents/test_agent.py @@ -50,7 +50,7 @@ class CheatMiniWoBLLM_ParseRetry: n_retry: int retry_count: int = 0 - def invoke(self, messages) -> str: + def __call__(self, messages) -> str: if self.retry_count < self.n_retry: self.retry_count += 1 return dict(role="assistant", content="I'm retrying") @@ -71,9 +71,6 @@ def invoke(self, messages) -> str: """ return dict(role="assistant", content=answer) - def __call__(self, messages) -> str: - return self.invoke(messages) - def get_stats(self): return {} @@ -94,7 +91,7 @@ class CheatLLM_LLMError: n_retry: int = 0 success: bool = False - def invoke(self, messages) -> str: + def __call__(self, messages) -> str: if self.success: prompt = messages[1].get("content", "") match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) @@ -113,9 +110,6 @@ def invoke(self, messages) -> str: return dict(role="assistant", content=answer) raise OpenAIError("LLM failed to respond") - def __call__(self, messages) -> str: - return self.invoke(messages) - def get_stats(self): return {"n_llm_retry": self.n_retry, "n_llm_busted_retry": int(not self.success)} diff --git a/tests/llm/test_chat_api.py b/tests/llm/test_chat_api.py index b49f3588..f06fa7fa 100644 --- a/tests/llm/test_chat_api.py +++ b/tests/llm/test_chat_api.py @@ -35,7 +35,7 @@ def test_api_model_args_azure(): make_system_message("You are an helpful virtual assistant"), make_user_message("Give the third prime number"), ] - answer = model.invoke(messages) + answer = model(messages) assert "5" in answer.get("content") @@ -56,6 +56,6 @@ def test_api_model_args_openai(): make_system_message("You are an helpful virtual assistant"), make_user_message("Give the third prime number"), ] - answer = model.invoke(messages) + answer = model(messages) assert "5" in answer.get("content") diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index 1314bea0..7e5bb87c 100644 --- a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -93,9 +93,12 @@ def test_compress_string(): # Mock ChatOpenAI class class MockChatOpenAI: - def invoke(self, messages): + def call(self, messages): return "mocked response" + def __call__(self, messages): + return self.call(messages) + def mock_parser(answer): if answer == "correct content": @@ -126,7 +129,7 @@ def mock_rate_limit_error(message: str, status_code: Literal[429] = 429) -> Rate # Test to ensure function stops retrying after reaching the max wait time # def test_rate_limit_max_wait_time(): # mock_chat = MockChatOpenAI() -# mock_chat.invoke = Mock( +# mock_chat.call = Mock( # side_effect=mock_rate_limit_error("Rate limit reached. Please try again in 2s.") # ) @@ -141,12 +144,12 @@ def mock_rate_limit_error(message: str, status_code: Literal[429] = 429) -> Rate # ) # # The function should stop retrying after 2 attempts (6s each time, 12s total which is greater than the 10s max wait time) -# assert mock_chat.invoke.call_count == 3 +# assert mock_chat.call.call_count == 3 # def test_rate_limit_success(): # mock_chat = MockChatOpenAI() -# mock_chat.invoke = Mock( +# mock_chat.call = Mock( # side_effect=[ # mock_rate_limit_error("Rate limit reached. Please try again in 2s."), # make_system_message("correct content"), @@ -163,7 +166,7 @@ def mock_rate_limit_error(message: str, status_code: Literal[429] = 429) -> Rate # ) # assert result == "Parsed value" -# assert mock_chat.invoke.call_count == 2 +# assert mock_chat.call.call_count == 2 # Mock a successful parser response to test function exit before max retries @@ -172,7 +175,7 @@ def test_successful_parse_before_max_retries(): # mock a chat that returns the wrong content the first 2 time, but the right # content on the 3rd time - mock_chat.invoke = Mock( + mock_chat.call = Mock( side_effect=[ make_system_message("wrong content"), make_system_message("wrong content"), @@ -183,7 +186,7 @@ def test_successful_parse_before_max_retries(): result = llm_utils.retry(mock_chat, [], 5, mock_parser) assert result == "Parsed value" - assert mock_chat.invoke.call_count == 3 + assert mock_chat.call.call_count == 3 def test_unsuccessful_parse_before_max_retries(): @@ -191,7 +194,7 @@ def test_unsuccessful_parse_before_max_retries(): # mock a chat that returns the wrong content the first 2 time, but the right # content on the 3rd time - mock_chat.invoke = Mock( + mock_chat.call = Mock( side_effect=[ make_system_message("wrong content"), make_system_message("wrong content"), @@ -201,12 +204,12 @@ def test_unsuccessful_parse_before_max_retries(): with pytest.raises(llm_utils.ParseError): result = llm_utils.retry(mock_chat, [], 2, mock_parser) - assert mock_chat.invoke.call_count == 2 + assert mock_chat.call.call_count == 2 def test_retry_parse_raises(): mock_chat = MockChatOpenAI() - mock_chat.invoke = Mock(return_value=make_system_message("mocked response")) + mock_chat.call = Mock(return_value=make_system_message("mocked response")) parser_raises = Mock(side_effect=ValueError("Parser error")) with pytest.raises(ValueError): diff --git a/tests/llm/test_tracking.py b/tests/llm/test_tracking.py index cc5abd36..01ebcc06 100644 --- a/tests/llm/test_tracking.py +++ b/tests/llm/test_tracking.py @@ -136,7 +136,7 @@ def test_openai_chat_model(): make_user_message("Give the third prime number"), ] with tracking.set_tracker() as tracker: - answer = chat_model.invoke(messages) + answer = chat_model(messages) assert "5" in answer.get("content") assert tracker.stats["cost"] > 0 @@ -161,7 +161,7 @@ def test_azure_chat_model(): make_user_message("Give the third prime number"), ] with tracking.set_tracker() as tracker: - answer = chat_model.invoke(messages) + answer = chat_model(messages) assert "5" in answer.get("content") assert tracker.stats["cost"] > 0 @@ -178,6 +178,6 @@ def test_openrouter_chat_model(): make_user_message("Give the third prime number"), ] with tracking.set_tracker() as tracker: - answer = chat_model.invoke(messages) + answer = chat_model(messages) assert "5" in answer.get("content") assert tracker.stats["cost"] > 0 From 7384d4998067be381c309b61ef401b954b628450 Mon Sep 17 00:00:00 2001 From: xhluca Date: Wed, 16 Oct 2024 15:30:21 -0400 Subject: [PATCH 2/3] Make share=TRue into a environment variable, disabled by default for security --- src/agentlab/analyze/agent_xray.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 228901b3..4865996e 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -482,7 +482,9 @@ def run_gradio(results_dir: Path): tabs.select(tab_select) demo.queue() - demo.launch(server_port=int(os.getenv("AGENTXRAY_APP_PORT", 7899)), share=True) + + do_share = os.getenv("AGENTXRAY_SHARE_GRADIO", 'false').lower() == 'true' + demo.launch(server_port=int(os.getenv("AGENTXRAY_APP_PORT", "7899")), share=do_share) def tab_select(evt: gr.SelectData): From a332b77984083fcf8e1ad52086adb23962afbf3e Mon Sep 17 00:00:00 2001 From: xhluca Date: Wed, 16 Oct 2024 17:51:14 -0400 Subject: [PATCH 3/3] fix floating point issue with std_reward in agent xray --- src/agentlab/analyze/inspect_results.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/agentlab/analyze/inspect_results.py b/src/agentlab/analyze/inspect_results.py index 7d46113c..6a172f3d 100644 --- a/src/agentlab/analyze/inspect_results.py +++ b/src/agentlab/analyze/inspect_results.py @@ -247,6 +247,7 @@ def get_std_err(df, metric): std_err = np.sqrt(mean * (1 - mean) / len(data)) else: return get_sample_std_err(df, metric) + return mean, std_err @@ -258,7 +259,7 @@ def get_sample_std_err(df, metric): mean = np.mean(data) std_err = np.std(data, ddof=1) / np.sqrt(len(data)) if np.isnan(std_err): - std_err = 0 + std_err = np.zeros_like(std_err) return mean, std_err @@ -289,7 +290,7 @@ def summarize(sub_df, use_bootstrap=False): record = dict( avg_reward=sub_df["cum_reward"].mean(skipna=True).round(3), - std_err=std_reward.round(3), + std_err=std_reward.astype(float).round(3), # avg_raw_reward=sub_df["cum_raw_reward"].mean(skipna=True).round(3), avg_steps=sub_df["n_steps"].mean(skipna=True).round(3), n_completed=f"{n_completed}/{len(sub_df)}",