From 4b86d9abee3362303719f5c24fed3b22797fabb9 Mon Sep 17 00:00:00 2001 From: Siddharth Prajosh Date: Thu, 20 Jun 2024 02:55:32 +0530 Subject: [PATCH] Ollama Support (#237) --- agentops/llm_tracker.py | 128 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 117 insertions(+), 11 deletions(-) diff --git a/agentops/llm_tracker.py b/agentops/llm_tracker.py index dcead775..c07b2c9e 100644 --- a/agentops/llm_tracker.py +++ b/agentops/llm_tracker.py @@ -1,15 +1,18 @@ import functools +import inspect +import pprint import sys from importlib import import_module from importlib.metadata import version +from typing import Optional + from packaging.version import Version, parse + +from .event import ActionEvent, ErrorEvent, LLMEvent +from .helpers import check_call_stack_for_agent_id, get_ISO_time from .log_config import logger -from .event import LLMEvent, ActionEvent, ToolEvent, ErrorEvent -from .helpers import get_ISO_time, check_call_stack_for_agent_id -import inspect -from typing import Optional -import pprint +original_func = {} original_create = None original_create_async = None @@ -27,6 +30,7 @@ class LlmTracker: "cohere": { "5.4.0": ("chat", "chat_stream"), }, + "ollama": {"0.0.1": ("chat", "Client.chat", "AsyncClient.chat")}, } def __init__(self, client): @@ -134,9 +138,9 @@ def generator(): def _handle_response_v1_openai(self, response, kwargs, init_timestamp): """Handle responses for OpenAI versions >v1.0.0""" - from openai import Stream, AsyncStream - from openai.types.chat import ChatCompletionChunk + from openai import AsyncStream, Stream from openai.resources import AsyncCompletions + from openai.types.chat import ChatCompletionChunk self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) @@ -151,9 +155,9 @@ def handle_stream_chunk(chunk: ChatCompletionChunk): self.llm_event.agent_id = check_call_stack_for_agent_id() self.llm_event.model = chunk.model self.llm_event.prompt = kwargs["messages"] - choice = chunk.choices[ - 0 - ] # NOTE: We assume for completion only choices[0] is relevant + + # NOTE: We assume for completion only choices[0] is relevant + choice = chunk.choices[0] if choice.delta.content: accumulated_delta.content += choice.delta.content @@ -261,7 +265,6 @@ def _handle_response_cohere(self, response, kwargs, init_timestamp): ) # from cohere.types.chat import ChatGenerationChunk - # NOTE: Cohere only returns one message and its role will be CHATBOT which we are coercing to "assistant" self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) @@ -417,6 +420,47 @@ def generator(): return response + def _handle_response_ollama(self, response, kwargs, init_timestamp): + self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) + + def handle_stream_chunk(chunk: dict): + message = chunk.get("message", {"role": None, "content": ""}) + + if chunk.get("done"): + self.llm_event.completion["content"] += message.get("content") + self.llm_event.end_timestamp = get_ISO_time() + self.llm_event.model = f'ollama/{chunk.get("model")}' + self.llm_event.returns = chunk + self.llm_event.returns["message"] = self.llm_event.completion + self.llm_event.prompt = kwargs["messages"] + self.llm_event.agent_id = check_call_stack_for_agent_id() + self.client.record(self.llm_event) + + if self.llm_event.completion is None: + self.llm_event.completion = message + else: + self.llm_event.completion["content"] += message.get("content") + + if inspect.isgenerator(response): + + def generator(): + for chunk in response: + handle_stream_chunk(chunk) + yield chunk + + return generator() + + self.llm_event.end_timestamp = get_ISO_time() + + self.llm_event.model = f'ollama/{response["model"]}' + self.llm_event.returns = response + self.llm_event.agent_id = check_call_stack_for_agent_id() + self.llm_event.prompt = kwargs["messages"] + self.llm_event.completion = response["message"] + + self.client.record(self.llm_event) + return response + def override_openai_v1_completion(self): from openai.resources.chat import completions @@ -506,6 +550,48 @@ def patched_function(*args, **kwargs): # Override the original method with the patched one cohere.Client.chat_stream = patched_function + def override_ollama_chat(self): + import ollama + + original_func["ollama.chat"] = ollama.chat + + def patched_function(*args, **kwargs): + # Call the original function with its original arguments + init_timestamp = get_ISO_time() + result = original_func["ollama.chat"](*args, **kwargs) + return self._handle_response_ollama(result, kwargs, init_timestamp) + + # Override the original method with the patched one + ollama.chat = patched_function + + def override_ollama_chat_client(self): + from ollama import Client + + original_func["ollama.Client.chat"] = Client.chat + + def patched_function(*args, **kwargs): + # Call the original function with its original arguments + init_timestamp = get_ISO_time() + result = original_func["ollama.Client.chat"](*args, **kwargs) + return self._handle_response_ollama(result, kwargs, init_timestamp) + + # Override the original method with the patched one + Client.chat = patched_function + + def override_ollama_chat_async_client(self): + from ollama import AsyncClient + + original_func["ollama.AsyncClient.chat"] = AsyncClient.chat + + async def patched_function(*args, **kwargs): + # Call the original function with its original arguments + init_timestamp = get_ISO_time() + result = await original_func["ollama.AsyncClient.chat"](*args, **kwargs) + return self._handle_response_ollama(result, kwargs, init_timestamp) + + # Override the original method with the patched one + AsyncClient.chat = patched_function + def _override_method(self, api, method_path, module): def handle_response(result, kwargs, init_timestamp): if api == "openai": @@ -595,9 +681,22 @@ def override_api(self): f"Only Cohere>=5.4.0 supported. v{module_version} found." ) + if api == "ollama": + module_version = version(api) + + if Version(module_version) >= parse("0.0.1"): + self.override_ollama_chat() + self.override_ollama_chat_client() + self.override_ollama_chat_async_client() + else: + logger.warning( + f"Only Ollama>=0.0.1 supported. v{module_version} found." + ) + def stop_instrumenting(self): self.undo_override_openai_v1_async_completion() self.undo_override_openai_v1_completion() + self.undo_override_ollama() def undo_override_openai_v1_completion(self): global original_create @@ -610,3 +709,10 @@ def undo_override_openai_v1_async_completion(self): from openai.resources.chat import completions completions.AsyncCompletions.create = original_create_async + + def undo_override_ollama(self): + import ollama + + ollama.chat = original_func["ollama.chat"] + ollama.Client.chat = original_func["ollama.Client.chat"] + ollama.AsyncClient.chat = original_func["ollama.AsyncClient.chat"]