Skip to content

Commit

Permalink
Ollama Support (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
sprajosh authored Jun 19, 2024
1 parent 4a5695c commit 4b86d9a
Showing 1 changed file with 117 additions and 11 deletions.
128 changes: 117 additions & 11 deletions agentops/llm_tracker.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import functools
import inspect
import pprint
import sys
from importlib import import_module
from importlib.metadata import version
from typing import Optional

from packaging.version import Version, parse

from .event import ActionEvent, ErrorEvent, LLMEvent
from .helpers import check_call_stack_for_agent_id, get_ISO_time
from .log_config import logger
from .event import LLMEvent, ActionEvent, ToolEvent, ErrorEvent
from .helpers import get_ISO_time, check_call_stack_for_agent_id
import inspect
from typing import Optional
import pprint

original_func = {}
original_create = None
original_create_async = None

Expand All @@ -27,6 +30,7 @@ class LlmTracker:
"cohere": {
"5.4.0": ("chat", "chat_stream"),
},
"ollama": {"0.0.1": ("chat", "Client.chat", "AsyncClient.chat")},
}

def __init__(self, client):
Expand Down Expand Up @@ -134,9 +138,9 @@ def generator():

def _handle_response_v1_openai(self, response, kwargs, init_timestamp):
"""Handle responses for OpenAI versions >v1.0.0"""
from openai import Stream, AsyncStream
from openai.types.chat import ChatCompletionChunk
from openai import AsyncStream, Stream
from openai.resources import AsyncCompletions
from openai.types.chat import ChatCompletionChunk

self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)

Expand All @@ -151,9 +155,9 @@ def handle_stream_chunk(chunk: ChatCompletionChunk):
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.model = chunk.model
self.llm_event.prompt = kwargs["messages"]
choice = chunk.choices[
0
] # NOTE: We assume for completion only choices[0] is relevant

# NOTE: We assume for completion only choices[0] is relevant
choice = chunk.choices[0]

if choice.delta.content:
accumulated_delta.content += choice.delta.content
Expand Down Expand Up @@ -261,7 +265,6 @@ def _handle_response_cohere(self, response, kwargs, init_timestamp):
)

# from cohere.types.chat import ChatGenerationChunk

# NOTE: Cohere only returns one message and its role will be CHATBOT which we are coercing to "assistant"
self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)

Expand Down Expand Up @@ -417,6 +420,47 @@ def generator():

return response

def _handle_response_ollama(self, response, kwargs, init_timestamp):
self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)

def handle_stream_chunk(chunk: dict):
message = chunk.get("message", {"role": None, "content": ""})

if chunk.get("done"):
self.llm_event.completion["content"] += message.get("content")
self.llm_event.end_timestamp = get_ISO_time()
self.llm_event.model = f'ollama/{chunk.get("model")}'
self.llm_event.returns = chunk
self.llm_event.returns["message"] = self.llm_event.completion
self.llm_event.prompt = kwargs["messages"]
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.client.record(self.llm_event)

if self.llm_event.completion is None:
self.llm_event.completion = message
else:
self.llm_event.completion["content"] += message.get("content")

if inspect.isgenerator(response):

def generator():
for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return generator()

self.llm_event.end_timestamp = get_ISO_time()

self.llm_event.model = f'ollama/{response["model"]}'
self.llm_event.returns = response
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.prompt = kwargs["messages"]
self.llm_event.completion = response["message"]

self.client.record(self.llm_event)
return response

def override_openai_v1_completion(self):
from openai.resources.chat import completions

Expand Down Expand Up @@ -506,6 +550,48 @@ def patched_function(*args, **kwargs):
# Override the original method with the patched one
cohere.Client.chat_stream = patched_function

def override_ollama_chat(self):
import ollama

original_func["ollama.chat"] = ollama.chat

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
result = original_func["ollama.chat"](*args, **kwargs)
return self._handle_response_ollama(result, kwargs, init_timestamp)

# Override the original method with the patched one
ollama.chat = patched_function

def override_ollama_chat_client(self):
from ollama import Client

original_func["ollama.Client.chat"] = Client.chat

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
result = original_func["ollama.Client.chat"](*args, **kwargs)
return self._handle_response_ollama(result, kwargs, init_timestamp)

# Override the original method with the patched one
Client.chat = patched_function

def override_ollama_chat_async_client(self):
from ollama import AsyncClient

original_func["ollama.AsyncClient.chat"] = AsyncClient.chat

async def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
result = await original_func["ollama.AsyncClient.chat"](*args, **kwargs)
return self._handle_response_ollama(result, kwargs, init_timestamp)

# Override the original method with the patched one
AsyncClient.chat = patched_function

def _override_method(self, api, method_path, module):
def handle_response(result, kwargs, init_timestamp):
if api == "openai":
Expand Down Expand Up @@ -595,9 +681,22 @@ def override_api(self):
f"Only Cohere>=5.4.0 supported. v{module_version} found."
)

if api == "ollama":
module_version = version(api)

if Version(module_version) >= parse("0.0.1"):
self.override_ollama_chat()
self.override_ollama_chat_client()
self.override_ollama_chat_async_client()
else:
logger.warning(
f"Only Ollama>=0.0.1 supported. v{module_version} found."
)

def stop_instrumenting(self):
self.undo_override_openai_v1_async_completion()
self.undo_override_openai_v1_completion()
self.undo_override_ollama()

def undo_override_openai_v1_completion(self):
global original_create
Expand All @@ -610,3 +709,10 @@ def undo_override_openai_v1_async_completion(self):
from openai.resources.chat import completions

completions.AsyncCompletions.create = original_create_async

def undo_override_ollama(self):
import ollama

ollama.chat = original_func["ollama.chat"]
ollama.Client.chat = original_func["ollama.Client.chat"]
ollama.AsyncClient.chat = original_func["ollama.AsyncClient.chat"]

0 comments on commit 4b86d9a

Please sign in to comment.