diff --git a/.github/workflows/black-formatter.yml b/.github/workflows/black-formatter.yml
deleted file mode 100644
index c8b3e342..00000000
--- a/.github/workflows/black-formatter.yml
+++ /dev/null
@@ -1,29 +0,0 @@
-
-name: Black Code Formatter Check
-
-on:
- pull_request:
- paths:
- - '**/*.py'
- - '**/*.ipynb'
-
-jobs:
- black-check:
- runs-on: ubuntu-latest
-
- steps:
- - name: Checkout code
- uses: actions/checkout@v2
-
- - name: Set up Python
- uses: actions/setup-python@v2
- with:
- python-version: '3.11' # Specify the version of Python you need
-
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install "black[jupyter]"
-
- - name: Run Black
- run: black --diff --check .
\ No newline at end of file
diff --git a/.github/workflows/static-analysis.yaml b/.github/workflows/static-analysis.yaml
new file mode 100644
index 00000000..6b03faf9
--- /dev/null
+++ b/.github/workflows/static-analysis.yaml
@@ -0,0 +1,62 @@
+name: Static analysis
+
+# on PR and push to main
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ paths:
+ - '**/*.py'
+
+permissions:
+ contents: read
+
+# Limit concurrency by workflow/branch combination.
+#
+# For pull request builds, pushing additional changes to the
+# branch will cancel prior in-progress and pending builds.
+#
+# For builds triggered on a branch push, additional changes
+# will wait for prior builds to complete before starting.
+#
+# https://docs.github.com/en/actions/using-jobs/using-concurrency
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+
+jobs:
+ pre-commit-checks:
+ name: Pre-commit checks
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ persist-credentials: false
+ fetch-depth: 0
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ id: setup_python
+ with:
+ python-version: "3.11.10"
+
+ - name: UV Cache
+ # Manually cache the uv cache directory
+ # until setup-python supports it:
+ # https://github.com/actions/setup-python/issues/822
+ uses: actions/cache@v4
+ id: cache-uv
+ with:
+ path: ~/.cache/uv
+ key: uvcache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }}
+
+ - name: Install packages
+ run: |
+ python -m pip install -U uv pre-commit
+ uv pip install --upgrade --system -e .[dev]
+
+ - name: Run pre-commit
+ run: |
+ pre-commit run --show-diff-on-failure --color=always --all-files
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..138eedbc
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,7 @@
+repos:
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: "v0.2.1"
+ hooks:
+ - id: ruff
+ args: [--fix, --exit-non-zero-on-fix]
+ - id: ruff-format
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 4ba0523d..1db574e5 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -77,6 +77,18 @@ Even if you're not ready to contribute code, we'd love to hear your thoughts. Dr
.\venv\Scripts\activate # Windows
```
+3. **Pre-commit Setup**:
+ We use pre-commit hooks to automatically format and lint code. Set them up with:
+ ```bash
+ pip install pre-commit
+ pre-commit install
+ ```
+
+ That's it! The hooks will run automatically when you commit. To manually check all files:
+ ```bash
+ pre-commit run --all-files
+ ```
+
## Testing
We use a comprehensive testing stack to ensure code quality and reliability. Our testing framework includes pytest and several specialized testing tools.
diff --git a/agentops/cli.py b/agentops/cli.py
index 562c8053..29a81123 100644
--- a/agentops/cli.py
+++ b/agentops/cli.py
@@ -6,9 +6,7 @@ def main():
parser = argparse.ArgumentParser(description="AgentOps CLI")
subparsers = parser.add_subparsers(dest="command")
- timetravel_parser = subparsers.add_parser(
- "timetravel", help="Time Travel Debugging commands", aliases=["tt"]
- )
+ timetravel_parser = subparsers.add_parser("timetravel", help="Time Travel Debugging commands", aliases=["tt"])
timetravel_parser.add_argument(
"branch_name",
type=str,
diff --git a/agentops/client.py b/agentops/client.py
index 80cb1cc5..86fe49b8 100644
--- a/agentops/client.py
+++ b/agentops/client.py
@@ -44,10 +44,7 @@ def __init__(self):
api_key=os.environ.get("AGENTOPS_API_KEY"),
parent_key=os.environ.get("AGENTOPS_PARENT_KEY"),
endpoint=os.environ.get("AGENTOPS_API_ENDPOINT"),
- env_data_opt_out=os.environ.get(
- "AGENTOPS_ENV_DATA_OPT_OUT", "False"
- ).lower()
- == "true",
+ env_data_opt_out=os.environ.get("AGENTOPS_ENV_DATA_OPT_OUT", "False").lower() == "true",
)
def configure(
@@ -106,9 +103,7 @@ def initialize(self) -> Union[Session, None]:
if session:
for agent_args in self._pre_init_queue["agents"]:
- session.create_agent(
- name=agent_args["name"], agent_id=agent_args["agent_id"]
- )
+ session.create_agent(name=agent_args["name"], agent_id=agent_args["agent_id"])
self._pre_init_queue["agents"] = []
return session
@@ -141,9 +136,7 @@ def add_tags(self, tags: List[str]) -> None:
session = self._safe_get_session()
if session is None:
- return logger.warning(
- "Could not add tags. Start a session by calling agentops.start_session()."
- )
+ return logger.warning("Could not add tags. Start a session by calling agentops.start_session().")
session.add_tags(tags=tags)
@@ -162,9 +155,7 @@ def set_tags(self, tags: List[str]) -> None:
session = self._safe_get_session()
if session is None:
- return logger.warning(
- "Could not set tags. Start a session by calling agentops.start_session()."
- )
+ return logger.warning("Could not set tags. Start a session by calling agentops.start_session().")
else:
session.set_tags(tags=tags)
@@ -198,9 +189,7 @@ def record(self, event: Union[Event, ErrorEvent]) -> None:
session = self._safe_get_session()
if session is None:
- return logger.error(
- "Could not record event. Start a session by calling agentops.start_session()."
- )
+ return logger.error("Could not record event. Start a session by calling agentops.start_session().")
session.record(event)
def start_session(
@@ -244,9 +233,7 @@ def start_session(
if self._pre_init_queue["agents"] and len(self._pre_init_queue["agents"]) > 0:
for agent_args in self._pre_init_queue["agents"]:
- session.create_agent(
- name=agent_args["name"], agent_id=agent_args["agent_id"]
- )
+ session.create_agent(name=agent_args["name"], agent_id=agent_args["agent_id"])
self._pre_init_queue["agents"] = []
self._sessions.append(session)
@@ -277,9 +264,7 @@ def end_session(
if is_auto_end and self._config.skip_auto_end_session:
return
- token_cost = session.end_session(
- end_state=end_state, end_state_reason=end_state_reason, video=video
- )
+ token_cost = session.end_session(end_state=end_state, end_state_reason=end_state_reason, video=video)
return token_cost
@@ -299,9 +284,7 @@ def create_agent(
# if no session passed, assume single session
session = self._safe_get_session()
if session is None:
- self._pre_init_queue["agents"].append(
- {"name": name, "agent_id": agent_id}
- )
+ self._pre_init_queue["agents"].append({"name": name, "agent_id": agent_id})
else:
session.create_agent(name=name, agent_id=agent_id)
@@ -326,9 +309,7 @@ def signal_handler(signum, frame):
"""
signal_name = "SIGINT" if signum == signal.SIGINT else "SIGTERM"
logger.info("%s detected. Ending session...", signal_name)
- self.end_session(
- end_state="Fail", end_state_reason=f"Signal {signal_name} detected"
- )
+ self.end_session(end_state="Fail", end_state_reason=f"Signal {signal_name} detected")
sys.exit(0)
def handle_exception(exc_type, exc_value, exc_traceback):
@@ -341,9 +322,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
exc_traceback (TracebackType): A traceback object encapsulating the call stack at the
point where the exception originally occurred.
"""
- formatted_traceback = "".join(
- traceback.format_exception(exc_type, exc_value, exc_traceback)
- )
+ formatted_traceback = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
for session in self._sessions:
session.end_session(
@@ -376,13 +355,7 @@ def add_pre_init_warning(self, message: str):
# replaces the session currently stored with a specific session_id, with a new session
def _update_session(self, session: Session):
self._sessions[
- self._sessions.index(
- [
- sess
- for sess in self._sessions
- if sess.session_id == session.session_id
- ][0]
- )
+ self._sessions.index([sess for sess in self._sessions if sess.session_id == session.session_id][0])
] = session
def _safe_get_session(self) -> Optional[Session]:
@@ -392,9 +365,7 @@ def _safe_get_session(self) -> Optional[Session]:
return self._sessions[0]
if len(self._sessions) > 1:
- calling_function = inspect.stack()[
- 2
- ].function # Using index 2 because we have a wrapper at index 1
+ calling_function = inspect.stack()[2].function # Using index 2 because we have a wrapper at index 1
return logger.warning(
f"Multiple sessions detected. You must use session.{calling_function}(). More info: https://docs.agentops.ai/v1/concepts/core-concepts#session-management"
)
diff --git a/agentops/decorators.py b/agentops/decorators.py
index d291c416..860b7c73 100644
--- a/agentops/decorators.py
+++ b/agentops/decorators.py
@@ -45,9 +45,7 @@ async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
- name: func_args[name].default
- for name in arg_names
- if func_args[name].default is not inspect._empty
+ name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
@@ -111,9 +109,7 @@ def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
- name: func_args[name].default
- for name in arg_names
- if func_args[name].default is not inspect._empty
+ name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
@@ -191,9 +187,7 @@ async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
- name: func_args[name].default
- for name in arg_names
- if func_args[name].default is not inspect._empty
+ name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
@@ -257,9 +251,7 @@ def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
arg_names = list(func_args.keys())
# Get default values
arg_values = {
- name: func_args[name].default
- for name in arg_names
- if func_args[name].default is not inspect._empty
+ name: func_args[name].default for name in arg_names if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
@@ -338,12 +330,8 @@ def new_init(self, *args, **kwargs):
session=session,
)
except AttributeError as e:
- Client().add_pre_init_warning(
- f"Failed to track an agent {name} with the @track_agent decorator."
- )
- logger.warning(
- "Failed to track an agent with the @track_agent decorator."
- )
+ Client().add_pre_init_warning(f"Failed to track an agent {name} with the @track_agent decorator.")
+ logger.warning("Failed to track an agent with the @track_agent decorator.")
original_init(self, *args, **kwargs)
obj.__init__ = new_init
@@ -351,7 +339,8 @@ def new_init(self, *args, **kwargs):
elif inspect.isfunction(obj):
obj.agent_ops_agent_id = str(uuid4()) # type: ignore
Client().create_agent(
- name=obj.agent_ops_agent_name, agent_id=obj.agent_ops_agent_id # type: ignore
+ name=obj.agent_ops_agent_name,
+ agent_id=obj.agent_ops_agent_id, # type: ignore
)
else:
diff --git a/agentops/helpers.py b/agentops/helpers.py
index fbcd21d5..02530213 100644
--- a/agentops/helpers.py
+++ b/agentops/helpers.py
@@ -9,7 +9,6 @@
from .log_config import logger
from uuid import UUID
-from importlib.metadata import version
def get_ISO_time():
@@ -38,7 +37,9 @@ def filter_dict(obj):
k: (
filter_dict(v)
if isinstance(v, (dict, list)) or is_jsonable(v)
- else str(v) if isinstance(v, UUID) else ""
+ else str(v)
+ if isinstance(v, UUID)
+ else ""
)
for k, v in obj.items()
}
@@ -47,7 +48,9 @@ def filter_dict(obj):
(
filter_dict(x)
if isinstance(x, (dict, list)) or is_jsonable(x)
- else str(x) if isinstance(x, UUID) else ""
+ else str(x)
+ if isinstance(x, UUID)
+ else ""
)
for x in obj
]
@@ -85,9 +88,7 @@ def remove_unwanted_items(value):
"""Recursively remove self key and None/... values from dictionaries so they aren't serialized"""
if isinstance(value, dict):
return {
- k: remove_unwanted_items(v)
- for k, v in value.items()
- if v is not None and v is not ... and k != "self"
+ k: remove_unwanted_items(v) for k, v in value.items() if v is not None and v is not ... and k != "self"
}
elif isinstance(value, list):
return [remove_unwanted_items(item) for item in value]
@@ -106,9 +107,7 @@ def check_call_stack_for_agent_id() -> Union[UUID, None]:
# We stop looking up the stack at main because after that we see global variables
if var == "__main__":
return None
- if hasattr(var, "agent_ops_agent_id") and getattr(
- var, "agent_ops_agent_id"
- ):
+ if hasattr(var, "agent_ops_agent_id") and getattr(var, "agent_ops_agent_id"):
logger.debug(
"LLM call from agent named: %s",
getattr(var, "agent_ops_agent_name"),
@@ -141,7 +140,7 @@ def check_agentops_update():
if not latest_version == current_version:
logger.warning(
- f" WARNING: agentops is out of date. Please update with the command: 'pip install --upgrade agentops'"
+ " WARNING: agentops is out of date. Please update with the command: 'pip install --upgrade agentops'"
)
except Exception as e:
logger.debug(f"Failed to check for updates: {e}")
diff --git a/agentops/host_env.py b/agentops/host_env.py
index e466c86e..5307dec4 100644
--- a/agentops/host_env.py
+++ b/agentops/host_env.py
@@ -51,8 +51,7 @@ def get_installed_packages():
return {
# TODO: add to opt out
"Installed Packages": {
- dist.metadata.get("Name"): dist.metadata.get("Version")
- for dist in importlib.metadata.distributions()
+ dist.metadata.get("Name"): dist.metadata.get("Version") for dist in importlib.metadata.distributions()
}
}
except:
@@ -125,9 +124,7 @@ def get_disk_details():
}
except OSError as inaccessible:
# Skip inaccessible partitions, such as removable drives with no media
- logger.debug(
- "Mountpoint %s inaccessible: %s", partition.mountpoint, inaccessible
- )
+ logger.debug("Mountpoint %s inaccessible: %s", partition.mountpoint, inaccessible)
return disk_info
diff --git a/agentops/http_client.py b/agentops/http_client.py
index 32d7d336..caa18b27 100644
--- a/agentops/http_client.py
+++ b/agentops/http_client.py
@@ -22,10 +22,7 @@ class HttpStatus(Enum):
class Response:
-
- def __init__(
- self, status: HttpStatus = HttpStatus.UNKNOWN, body: Optional[dict] = None
- ):
+ def __init__(self, status: HttpStatus = HttpStatus.UNKNOWN, body: Optional[dict] = None):
self.status: HttpStatus = status
self.code: int = status.value
self.body = body if body else {}
@@ -57,7 +54,6 @@ def get_status(code: int) -> HttpStatus:
class HttpClient:
-
@staticmethod
def post(
url: str,
@@ -82,17 +78,13 @@ def post(
if jwt is not None:
JSON_HEADER["Authorization"] = f"Bearer {jwt}"
- res = request_session.post(
- url, data=payload, headers=JSON_HEADER, timeout=20
- )
+ res = request_session.post(url, data=payload, headers=JSON_HEADER, timeout=20)
result.parse(res)
except requests.exceptions.Timeout:
result.code = 408
result.status = HttpStatus.TIMEOUT
- raise ApiServerException(
- "Could not reach API server - connection timed out"
- )
+ raise ApiServerException("Could not reach API server - connection timed out")
except requests.exceptions.HTTPError as e:
try:
result.parse(e.response)
@@ -145,9 +137,7 @@ def get(
except requests.exceptions.Timeout:
result.code = 408
result.status = HttpStatus.TIMEOUT
- raise ApiServerException(
- "Could not reach API server - connection timed out"
- )
+ raise ApiServerException("Could not reach API server - connection timed out")
except requests.exceptions.HTTPError as e:
try:
result.parse(e.response)
diff --git a/agentops/llms/__init__.py b/agentops/llms/__init__.py
index 1bd8c3b7..a5852d8c 100644
--- a/agentops/llms/__init__.py
+++ b/agentops/llms/__init__.py
@@ -66,17 +66,13 @@ def override_api(self):
if api == "litellm":
module_version = version(api)
if module_version is None:
- logger.warning(
- f"Cannot determine LiteLLM version. Only LiteLLM>=1.3.1 supported."
- )
+ logger.warning("Cannot determine LiteLLM version. Only LiteLLM>=1.3.1 supported.")
if Version(module_version) >= parse("1.3.1"):
provider = LiteLLMProvider(self.client)
provider.override()
else:
- logger.warning(
- f"Only LiteLLM>=1.3.1 supported. v{module_version} found."
- )
+ logger.warning(f"Only LiteLLM>=1.3.1 supported. v{module_version} found.")
return # If using an abstraction like litellm, do not patch the underlying LLM APIs
if api == "openai":
@@ -96,17 +92,13 @@ def override_api(self):
# Patch cohere v5.4.0+ methods
module_version = version(api)
if module_version is None:
- logger.warning(
- f"Cannot determine Cohere version. Only Cohere>=5.4.0 supported."
- )
+ logger.warning("Cannot determine Cohere version. Only Cohere>=5.4.0 supported.")
if Version(module_version) >= parse("5.4.0"):
provider = CohereProvider(self.client)
provider.override()
else:
- logger.warning(
- f"Only Cohere>=5.4.0 supported. v{module_version} found."
- )
+ logger.warning(f"Only Cohere>=5.4.0 supported. v{module_version} found.")
if api == "ollama":
module_version = version(api)
@@ -115,9 +107,7 @@ def override_api(self):
provider = OllamaProvider(self.client)
provider.override()
else:
- logger.warning(
- f"Only Ollama>=0.0.1 supported. v{module_version} found."
- )
+ logger.warning(f"Only Ollama>=0.0.1 supported. v{module_version} found.")
if api == "groq":
module_version = version(api)
@@ -126,25 +116,19 @@ def override_api(self):
provider = GroqProvider(self.client)
provider.override()
else:
- logger.warning(
- f"Only Groq>=0.9.0 supported. v{module_version} found."
- )
+ logger.warning(f"Only Groq>=0.9.0 supported. v{module_version} found.")
if api == "anthropic":
module_version = version(api)
if module_version is None:
- logger.warning(
- f"Cannot determine Anthropic version. Only Anthropic>=0.32.0 supported."
- )
+ logger.warning("Cannot determine Anthropic version. Only Anthropic>=0.32.0 supported.")
if Version(module_version) >= parse("0.32.0"):
provider = AnthropicProvider(self.client)
provider.override()
else:
- logger.warning(
- f"Only Anthropic>=0.32.0 supported. v{module_version} found."
- )
+ logger.warning(f"Only Anthropic>=0.32.0 supported. v{module_version} found.")
if api == "mistralai":
module_version = version(api)
@@ -153,25 +137,19 @@ def override_api(self):
provider = MistralProvider(self.client)
provider.override()
else:
- logger.warning(
- f"Only MistralAI>=1.0.1 supported. v{module_version} found."
- )
+ logger.warning(f"Only MistralAI>=1.0.1 supported. v{module_version} found.")
if api == "ai21":
module_version = version(api)
if module_version is None:
- logger.warning(
- f"Cannot determine AI21 version. Only AI21>=2.0.0 supported."
- )
+ logger.warning("Cannot determine AI21 version. Only AI21>=2.0.0 supported.")
if Version(module_version) >= parse("2.0.0"):
provider = AI21Provider(self.client)
provider.override()
else:
- logger.warning(
- f"Only AI21>=2.0.0 supported. v{module_version} found."
- )
+ logger.warning(f"Only AI21>=2.0.0 supported. v{module_version} found.")
def stop_instrumenting(self):
OpenAiProvider(self.client).undo_override()
diff --git a/agentops/llms/ai21.py b/agentops/llms/ai21.py
index 533ad276..4e080a5b 100644
--- a/agentops/llms/ai21.py
+++ b/agentops/llms/ai21.py
@@ -14,7 +14,6 @@
@singleton
class AI21Provider(InstrumentedProvider):
-
original_create = None
original_create_async = None
original_answer = None
@@ -24,9 +23,7 @@ def __init__(self, client):
super().__init__(client)
self._provider_name = "AI21"
- def handle_response(
- self, response, kwargs, init_timestamp, session: Optional[Session] = None
- ):
+ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None):
"""Handle responses for AI21"""
from ai21.stream.stream import Stream
from ai21.stream.async_stream import AsyncStream
@@ -51,9 +48,7 @@ def handle_stream_chunk(chunk: ChatCompletionChunk):
accumulated_delta = llm_event.returns.choices[0].delta
llm_event.agent_id = check_call_stack_for_agent_id()
llm_event.model = kwargs["model"]
- llm_event.prompt = [
- message.model_dump() for message in kwargs["messages"]
- ]
+ llm_event.prompt = [message.model_dump() for message in kwargs["messages"]]
# NOTE: We assume for completion only choices[0] is relevant
choice = chunk.choices[0]
@@ -80,9 +75,7 @@ def handle_stream_chunk(chunk: ChatCompletionChunk):
self._safe_record(session, llm_event)
except Exception as e:
- self._safe_record(
- session, ErrorEvent(trigger_event=llm_event, exception=e)
- )
+ self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
@@ -119,9 +112,7 @@ async def async_generator():
llm_event.returns = response
llm_event.agent_id = check_call_stack_for_agent_id()
llm_event.model = kwargs["model"]
- llm_event.prompt = [
- message.model_dump() for message in kwargs["messages"]
- ]
+ llm_event.prompt = [message.model_dump() for message in kwargs["messages"]]
llm_event.prompt_tokens = response.usage.prompt_tokens
llm_event.completion = response.choices[0].message.model_dump()
llm_event.completion_tokens = response.usage.completion_tokens
diff --git a/agentops/llms/anthropic.py b/agentops/llms/anthropic.py
index 183694a4..bc63d85c 100644
--- a/agentops/llms/anthropic.py
+++ b/agentops/llms/anthropic.py
@@ -13,7 +13,6 @@
@singleton
class AnthropicProvider(InstrumentedProvider):
-
original_create = None
original_create_async = None
@@ -23,9 +22,7 @@ def __init__(self, client):
self.tool_event = {}
self.tool_id = ""
- def handle_response(
- self, response, kwargs, init_timestamp, session: Optional[Session] = None
- ):
+ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None):
"""Handle responses for Anthropic"""
from anthropic import Stream, AsyncStream
from anthropic.resources import AsyncMessages
@@ -66,9 +63,7 @@ def handle_stream_chunk(chunk: Message):
llm_event.completion["content"] += chunk.delta.text
elif chunk.delta.type == "input_json_delta":
- self.tool_event[self.tool_id].logs[
- "input"
- ] += chunk.delta.partial_json
+ self.tool_event[self.tool_id].logs["input"] += chunk.delta.partial_json
elif chunk.type == "content_block_stop":
pass
@@ -81,9 +76,7 @@ def handle_stream_chunk(chunk: Message):
self._safe_record(session, llm_event)
except Exception as e:
- self._safe_record(
- session, ErrorEvent(trigger_event=llm_event, exception=e)
- )
+ self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
@@ -178,9 +171,7 @@ def patched_function(*args, **kwargs):
if "session" in kwargs.keys():
del kwargs["session"]
- completion_override = fetch_completion_override_from_time_travel_cache(
- kwargs
- )
+ completion_override = fetch_completion_override_from_time_travel_cache(kwargs)
if completion_override:
result_model = None
pydantic_models = (
@@ -195,9 +186,7 @@ def patched_function(*args, **kwargs):
for pydantic_model in pydantic_models:
try:
- result_model = pydantic_model.model_validate_json(
- completion_override
- )
+ result_model = pydantic_model.model_validate_json(completion_override)
break
except Exception as e:
pass
@@ -209,18 +198,12 @@ def patched_function(*args, **kwargs):
f"{pprint.pformat(completion_override)}"
)
return None
- return self.handle_response(
- result_model, kwargs, init_timestamp, session=session
- )
+ return self.handle_response(result_model, kwargs, init_timestamp, session=session)
# Call the original function with its original arguments
- original_func = (
- self.original_create_beta if is_beta else self.original_create
- )
+ original_func = self.original_create_beta if is_beta else self.original_create
result = original_func(*args, **kwargs)
- return self.handle_response(
- result, kwargs, init_timestamp, session=session
- )
+ return self.handle_response(result, kwargs, init_timestamp, session=session)
return patched_function
@@ -252,9 +235,7 @@ async def patched_function(*args, **kwargs):
if "session" in kwargs.keys():
del kwargs["session"]
- completion_override = fetch_completion_override_from_time_travel_cache(
- kwargs
- )
+ completion_override = fetch_completion_override_from_time_travel_cache(kwargs)
if completion_override:
result_model = None
pydantic_models = (
@@ -269,9 +250,7 @@ async def patched_function(*args, **kwargs):
for pydantic_model in pydantic_models:
try:
- result_model = pydantic_model.model_validate_json(
- completion_override
- )
+ result_model = pydantic_model.model_validate_json(completion_override)
break
except Exception as e:
pass
@@ -284,20 +263,12 @@ async def patched_function(*args, **kwargs):
)
return None
- return self.handle_response(
- result_model, kwargs, init_timestamp, session=session
- )
+ return self.handle_response(result_model, kwargs, init_timestamp, session=session)
# Call the original function with its original arguments
- original_func = (
- self.original_create_async_beta
- if is_beta
- else self.original_create_async
- )
+ original_func = self.original_create_async_beta if is_beta else self.original_create_async
result = await original_func(*args, **kwargs)
- return self.handle_response(
- result, kwargs, init_timestamp, session=session
- )
+ return self.handle_response(result, kwargs, init_timestamp, session=session)
return patched_function
diff --git a/agentops/llms/cohere.py b/agentops/llms/cohere.py
index ad8f93a5..0756bc69 100644
--- a/agentops/llms/cohere.py
+++ b/agentops/llms/cohere.py
@@ -36,9 +36,7 @@ def undo_override(self):
def __init__(self, client):
super().__init__(client)
- def handle_response(
- self, response, kwargs, init_timestamp, session: Optional[Session] = None
- ):
+ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None):
"""Handle responses for Cohere versions >v5.4.0"""
from cohere.types.streamed_chat_response import (
StreamedChatResponse_CitationGeneration,
@@ -59,7 +57,6 @@ def handle_response(
self.action_events = {}
def handle_stream_chunk(chunk, session: Optional[Session] = None):
-
# We take the first chunk and accumulate the deltas from all subsequent chunks to build one full chat completion
if isinstance(chunk, StreamedChatResponse_StreamStart):
llm_event.returns = chunk
@@ -102,9 +99,7 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None):
citation_dict = citation.dict()
# Replace document_ids with the actual documents
citation_dict["documents"] = [
- documents[doc_id]
- for doc_id in citation_dict["document_ids"]
- if doc_id in documents
+ documents[doc_id] for doc_id in citation_dict["document_ids"] if doc_id in documents
]
del citation_dict["document_ids"]
@@ -120,12 +115,10 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None):
pass
elif isinstance(chunk, StreamedChatResponse_CitationGeneration):
for citation in chunk.citations:
- self.action_events[f"{citation.start}.{citation.end}"] = (
- ActionEvent(
- action_type="citation",
- init_timestamp=get_ISO_time(),
- params=citation.text,
- )
+ self.action_events[f"{citation.start}.{citation.end}"] = ActionEvent(
+ action_type="citation",
+ init_timestamp=get_ISO_time(),
+ params=citation.text,
)
elif isinstance(chunk, StreamedChatResponse_SearchQueriesGeneration):
for query in chunk.search_queries:
@@ -138,9 +131,7 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None):
pass
except Exception as e:
- self._safe_record(
- session, ErrorEvent(trigger_event=llm_event, exception=e)
- )
+ self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
diff --git a/agentops/llms/groq.py b/agentops/llms/groq.py
index ca869638..23d5b8db 100644
--- a/agentops/llms/groq.py
+++ b/agentops/llms/groq.py
@@ -29,9 +29,7 @@ def undo_override(self):
completions.Completions.create = self.original_create
completions.AsyncCompletions.create = self.original_create
- def handle_response(
- self, response, kwargs, init_timestamp, session: Optional[Session] = None
- ):
+ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None):
"""Handle responses for OpenAI versions >v1.0.0"""
from groq import AsyncStream, Stream
from groq.resources.chat import AsyncCompletions
@@ -81,9 +79,7 @@ def handle_stream_chunk(chunk: ChatCompletionChunk):
self._safe_record(session, llm_event)
except Exception as e:
- self._safe_record(
- session, ErrorEvent(trigger_event=llm_event, exception=e)
- )
+ self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
diff --git a/agentops/llms/instrumented_provider.py b/agentops/llms/instrumented_provider.py
index f33f397a..f4dc78cf 100644
--- a/agentops/llms/instrumented_provider.py
+++ b/agentops/llms/instrumented_provider.py
@@ -14,9 +14,7 @@ def __init__(self, client):
self.client = client
@abstractmethod
- def handle_response(
- self, response, kwargs, init_timestamp, session: Optional[Session] = None
- ) -> dict:
+ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
pass
@abstractmethod
diff --git a/agentops/llms/litellm.py b/agentops/llms/litellm.py
index 30b4c25a..0f73c22b 100644
--- a/agentops/llms/litellm.py
+++ b/agentops/llms/litellm.py
@@ -40,9 +40,7 @@ def undo_override(self):
completions.Completions.create = self.original_oai_create
completions.AsyncCompletions.create = self.original_oai_create_async
- def handle_response(
- self, response, kwargs, init_timestamp, session: Optional[Session] = None
- ) -> dict:
+ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
"""Handle responses for OpenAI versions >v1.0.0"""
from openai import AsyncStream, Stream
from openai.resources import AsyncCompletions
@@ -93,9 +91,7 @@ def handle_stream_chunk(chunk: ChatCompletionChunk):
self._safe_record(session, llm_event)
except Exception as e:
- self._safe_record(
- session, ErrorEvent(trigger_event=llm_event, exception=e)
- )
+ self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
@@ -186,14 +182,10 @@ def patched_function(*args, **kwargs):
if "session" in kwargs.keys():
del kwargs["session"]
- completion_override = fetch_completion_override_from_time_travel_cache(
- kwargs
- )
+ completion_override = fetch_completion_override_from_time_travel_cache(kwargs)
if completion_override:
result_model = ChatCompletion.model_validate_json(completion_override)
- return self.handle_response(
- result_model, kwargs, init_timestamp, session=session
- )
+ return self.handle_response(result_model, kwargs, init_timestamp, session=session)
# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
@@ -222,14 +214,10 @@ async def patched_function(*args, **kwargs):
if "session" in kwargs.keys():
del kwargs["session"]
- completion_override = fetch_completion_override_from_time_travel_cache(
- kwargs
- )
+ completion_override = fetch_completion_override_from_time_travel_cache(kwargs)
if completion_override:
result_model = ChatCompletion.model_validate_json(completion_override)
- return self.handle_response(
- result_model, kwargs, init_timestamp, session=session
- )
+ return self.handle_response(result_model, kwargs, init_timestamp, session=session)
# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
diff --git a/agentops/llms/mistral.py b/agentops/llms/mistral.py
index 8be21946..770d4a81 100644
--- a/agentops/llms/mistral.py
+++ b/agentops/llms/mistral.py
@@ -11,7 +11,6 @@
class MistralProvider(InstrumentedProvider):
-
original_complete = None
original_complete_async = None
original_stream = None
@@ -21,9 +20,7 @@ def __init__(self, client):
super().__init__(client)
self._provider_name = "Mistral"
- def handle_response(
- self, response, kwargs, init_timestamp, session: Optional[Session] = None
- ) -> dict:
+ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
"""Handle responses for Mistral"""
from mistralai import Chat
from mistralai.types import UNSET, UNSET_SENTINEL
@@ -73,9 +70,7 @@ def handle_stream_chunk(chunk: dict):
self._safe_record(session, llm_event)
except Exception as e:
- self._safe_record(
- session, ErrorEvent(trigger_event=llm_event, exception=e)
- )
+ self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
diff --git a/agentops/llms/ollama.py b/agentops/llms/ollama.py
index e5779283..5c998897 100644
--- a/agentops/llms/ollama.py
+++ b/agentops/llms/ollama.py
@@ -16,9 +16,7 @@ class OllamaProvider(InstrumentedProvider):
original_create = None
original_create_async = None
- def handle_response(
- self, response, kwargs, init_timestamp, session: Optional[Session] = None
- ) -> dict:
+ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
def handle_stream_chunk(chunk: dict):
@@ -84,9 +82,7 @@ def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
result = original_func["ollama.chat"](*args, **kwargs)
- return self.handle_response(
- result, kwargs, init_timestamp, session=kwargs.get("session", None)
- )
+ return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None))
# Override the original method with the patched one
ollama.chat = patched_function
diff --git a/agentops/llms/openai.py b/agentops/llms/openai.py
index c99523d7..24756971 100644
--- a/agentops/llms/openai.py
+++ b/agentops/llms/openai.py
@@ -14,7 +14,6 @@
@singleton
class OpenAiProvider(InstrumentedProvider):
-
original_create = None
original_create_async = None
@@ -22,9 +21,7 @@ def __init__(self, client):
super().__init__(client)
self._provider_name = "OpenAI"
- def handle_response(
- self, response, kwargs, init_timestamp, session: Optional[Session] = None
- ) -> dict:
+ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
"""Handle responses for OpenAI versions >v1.0.0"""
from openai import AsyncStream, Stream
from openai.resources import AsyncCompletions
@@ -74,9 +71,7 @@ def handle_stream_chunk(chunk: ChatCompletionChunk):
self._safe_record(session, llm_event)
except Exception as e:
- self._safe_record(
- session, ErrorEvent(trigger_event=llm_event, exception=e)
- )
+ self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
@@ -157,17 +152,13 @@ def patched_function(*args, **kwargs):
if "session" in kwargs.keys():
del kwargs["session"]
- completion_override = fetch_completion_override_from_time_travel_cache(
- kwargs
- )
+ completion_override = fetch_completion_override_from_time_travel_cache(kwargs)
if completion_override:
result_model = None
pydantic_models = (ChatCompletion, ChatCompletionChunk)
for pydantic_model in pydantic_models:
try:
- result_model = pydantic_model.model_validate_json(
- completion_override
- )
+ result_model = pydantic_model.model_validate_json(completion_override)
break
except Exception as e:
pass
@@ -179,9 +170,7 @@ def patched_function(*args, **kwargs):
f"{pprint.pformat(completion_override)}"
)
return None
- return self.handle_response(
- result_model, kwargs, init_timestamp, session=session
- )
+ return self.handle_response(result_model, kwargs, init_timestamp, session=session)
# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
@@ -202,24 +191,19 @@ def _override_openai_v1_async_completion(self):
self.original_create_async = completions.AsyncCompletions.create
async def patched_function(*args, **kwargs):
-
init_timestamp = get_ISO_time()
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
- completion_override = fetch_completion_override_from_time_travel_cache(
- kwargs
- )
+ completion_override = fetch_completion_override_from_time_travel_cache(kwargs)
if completion_override:
result_model = None
pydantic_models = (ChatCompletion, ChatCompletionChunk)
for pydantic_model in pydantic_models:
try:
- result_model = pydantic_model.model_validate_json(
- completion_override
- )
+ result_model = pydantic_model.model_validate_json(completion_override)
break
except Exception as e:
pass
@@ -231,9 +215,7 @@ async def patched_function(*args, **kwargs):
f"{pprint.pformat(completion_override)}"
)
return None
- return self.handle_response(
- result_model, kwargs, init_timestamp, session=session
- )
+ return self.handle_response(result_model, kwargs, init_timestamp, session=session)
# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
diff --git a/agentops/meta_client.py b/agentops/meta_client.py
index 6e222771..6cc7ed2e 100644
--- a/agentops/meta_client.py
+++ b/agentops/meta_client.py
@@ -14,9 +14,7 @@ class MetaClient(type):
def __new__(cls, name, bases, dct):
# Wrap each method with the handle_exceptions decorator
for method_name, method in dct.items():
- if (
- callable(method) and not method_name.startswith("__")
- ) or method_name == "__init__":
+ if (callable(method) and not method_name.startswith("__")) or method_name == "__init__":
dct[method_name] = handle_exceptions(method)
return super().__new__(cls, name, bases, dct)
diff --git a/agentops/partners/autogen_logger.py b/agentops/partners/autogen_logger.py
index 4080d792..77aca142 100644
--- a/agentops/partners/autogen_logger.py
+++ b/agentops/partners/autogen_logger.py
@@ -72,13 +72,9 @@ def log_chat_completion(
def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
"""Calls agentops.create_agent"""
ao_agent_id = agentops.create_agent(agent.name, str(uuid4()))
- self.agent_store.append(
- {"agentops_id": ao_agent_id, "autogen_id": str(id(agent))}
- )
+ self.agent_store.append({"agentops_id": ao_agent_id, "autogen_id": str(id(agent))})
- def log_event(
- self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]
- ) -> None:
+ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None:
"""Records an ActionEvent to AgentOps session"""
event = ActionEvent(action_type=name)
agentops_id = self._get_agentops_id_from_agent(str(id(source)))
@@ -86,9 +82,7 @@ def log_event(
event.params = kwargs
agentops.record(event)
- def log_function_use(
- self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: any
- ):
+ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: any):
"""Records a ToolEvent to AgentOps session"""
event = ToolEvent()
agentops_id = self._get_agentops_id_from_agent(str(id(source)))
diff --git a/agentops/partners/langchain_callback_handler.py b/agentops/partners/langchain_callback_handler.py
index 847abe66..7d3ded9b 100644
--- a/agentops/partners/langchain_callback_handler.py
+++ b/agentops/partners/langchain_callback_handler.py
@@ -131,12 +131,8 @@ def on_llm_end(
llm_event.end_timestamp = get_ISO_time()
llm_event.completion = response.generations[0][0].text
if response.llm_output is not None:
- llm_event.prompt_tokens = response.llm_output["token_usage"][
- "prompt_tokens"
- ]
- llm_event.completion_tokens = response.llm_output["token_usage"][
- "completion_tokens"
- ]
+ llm_event.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"]
+ llm_event.completion_tokens = response.llm_output["token_usage"]["completion_tokens"]
if len(response.generations) == 0:
# TODO: more descriptive error
@@ -297,9 +293,7 @@ def on_retriever_end(
**kwargs: Any,
) -> None:
action_event: ActionEvent = self.events.retriever[str(run_id)]
- action_event.logs = (
- documents # TODO: Adding this. Might want to add elsewhere e.g. params
- )
+ action_event.logs = documents # TODO: Adding this. Might want to add elsewhere e.g. params
action_event.end_timestamp = get_ISO_time()
self.ao_client.record(action_event)
@@ -326,9 +320,7 @@ def on_agent_action(
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
- self.agent_actions[run_id].append(
- ActionEvent(params={"action": action, **kwargs}, action_type="agent")
- )
+ self.agent_actions[run_id].append(ActionEvent(params={"action": action, **kwargs}, action_type="agent"))
@debug_print_function_params
def on_agent_finish(
@@ -367,9 +359,7 @@ def on_retry(
@property
def session_id(self):
- raise DeprecationWarning(
- "session_id is deprecated in favor of current_session_ids"
- )
+ raise DeprecationWarning("session_id is deprecated in favor of current_session_ids")
@property
def current_session_ids(self):
@@ -387,7 +377,6 @@ def __init__(
max_queue_size: Optional[int] = None,
tags: Optional[List[str]] = None,
):
-
client_params: Dict[str, Any] = {
"api_key": api_key,
"endpoint": endpoint,
@@ -396,9 +385,7 @@ def __init__(
"tags": tags,
}
- self.ao_client = AOClient(
- **{k: v for k, v in client_params.items() if v is not None}, override=False
- )
+ self.ao_client = AOClient(**{k: v for k, v in client_params.items() if v is not None}, override=False)
self.events = Events()
self.agent_actions: Dict[UUID, List[ActionEvent]] = defaultdict(list)
@@ -482,12 +469,8 @@ async def on_llm_end(
llm_event.end_timestamp = get_ISO_time()
llm_event.completion = response.generations[0][0].text
if response.llm_output is not None:
- llm_event.prompt_tokens = response.llm_output["token_usage"][
- "prompt_tokens"
- ]
- llm_event.completion_tokens = response.llm_output["token_usage"][
- "completion_tokens"
- ]
+ llm_event.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"]
+ llm_event.completion_tokens = response.llm_output["token_usage"]["completion_tokens"]
if len(response.generations) == 0:
# TODO: more descriptive error
@@ -645,9 +628,7 @@ async def on_retriever_end(
**kwargs: Any,
) -> None:
action_event: ActionEvent = self.events.retriever[str(run_id)]
- action_event.logs = (
- documents # TODO: Adding this. Might want to add elsewhere e.g. params
- )
+ action_event.logs = documents # TODO: Adding this. Might want to add elsewhere e.g. params
action_event.end_timestamp = get_ISO_time()
self.ao_client.record(action_event)
@@ -674,9 +655,7 @@ async def on_agent_action(
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
- self.agent_actions[run_id].append(
- ActionEvent(params={"action": action, **kwargs}, action_type="agent")
- )
+ self.agent_actions[run_id].append(ActionEvent(params={"action": action, **kwargs}, action_type="agent"))
@debug_print_function_params
async def on_agent_finish(
diff --git a/agentops/session.py b/agentops/session.py
index 58225b6c..3cfc1303 100644
--- a/agentops/session.py
+++ b/agentops/session.py
@@ -86,14 +86,11 @@ def end_session(
end_state_reason: Optional[str] = None,
video: Optional[str] = None,
) -> Union[Decimal, None]:
-
if not self.is_running:
return
if not any(end_state == state.value for state in EndState):
- return logger.warning(
- "Invalid end_state. Please use one of the EndState enums"
- )
+ return logger.warning("Invalid end_state. Please use one of the EndState enums")
self.end_timestamp = get_ISO_time()
self.end_state = end_state
@@ -146,9 +143,7 @@ def format_duration(start_time, end_time):
formatted_cost = (
"{:.2f}".format(token_cost_d)
if token_cost_d == 0
- else "{:.6f}".format(
- token_cost_d.quantize(Decimal("0.000001"), rounding=ROUND_HALF_UP)
- )
+ else "{:.6f}".format(token_cost_d.quantize(Decimal("0.000001"), rounding=ROUND_HALF_UP))
)
analytics = (
@@ -222,8 +217,7 @@ def record(self, event: Union[Event, ErrorEvent]):
if event.trigger_event:
if (
not event.trigger_event.end_timestamp
- or event.trigger_event.init_timestamp
- == event.trigger_event.end_timestamp
+ or event.trigger_event.init_timestamp == event.trigger_event.end_timestamp
):
event.trigger_event.end_timestamp = get_ISO_time()
@@ -338,9 +332,7 @@ def _flush_queue(self) -> None:
return logger.error(f"Could not post events - {e}")
logger.debug("\n")
- logger.debug(
- f"Session request to {self.config.endpoint}/v2/create_events"
- )
+ logger.debug(f"Session request to {self.config.endpoint}/v2/create_events")
logger.debug(serialized_payload)
logger.debug("\n")
diff --git a/agentops/singleton.py b/agentops/singleton.py
index 8200510e..b22e4edc 100644
--- a/agentops/singleton.py
+++ b/agentops/singleton.py
@@ -2,7 +2,6 @@
def singleton(class_):
-
def getinstance(*args, **kwargs):
if class_ not in ao_instances:
ao_instances[class_] = class_(*args, **kwargs)
@@ -12,7 +11,6 @@ def getinstance(*args, **kwargs):
def conditional_singleton(class_):
-
def getinstance(*args, **kwargs):
use_singleton = kwargs.pop("use_singleton", True)
if use_singleton:
diff --git a/agentops/time_travel.py b/agentops/time_travel.py
index 6c1c7588..55ad6662 100644
--- a/agentops/time_travel.py
+++ b/agentops/time_travel.py
@@ -20,9 +20,7 @@ def __init__(self):
try:
with open(cache_path, "r") as file:
time_travel_cache_json = json.load(file)
- self._completion_overrides = time_travel_cache_json.get(
- "completion_overrides"
- )
+ self._completion_overrides = time_travel_cache_json.get("completion_overrides")
except FileNotFoundError:
return
@@ -99,20 +97,14 @@ def find_cache_hit(prompt_messages, completion_overrides):
continue
if all(
- isinstance(a, dict)
- and isinstance(b, dict)
- and a.get("content") == b.get("content")
+ isinstance(a, dict) and isinstance(b, dict) and a.get("content") == b.get("content")
for a, b in zip(prompt_messages, cached_messages)
):
return value
except (SyntaxError, ValueError, TypeError) as e:
- print(
- f"{ttd_prepend_string} Error - Error processing completion_overrides item: {e}"
- )
+ print(f"{ttd_prepend_string} Error - Error processing completion_overrides item: {e}")
except Exception as e:
- print(
- f"{ttd_prepend_string} Error - Unexpected error in find_cache_hit: {e}"
- )
+ print(f"{ttd_prepend_string} Error - Unexpected error in find_cache_hit: {e}")
return None
@@ -143,9 +135,7 @@ def set_time_travel_active_state(is_active: bool):
try:
yaml.dump(config, config_file)
except:
- print(
- f"{ttd_prepend_string} Error - Unable to write to {config_path}. Time Travel not activated"
- )
+ print(f"{ttd_prepend_string} Error - Unable to write to {config_path}. Time Travel not activated")
return
if is_active:
diff --git a/pyproject.toml b/pyproject.toml
index 59954095..9093e7ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ dev = [
"pytest-mock",
"pyfakefs",
"requests_mock==1.11.0",
+ "ruff",
"tach~=0.9",
]
langchain = [
@@ -54,3 +55,50 @@ agentops = "agentops.cli:main"
[tool.pytest.ini_options]
asyncio_mode = "strict"
asyncio_default_fixture_loop_scope = "function"
+
+[tool.ruff]
+line-length = 120
+
+[tool.ruff.lint]
+ignore = [
+ "F401", # Unused imports
+ "E712", # Comparison to True/False
+ "E711", # Comparison to None
+ "E722", # Bare except
+ "F821", # Undefined names
+ "F841", # Unused variables
+]
+
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".git-rewrite",
+ ".github",
+ ".hg",
+ ".ipynb_checkpoints",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".pyenv",
+ ".pytest_cache",
+ ".vscode",
+ ".pytype",
+ ".ruff_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ ".vscode",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "docs",
+ "examples",
+ "node_modules",
+ "site-packages",
+ "venv",
+ "tests/core_manual_tests",
+]
diff --git a/tests/core_manual_tests/api_server/server.py b/tests/core_manual_tests/api_server/server.py
index 61ee8c67..5ae5e1ae 100644
--- a/tests/core_manual_tests/api_server/server.py
+++ b/tests/core_manual_tests/api_server/server.py
@@ -14,7 +14,6 @@
@app.get("/completion")
def completion():
-
session = agentops.start_session(tags=["api-server-test"])
messages = [{"role": "user", "content": "Hello"}]
diff --git a/tests/core_manual_tests/multi_session_llm.py b/tests/core_manual_tests/multi_session_llm.py
index 25a652ae..bd55477c 100644
--- a/tests/core_manual_tests/multi_session_llm.py
+++ b/tests/core_manual_tests/multi_session_llm.py
@@ -25,9 +25,7 @@
session_1.record(ActionEvent(action_type="test event"))
# option 2: add session as a keyword argument
-response2 = openai.chat.completions.create(
- model="gpt-3.5-turbo", messages=messages, temperature=0.5, session=session_2
-)
+response2 = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages, temperature=0.5, session=session_2)
session_1.end_session(end_state="Success")
session_2.end_session(end_state="Success")
diff --git a/tests/core_manual_tests/providers/litellm_canary.py b/tests/core_manual_tests/providers/litellm_canary.py
index ecd2186a..0b9cf98c 100644
--- a/tests/core_manual_tests/providers/litellm_canary.py
+++ b/tests/core_manual_tests/providers/litellm_canary.py
@@ -7,9 +7,7 @@
load_dotenv()
agentops.init(default_tags=["litellm-provider-test"])
-response = litellm.completion(
- model="gpt-3.5-turbo", messages=[{"content": "Hello, how are you?", "role": "user"}]
-)
+response = litellm.completion(model="gpt-3.5-turbo", messages=[{"content": "Hello, how are you?", "role": "user"}])
stream_response = litellm.completion(
model="gpt-3.5-turbo",
diff --git a/tests/langchain_handlers/_test_langchain_handler.py b/tests/langchain_handlers/_test_langchain_handler.py
index 97cc2205..8f468c54 100644
--- a/tests/langchain_handlers/_test_langchain_handler.py
+++ b/tests/langchain_handlers/_test_langchain_handler.py
@@ -14,13 +14,9 @@
AGENTOPS_API_KEY = os.environ.get("AGENTOPS_API_KEY")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
-agentops_handler = AgentOpsLangchainCallbackHandler(
- api_key=AGENTOPS_API_KEY, tags=["Langchain Example"]
-)
+agentops_handler = AgentOpsLangchainCallbackHandler(api_key=AGENTOPS_API_KEY, tags=["Langchain Example"])
-llm = ChatOpenAI(
- openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-3.5-turbo"
-)
+llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-3.5-turbo")
@tool
@@ -43,9 +39,7 @@ def find_movie(genre) -> str:
llm,
agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
- callbacks=[
- agentops_handler
- ], # You must pass in a callback handler to record your agent
+ callbacks=[agentops_handler], # You must pass in a callback handler to record your agent
handle_parsing_errors=True,
)
@@ -56,22 +50,16 @@ def find_movie(genre) -> str:
########
# Async
-agentops_handler = AgentOpsAsyncLangchainCallbackHandler(
- api_key=AGENTOPS_API_KEY, tags=["Async Example"]
-)
+agentops_handler = AgentOpsAsyncLangchainCallbackHandler(api_key=AGENTOPS_API_KEY, tags=["Async Example"])
-llm = ChatOpenAI(
- openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-3.5-turbo"
-)
+llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, callbacks=[agentops_handler], model="gpt-3.5-turbo")
agent = initialize_agent(
tools,
llm,
agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
- callbacks=[
- agentops_handler
- ], # You must pass in a callback handler to record your agent
+ callbacks=[agentops_handler], # You must pass in a callback handler to record your agent
handle_parsing_errors=True,
)
diff --git a/tests/openai_handlers/_test_gpt_vision.py b/tests/openai_handlers/_test_gpt_vision.py
index 29a04e7e..3aae1231 100644
--- a/tests/openai_handlers/_test_gpt_vision.py
+++ b/tests/openai_handlers/_test_gpt_vision.py
@@ -45,9 +45,7 @@ def encode_image(image_path):
# Path to your image
-image_path = os.path.join(
- os.path.dirname(os.path.abspath(__file__)), "logo_for_vision_test.png"
-)
+image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logo_for_vision_test.png")
# Getting the base64 string
base64_image = encode_image(image_path)
@@ -71,9 +69,7 @@ def encode_image(image_path):
"max_tokens": 300,
}
-response = requests.post(
- "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
-)
+response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
print(response.json())
diff --git a/tests/openai_handlers/_test_handler_openai_v0.py b/tests/openai_handlers/_test_handler_openai_v0.py
index 601be8da..3dc4d5e0 100644
--- a/tests/openai_handlers/_test_handler_openai_v0.py
+++ b/tests/openai_handlers/_test_handler_openai_v0.py
@@ -7,9 +7,7 @@
print("Running OpenAI