From 8313a6962c1ec66167b256ce2f6e6a034dfa4180 Mon Sep 17 00:00:00 2001 From: Braelyn Boynton Date: Mon, 12 Aug 2024 12:23:10 -0700 Subject: [PATCH] llm refactor progress --- agentops/client.py | 2 +- agentops/helpers.py | 39 +-- agentops/llms/__init__.py | 6 - agentops/llms/cohere.py | 14 +- agentops/llms/groq.py | 10 +- agentops/llms/instrumented_provider.py | 26 ++ agentops/llms/ollama.py | 2 +- agentops/llms/openai.py | 468 +++++++++++++------------ agentops/llms/openai_v0.py | 126 +++++++ agentops/singleton.py | 30 ++ 10 files changed, 440 insertions(+), 283 deletions(-) create mode 100644 agentops/llms/instrumented_provider.py create mode 100644 agentops/llms/openai_v0.py create mode 100644 agentops/singleton.py diff --git a/agentops/client.py b/agentops/client.py index 4007e410..acb4fc47 100644 --- a/agentops/client.py +++ b/agentops/client.py @@ -19,7 +19,7 @@ from termcolor import colored from .event import Event, ErrorEvent -from .helpers import ( +from .singleton import ( conditional_singleton, ) from .session import Session, active_sessions diff --git a/agentops/helpers.py b/agentops/helpers.py index e04f4b57..4dbae3f5 100644 --- a/agentops/helpers.py +++ b/agentops/helpers.py @@ -8,42 +8,12 @@ import json from importlib.metadata import version, PackageNotFoundError +from . import Client from .log_config import logger from uuid import UUID from importlib.metadata import version import subprocess -ao_instances = {} - - -def singleton(class_): - - def getinstance(*args, **kwargs): - if class_ not in ao_instances: - ao_instances[class_] = class_(*args, **kwargs) - return ao_instances[class_] - - return getinstance - - -def conditional_singleton(class_): - - def getinstance(*args, **kwargs): - use_singleton = kwargs.pop("use_singleton", True) - if use_singleton: - if class_ not in ao_instances: - ao_instances[class_] = class_(*args, **kwargs) - return ao_instances[class_] - else: - return class_(*args, **kwargs) - - return getinstance - - -def clear_singletons(): - global ao_instances - ao_instances = {} - def get_ISO_time(): """ @@ -212,3 +182,10 @@ def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) return wrapper + + +def safe_record(session, event): + if session is not None: + session.record(event) + else: + Client().record(event) diff --git a/agentops/llms/__init__.py b/agentops/llms/__init__.py index f831bd17..76951435 100644 --- a/agentops/llms/__init__.py +++ b/agentops/llms/__init__.py @@ -173,9 +173,3 @@ def stop_instrumenting(self): undo_override_openai_v1_async_completion() undo_override_openai_v1_completion() undo_override_ollama(self) - - def _safe_record(self, session, event): - if session is not None: - session.record(event) - else: - self.client.record(event) diff --git a/agentops/llms/cohere.py b/agentops/llms/cohere.py index 87f4e2ed..64ad6164 100644 --- a/agentops/llms/cohere.py +++ b/agentops/llms/cohere.py @@ -47,7 +47,7 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None): "content": chunk.response.text, } tracker.llm_event.end_timestamp = get_ISO_time() - tracker._safe_record(session, tracker.llm_event) + safe_record(session, tracker.llm_event) # StreamedChatResponse_SearchResults = ActionEvent search_results = chunk.response.search_results @@ -80,7 +80,7 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None): action_event.end_timestamp = get_ISO_time() for key, action_event in tracker.action_events.items(): - tracker._safe_record(session, action_event) + safe_record(session, action_event) elif 
isinstance(chunk, StreamedChatResponse_TextGeneration): tracker.llm_event.completion += chunk.text @@ -106,7 +106,7 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None): pass except Exception as e: - tracker._safe_record( + safe_record( session, ErrorEvent(trigger_event=tracker.llm_event, exception=e) ) @@ -166,11 +166,9 @@ def generator(): tracker.llm_event.completion_tokens = response.meta.tokens.output_tokens tracker.llm_event.model = kwargs.get("model", "command-r-plus") - tracker._safe_record(session, tracker.llm_event) + safe_record(session, tracker.llm_event) except Exception as e: - tracker._safe_record( - session, ErrorEvent(trigger_event=tracker.llm_event, exception=e) - ) + safe_record(session, ErrorEvent(trigger_event=tracker.llm_event, exception=e)) kwargs_str = pprint.pformat(kwargs) response = pprint.pformat(response) logger.warning( @@ -194,7 +192,7 @@ def patched_function(*args, **kwargs): if "session" in kwargs.keys(): del kwargs["session"] result = original_chat(*args, **kwargs) - return tracker._handle_response_cohere( + return tracker.handle_response_cohere( result, kwargs, init_timestamp, session=session ) diff --git a/agentops/llms/groq.py b/agentops/llms/groq.py index 7391d3af..d2432af6 100644 --- a/agentops/llms/groq.py +++ b/agentops/llms/groq.py @@ -57,9 +57,9 @@ def handle_stream_chunk(chunk: ChatCompletionChunk): } tracker.llm_event.end_timestamp = get_ISO_time() - tracker._safe_record(session, tracker.llm_event) + safe_record(session, tracker.llm_event) except Exception as e: - tracker._safe_record( + safe_record( session, ErrorEvent(trigger_event=tracker.llm_event, exception=e) ) @@ -111,11 +111,9 @@ async def async_generator(): tracker.llm_event.completion_tokens = response.usage.completion_tokens tracker.llm_event.model = response.model - tracker._safe_record(session, tracker.llm_event) + safe_record(session, tracker.llm_event) except Exception as e: - tracker._safe_record( - session, ErrorEvent(trigger_event=tracker.llm_event, exception=e) - ) + safe_record(session, ErrorEvent(trigger_event=tracker.llm_event, exception=e)) kwargs_str = pprint.pformat(kwargs) response = pprint.pformat(response) diff --git a/agentops/llms/instrumented_provider.py b/agentops/llms/instrumented_provider.py new file mode 100644 index 00000000..55652220 --- /dev/null +++ b/agentops/llms/instrumented_provider.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from agentops import Session + + +class InstrumentedProvider(ABC): + _provider_name: str = "InstrumentedModel" + + @abstractmethod + def handle_response( + self, response, kwargs, init_timestamp, session: Optional[Session] = None + ) -> dict: + pass + + @abstractmethod + def override(self): + pass + + @abstractmethod + def undo_override(self): + pass + + @property + def provider_name(self): + return self._provider_name diff --git a/agentops/llms/ollama.py b/agentops/llms/ollama.py index 5118186a..fa196a12 100644 --- a/agentops/llms/ollama.py +++ b/agentops/llms/ollama.py @@ -97,7 +97,7 @@ def generator(): tracker.llm_event.prompt = kwargs["messages"] tracker.llm_event.completion = response["message"] - tracker._safe_record(session, tracker.llm_event) + safe_record(session, tracker.llm_event) return response diff --git a/agentops/llms/openai.py b/agentops/llms/openai.py index c1f7f29a..1482c98d 100644 --- a/agentops/llms/openai.py +++ b/agentops/llms/openai.py @@ -2,116 +2,263 @@ import pprint from typing import Optional +from agentops.llms.instrumented_provider import 
InstrumentedProvider from agentops.time_travel import fetch_completion_override_from_time_travel_cache from agentops import LLMEvent, Session, ErrorEvent, logger -from agentops.helpers import check_call_stack_for_agent_id, get_ISO_time - - -def override_openai_v1_completion(tracker): - from openai.resources.chat import completions - from openai.types.chat import ChatCompletion, ChatCompletionChunk - - # Store the original method - global original_create - original_create = completions.Completions.create - - def patched_function(*args, **kwargs): - init_timestamp = get_ISO_time() - session = kwargs.get("session", None) - if "session" in kwargs.keys(): - del kwargs["session"] - - completion_override = fetch_completion_override_from_time_travel_cache(kwargs) - if completion_override: - result_model = None - pydantic_models = (ChatCompletion, ChatCompletionChunk) - for pydantic_model in pydantic_models: - try: - result_model = pydantic_model.model_validate_json( - completion_override +from agentops.helpers import check_call_stack_for_agent_id, get_ISO_time, safe_record + + +class OpenAiInstrumentedProvider(InstrumentedProvider): + original_create = None + original_create_async = None + + def __init__(self): + self._provider_name = "OpenAI" + + def handle_response( + self, response, kwargs, init_timestamp, session: Optional[Session] = None + ) -> dict: + """Handle responses for OpenAI versions >v1.0.0""" + from openai import AsyncStream, Stream + from openai.resources import AsyncCompletions + from openai.types.chat import ChatCompletionChunk + + tracker.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) + if session is not None: + tracker.llm_event.session_id = session.session_id + + def handle_stream_chunk(chunk: ChatCompletionChunk): + # NOTE: prompt/completion usage not returned in response when streaming + # We take the first ChatCompletionChunk and accumulate the deltas from all subsequent chunks to build one full chat completion + if tracker.llm_event.returns == None: + tracker.llm_event.returns = chunk + + try: + accumulated_delta = tracker.llm_event.returns.choices[0].delta + tracker.llm_event.agent_id = check_call_stack_for_agent_id() + tracker.llm_event.model = chunk.model + tracker.llm_event.prompt = kwargs["messages"] + + # NOTE: We assume for completion only choices[0] is relevant + choice = chunk.choices[0] + + if choice.delta.content: + accumulated_delta.content += choice.delta.content + + if choice.delta.role: + accumulated_delta.role = choice.delta.role + + if choice.delta.tool_calls: + accumulated_delta.tool_calls = choice.delta.tool_calls + + if choice.delta.function_call: + accumulated_delta.function_call = choice.delta.function_call + + if choice.finish_reason: + # Streaming is done. 
Record LLMEvent + tracker.llm_event.returns.choices[0].finish_reason = ( + choice.finish_reason ) - break - except Exception as e: - pass - - if result_model is None: - logger.error( - f"Time Travel: Pydantic validation failed for {pydantic_models} \n" - f"Time Travel: Completion override was:\n" - f"{pprint.pformat(completion_override)}" + tracker.llm_event.completion = { + "role": accumulated_delta.role, + "content": accumulated_delta.content, + "function_call": accumulated_delta.function_call, + "tool_calls": accumulated_delta.tool_calls, + } + tracker.llm_event.end_timestamp = get_ISO_time() + + safe_record(session, tracker.llm_event) + except Exception as e: + safe_record( + session, ErrorEvent(trigger_event=tracker.llm_event, exception=e) ) - return None - return tracker._handle_response_v1_openai( - result_model, kwargs, init_timestamp, session=session + + kwargs_str = pprint.pformat(kwargs) + chunk = pprint.pformat(chunk) + logger.warning( + f"Unable to parse a chunk for LLM call. Skipping upload to AgentOps\n" + f"chunk:\n {chunk}\n" + f"kwargs:\n {kwargs_str}\n" + ) + + # if the response is a generator, decorate the generator + if isinstance(response, Stream): + + def generator(): + for chunk in response: + handle_stream_chunk(chunk) + yield chunk + + return generator() + + # For asynchronous AsyncStream + elif isinstance(response, AsyncStream): + + async def async_generator(): + async for chunk in response: + handle_stream_chunk(chunk) + yield chunk + + return async_generator() + + # For async AsyncCompletion + elif isinstance(response, AsyncCompletions): + + async def async_generator(): + async for chunk in response: + handle_stream_chunk(chunk) + yield chunk + + return async_generator() + + # v1.0.0+ responses are objects + try: + tracker.llm_event.returns = response + tracker.llm_event.agent_id = check_call_stack_for_agent_id() + tracker.llm_event.prompt = kwargs["messages"] + tracker.llm_event.prompt_tokens = response.usage.prompt_tokens + tracker.llm_event.completion = response.choices[0].message.model_dump() + tracker.llm_event.completion_tokens = response.usage.completion_tokens + tracker.llm_event.model = response.model + + safe_record(session, tracker.llm_event) + except Exception as e: + safe_record( + session, ErrorEvent(trigger_event=tracker.llm_event, exception=e) + ) + + kwargs_str = pprint.pformat(kwargs) + response = pprint.pformat(response) + logger.warning( + f"Unable to parse response for LLM call. 
Skipping upload to AgentOps\n" + f"response:\n {response}\n" + f"kwargs:\n {kwargs_str}\n" ) - # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs) - # if prompt_override: - # kwargs["messages"] = prompt_override["messages"] + return response - # Call the original function with its original arguments - result = original_create(*args, **kwargs) - return tracker._handle_response_v1_openai( - result, kwargs, init_timestamp, session=session - ) + def override(self): + self._override_openai_v1_completion() + self._override_openai_v1_async_completion() + + def _override_openai_v1_completion(self): + from openai.resources.chat import completions + from openai.types.chat import ChatCompletion, ChatCompletionChunk + + # Store the original method + global original_create + original_create = completions.Completions.create + + def patched_function(*args, **kwargs): + init_timestamp = get_ISO_time() + session = kwargs.get("session", None) + if "session" in kwargs.keys(): + del kwargs["session"] + + completion_override = fetch_completion_override_from_time_travel_cache( + kwargs + ) + if completion_override: + result_model = None + pydantic_models = (ChatCompletion, ChatCompletionChunk) + for pydantic_model in pydantic_models: + try: + result_model = pydantic_model.model_validate_json( + completion_override + ) + break + except Exception as e: + pass + + if result_model is None: + logger.error( + f"Time Travel: Pydantic validation failed for {pydantic_models} \n" + f"Time Travel: Completion override was:\n" + f"{pprint.pformat(completion_override)}" + ) + return None + return tracker._handle_response_v1_openai( + result_model, kwargs, init_timestamp, session=session + ) - # Override the original method with the patched one - completions.Completions.create = patched_function + # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs) + # if prompt_override: + # kwargs["messages"] = prompt_override["messages"] + # Call the original function with its original arguments + result = original_create(*args, **kwargs) + return tracker._handle_response_v1_openai( + result, kwargs, init_timestamp, session=session + ) + + # Override the original method with the patched one + completions.Completions.create = patched_function -def override_openai_v1_async_completion(tracker): - from openai.resources.chat import completions - from openai.types.chat import ChatCompletion, ChatCompletionChunk + def _override_openai_v1_async_completion(self): + from openai.resources.chat import completions + from openai.types.chat import ChatCompletion, ChatCompletionChunk - # Store the original method - global original_create_async - original_create_async = completions.AsyncCompletions.create + # Store the original method + global original_create_async + original_create_async = completions.AsyncCompletions.create - async def patched_function(*args, **kwargs): + async def patched_function(*args, **kwargs): - init_timestamp = get_ISO_time() + init_timestamp = get_ISO_time() - session = kwargs.get("session", None) - if "session" in kwargs.keys(): - del kwargs["session"] + session = kwargs.get("session", None) + if "session" in kwargs.keys(): + del kwargs["session"] - completion_override = fetch_completion_override_from_time_travel_cache(kwargs) - if completion_override: - result_model = None - pydantic_models = (ChatCompletion, ChatCompletionChunk) - for pydantic_model in pydantic_models: - try: - result_model = pydantic_model.model_validate_json( - completion_override + completion_override = 
fetch_completion_override_from_time_travel_cache( + kwargs + ) + if completion_override: + result_model = None + pydantic_models = (ChatCompletion, ChatCompletionChunk) + for pydantic_model in pydantic_models: + try: + result_model = pydantic_model.model_validate_json( + completion_override + ) + break + except Exception as e: + pass + + if result_model is None: + logger.error( + f"Time Travel: Pydantic validation failed for {pydantic_models} \n" + f"Time Travel: Completion override was:\n" + f"{pprint.pformat(completion_override)}" ) - break - except Exception as e: - pass - - if result_model is None: - logger.error( - f"Time Travel: Pydantic validation failed for {pydantic_models} \n" - f"Time Travel: Completion override was:\n" - f"{pprint.pformat(completion_override)}" + return None + return tracker._handle_response_v1_openai( + result_model, kwargs, init_timestamp, session=session ) - return None + + # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs) + # if prompt_override: + # kwargs["messages"] = prompt_override["messages"] + + # Call the original function with its original arguments + result = await original_create_async(*args, **kwargs) return tracker._handle_response_v1_openai( - result_model, kwargs, init_timestamp, session=session + result, kwargs, init_timestamp, session=session ) - # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs) - # if prompt_override: - # kwargs["messages"] = prompt_override["messages"] + # Override the original method with the patched one + completions.AsyncCompletions.create = patched_function - # Call the original function with its original arguments - result = await original_create_async(*args, **kwargs) - return tracker._handle_response_v1_openai( - result, kwargs, init_timestamp, session=session - ) + def _undo_override_openai_v1_completion(self): + from openai.resources.chat import completions + + completions.Completions.create = self.original_create + + def _undo_override_openai_v1_async_completion(self): + from openai.resources.chat import completions - # Override the original method with the patched one - completions.AsyncCompletions.create = patched_function + completions.AsyncCompletions.create = self.original_create_async def handle_response_v0_openai( @@ -155,9 +302,9 @@ def handle_stream_chunk(chunk): } tracker.llm_event.end_timestamp = get_ISO_time() - tracker._safe_record(session, tracker.llm_event) + safe_record(session, tracker.llm_event) except Exception as e: - tracker._safe_record( + safe_record( session, ErrorEvent(trigger_event=tracker.llm_event, exception=e) ) @@ -204,134 +351,9 @@ def generator(): tracker.llm_event.model = response["model"] tracker.llm_event.end_timestamp = get_ISO_time() - tracker._safe_record(session, tracker.llm_event) - except Exception as e: - tracker._safe_record( - session, ErrorEvent(trigger_event=tracker.llm_event, exception=e) - ) - - kwargs_str = pprint.pformat(kwargs) - response = pprint.pformat(response) - logger.warning( - f"Unable to parse response for LLM call. 
Skipping upload to AgentOps\n" - f"response:\n {response}\n" - f"kwargs:\n {kwargs_str}\n" - ) - - return response - - -def handle_response_v1_openai( - tracker, response, kwargs, init_timestamp, session: Optional[Session] = None -): - """Handle responses for OpenAI versions >v1.0.0""" - from openai import AsyncStream, Stream - from openai.resources import AsyncCompletions - from openai.types.chat import ChatCompletionChunk - - tracker.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) - if session is not None: - tracker.llm_event.session_id = session.session_id - - def handle_stream_chunk(chunk: ChatCompletionChunk): - # NOTE: prompt/completion usage not returned in response when streaming - # We take the first ChatCompletionChunk and accumulate the deltas from all subsequent chunks to build one full chat completion - if tracker.llm_event.returns == None: - tracker.llm_event.returns = chunk - - try: - accumulated_delta = tracker.llm_event.returns.choices[0].delta - tracker.llm_event.agent_id = check_call_stack_for_agent_id() - tracker.llm_event.model = chunk.model - tracker.llm_event.prompt = kwargs["messages"] - - # NOTE: We assume for completion only choices[0] is relevant - choice = chunk.choices[0] - - if choice.delta.content: - accumulated_delta.content += choice.delta.content - - if choice.delta.role: - accumulated_delta.role = choice.delta.role - - if choice.delta.tool_calls: - accumulated_delta.tool_calls = choice.delta.tool_calls - - if choice.delta.function_call: - accumulated_delta.function_call = choice.delta.function_call - - if choice.finish_reason: - # Streaming is done. Record LLMEvent - tracker.llm_event.returns.choices[0].finish_reason = ( - choice.finish_reason - ) - tracker.llm_event.completion = { - "role": accumulated_delta.role, - "content": accumulated_delta.content, - "function_call": accumulated_delta.function_call, - "tool_calls": accumulated_delta.tool_calls, - } - tracker.llm_event.end_timestamp = get_ISO_time() - - tracker._safe_record(session, tracker.llm_event) - except Exception as e: - tracker._safe_record( - session, ErrorEvent(trigger_event=tracker.llm_event, exception=e) - ) - - kwargs_str = pprint.pformat(kwargs) - chunk = pprint.pformat(chunk) - logger.warning( - f"Unable to parse a chunk for LLM call. 
Skipping upload to AgentOps\n" - f"chunk:\n {chunk}\n" - f"kwargs:\n {kwargs_str}\n" - ) - - # if the response is a generator, decorate the generator - if isinstance(response, Stream): - - def generator(): - for chunk in response: - handle_stream_chunk(chunk) - yield chunk - - return generator() - - # For asynchronous AsyncStream - elif isinstance(response, AsyncStream): - - async def async_generator(): - async for chunk in response: - handle_stream_chunk(chunk) - yield chunk - - return async_generator() - - # For async AsyncCompletion - elif isinstance(response, AsyncCompletions): - - async def async_generator(): - async for chunk in response: - handle_stream_chunk(chunk) - yield chunk - - return async_generator() - - # v1.0.0+ responses are objects - try: - tracker.llm_event.returns = response - tracker.llm_event.agent_id = check_call_stack_for_agent_id() - tracker.llm_event.prompt = kwargs["messages"] - tracker.llm_event.prompt_tokens = response.usage.prompt_tokens - tracker.llm_event.completion = response.choices[0].message.model_dump() - tracker.llm_event.completion_tokens = response.usage.completion_tokens - tracker.llm_event.model = response.model - - tracker._safe_record(session, tracker.llm_event) + safe_record(session, tracker.llm_event) except Exception as e: - tracker._safe_record( - session, ErrorEvent(trigger_event=tracker.llm_event, exception=e) - ) + safe_record(session, ErrorEvent(trigger_event=tracker.llm_event, exception=e)) kwargs_str = pprint.pformat(kwargs) response = pprint.pformat(response) @@ -342,17 +364,3 @@ async def async_generator(): ) return response - - -def undo_override_openai_v1_completion(): - global original_create - from openai.resources.chat import completions - - completions.Completions.create = original_create - - -def undo_override_openai_v1_async_completion(): - global original_create_async - from openai.resources.chat import completions - - completions.AsyncCompletions.create = original_create_async diff --git a/agentops/llms/openai_v0.py b/agentops/llms/openai_v0.py new file mode 100644 index 00000000..570da993 --- /dev/null +++ b/agentops/llms/openai_v0.py @@ -0,0 +1,126 @@ +import inspect +import pprint +from typing import Optional + +from agentops.llms.instrumented_provider import InstrumentedProvider +from agentops.time_travel import fetch_completion_override_from_time_travel_cache + +from agentops import LLMEvent, Session, ErrorEvent, logger +from agentops.helpers import check_call_stack_for_agent_id, get_ISO_time, safe_record + + +class OpenAiInstrumentedProvider(InstrumentedProvider): + original_create = None + original_create_async = None + + def __init__(self): + self._provider_name = "OpenAI" + + def handle_response( + self, response, kwargs, init_timestamp, session: Optional[Session] = None + ) -> dict: + """Handle responses for OpenAI versions
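

Note (not part of the patch above): for readers following the refactor, the pattern this change moves toward can be sketched as a small, hypothetical provider built on the new InstrumentedProvider ABC and the module-level safe_record helper that replaces the removed tracker._safe_record method. The ExampleProvider class below is illustrative only; everything it imports (InstrumentedProvider, LLMEvent, Session, get_ISO_time, safe_record) is introduced or already present in this patch, but the subclass itself is an assumption about how providers are intended to be wired up, not code from the commit.

    # Illustrative sketch only -- a minimal provider following the pattern in this patch.
    from typing import Optional

    from agentops import LLMEvent, Session
    from agentops.helpers import get_ISO_time, safe_record
    from agentops.llms.instrumented_provider import InstrumentedProvider


    class ExampleProvider(InstrumentedProvider):  # hypothetical subclass, not in the patch
        _provider_name = "Example"

        def handle_response(
            self, response, kwargs, init_timestamp, session: Optional[Session] = None
        ) -> dict:
            # Build an LLMEvent from the raw response and hand it to safe_record,
            # which records against the session if one was passed, and otherwise
            # falls back to the singleton Client (see safe_record in helpers.py).
            event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
            if session is not None:
                event.session_id = session.session_id
            event.returns = response
            event.end_timestamp = get_ISO_time()
            safe_record(session, event)
            return response

        def override(self):
            # Monkey-patch the provider SDK's completion method here, keeping a
            # reference to the original (as the OpenAI provider does with
            # completions.Completions.create).
            pass

        def undo_override(self):
            # Restore the original, unpatched SDK method here.
            pass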