From 80d4bf59cc7b0faef43b8e251b2002a3ee40f65f Mon Sep 17 00:00:00 2001 From: teocns <59549574+teocns@users.noreply.github.com> Date: Fri, 15 Nov 2024 19:28:12 -0600 Subject: [PATCH] ci: pre-commit hooks for static analysis (#510) * chore: add pre-commit for ruff & black * docs: add pre-commit setup instructions to CONTRIBUTING.md * chore: update ruff configuration in pyproject.toml * when you run pre-commit run --all-files, it should actually format the files instead of just reporting what would change. After the formatting is done, you should see the changes in git status and be able to commit them. * Remove black; will fight over eachother with ruff * run ruff on pre-existing issues * chore(pyproject.toml): remove black, update dirctories * ci: remove black-formatter.yaml * ci: add static-analysis.yaml running pre-commit hook * build: add ruff to development dependencies --- .github/workflows/black-formatter.yml | 29 --------- .github/workflows/static-analysis.yaml | 62 +++++++++++++++++++ .pre-commit-config.yaml | 7 +++ CONTRIBUTING.md | 12 ++++ agentops/cli.py | 4 +- agentops/client.py | 53 ++++------------ agentops/decorators.py | 27 +++----- agentops/helpers.py | 19 +++--- agentops/host_env.py | 7 +-- agentops/http_client.py | 18 ++---- agentops/llms/__init__.py | 44 ++++--------- agentops/llms/ai21.py | 17 ++--- agentops/llms/anthropic.py | 55 ++++------------ agentops/llms/cohere.py | 23 +++---- agentops/llms/groq.py | 8 +-- agentops/llms/instrumented_provider.py | 4 +- agentops/llms/litellm.py | 24 ++----- agentops/llms/mistral.py | 9 +-- agentops/llms/ollama.py | 8 +-- agentops/llms/openai.py | 34 +++------- agentops/meta_client.py | 4 +- agentops/partners/autogen_logger.py | 12 +--- .../partners/langchain_callback_handler.py | 41 +++--------- agentops/session.py | 16 ++--- agentops/singleton.py | 2 - agentops/time_travel.py | 20 ++---- pyproject.toml | 48 ++++++++++++++ tests/core_manual_tests/api_server/server.py | 1 - tests/core_manual_tests/multi_session_llm.py | 4 +- .../providers/litellm_canary.py | 4 +- .../_test_langchain_handler.py | 24 ++----- tests/openai_handlers/_test_gpt_vision.py | 8 +-- .../_test_handler_openai_v0.py | 20 ++---- .../_test_llm_tracker_ge_1_async.py | 4 +- tests/test_canary.py | 4 +- tests/test_events.py | 8 +-- tests/test_pre_init.py | 8 +-- tests/test_record_action.py | 20 ++---- tests/test_record_tool.py | 20 ++---- tests/test_session.py | 26 +++----- tests/test_teardown.py | 8 +-- tests/test_time_travel.py | 1 - 42 files changed, 288 insertions(+), 479 deletions(-) delete mode 100644 .github/workflows/black-formatter.yml create mode 100644 .github/workflows/static-analysis.yaml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/black-formatter.yml b/.github/workflows/black-formatter.yml deleted file mode 100644 index c8b3e342..00000000 --- a/.github/workflows/black-formatter.yml +++ /dev/null @@ -1,29 +0,0 @@ - -name: Black Code Formatter Check - -on: - pull_request: - paths: - - '**/*.py' - - '**/*.ipynb' - -jobs: - black-check: - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.11' # Specify the version of Python you need - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install "black[jupyter]" - - - name: Run Black - run: black --diff --check . \ No newline at end of file diff --git a/.github/workflows/static-analysis.yaml b/.github/workflows/static-analysis.yaml new file mode 100644 index 00000000..6b03faf9 --- /dev/null +++ b/.github/workflows/static-analysis.yaml @@ -0,0 +1,62 @@ +name: Static analysis + +# on PR and push to main +on: + push: + branches: + - main + pull_request: + paths: + - '**/*.py' + +permissions: + contents: read + +# Limit concurrency by workflow/branch combination. +# +# For pull request builds, pushing additional changes to the +# branch will cancel prior in-progress and pending builds. +# +# For builds triggered on a branch push, additional changes +# will wait for prior builds to complete before starting. +# +# https://docs.github.com/en/actions/using-jobs/using-concurrency +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + +jobs: + pre-commit-checks: + name: Pre-commit checks + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + id: setup_python + with: + python-version: "3.11.10" + + - name: UV Cache + # Manually cache the uv cache directory + # until setup-python supports it: + # https://github.com/actions/setup-python/issues/822 + uses: actions/cache@v4 + id: cache-uv + with: + path: ~/.cache/uv + key: uvcache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} + + - name: Install packages + run: | + python -m pip install -U uv pre-commit + uv pip install --upgrade --system -e .[dev] + + - name: Run pre-commit + run: | + pre-commit run --show-diff-on-failure --color=always --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..138eedbc --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.2.1" + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - id: ruff-format diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4ba0523d..1db574e5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -77,6 +77,18 @@ Even if you're not ready to contribute code, we'd love to hear your thoughts. Dr .\venv\Scripts\activate # Windows ``` +3. **Pre-commit Setup**: + We use pre-commit hooks to automatically format and lint code. Set them up with: + ```bash + pip install pre-commit + pre-commit install + ``` + + That's it! The hooks will run automatically when you commit. To manually check all files: + ```bash + pre-commit run --all-files + ``` + ## Testing We use a comprehensive testing stack to ensure code quality and reliability. Our testing framework includes pytest and several specialized testing tools. diff --git a/agentops/cli.py b/agentops/cli.py index 562c8053..29a81123 100644 --- a/agentops/cli.py +++ b/agentops/cli.py @@ -6,9 +6,7 @@ def main(): parser = argparse.ArgumentParser(description="AgentOps CLI") subparsers = parser.add_subparsers(dest="command") - timetravel_parser = subparsers.add_parser( - "timetravel", help="Time Travel Debugging commands", aliases=["tt"] - ) + timetravel_parser = subparsers.add_parser("timetravel", help="Time Travel Debugging commands", aliases=["tt"]) timetravel_parser.add_argument( "branch_name", type=str, diff --git a/agentops/client.py b/agentops/client.py index 80cb1cc5..86fe49b8 100644 --- a/agentops/client.py +++ b/agentops/client.py @@ -44,10 +44,7 @@ def __init__(self): api_key=os.environ.get("AGENTOPS_API_KEY"), parent_key=os.environ.get("AGENTOPS_PARENT_KEY"), endpoint=os.environ.get("AGENTOPS_API_ENDPOINT"), - env_data_opt_out=os.environ.get( - "AGENTOPS_ENV_DATA_OPT_OUT", "False" - ).lower() - == "true", + env_data_opt_out=os.environ.get("AGENTOPS_ENV_DATA_OPT_OUT", "False").lower() == "true", ) def configure( @@ -106,9 +103,7 @@ def initialize(self) -> Union[Session, None]: if session: for agent_args in self._pre_init_queue["agents"]: - session.create_agent( - name=agent_args["name"], agent_id=agent_args["agent_id"] - ) + session.create_agent(name=agent_args["name"], agent_id=agent_args["agent_id"]) self._pre_init_queue["agents"] = [] return session @@ -141,9 +136,7 @@ def add_tags(self, tags: List[str]) -> None: session = self._safe_get_session() if session is None: - return logger.warning( - "Could not add tags. Start a session by calling agentops.start_session()." - ) + return logger.warning("Could not add tags. Start a session by calling agentops.start_session().") session.add_tags(tags=tags) @@ -162,9 +155,7 @@ def set_tags(self, tags: List[str]) -> None: session = self._safe_get_session() if session is None: - return logger.warning( - "Could not set tags. Start a session by calling agentops.start_session()." - ) + return logger.warning("Could not set tags. Start a session by calling agentops.start_session().") else: session.set_tags(tags=tags) @@ -198,9 +189,7 @@ def record(self, event: Union[Event, ErrorEvent]) -> None: session = self._safe_get_session() if session is None: - return logger.error( - "Could not record event. Start a session by calling agentops.start_session()." - ) + return logger.error("Could not record event. Start a session by calling agentops.start_session().") session.record(event) def start_session( @@ -244,9 +233,7 @@ def start_session( if self._pre_init_queue["agents"] and len(self._pre_init_queue["agents"]) > 0: for agent_args in self._pre_init_queue["agents"]: - session.create_agent( - name=agent_args["name"], agent_id=agent_args["agent_id"] - ) + session.create_agent(name=agent_args["name"], agent_id=agent_args["agent_id"]) self._pre_init_queue["agents"] = [] self._sessions.append(session) @@ -277,9 +264,7 @@ def end_session( if is_auto_end and self._config.skip_auto_end_session: return - token_cost = session.end_session( - end_state=end_state, end_state_reason=end_state_reason, video=video - ) + token_cost = session.end_session(end_state=end_state, end_state_reason=end_state_reason, video=video) return token_cost @@ -299,9 +284,7 @@ def create_agent( # if no session passed, assume single session session = self._safe_get_session() if session is None: - self._pre_init_queue["agents"].append( - {"name": name, "agent_id": agent_id} - ) + self._pre_init_queue["agents"].append({"name": name, "agent_id": agent_id}) else: session.create_agent(name=name, agent_id=agent_id) @@ -326,9 +309,7 @@ def signal_handler(signum, frame): """ signal_name = "SIGINT" if signum == signal.SIGINT else "SIGTERM" logger.info("%s detected. Ending session...", signal_name) - self.end_session( - end_state="Fail", end_state_reason=f"Signal {signal_name} detected" - ) + self.end_session(end_state="Fail", end_state_reason=f"Signal {signal_name} detected") sys.exit(0) def handle_exception(exc_type, exc_value, exc_traceback): @@ -341,9 +322,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): exc_traceback (TracebackType): A traceback object encapsulating the call stack at the point where the exception originally occurred. """ - formatted_traceback = "".join( - traceback.format_exception(exc_type, exc_value, exc_traceback) - ) + formatted_traceback = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) for session in self._sessions: session.end_session( @@ -376,13 +355,7 @@ def add_pre_init_warning(self, message: str): # replaces the session currently stored with a specific session_id, with a new session def _update_session(self, session: Session): self._sessions[ - self._sessions.index( - [ - sess - for sess in self._sessions - if sess.session_id == session.session_id - ][0] - ) + self._sessions.index([sess for sess in self._sessions if sess.session_id == session.session_id][0]) ] = session def _safe_get_session(self) -> Optional[Session]: @@ -392,9 +365,7 @@ def _safe_get_session(self) -> Optional[Session]: return self._sessions[0] if len(self._sessions) > 1: - calling_function = inspect.stack()[ - 2 - ].function # Using index 2 because we have a wrapper at index 1 + calling_function = inspect.stack()[2].function # Using index 2 because we have a wrapper at index 1 return logger.warning( f"Multiple sessions detected. You must use session.{calling_function}(). More info: https://docs.agentops.ai/v1/concepts/core-concepts#session-management" ) diff --git a/agentops/decorators.py b/agentops/decorators.py index d291c416..860b7c73 100644 --- a/agentops/decorators.py +++ b/agentops/decorators.py @@ -45,9 +45,7 @@ async def async_wrapper(*args, session: Optional[Session] = None, **kwargs): arg_names = list(func_args.keys()) # Get default values arg_values = { - name: func_args[name].default - for name in arg_names - if func_args[name].default is not inspect._empty + name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty } # Update with positional arguments arg_values.update(dict(zip(arg_names, args))) @@ -111,9 +109,7 @@ def sync_wrapper(*args, session: Optional[Session] = None, **kwargs): arg_names = list(func_args.keys()) # Get default values arg_values = { - name: func_args[name].default - for name in arg_names - if func_args[name].default is not inspect._empty + name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty } # Update with positional arguments arg_values.update(dict(zip(arg_names, args))) @@ -191,9 +187,7 @@ async def async_wrapper(*args, session: Optional[Session] = None, **kwargs): arg_names = list(func_args.keys()) # Get default values arg_values = { - name: func_args[name].default - for name in arg_names - if func_args[name].default is not inspect._empty + name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty } # Update with positional arguments arg_values.update(dict(zip(arg_names, args))) @@ -257,9 +251,7 @@ def sync_wrapper(*args, session: Optional[Session] = None, **kwargs): arg_names = list(func_args.keys()) # Get default values arg_values = { - name: func_args[name].default - for name in arg_names - if func_args[name].default is not inspect._empty + name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty } # Update with positional arguments arg_values.update(dict(zip(arg_names, args))) @@ -338,12 +330,8 @@ def new_init(self, *args, **kwargs): session=session, ) except AttributeError as e: - Client().add_pre_init_warning( - f"Failed to track an agent {name} with the @track_agent decorator." - ) - logger.warning( - "Failed to track an agent with the @track_agent decorator." - ) + Client().add_pre_init_warning(f"Failed to track an agent {name} with the @track_agent decorator.") + logger.warning("Failed to track an agent with the @track_agent decorator.") original_init(self, *args, **kwargs) obj.__init__ = new_init @@ -351,7 +339,8 @@ def new_init(self, *args, **kwargs): elif inspect.isfunction(obj): obj.agent_ops_agent_id = str(uuid4()) # type: ignore Client().create_agent( - name=obj.agent_ops_agent_name, agent_id=obj.agent_ops_agent_id # type: ignore + name=obj.agent_ops_agent_name, + agent_id=obj.agent_ops_agent_id, # type: ignore ) else: diff --git a/agentops/helpers.py b/agentops/helpers.py index fbcd21d5..02530213 100644 --- a/agentops/helpers.py +++ b/agentops/helpers.py @@ -9,7 +9,6 @@ from .log_config import logger from uuid import UUID -from importlib.metadata import version def get_ISO_time(): @@ -38,7 +37,9 @@ def filter_dict(obj): k: ( filter_dict(v) if isinstance(v, (dict, list)) or is_jsonable(v) - else str(v) if isinstance(v, UUID) else "" + else str(v) + if isinstance(v, UUID) + else "" ) for k, v in obj.items() } @@ -47,7 +48,9 @@ def filter_dict(obj): ( filter_dict(x) if isinstance(x, (dict, list)) or is_jsonable(x) - else str(x) if isinstance(x, UUID) else "" + else str(x) + if isinstance(x, UUID) + else "" ) for x in obj ] @@ -85,9 +88,7 @@ def remove_unwanted_items(value): """Recursively remove self key and None/... values from dictionaries so they aren't serialized""" if isinstance(value, dict): return { - k: remove_unwanted_items(v) - for k, v in value.items() - if v is not None and v is not ... and k != "self" + k: remove_unwanted_items(v) for k, v in value.items() if v is not None and v is not ... and k != "self" } elif isinstance(value, list): return [remove_unwanted_items(item) for item in value] @@ -106,9 +107,7 @@ def check_call_stack_for_agent_id() -> Union[UUID, None]: # We stop looking up the stack at main because after that we see global variables if var == "__main__": return None - if hasattr(var, "agent_ops_agent_id") and getattr( - var, "agent_ops_agent_id" - ): + if hasattr(var, "agent_ops_agent_id") and getattr(var, "agent_ops_agent_id"): logger.debug( "LLM call from agent named: %s", getattr(var, "agent_ops_agent_name"), @@ -141,7 +140,7 @@ def check_agentops_update(): if not latest_version == current_version: logger.warning( - f" WARNING: agentops is out of date. Please update with the command: 'pip install --upgrade agentops'" + " WARNING: agentops is out of date. Please update with the command: 'pip install --upgrade agentops'" ) except Exception as e: logger.debug(f"Failed to check for updates: {e}") diff --git a/agentops/host_env.py b/agentops/host_env.py index e466c86e..5307dec4 100644 --- a/agentops/host_env.py +++ b/agentops/host_env.py @@ -51,8 +51,7 @@ def get_installed_packages(): return { # TODO: add to opt out "Installed Packages": { - dist.metadata.get("Name"): dist.metadata.get("Version") - for dist in importlib.metadata.distributions() + dist.metadata.get("Name"): dist.metadata.get("Version") for dist in importlib.metadata.distributions() } } except: @@ -125,9 +124,7 @@ def get_disk_details(): } except OSError as inaccessible: # Skip inaccessible partitions, such as removable drives with no media - logger.debug( - "Mountpoint %s inaccessible: %s", partition.mountpoint, inaccessible - ) + logger.debug("Mountpoint %s inaccessible: %s", partition.mountpoint, inaccessible) return disk_info diff --git a/agentops/http_client.py b/agentops/http_client.py index 32d7d336..caa18b27 100644 --- a/agentops/http_client.py +++ b/agentops/http_client.py @@ -22,10 +22,7 @@ class HttpStatus(Enum): class Response: - - def __init__( - self, status: HttpStatus = HttpStatus.UNKNOWN, body: Optional[dict] = None - ): + def __init__(self, status: HttpStatus = HttpStatus.UNKNOWN, body: Optional[dict] = None): self.status: HttpStatus = status self.code: int = status.value self.body = body if body else {} @@ -57,7 +54,6 @@ def get_status(code: int) -> HttpStatus: class HttpClient: - @staticmethod def post( url: str, @@ -82,17 +78,13 @@ def post( if jwt is not None: JSON_HEADER["Authorization"] = f"Bearer {jwt}" - res = request_session.post( - url, data=payload, headers=JSON_HEADER, timeout=20 - ) + res = request_session.post(url, data=payload, headers=JSON_HEADER, timeout=20) result.parse(res) except requests.exceptions.Timeout: result.code = 408 result.status = HttpStatus.TIMEOUT - raise ApiServerException( - "Could not reach API server - connection timed out" - ) + raise ApiServerException("Could not reach API server - connection timed out") except requests.exceptions.HTTPError as e: try: result.parse(e.response) @@ -145,9 +137,7 @@ def get( except requests.exceptions.Timeout: result.code = 408 result.status = HttpStatus.TIMEOUT - raise ApiServerException( - "Could not reach API server - connection timed out" - ) + raise ApiServerException("Could not reach API server - connection timed out") except requests.exceptions.HTTPError as e: try: result.parse(e.response) diff --git a/agentops/llms/__init__.py b/agentops/llms/__init__.py index 1bd8c3b7..a5852d8c 100644 --- a/agentops/llms/__init__.py +++ b/agentops/llms/__init__.py @@ -66,17 +66,13 @@ def override_api(self): if api == "litellm": module_version = version(api) if module_version is None: - logger.warning( - f"Cannot determine LiteLLM version. Only LiteLLM>=1.3.1 supported." - ) + logger.warning("Cannot determine LiteLLM version. Only LiteLLM>=1.3.1 supported.") if Version(module_version) >= parse("1.3.1"): provider = LiteLLMProvider(self.client) provider.override() else: - logger.warning( - f"Only LiteLLM>=1.3.1 supported. v{module_version} found." - ) + logger.warning(f"Only LiteLLM>=1.3.1 supported. v{module_version} found.") return # If using an abstraction like litellm, do not patch the underlying LLM APIs if api == "openai": @@ -96,17 +92,13 @@ def override_api(self): # Patch cohere v5.4.0+ methods module_version = version(api) if module_version is None: - logger.warning( - f"Cannot determine Cohere version. Only Cohere>=5.4.0 supported." - ) + logger.warning("Cannot determine Cohere version. Only Cohere>=5.4.0 supported.") if Version(module_version) >= parse("5.4.0"): provider = CohereProvider(self.client) provider.override() else: - logger.warning( - f"Only Cohere>=5.4.0 supported. v{module_version} found." - ) + logger.warning(f"Only Cohere>=5.4.0 supported. v{module_version} found.") if api == "ollama": module_version = version(api) @@ -115,9 +107,7 @@ def override_api(self): provider = OllamaProvider(self.client) provider.override() else: - logger.warning( - f"Only Ollama>=0.0.1 supported. v{module_version} found." - ) + logger.warning(f"Only Ollama>=0.0.1 supported. v{module_version} found.") if api == "groq": module_version = version(api) @@ -126,25 +116,19 @@ def override_api(self): provider = GroqProvider(self.client) provider.override() else: - logger.warning( - f"Only Groq>=0.9.0 supported. v{module_version} found." - ) + logger.warning(f"Only Groq>=0.9.0 supported. v{module_version} found.") if api == "anthropic": module_version = version(api) if module_version is None: - logger.warning( - f"Cannot determine Anthropic version. Only Anthropic>=0.32.0 supported." - ) + logger.warning("Cannot determine Anthropic version. Only Anthropic>=0.32.0 supported.") if Version(module_version) >= parse("0.32.0"): provider = AnthropicProvider(self.client) provider.override() else: - logger.warning( - f"Only Anthropic>=0.32.0 supported. v{module_version} found." - ) + logger.warning(f"Only Anthropic>=0.32.0 supported. v{module_version} found.") if api == "mistralai": module_version = version(api) @@ -153,25 +137,19 @@ def override_api(self): provider = MistralProvider(self.client) provider.override() else: - logger.warning( - f"Only MistralAI>=1.0.1 supported. v{module_version} found." - ) + logger.warning(f"Only MistralAI>=1.0.1 supported. v{module_version} found.") if api == "ai21": module_version = version(api) if module_version is None: - logger.warning( - f"Cannot determine AI21 version. Only AI21>=2.0.0 supported." - ) + logger.warning("Cannot determine AI21 version. Only AI21>=2.0.0 supported.") if Version(module_version) >= parse("2.0.0"): provider = AI21Provider(self.client) provider.override() else: - logger.warning( - f"Only AI21>=2.0.0 supported. v{module_version} found." - ) + logger.warning(f"Only AI21>=2.0.0 supported. v{module_version} found.") def stop_instrumenting(self): OpenAiProvider(self.client).undo_override() diff --git a/agentops/llms/ai21.py b/agentops/llms/ai21.py index 533ad276..4e080a5b 100644 --- a/agentops/llms/ai21.py +++ b/agentops/llms/ai21.py @@ -14,7 +14,6 @@ @singleton class AI21Provider(InstrumentedProvider): - original_create = None original_create_async = None original_answer = None @@ -24,9 +23,7 @@ def __init__(self, client): super().__init__(client) self._provider_name = "AI21" - def handle_response( - self, response, kwargs, init_timestamp, session: Optional[Session] = None - ): + def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None): """Handle responses for AI21""" from ai21.stream.stream import Stream from ai21.stream.async_stream import AsyncStream @@ -51,9 +48,7 @@ def handle_stream_chunk(chunk: ChatCompletionChunk): accumulated_delta = llm_event.returns.choices[0].delta llm_event.agent_id = check_call_stack_for_agent_id() llm_event.model = kwargs["model"] - llm_event.prompt = [ - message.model_dump() for message in kwargs["messages"] - ] + llm_event.prompt = [message.model_dump() for message in kwargs["messages"]] # NOTE: We assume for completion only choices[0] is relevant choice = chunk.choices[0] @@ -80,9 +75,7 @@ def handle_stream_chunk(chunk: ChatCompletionChunk): self._safe_record(session, llm_event) except Exception as e: - self._safe_record( - session, ErrorEvent(trigger_event=llm_event, exception=e) - ) + self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e)) kwargs_str = pprint.pformat(kwargs) chunk = pprint.pformat(chunk) @@ -119,9 +112,7 @@ async def async_generator(): llm_event.returns = response llm_event.agent_id = check_call_stack_for_agent_id() llm_event.model = kwargs["model"] - llm_event.prompt = [ - message.model_dump() for message in kwargs["messages"] - ] + llm_event.prompt = [message.model_dump() for message in kwargs["messages"]] llm_event.prompt_tokens = response.usage.prompt_tokens llm_event.completion = response.choices[0].message.model_dump() llm_event.completion_tokens = response.usage.completion_tokens diff --git a/agentops/llms/anthropic.py b/agentops/llms/anthropic.py index 183694a4..bc63d85c 100644 --- a/agentops/llms/anthropic.py +++ b/agentops/llms/anthropic.py @@ -13,7 +13,6 @@ @singleton class AnthropicProvider(InstrumentedProvider): - original_create = None original_create_async = None @@ -23,9 +22,7 @@ def __init__(self, client): self.tool_event = {} self.tool_id = "" - def handle_response( - self, response, kwargs, init_timestamp, session: Optional[Session] = None - ): + def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None): """Handle responses for Anthropic""" from anthropic import Stream, AsyncStream from anthropic.resources import AsyncMessages @@ -66,9 +63,7 @@ def handle_stream_chunk(chunk: Message): llm_event.completion["content"] += chunk.delta.text elif chunk.delta.type == "input_json_delta": - self.tool_event[self.tool_id].logs[ - "input" - ] += chunk.delta.partial_json + self.tool_event[self.tool_id].logs["input"] += chunk.delta.partial_json elif chunk.type == "content_block_stop": pass @@ -81,9 +76,7 @@ def handle_stream_chunk(chunk: Message): self._safe_record(session, llm_event) except Exception as e: - self._safe_record( - session, ErrorEvent(trigger_event=llm_event, exception=e) - ) + self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e)) kwargs_str = pprint.pformat(kwargs) chunk = pprint.pformat(chunk) @@ -178,9 +171,7 @@ def patched_function(*args, **kwargs): if "session" in kwargs.keys(): del kwargs["session"] - completion_override = fetch_completion_override_from_time_travel_cache( - kwargs - ) + completion_override = fetch_completion_override_from_time_travel_cache(kwargs) if completion_override: result_model = None pydantic_models = ( @@ -195,9 +186,7 @@ def patched_function(*args, **kwargs): for pydantic_model in pydantic_models: try: - result_model = pydantic_model.model_validate_json( - completion_override - ) + result_model = pydantic_model.model_validate_json(completion_override) break except Exception as e: pass @@ -209,18 +198,12 @@ def patched_function(*args, **kwargs): f"{pprint.pformat(completion_override)}" ) return None - return self.handle_response( - result_model, kwargs, init_timestamp, session=session - ) + return self.handle_response(result_model, kwargs, init_timestamp, session=session) # Call the original function with its original arguments - original_func = ( - self.original_create_beta if is_beta else self.original_create - ) + original_func = self.original_create_beta if is_beta else self.original_create result = original_func(*args, **kwargs) - return self.handle_response( - result, kwargs, init_timestamp, session=session - ) + return self.handle_response(result, kwargs, init_timestamp, session=session) return patched_function @@ -252,9 +235,7 @@ async def patched_function(*args, **kwargs): if "session" in kwargs.keys(): del kwargs["session"] - completion_override = fetch_completion_override_from_time_travel_cache( - kwargs - ) + completion_override = fetch_completion_override_from_time_travel_cache(kwargs) if completion_override: result_model = None pydantic_models = ( @@ -269,9 +250,7 @@ async def patched_function(*args, **kwargs): for pydantic_model in pydantic_models: try: - result_model = pydantic_model.model_validate_json( - completion_override - ) + result_model = pydantic_model.model_validate_json(completion_override) break except Exception as e: pass @@ -284,20 +263,12 @@ async def patched_function(*args, **kwargs): ) return None - return self.handle_response( - result_model, kwargs, init_timestamp, session=session - ) + return self.handle_response(result_model, kwargs, init_timestamp, session=session) # Call the original function with its original arguments - original_func = ( - self.original_create_async_beta - if is_beta - else self.original_create_async - ) + original_func = self.original_create_async_beta if is_beta else self.original_create_async result = await original_func(*args, **kwargs) - return self.handle_response( - result, kwargs, init_timestamp, session=session - ) + return self.handle_response(result, kwargs, init_timestamp, session=session) return patched_function diff --git a/agentops/llms/cohere.py b/agentops/llms/cohere.py index ad8f93a5..0756bc69 100644 --- a/agentops/llms/cohere.py +++ b/agentops/llms/cohere.py @@ -36,9 +36,7 @@ def undo_override(self): def __init__(self, client): super().__init__(client) - def handle_response( - self, response, kwargs, init_timestamp, session: Optional[Session] = None - ): + def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None): """Handle responses for Cohere versions >v5.4.0""" from cohere.types.streamed_chat_response import ( StreamedChatResponse_CitationGeneration, @@ -59,7 +57,6 @@ def handle_response( self.action_events = {} def handle_stream_chunk(chunk, session: Optional[Session] = None): - # We take the first chunk and accumulate the deltas from all subsequent chunks to build one full chat completion if isinstance(chunk, StreamedChatResponse_StreamStart): llm_event.returns = chunk @@ -102,9 +99,7 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None): citation_dict = citation.dict() # Replace document_ids with the actual documents citation_dict["documents"] = [ - documents[doc_id] - for doc_id in citation_dict["document_ids"] - if doc_id in documents + documents[doc_id] for doc_id in citation_dict["document_ids"] if doc_id in documents ] del citation_dict["document_ids"] @@ -120,12 +115,10 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None): pass elif isinstance(chunk, StreamedChatResponse_CitationGeneration): for citation in chunk.citations: - self.action_events[f"{citation.start}.{citation.end}"] = ( - ActionEvent( - action_type="citation", - init_timestamp=get_ISO_time(), - params=citation.text, - ) + self.action_events[f"{citation.start}.{citation.end}"] = ActionEvent( + action_type="citation", + init_timestamp=get_ISO_time(), + params=citation.text, ) elif isinstance(chunk, StreamedChatResponse_SearchQueriesGeneration): for query in chunk.search_queries: @@ -138,9 +131,7 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None): pass except Exception as e: - self._safe_record( - session, ErrorEvent(trigger_event=llm_event, exception=e) - ) + self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e)) kwargs_str = pprint.pformat(kwargs) chunk = pprint.pformat(chunk) diff --git a/agentops/llms/groq.py b/agentops/llms/groq.py index ca869638..23d5b8db 100644 --- a/agentops/llms/groq.py +++ b/agentops/llms/groq.py @@ -29,9 +29,7 @@ def undo_override(self): completions.Completions.create = self.original_create completions.AsyncCompletions.create = self.original_create - def handle_response( - self, response, kwargs, init_timestamp, session: Optional[Session] = None - ): + def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None): """Handle responses for OpenAI versions >v1.0.0""" from groq import AsyncStream, Stream from groq.resources.chat import AsyncCompletions @@ -81,9 +79,7 @@ def handle_stream_chunk(chunk: ChatCompletionChunk): self._safe_record(session, llm_event) except Exception as e: - self._safe_record( - session, ErrorEvent(trigger_event=llm_event, exception=e) - ) + self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e)) kwargs_str = pprint.pformat(kwargs) chunk = pprint.pformat(chunk) diff --git a/agentops/llms/instrumented_provider.py b/agentops/llms/instrumented_provider.py index f33f397a..f4dc78cf 100644 --- a/agentops/llms/instrumented_provider.py +++ b/agentops/llms/instrumented_provider.py @@ -14,9 +14,7 @@ def __init__(self, client): self.client = client @abstractmethod - def handle_response( - self, response, kwargs, init_timestamp, session: Optional[Session] = None - ) -> dict: + def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict: pass @abstractmethod diff --git a/agentops/llms/litellm.py b/agentops/llms/litellm.py index 30b4c25a..0f73c22b 100644 --- a/agentops/llms/litellm.py +++ b/agentops/llms/litellm.py @@ -40,9 +40,7 @@ def undo_override(self): completions.Completions.create = self.original_oai_create completions.AsyncCompletions.create = self.original_oai_create_async - def handle_response( - self, response, kwargs, init_timestamp, session: Optional[Session] = None - ) -> dict: + def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict: """Handle responses for OpenAI versions >v1.0.0""" from openai import AsyncStream, Stream from openai.resources import AsyncCompletions @@ -93,9 +91,7 @@ def handle_stream_chunk(chunk: ChatCompletionChunk): self._safe_record(session, llm_event) except Exception as e: - self._safe_record( - session, ErrorEvent(trigger_event=llm_event, exception=e) - ) + self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e)) kwargs_str = pprint.pformat(kwargs) chunk = pprint.pformat(chunk) @@ -186,14 +182,10 @@ def patched_function(*args, **kwargs): if "session" in kwargs.keys(): del kwargs["session"] - completion_override = fetch_completion_override_from_time_travel_cache( - kwargs - ) + completion_override = fetch_completion_override_from_time_travel_cache(kwargs) if completion_override: result_model = ChatCompletion.model_validate_json(completion_override) - return self.handle_response( - result_model, kwargs, init_timestamp, session=session - ) + return self.handle_response(result_model, kwargs, init_timestamp, session=session) # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs) # if prompt_override: @@ -222,14 +214,10 @@ async def patched_function(*args, **kwargs): if "session" in kwargs.keys(): del kwargs["session"] - completion_override = fetch_completion_override_from_time_travel_cache( - kwargs - ) + completion_override = fetch_completion_override_from_time_travel_cache(kwargs) if completion_override: result_model = ChatCompletion.model_validate_json(completion_override) - return self.handle_response( - result_model, kwargs, init_timestamp, session=session - ) + return self.handle_response(result_model, kwargs, init_timestamp, session=session) # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs) # if prompt_override: diff --git a/agentops/llms/mistral.py b/agentops/llms/mistral.py index 8be21946..770d4a81 100644 --- a/agentops/llms/mistral.py +++ b/agentops/llms/mistral.py @@ -11,7 +11,6 @@ class MistralProvider(InstrumentedProvider): - original_complete = None original_complete_async = None original_stream = None @@ -21,9 +20,7 @@ def __init__(self, client): super().__init__(client) self._provider_name = "Mistral" - def handle_response( - self, response, kwargs, init_timestamp, session: Optional[Session] = None - ) -> dict: + def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict: """Handle responses for Mistral""" from mistralai import Chat from mistralai.types import UNSET, UNSET_SENTINEL @@ -73,9 +70,7 @@ def handle_stream_chunk(chunk: dict): self._safe_record(session, llm_event) except Exception as e: - self._safe_record( - session, ErrorEvent(trigger_event=llm_event, exception=e) - ) + self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e)) kwargs_str = pprint.pformat(kwargs) chunk = pprint.pformat(chunk) diff --git a/agentops/llms/ollama.py b/agentops/llms/ollama.py index e5779283..5c998897 100644 --- a/agentops/llms/ollama.py +++ b/agentops/llms/ollama.py @@ -16,9 +16,7 @@ class OllamaProvider(InstrumentedProvider): original_create = None original_create_async = None - def handle_response( - self, response, kwargs, init_timestamp, session: Optional[Session] = None - ) -> dict: + def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict: llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) def handle_stream_chunk(chunk: dict): @@ -84,9 +82,7 @@ def patched_function(*args, **kwargs): # Call the original function with its original arguments init_timestamp = get_ISO_time() result = original_func["ollama.chat"](*args, **kwargs) - return self.handle_response( - result, kwargs, init_timestamp, session=kwargs.get("session", None) - ) + return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None)) # Override the original method with the patched one ollama.chat = patched_function diff --git a/agentops/llms/openai.py b/agentops/llms/openai.py index c99523d7..24756971 100644 --- a/agentops/llms/openai.py +++ b/agentops/llms/openai.py @@ -14,7 +14,6 @@ @singleton class OpenAiProvider(InstrumentedProvider): - original_create = None original_create_async = None @@ -22,9 +21,7 @@ def __init__(self, client): super().__init__(client) self._provider_name = "OpenAI" - def handle_response( - self, response, kwargs, init_timestamp, session: Optional[Session] = None - ) -> dict: + def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict: """Handle responses for OpenAI versions >v1.0.0""" from openai import AsyncStream, Stream from openai.resources import AsyncCompletions @@ -74,9 +71,7 @@ def handle_stream_chunk(chunk: ChatCompletionChunk): self._safe_record(session, llm_event) except Exception as e: - self._safe_record( - session, ErrorEvent(trigger_event=llm_event, exception=e) - ) + self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e)) kwargs_str = pprint.pformat(kwargs) chunk = pprint.pformat(chunk) @@ -157,17 +152,13 @@ def patched_function(*args, **kwargs): if "session" in kwargs.keys(): del kwargs["session"] - completion_override = fetch_completion_override_from_time_travel_cache( - kwargs - ) + completion_override = fetch_completion_override_from_time_travel_cache(kwargs) if completion_override: result_model = None pydantic_models = (ChatCompletion, ChatCompletionChunk) for pydantic_model in pydantic_models: try: - result_model = pydantic_model.model_validate_json( - completion_override - ) + result_model = pydantic_model.model_validate_json(completion_override) break except Exception as e: pass @@ -179,9 +170,7 @@ def patched_function(*args, **kwargs): f"{pprint.pformat(completion_override)}" ) return None - return self.handle_response( - result_model, kwargs, init_timestamp, session=session - ) + return self.handle_response(result_model, kwargs, init_timestamp, session=session) # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs) # if prompt_override: @@ -202,24 +191,19 @@ def _override_openai_v1_async_completion(self): self.original_create_async = completions.AsyncCompletions.create async def patched_function(*args, **kwargs): - init_timestamp = get_ISO_time() session = kwargs.get("session", None) if "session" in kwargs.keys(): del kwargs["session"] - completion_override = fetch_completion_override_from_time_travel_cache( - kwargs - ) + completion_override = fetch_completion_override_from_time_travel_cache(kwargs) if completion_override: result_model = None pydantic_models = (ChatCompletion, ChatCompletionChunk) for pydantic_model in pydantic_models: try: - result_model = pydantic_model.model_validate_json( - completion_override - ) + result_model = pydantic_model.model_validate_json(completion_override) break except Exception as e: pass @@ -231,9 +215,7 @@ async def patched_function(*args, **kwargs): f"{pprint.pformat(completion_override)}" ) return None - return self.handle_response( - result_model, kwargs, init_timestamp, session=session - ) + return self.handle_response(result_model, kwargs, init_timestamp, session=session) # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs) # if prompt_override: diff --git a/agentops/meta_client.py b/agentops/meta_client.py index 6e222771..6cc7ed2e 100644 --- a/agentops/meta_client.py +++ b/agentops/meta_client.py @@ -14,9 +14,7 @@ class MetaClient(type): def __new__(cls, name, bases, dct): # Wrap each method with the handle_exceptions decorator for method_name, method in dct.items(): - if ( - callable(method) and not method_name.startswith("__") - ) or method_name == "__init__": + if (callable(method) and not method_name.startswith("__")) or method_name == "__init__": dct[method_name] = handle_exceptions(method) return super().__new__(cls, name, bases, dct) diff --git a/agentops/partners/autogen_logger.py b/agentops/partners/autogen_logger.py index 4080d792..77aca142 100644 --- a/agentops/partners/autogen_logger.py +++ b/agentops/partners/autogen_logger.py @@ -72,13 +72,9 @@ def log_chat_completion( def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None: """Calls agentops.create_agent""" ao_agent_id = agentops.create_agent(agent.name, str(uuid4())) - self.agent_store.append( - {"agentops_id": ao_agent_id, "autogen_id": str(id(agent))} - ) + self.agent_store.append({"agentops_id": ao_agent_id, "autogen_id": str(id(agent))}) - def log_event( - self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any] - ) -> None: + def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None: """Records an ActionEvent to AgentOps session""" event = ActionEvent(action_type=name) agentops_id = self._get_agentops_id_from_agent(str(id(source))) @@ -86,9 +82,7 @@ def log_event( event.params = kwargs agentops.record(event) - def log_function_use( - self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: any - ): + def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: any): """Records a ToolEvent to AgentOps session""" event = ToolEvent() agentops_id = self._get_agentops_id_from_agent(str(id(source))) diff --git a/agentops/partners/langchain_callback_handler.py b/agentops/partners/langchain_callback_handler.py index 847abe66..7d3ded9b 100644 --- a/agentops/partners/langchain_callback_handler.py +++ b/agentops/partners/langchain_callback_handler.py @@ -131,12 +131,8 @@ def on_llm_end( llm_event.end_timestamp = get_ISO_time() llm_event.completion = response.generations[0][0].text if response.llm_output is not None: - llm_event.prompt_tokens = response.llm_output["token_usage"][ - "prompt_tokens" - ] - llm_event.completion_tokens = response.llm_output["token_usage"][ - "completion_tokens" - ] + llm_event.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"] + llm_event.completion_tokens = response.llm_output["token_usage"]["completion_tokens"] if len(response.generations) == 0: # TODO: more descriptive error @@ -297,9 +293,7 @@ def on_retriever_end( **kwargs: Any, ) -> None: action_event: ActionEvent = self.events.retriever[str(run_id)] - action_event.logs = ( - documents # TODO: Adding this. Might want to add elsewhere e.g. params - ) + action_event.logs = documents # TODO: Adding this. Might want to add elsewhere e.g. params action_event.end_timestamp = get_ISO_time() self.ao_client.record(action_event) @@ -326,9 +320,7 @@ def on_agent_action( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - self.agent_actions[run_id].append( - ActionEvent(params={"action": action, **kwargs}, action_type="agent") - ) + self.agent_actions[run_id].append(ActionEvent(params={"action": action, **kwargs}, action_type="agent")) @debug_print_function_params def on_agent_finish( @@ -367,9 +359,7 @@ def on_retry( @property def session_id(self): - raise DeprecationWarning( - "session_id is deprecated in favor of current_session_ids" - ) + raise DeprecationWarning("session_id is deprecated in favor of current_session_ids") @property def current_session_ids(self): @@ -387,7 +377,6 @@ def __init__( max_queue_size: Optional[int] = None, tags: Optional[List[str]] = None, ): - client_params: Dict[str, Any] = { "api_key": api_key, "endpoint": endpoint, @@ -396,9 +385,7 @@ def __init__( "tags": tags, } - self.ao_client = AOClient( - **{k: v for k, v in client_params.items() if v is not None}, override=False - ) + self.ao_client = AOClient(**{k: v for k, v in client_params.items() if v is not None}, override=False) self.events = Events() self.agent_actions: Dict[UUID, List[ActionEvent]] = defaultdict(list) @@ -482,12 +469,8 @@ async def on_llm_end( llm_event.end_timestamp = get_ISO_time() llm_event.completion = response.generations[0][0].text if response.llm_output is not None: - llm_event.prompt_tokens = response.llm_output["token_usage"][ - "prompt_tokens" - ] - llm_event.completion_tokens = response.llm_output["token_usage"][ - "completion_tokens" - ] + llm_event.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"] + llm_event.completion_tokens = response.llm_output["token_usage"]["completion_tokens"] if len(response.generations) == 0: # TODO: more descriptive error @@ -645,9 +628,7 @@ async def on_retriever_end( **kwargs: Any, ) -> None: action_event: ActionEvent = self.events.retriever[str(run_id)] - action_event.logs = ( - documents # TODO: Adding this. Might want to add elsewhere e.g. params - ) + action_event.logs = documents # TODO: Adding this. Might want to add elsewhere e.g. params action_event.end_timestamp = get_ISO_time() self.ao_client.record(action_event) @@ -674,9 +655,7 @@ async def on_agent_action( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - self.agent_actions[run_id].append( - ActionEvent(params={"action": action, **kwargs}, action_type="agent") - ) + self.agent_actions[run_id].append(ActionEvent(params={"action": action, **kwargs}, action_type="agent")) @debug_print_function_params async def on_agent_finish( diff --git a/agentops/session.py b/agentops/session.py index 58225b6c..3cfc1303 100644 --- a/agentops/session.py +++ b/agentops/session.py @@ -86,14 +86,11 @@ def end_session( end_state_reason: Optional[str] = None, video: Optional[str] = None, ) -> Union[Decimal, None]: - if not self.is_running: return if not any(end_state == state.value for state in EndState): - return logger.warning( - "Invalid end_state. Please use one of the EndState enums" - ) + return logger.warning("Invalid end_state. Please use one of the EndState enums") self.end_timestamp = get_ISO_time() self.end_state = end_state @@ -146,9 +143,7 @@ def format_duration(start_time, end_time): formatted_cost = ( "{:.2f}".format(token_cost_d) if token_cost_d == 0 - else "{:.6f}".format( - token_cost_d.quantize(Decimal("0.000001"), rounding=ROUND_HALF_UP) - ) + else "{:.6f}".format(token_cost_d.quantize(Decimal("0.000001"), rounding=ROUND_HALF_UP)) ) analytics = ( @@ -222,8 +217,7 @@ def record(self, event: Union[Event, ErrorEvent]): if event.trigger_event: if ( not event.trigger_event.end_timestamp - or event.trigger_event.init_timestamp - == event.trigger_event.end_timestamp + or event.trigger_event.init_timestamp == event.trigger_event.end_timestamp ): event.trigger_event.end_timestamp = get_ISO_time() @@ -338,9 +332,7 @@ def _flush_queue(self) -> None: return logger.error(f"Could not post events - {e}") logger.debug("\n") - logger.debug( - f"Session request to {self.config.endpoint}/v2/create_events" - ) + logger.debug(f"Session request to {self.config.endpoint}/v2/create_events") logger.debug(serialized_payload) logger.debug("\n") diff --git a/agentops/singleton.py b/agentops/singleton.py index 8200510e..b22e4edc 100644 --- a/agentops/singleton.py +++ b/agentops/singleton.py @@ -2,7 +2,6 @@ def singleton(class_): - def getinstance(*args, **kwargs): if class_ not in ao_instances: ao_instances[class_] = class_(*args, **kwargs) @@ -12,7 +11,6 @@ def getinstance(*args, **kwargs): def conditional_singleton(class_): - def getinstance(*args, **kwargs): use_singleton = kwargs.pop("use_singleton", True) if use_singleton: diff --git a/agentops/time_travel.py b/agentops/time_travel.py index 6c1c7588..55ad6662 100644 --- a/agentops/time_travel.py +++ b/agentops/time_travel.py @@ -20,9 +20,7 @@ def __init__(self): try: with open(cache_path, "r") as file: time_travel_cache_json = json.load(file) - self._completion_overrides = time_travel_cache_json.get( - "completion_overrides" - ) + self._completion_overrides = time_travel_cache_json.get("completion_overrides") except FileNotFoundError: return @@ -99,20 +97,14 @@ def find_cache_hit(prompt_messages, completion_overrides): continue if all( - isinstance(a, dict) - and isinstance(b, dict) - and a.get("content") == b.get("content") + isinstance(a, dict) and isinstance(b, dict) and a.get("content") == b.get("content") for a, b in zip(prompt_messages, cached_messages) ): return value except (SyntaxError, ValueError, TypeError) as e: - print( - f"{ttd_prepend_string} Error - Error processing completion_overrides item: {e}" - ) + print(f"{ttd_prepend_string} Error - Error processing completion_overrides item: {e}") except Exception as e: - print( - f"{ttd_prepend_string} Error - Unexpected error in find_cache_hit: {e}" - ) + print(f"{ttd_prepend_string} Error - Unexpected error in find_cache_hit: {e}") return None @@ -143,9 +135,7 @@ def set_time_travel_active_state(is_active: bool): try: yaml.dump(config, config_file) except: - print( - f"{ttd_prepend_string} Error - Unable to write to {config_path}. Time Travel not activated" - ) + print(f"{ttd_prepend_string} Error - Unable to write to {config_path}. Time Travel not activated") return if is_active: diff --git a/pyproject.toml b/pyproject.toml index 59954095..9093e7ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dev = [ "pytest-mock", "pyfakefs", "requests_mock==1.11.0", + "ruff", "tach~=0.9", ] langchain = [ @@ -54,3 +55,50 @@ agentops = "agentops.cli:main" [tool.pytest.ini_options] asyncio_mode = "strict" asyncio_default_fixture_loop_scope = "function" + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +ignore = [ + "F401", # Unused imports + "E712", # Comparison to True/False + "E711", # Comparison to None + "E722", # Bare except + "F821", # Undefined names + "F841", # Unused variables +] + +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".github", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".vscode", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "docs", + "examples", + "node_modules", + "site-packages", + "venv", + "tests/core_manual_tests", +] diff --git a/tests/core_manual_tests/api_server/server.py b/tests/core_manual_tests/api_server/server.py index 61ee8c67..5ae5e1ae 100644 --- a/tests/core_manual_tests/api_server/server.py +++ b/tests/core_manual_tests/api_server/server.py @@ -14,7 +14,6 @@ @app.get("/completion") def completion(): - session = agentops.start_session(tags=["api-server-test"]) messages = [{"role": "user", "content": "Hello"}] diff --git a/tests/core_manual_tests/multi_session_llm.py b/tests/core_manual_tests/multi_session_llm.py index 25a652ae..bd55477c 100644 --- a/tests/core_manual_tests/multi_session_llm.py +++ b/tests/core_manual_tests/multi_session_llm.py @@ -25,9 +25,7 @@ session_1.record(ActionEvent(action_type="test event")) # option 2: add session as a keyword argument -response2 = openai.chat.completions.create( - model="gpt-3.5-turbo", messages=messages, temperature=0.5, session=session_2 -) +response2 = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages, temperature=0.5, session=session_2) session_1.end_session(end_state="Success") session_2.end_session(end_state="Success") diff --git a/tests/core_manual_tests/providers/litellm_canary.py b/tests/core_manual_tests/providers/litellm_canary.py index ecd2186a..0b9cf98c 100644 --- a/tests/core_manual_tests/providers/litellm_canary.py +++ b/tests/core_manual_tests/providers/litellm_canary.py @@ -7,9 +7,7 @@ load_dotenv() agentops.init(default_tags=["litellm-provider-test"]) -response = litellm.completion( - model="gpt-3.5-turbo", messages=[{"content": "Hello, how are you?", "role": "user"}] -) +response = litellm.completion(model="gpt-3.5-turbo", messages=[{"content": "Hello, how are you?", "role": "user"}]) stream_response = litellm.completion( model="gpt-3.5-turbo", diff --git a/tests/langchain_handlers/_test_langchain_handler.py b/tests/langchain_handlers/_test_langchain_handler.py index 97cc2205..8f468c54 100644 --- a/tests/langchain_handlers/_test_langchain_handler.py +++ b/tests/langchain_handlers/_test_langchain_handler.py @@ -14,13 +14,9 @@ AGENTOPS_API_KEY = os.environ.get("AGENTOPS_API_KEY") OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") -agentops_handler = AgentOpsLangchainCallbackHandler( - api_key=AGENTOPS_API_KEY, tags=["Langchain Example"] -) +agentops_handler = AgentOpsLangchainCallbackHandler(api_key=AGENTOPS_API_KEY, tags=["Langchain Example"]) -llm = ChatOpenAI( - openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-3.5-turbo" -) +llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-3.5-turbo") @tool @@ -43,9 +39,7 @@ def find_movie(genre) -> str: llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True, - callbacks=[ - agentops_handler - ], # You must pass in a callback handler to record your agent + callbacks=[agentops_handler], # You must pass in a callback handler to record your agent handle_parsing_errors=True, ) @@ -56,22 +50,16 @@ def find_movie(genre) -> str: ######## # Async -agentops_handler = AgentOpsAsyncLangchainCallbackHandler( - api_key=AGENTOPS_API_KEY, tags=["Async Example"] -) +agentops_handler = AgentOpsAsyncLangchainCallbackHandler(api_key=AGENTOPS_API_KEY, tags=["Async Example"]) -llm = ChatOpenAI( - openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-3.5-turbo" -) +llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-3.5-turbo") agent = initialize_agent( tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True, - callbacks=[ - agentops_handler - ], # You must pass in a callback handler to record your agent + callbacks=[agentops_handler], # You must pass in a callback handler to record your agent handle_parsing_errors=True, ) diff --git a/tests/openai_handlers/_test_gpt_vision.py b/tests/openai_handlers/_test_gpt_vision.py index 29a04e7e..3aae1231 100644 --- a/tests/openai_handlers/_test_gpt_vision.py +++ b/tests/openai_handlers/_test_gpt_vision.py @@ -45,9 +45,7 @@ def encode_image(image_path): # Path to your image -image_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "logo_for_vision_test.png" -) +image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logo_for_vision_test.png") # Getting the base64 string base64_image = encode_image(image_path) @@ -71,9 +69,7 @@ def encode_image(image_path): "max_tokens": 300, } -response = requests.post( - "https://api.openai.com/v1/chat/completions", headers=headers, json=payload -) +response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload) print(response.json()) diff --git a/tests/openai_handlers/_test_handler_openai_v0.py b/tests/openai_handlers/_test_handler_openai_v0.py index 601be8da..3dc4d5e0 100644 --- a/tests/openai_handlers/_test_handler_openai_v0.py +++ b/tests/openai_handlers/_test_handler_openai_v0.py @@ -7,9 +7,7 @@ print("Running OpenAI