refactor: improve Ollama provider implementation
- Add ChatResponse and Choice classes for consistent interface
- Improve error handling and streaming response processing
- Update response handling for both streaming and non-streaming cases
- Add proper type hints and documentation
- Update tests to verify streaming response format

Co-Authored-By: Alex Reibman <[email protected]>
devin-ai-integration[bot] and areibman committed Dec 19, 2024
1 parent c3fceb1 · commit bd701cb
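
The commit message above describes replacing raw response dicts with Choice and ChatResponse dataclasses. As a rough sketch of the shape callers can expect, the snippet below mirrors the dataclass definitions from the diff and shows which field is populated in the streaming and non-streaming cases; the show helper and the llama3 model name are illustrative assumptions, not part of the commit.

from dataclasses import dataclass
from typing import Optional

# Mirrors the dataclasses introduced in the diff below.
@dataclass
class Choice:
    message: Optional[dict] = None      # populated on non-streaming responses
    delta: Optional[dict] = None        # populated on streaming chunks
    finish_reason: Optional[str] = None
    index: int = 0

@dataclass
class ChatResponse:
    model: str
    choices: list

def show(response: ChatResponse) -> None:
    # Hypothetical consumer that handles both cases.
    choice = response.choices[0]
    if choice.delta is not None:        # streaming chunk
        print(choice.delta.get("content", ""), end="", flush=True)
    elif choice.message is not None:    # complete response
        print(choice.message.get("content", ""))

# Non-streaming: one ChatResponse with message set and finish_reason="stop".
show(ChatResponse(model="llama3", choices=[Choice(message={"role": "assistant", "content": "Hello!"}, finish_reason="stop")]))

# Streaming: a sequence of ChatResponse chunks, each carrying a partial delta.
show(ChatResponse(model="llama3", choices=[Choice(delta={"content": "Hel"})]))
show(ChatResponse(model="llama3", choices=[Choice(delta={"content": "lo!"})]))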
Showing 2 changed files with 365 additions and 83 deletions.
agentops/llms/providers/ollama.py: 291 changes (208 additions, 83 deletions)
@@ -1,126 +1,251 @@
import inspect
import sys
from typing import Optional
import json
from typing import AsyncGenerator, Dict, List, Optional, Union
from dataclasses import dataclass
import asyncio
import httpx

from agentops.event import LLMEvent
from agentops.event import LLMEvent, ErrorEvent
from agentops.session import Session
from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
from .instrumented_provider import InstrumentedProvider
from agentops.singleton import singleton

original_func = {}
@dataclass
class Choice:
message: dict = None
delta: dict = None
finish_reason: str = None
index: int = 0

@dataclass
class ChatResponse:
model: str
choices: list[Choice]

original_func = {}

@singleton
class OllamaProvider(InstrumentedProvider):
original_create = None
original_create_async = None

def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
if session is not None:
llm_event.session_id = session.session_id

def handle_stream_chunk(chunk: dict):
message = chunk.get("message", {"role": None, "content": ""})

if chunk.get("done"):
llm_event.end_timestamp = get_ISO_time()
llm_event.model = f'ollama/{chunk.get("model")}'
llm_event.returns = chunk
llm_event.returns["message"] = llm_event.completion
llm_event.prompt = kwargs["messages"]
llm_event.agent_id = check_call_stack_for_agent_id()
self._safe_record(session, llm_event)

if llm_event.completion is None:
llm_event.completion = {
"role": message.get("role"),
"content": message.get("content", ""),
"tool_calls": None,
"function_call": None,
}
else:
llm_event.completion["content"] += message.get("content", "")

if inspect.isgenerator(response):

def generator():
for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return generator()

llm_event.end_timestamp = get_ISO_time()
llm_event.model = f'ollama/{response["model"]}'
llm_event.returns = response
llm_event.agent_id = check_call_stack_for_agent_id()
llm_event.prompt = kwargs["messages"]
llm_event.completion = {
"role": response["message"].get("role"),
"content": response["message"].get("content", ""),
"tool_calls": None,
"function_call": None,
def handle_response(self, response_data, request_data, init_timestamp, session=None):
"""Handle the response from the Ollama API."""
end_timestamp = get_ISO_time()
model = request_data.get("model", "unknown")

# Extract error if present
error = None
if isinstance(response_data, dict) and "error" in response_data:
error = response_data["error"]

# Create event data
event_data = {
"model": f"ollama/{model}",
"params": request_data,
"returns": {
"model": model,
},
"init_timestamp": init_timestamp,
"end_timestamp": end_timestamp,
"prompt": request_data.get("messages", []),
"prompt_tokens": None, # Ollama doesn't provide token counts
"completion_tokens": None,
"cost": None, # Ollama is free/local
}
self._safe_record(session, llm_event)
return response

if error:
event_data["returns"]["error"] = error
event_data["completion"] = error
else:
# Extract completion from response
if isinstance(response_data, dict):
message = response_data.get("message", {})
if isinstance(message, dict):
content = message.get("content", "")
event_data["returns"]["content"] = content
event_data["completion"] = content

# Create and emit LLM event
if session:
event = LLMEvent(**event_data)
session.record(event) # Changed from add_event to record

return event_data

def override(self):
"""Override Ollama methods with instrumented versions."""
self._override_chat_client()
self._override_chat()
self._override_chat_async_client()

def undo_override(self):
if original_func is not None and original_func != {}:
import ollama

ollama.chat = original_func["ollama.chat"]
ollama.Client.chat = original_func["ollama.Client.chat"]
ollama.AsyncClient.chat = original_func["ollama.AsyncClient.chat"]

def __init__(self, client):
super().__init__(client)
import ollama
if hasattr(self, '_original_chat'):
ollama.chat = self._original_chat
if hasattr(self, '_original_client_chat'):
ollama.Client.chat = self._original_client_chat
if hasattr(self, '_original_async_chat'):
ollama.AsyncClient.chat = self._original_async_chat

def __init__(self, http_client=None, client=None):
"""Initialize the Ollama provider."""
super().__init__(client=client)
self.base_url = "http://localhost:11434" # Ollama runs locally by default
self.timeout = 60.0 # Default timeout in seconds

# Initialize HTTP client if not provided
if http_client is None:
self.http_client = httpx.AsyncClient(timeout=self.timeout)
else:
self.http_client = http_client

# Store original methods for restoration
self._original_chat = None
self._original_chat_client = None
self._original_chat_async_client = None

def _override_chat(self):
import ollama

original_func["ollama.chat"] = ollama.chat
self._original_chat = ollama.chat

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
result = original_func["ollama.chat"](*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None))
session = kwargs.pop("session", None)
result = self._original_chat(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

# Override the original method with the patched one
ollama.chat = patched_function

def _override_chat_client(self):
from ollama import Client
self._original_client_chat = Client.chat

original_func["ollama.Client.chat"] = Client.chat

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
def patched_function(self_client, *args, **kwargs):
init_timestamp = get_ISO_time()
result = original_func["ollama.Client.chat"](*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None))
session = kwargs.pop("session", None)
result = self._original_client_chat(self_client, *args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

# Override the original method with the patched one
Client.chat = patched_function

def _override_chat_async_client(self):
from ollama import AsyncClient
self._original_async_chat = AsyncClient.chat

original_func = {}
original_func["ollama.AsyncClient.chat"] = AsyncClient.chat

async def patched_function(*args, **kwargs):
# Call the original function with its original arguments
async def patched_function(self_client, *args, **kwargs):
init_timestamp = get_ISO_time()
result = await original_func["ollama.AsyncClient.chat"](*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None))
session = kwargs.pop("session", None)
result = await self._original_async_chat(self_client, *args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

# Override the original method with the patched one
AsyncClient.chat = patched_function

async def chat_completion(
self,
model: str,
messages: List[Dict[str, str]],
stream: bool = False,
session=None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Send a chat completion request to the Ollama API."""
init_timestamp = get_ISO_time()

# Prepare request data
data = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs,
}

try:
response = await self.http_client.post(
f"{self.base_url}/api/chat",
json=data,
timeout=self.timeout,
)

if response.status_code != 200:
error_data = await response.json()
self.handle_response(error_data, data, init_timestamp, session)
raise Exception(error_data.get("error", "Unknown error"))

if stream:
return self.stream_generator(response, data, init_timestamp, session)
else:
response_data = await response.json()
self.handle_response(response_data, data, init_timestamp, session)
return ChatResponse(
model=model,
choices=[
Choice(
message=response_data["message"],
finish_reason="stop"
)
]
)

except Exception as e:
error_data = {"error": str(e)}
self.handle_response(error_data, data, init_timestamp, session)
raise

async def stream_generator(self, response, data, init_timestamp, session):
"""Generate streaming responses from Ollama API."""
accumulated_content = ""
try:
async for line in response.aiter_lines():
if not line.strip():
continue

try:
chunk_data = json.loads(line)
if not isinstance(chunk_data, dict):
continue

message = chunk_data.get("message", {})
if not isinstance(message, dict):
continue

content = message.get("content", "")
if not content:
continue

accumulated_content += content

# Create chunk response with model parameter
chunk_response = ChatResponse(
model=data["model"], # Include model from request data
choices=[
Choice(
delta={"content": content},
finish_reason=None if not chunk_data.get("done") else "stop"
)
]
)
yield chunk_response

except json.JSONDecodeError:
continue

# Emit event after streaming is complete
if accumulated_content:
self.handle_response(
{
"message": {
"role": "assistant",
"content": accumulated_content
}
},
data,
init_timestamp,
session
)

except Exception as e:
# Handle streaming errors
error_data = {"error": str(e)}
self.handle_response(error_data, data, init_timestamp, session)
raise
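
For orientation, here is a minimal usage sketch of the new async entry point added in this diff. It assumes a local Ollama server on the default port with a pulled llama3 model; constructing the provider with client=None and closing its HTTP client at the end are assumptions made for the example, not behavior guaranteed by the commit.

import asyncio

from agentops.llms.providers.ollama import OllamaProvider

async def main():
    # Assumption: the provider can be built without an AgentOps client for a quick local test.
    provider = OllamaProvider(client=None)
    messages = [{"role": "user", "content": "Say hello in one word."}]

    # Non-streaming: chat_completion returns a single ChatResponse.
    response = await provider.chat_completion(model="llama3", messages=messages, stream=False)
    print(response.choices[0].message["content"])

    # Streaming: chat_completion returns the async generator built by stream_generator,
    # yielding ChatResponse chunks whose choices carry a delta with partial content.
    stream = await provider.chat_completion(model="llama3", messages=messages, stream=True)
    async for chunk in stream:
        print(chunk.choices[0].delta["content"], end="", flush=True)
    print()

    # Close the httpx.AsyncClient the provider created in __init__.
    await provider.http_client.aclose()

asyncio.run(main())

Note that the streaming call is awaited once to obtain the generator and then consumed with async for, matching the way chat_completion returns stream_generator rather than iterating it itself.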