Skip to content

Commit

Permalink
Patch ollama.chat function
Browse files Browse the repository at this point in the history
  • Loading branch information
sprajosh committed Jun 5, 2024
1 parent 52e069e commit 566eb3d
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions agentops/llm_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class LlmTracker:
"cohere": {
"5.4.0": ("chat", "chat_stream"),
},
"ollama": {},
}

def __init__(self, client):
Expand Down Expand Up @@ -417,6 +418,22 @@ def generator():

return response

def _handle_response_ollama(self, response, kwargs, init_timestamp):
self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)

self.llm_event.end_timestamp = get_ISO_time()

self.llm_event.model = response["model"]
self.llm_event.returns = response
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.prompt = kwargs["messages"]
self.completion = response["message"]

# ollama doesn't return the tokens
# still looking at how to find this
self.llm_event.prompt_tokens = 0
self.llm_event.completion_tokens = 0

def override_openai_v1_completion(self):
from openai.resources.chat import completions

Expand Down Expand Up @@ -506,6 +523,20 @@ def patched_function(*args, **kwargs):
# Override the original method with the patched one
cohere.Client.chat_stream = patched_function

def override_ollama_chat(self):
import ollama

original_chat = ollama.chat

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
result = original_chat(*args, **kwargs)
return self._handle_response_ollama(result, kwargs, init_timestamp)

# Override the original method with the patched one
ollama.chat = patched_function

def _override_method(self, api, method_path, module):
def handle_response(result, kwargs, init_timestamp):
if api == "openai":
Expand Down Expand Up @@ -595,6 +626,12 @@ def override_api(self):
f"Only Cohere>=5.4.0 supported. v{module_version} found."
)

if api == "ollama":
logger.warning("ollama support is still in dev stage.")
module_version = version(api)

self.override_ollama_chat()

def stop_instrumenting(self):
self.undo_override_openai_v1_async_completion()
self.undo_override_openai_v1_completion()
Expand Down

0 comments on commit 566eb3d

Please sign in to comment.