From 566eb3d7142672680a36e11aa620c7dccb07aa66 Mon Sep 17 00:00:00 2001
From: Siddharth Prajosh
Date: Wed, 5 Jun 2024 11:14:54 +0530
Subject: [PATCH] Patch ollama.chat function

---
 agentops/llm_tracker.py | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/agentops/llm_tracker.py b/agentops/llm_tracker.py
index dcead775..5f8a0dca 100644
--- a/agentops/llm_tracker.py
+++ b/agentops/llm_tracker.py
@@ -27,6 +27,7 @@ class LlmTracker:
         "cohere": {
             "5.4.0": ("chat", "chat_stream"),
         },
+        "ollama": {},
     }
 
     def __init__(self, client):
@@ -417,6 +418,25 @@ def generator():
 
         return response
 
+    def _handle_response_ollama(self, response, kwargs, init_timestamp):
+        self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
+
+        self.llm_event.end_timestamp = get_ISO_time()
+
+        self.llm_event.model = response["model"]
+        self.llm_event.returns = response
+        self.llm_event.agent_id = check_call_stack_for_agent_id()
+        self.llm_event.prompt = kwargs["messages"]
+        self.llm_event.completion = response["message"]
+
+        # Ollama does not return token counts in this response;
+        # still investigating how to surface them
+        self.llm_event.prompt_tokens = 0
+        self.llm_event.completion_tokens = 0
+
+        self.client.record(self.llm_event)
+        return response
+
     def override_openai_v1_completion(self):
         from openai.resources.chat import completions
 
@@ -506,6 +526,20 @@ def patched_function(*args, **kwargs):
         # Override the original method with the patched one
         cohere.Client.chat_stream = patched_function
 
+    def override_ollama_chat(self):
+        import ollama
+
+        original_chat = ollama.chat
+
+        def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            result = original_chat(*args, **kwargs)
+            return self._handle_response_ollama(result, kwargs, init_timestamp)
+
+        # Override the original method with the patched one
+        ollama.chat = patched_function
+
     def _override_method(self, api, method_path, module):
         def handle_response(result, kwargs, init_timestamp):
             if api == "openai":
@@ -595,6 +629,12 @@ def override_api(self):
                         f"Only Cohere>=5.4.0 supported. v{module_version} found."
                     )
 
+                if api == "ollama":
+                    module_version = version(api)
+                    logger.warning(f"Ollama support is still in development (v{module_version} found).")
+
+                    self.override_ollama_chat()
+
     def stop_instrumenting(self):
         self.undo_override_openai_v1_async_completion()
         self.undo_override_openai_v1_completion()
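
---

Usage sketch (not part of the patch): a minimal check of the patched
ollama.chat, assuming agentops.init() triggers LlmTracker.override_api()
as in the rest of this repo. The model name and prompt are illustrative.

    import agentops
    import ollama

    agentops.init()  # instruments supported providers, including ollama.chat

    response = ollama.chat(
        model="llama3",
        messages=[{"role": "user", "content": "Why is the sky blue?"}],
    )

    # The patched ollama.chat records an LLMEvent and still returns the
    # original response object.
    print(response["message"]["content"])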