Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ollama Support #237

Merged
merged 12 commits into from
Jun 19, 2024
126 changes: 115 additions & 11 deletions agentops/llm_tracker.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import functools
import inspect
import pprint
import sys
from importlib import import_module
from importlib.metadata import version
from typing import Optional

from packaging.version import Version, parse

from .event import ActionEvent, ErrorEvent, LLMEvent
from .helpers import check_call_stack_for_agent_id, get_ISO_time
from .log_config import logger
from .event import LLMEvent, ActionEvent, ToolEvent, ErrorEvent
from .helpers import get_ISO_time, check_call_stack_for_agent_id
import inspect
from typing import Optional
import pprint

original_func = {}
original_create = None
original_create_async = None

Expand All @@ -27,6 +30,7 @@ class LlmTracker:
"cohere": {
"5.4.0": ("chat", "chat_stream"),
},
"ollama": {},
siyangqiu marked this conversation as resolved.
Show resolved Hide resolved
}

def __init__(self, client):
Expand Down Expand Up @@ -134,9 +138,9 @@ def generator():

def _handle_response_v1_openai(self, response, kwargs, init_timestamp):
"""Handle responses for OpenAI versions >v1.0.0"""
from openai import Stream, AsyncStream
from openai.types.chat import ChatCompletionChunk
from openai import AsyncStream, Stream
from openai.resources import AsyncCompletions
from openai.types.chat import ChatCompletionChunk

self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)

Expand All @@ -151,9 +155,9 @@ def handle_stream_chunk(chunk: ChatCompletionChunk):
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.model = chunk.model
self.llm_event.prompt = kwargs["messages"]
choice = chunk.choices[
0
] # NOTE: We assume for completion only choices[0] is relevant

# NOTE: We assume for completion only choices[0] is relevant
choice = chunk.choices[0]
siyangqiu marked this conversation as resolved.
Show resolved Hide resolved

if choice.delta.content:
accumulated_delta.content += choice.delta.content
Expand Down Expand Up @@ -261,7 +265,6 @@ def _handle_response_cohere(self, response, kwargs, init_timestamp):
)

# from cohere.types.chat import ChatGenerationChunk

# NOTE: Cohere only returns one message and its role will be CHATBOT which we are coercing to "assistant"
self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)

Expand Down Expand Up @@ -417,6 +420,47 @@ def generator():

return response

def _handle_response_ollama(self, response, kwargs, init_timestamp):
self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)

def handle_stream_chunk(chunk: dict):
message = chunk.get("message", {"role": None, "content": ""})

if chunk.get("done"):
self.llm_event.completion["content"] += message.get("content")
self.llm_event.end_timestamp = get_ISO_time()
self.llm_event.model = f'ollama/{chunk.get("model")}'
self.llm_event.returns = chunk
self.llm_event.returns["message"] = self.llm_event.completion
self.llm_event.prompt = kwargs["messages"]
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.client.record(self.llm_event)

if self.llm_event.completion is None:
self.llm_event.completion = message
else:
self.llm_event.completion["content"] += message.get("content")

if inspect.isgenerator(response):

def generator():
for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return generator()

self.llm_event.end_timestamp = get_ISO_time()

self.llm_event.model = f'ollama/{response["model"]}'
self.llm_event.returns = response
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.prompt = kwargs["messages"]
self.llm_event.completion = response["message"]

self.client.record(self.llm_event)
return response
sprajosh marked this conversation as resolved.
Show resolved Hide resolved

def override_openai_v1_completion(self):
from openai.resources.chat import completions

Expand Down Expand Up @@ -506,6 +550,47 @@ def patched_function(*args, **kwargs):
# Override the original method with the patched one
cohere.Client.chat_stream = patched_function

def override_ollama_chat(self):
import ollama

original_func['ollama.chat'] = ollama.chat

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
result = original_func['ollama.chat'](*args, **kwargs)
return self._handle_response_ollama(result, kwargs, init_timestamp)

# Override the original method with the patched one
ollama.chat = patched_function

def override_ollama_chat_client(self):
from ollama import Client

original_func['ollama.Client.chat'] = Client.chat

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
result = original_func['ollama.Client.chat'](*args, **kwargs)
return self._handle_response_ollama(result, kwargs, init_timestamp)

# Override the original method with the patched one
Client.chat = patched_function

def override_ollama_chat_async_client(self):
from ollama import AsyncClient

original_func['ollama.AsyncClient.chat'] = AsyncClient.chat
async def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
result = await original_func['ollama.AsyncClient.chat'](*args, **kwargs)
return self._handle_response_ollama(result, kwargs, init_timestamp)

# Override the original method with the patched one
AsyncClient.chat = patched_function

def _override_method(self, api, method_path, module):
def handle_response(result, kwargs, init_timestamp):
if api == "openai":
Expand Down Expand Up @@ -595,6 +680,18 @@ def override_api(self):
f"Only Cohere>=5.4.0 supported. v{module_version} found."
)

if api == "ollama":
module_version = version(api)

if Version(module_version) >= parse("0.0.1"):
self.override_ollama_chat()
self.override_ollama_chat_client()
self.override_ollama_chat_async_client()
else:
logger.warning(
f"Only Ollama>=0.0.1 supported. v{module_version} found."
)

def stop_instrumenting(self):
self.undo_override_openai_v1_async_completion()
self.undo_override_openai_v1_completion()
Expand All @@ -610,3 +707,10 @@ def undo_override_openai_v1_async_completion(self):
from openai.resources.chat import completions

completions.AsyncCompletions.create = original_create_async

def undo_override_ollama(self):
siyangqiu marked this conversation as resolved.
Show resolved Hide resolved
import ollama

ollama.chat = original_func['ollama.chat']
ollama.Client.chat = original_func['ollama.Client.chat']
ollama.AsyncClient.chat = original_func['ollama.AsyncClient.chat']
Loading