From 388101a3ad9947e0eabd1048d7864da04cce6a5a Mon Sep 17 00:00:00 2001 From: Braelyn Boynton Date: Fri, 3 May 2024 16:24:46 -0700 Subject: [PATCH] add stop_instrumenting (#185) * add stop_instrumenting * version bump * test deps --- agentops/__init__.py | 4 ++++ agentops/client.py | 5 ++++- agentops/llm_tracker.py | 22 ++++++++++++++++++++-- pyproject.toml | 2 +- tox.ini | 2 ++ 5 files changed, 31 insertions(+), 4 deletions(-) diff --git a/agentops/__init__.py b/agentops/__init__.py index 3d087f1e..1c2d4dfa 100755 --- a/agentops/__init__.py +++ b/agentops/__init__.py @@ -9,6 +9,7 @@ from .decorators import record_function from .agent import track_agent from .log_config import set_logging_level_info, set_logging_level_critial +from .langchain_callback_handler import LangchainCallbackHandler, AsyncLangchainCallbackHandler def init(api_key: Optional[str] = None, @@ -128,3 +129,6 @@ def set_parent_key(parent_key): parent_key (str): The API key of the parent organization to set. """ Client().set_parent_key(parent_key) + +def stop_instrumenting(): + Client().stop_instrumenting() diff --git a/agentops/client.py b/agentops/client.py index 11b4de24..7dc21559 100644 --- a/agentops/client.py +++ b/agentops/client.py @@ -136,7 +136,7 @@ def record(self, event: Union[Event, ErrorEvent]): Args: event (Event): The event to record. """ - if not event.end_timestamp or event.init_timestamp == event.end_timestamp: + if isinstance(event, Event) and not event.end_timestamp or event.init_timestamp == event.end_timestamp: event.end_timestamp = get_ISO_time() if self._session is not None and not self._session.has_ended and self._worker is not None: if isinstance(event, ErrorEvent): @@ -359,3 +359,6 @@ def set_parent_key(self, parent_key: str): @property def parent_key(self): return self.config.parent_key + + def stop_instrumenting(self): + self.llm_tracker.stop_instrumenting() diff --git a/agentops/llm_tracker.py b/agentops/llm_tracker.py index 0a8c78fe..2b823fe9 100644 --- a/agentops/llm_tracker.py +++ b/agentops/llm_tracker.py @@ -10,6 +10,8 @@ from typing import Optional import pprint +original_create = None +original_create_async = None class LlmTracker: SUPPORTED_APIS = { @@ -230,6 +232,7 @@ def override_openai_v1_completion(self): from openai.resources.chat import completions # Store the original method + global original_create original_create = completions.Completions.create def patched_function(*args, **kwargs): @@ -245,12 +248,13 @@ def override_openai_v1_async_completion(self): from openai.resources.chat import completions # Store the original method - original_create = completions.AsyncCompletions.create + global original_create_async + original_create_async = completions.AsyncCompletions.create async def patched_function(*args, **kwargs): # Call the original function with its original arguments init_timestamp = get_ISO_time() - result = await original_create(*args, **kwargs) + result = await original_create_async(*args, **kwargs) return self._handle_response_v1_openai(result, kwargs, init_timestamp) # Override the original method with the patched one @@ -345,3 +349,17 @@ def override_api(self): # Patch openai