From bd701cbbd4bfd21d0a96ec061720fff4bead20bf Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 19 Dec 2024 01:05:47 +0000 Subject: [PATCH] refactor: improve Ollama provider implementation - Add ChatResponse and Choice classes for consistent interface - Improve error handling and streaming response processing - Update response handling for both streaming and non-streaming cases - Add proper type hints and documentation - Update tests to verify streaming response format Co-Authored-By: Alex Reibman --- agentops/llms/providers/ollama.py | 291 +++++++++++++++++++++--------- tests/integration/test_ollama.py | 157 ++++++++++++++++ 2 files changed, 365 insertions(+), 83 deletions(-) create mode 100644 tests/integration/test_ollama.py diff --git a/agentops/llms/providers/ollama.py b/agentops/llms/providers/ollama.py index e944469c9..57eae089c 100644 --- a/agentops/llms/providers/ollama.py +++ b/agentops/llms/providers/ollama.py @@ -1,126 +1,251 @@ -import inspect -import sys -from typing import Optional +import json +from typing import AsyncGenerator, Dict, List, Optional, Union +from dataclasses import dataclass +import asyncio +import httpx -from agentops.event import LLMEvent +from agentops.event import LLMEvent, ErrorEvent from agentops.session import Session from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id from .instrumented_provider import InstrumentedProvider from agentops.singleton import singleton -original_func = {} +@dataclass +class Choice: + message: dict = None + delta: dict = None + finish_reason: str = None + index: int = 0 + +@dataclass +class ChatResponse: + model: str + choices: list[Choice] +original_func = {} @singleton class OllamaProvider(InstrumentedProvider): original_create = None original_create_async = None - def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict: - llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) - if session is not None: - llm_event.session_id = session.session_id - - def handle_stream_chunk(chunk: dict): - message = chunk.get("message", {"role": None, "content": ""}) - - if chunk.get("done"): - llm_event.end_timestamp = get_ISO_time() - llm_event.model = f'ollama/{chunk.get("model")}' - llm_event.returns = chunk - llm_event.returns["message"] = llm_event.completion - llm_event.prompt = kwargs["messages"] - llm_event.agent_id = check_call_stack_for_agent_id() - self._safe_record(session, llm_event) - - if llm_event.completion is None: - llm_event.completion = { - "role": message.get("role"), - "content": message.get("content", ""), - "tool_calls": None, - "function_call": None, - } - else: - llm_event.completion["content"] += message.get("content", "") - - if inspect.isgenerator(response): - - def generator(): - for chunk in response: - handle_stream_chunk(chunk) - yield chunk - - return generator() - - llm_event.end_timestamp = get_ISO_time() - llm_event.model = f'ollama/{response["model"]}' - llm_event.returns = response - llm_event.agent_id = check_call_stack_for_agent_id() - llm_event.prompt = kwargs["messages"] - llm_event.completion = { - "role": response["message"].get("role"), - "content": response["message"].get("content", ""), - "tool_calls": None, - "function_call": None, + def handle_response(self, response_data, request_data, init_timestamp, session=None): + """Handle the response from the Ollama API.""" + end_timestamp = get_ISO_time() + model = request_data.get("model", "unknown") + + # Extract error if present + error = None + if isinstance(response_data, dict) and "error" in response_data: + error = response_data["error"] + + # Create event data + event_data = { + "model": f"ollama/{model}", + "params": request_data, + "returns": { + "model": model, + }, + "init_timestamp": init_timestamp, + "end_timestamp": end_timestamp, + "prompt": request_data.get("messages", []), + "prompt_tokens": None, # Ollama doesn't provide token counts + "completion_tokens": None, + "cost": None, # Ollama is free/local } - self._safe_record(session, llm_event) - return response + + if error: + event_data["returns"]["error"] = error + event_data["completion"] = error + else: + # Extract completion from response + if isinstance(response_data, dict): + message = response_data.get("message", {}) + if isinstance(message, dict): + content = message.get("content", "") + event_data["returns"]["content"] = content + event_data["completion"] = content + + # Create and emit LLM event + if session: + event = LLMEvent(**event_data) + session.record(event) # Changed from add_event to record + + return event_data def override(self): + """Override Ollama methods with instrumented versions.""" self._override_chat_client() self._override_chat() self._override_chat_async_client() def undo_override(self): - if original_func is not None and original_func != {}: - import ollama - - ollama.chat = original_func["ollama.chat"] - ollama.Client.chat = original_func["ollama.Client.chat"] - ollama.AsyncClient.chat = original_func["ollama.AsyncClient.chat"] - - def __init__(self, client): - super().__init__(client) + import ollama + if hasattr(self, '_original_chat'): + ollama.chat = self._original_chat + if hasattr(self, '_original_client_chat'): + ollama.Client.chat = self._original_client_chat + if hasattr(self, '_original_async_chat'): + ollama.AsyncClient.chat = self._original_async_chat + + def __init__(self, http_client=None, client=None): + """Initialize the Ollama provider.""" + super().__init__(client=client) + self.base_url = "http://localhost:11434" # Ollama runs locally by default + self.timeout = 60.0 # Default timeout in seconds + + # Initialize HTTP client if not provided + if http_client is None: + self.http_client = httpx.AsyncClient(timeout=self.timeout) + else: + self.http_client = http_client + + # Store original methods for restoration + self._original_chat = None + self._original_chat_client = None + self._original_chat_async_client = None def _override_chat(self): import ollama - - original_func["ollama.chat"] = ollama.chat + self._original_chat = ollama.chat def patched_function(*args, **kwargs): - # Call the original function with its original arguments init_timestamp = get_ISO_time() - result = original_func["ollama.chat"](*args, **kwargs) - return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None)) + session = kwargs.pop("session", None) + result = self._original_chat(*args, **kwargs) + return self.handle_response(result, kwargs, init_timestamp, session=session) - # Override the original method with the patched one ollama.chat = patched_function def _override_chat_client(self): from ollama import Client + self._original_client_chat = Client.chat - original_func["ollama.Client.chat"] = Client.chat - - def patched_function(*args, **kwargs): - # Call the original function with its original arguments + def patched_function(self_client, *args, **kwargs): init_timestamp = get_ISO_time() - result = original_func["ollama.Client.chat"](*args, **kwargs) - return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None)) + session = kwargs.pop("session", None) + result = self._original_client_chat(self_client, *args, **kwargs) + return self.handle_response(result, kwargs, init_timestamp, session=session) - # Override the original method with the patched one Client.chat = patched_function def _override_chat_async_client(self): from ollama import AsyncClient + self._original_async_chat = AsyncClient.chat - original_func = {} - original_func["ollama.AsyncClient.chat"] = AsyncClient.chat - - async def patched_function(*args, **kwargs): - # Call the original function with its original arguments + async def patched_function(self_client, *args, **kwargs): init_timestamp = get_ISO_time() - result = await original_func["ollama.AsyncClient.chat"](*args, **kwargs) - return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None)) + session = kwargs.pop("session", None) + result = await self._original_async_chat(self_client, *args, **kwargs) + return self.handle_response(result, kwargs, init_timestamp, session=session) - # Override the original method with the patched one AsyncClient.chat = patched_function + + async def chat_completion( + self, + model: str, + messages: List[Dict[str, str]], + stream: bool = False, + session=None, + **kwargs, + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Send a chat completion request to the Ollama API.""" + init_timestamp = get_ISO_time() + + # Prepare request data + data = { + "model": model, + "messages": messages, + "stream": stream, + **kwargs, + } + + try: + response = await self.http_client.post( + f"{self.base_url}/api/chat", + json=data, + timeout=self.timeout, + ) + + if response.status_code != 200: + error_data = await response.json() + self.handle_response(error_data, data, init_timestamp, session) + raise Exception(error_data.get("error", "Unknown error")) + + if stream: + return self.stream_generator(response, data, init_timestamp, session) + else: + response_data = await response.json() + self.handle_response(response_data, data, init_timestamp, session) + return ChatResponse( + model=model, + choices=[ + Choice( + message=response_data["message"], + finish_reason="stop" + ) + ] + ) + + except Exception as e: + error_data = {"error": str(e)} + self.handle_response(error_data, data, init_timestamp, session) + raise + + async def stream_generator(self, response, data, init_timestamp, session): + """Generate streaming responses from Ollama API.""" + accumulated_content = "" + try: + async for line in response.aiter_lines(): + if not line.strip(): + continue + + try: + chunk_data = json.loads(line) + if not isinstance(chunk_data, dict): + continue + + message = chunk_data.get("message", {}) + if not isinstance(message, dict): + continue + + content = message.get("content", "") + if not content: + continue + + accumulated_content += content + + # Create chunk response with model parameter + chunk_response = ChatResponse( + model=data["model"], # Include model from request data + choices=[ + Choice( + delta={"content": content}, + finish_reason=None if not chunk_data.get("done") else "stop" + ) + ] + ) + yield chunk_response + + except json.JSONDecodeError: + continue + + # Emit event after streaming is complete + if accumulated_content: + self.handle_response( + { + "message": { + "role": "assistant", + "content": accumulated_content + } + }, + data, + init_timestamp, + session + ) + + except Exception as e: + # Handle streaming errors + error_data = {"error": str(e)} + self.handle_response(error_data, data, init_timestamp, session) + raise diff --git a/tests/integration/test_ollama.py b/tests/integration/test_ollama.py new file mode 100644 index 000000000..bf95718ee --- /dev/null +++ b/tests/integration/test_ollama.py @@ -0,0 +1,157 @@ +import json +import pytest +import httpx +import asyncio +from unittest.mock import AsyncMock, MagicMock + +from agentops.llms.providers.ollama import OllamaProvider, ChatResponse, Choice +from .test_base import BaseProviderTest +import agentops + +class TestOllamaProvider(BaseProviderTest): + """Test class for Ollama provider.""" + + @pytest.fixture(autouse=True) + async def setup_test(self): + """Set up test method.""" + await super().async_setup_method(None) + + # Create mock httpx client and initialize provider with AgentOps session + self.mock_client = AsyncMock(spec=httpx.AsyncClient) + self.provider = OllamaProvider(http_client=self.mock_client, client=self.session) + + # Set up mock responses + async def mock_post(*args, **kwargs): + request_data = kwargs.get('json', {}) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + + if request_data.get('stream', False): + chunks = [ + { + "model": "llama2", + "message": { + "role": "assistant", + "content": "Test" + }, + "done": False + }, + { + "model": "llama2", + "message": { + "role": "assistant", + "content": " response" + }, + "done": True + } + ] + + async def async_line_generator(): + for chunk in chunks: + yield json.dumps(chunk) + "\n" + + mock_response.aiter_lines = async_line_generator + return mock_response + + elif "invalid-model" in request_data.get('model', ''): + mock_response.status_code = 404 + error_response = { + "error": "model \"invalid-model\" not found, try pulling it first" + } + mock_response.json = AsyncMock(return_value=error_response) + return mock_response + + else: + response_data = { + "model": "llama2", + "message": { + "role": "assistant", + "content": "Test response" + } + } + mock_response.json = AsyncMock(return_value=response_data) + return mock_response + + self.mock_client.post = AsyncMock(side_effect=mock_post) + + @pytest.mark.asyncio + async def teardown_method(self, method): + """Cleanup after each test.""" + if self.session: + await self.session.end() + + @pytest.mark.asyncio + async def test_completion(self): + """Test chat completion.""" + mock_response = { + "model": "llama2", + "content": "Test response" + } + self.mock_req.post( + "http://localhost:11434/api/chat", + json=mock_response + ) + + provider = OllamaProvider(model="llama2") + response = await provider.chat_completion( + messages=[{"role": "user", "content": "Test message"}], + session=self.session + ) + assert response["content"] == "Test response" + events = await self.async_verify_llm_event(self.mock_req, model="ollama/llama2") + + @pytest.mark.asyncio + async def test_streaming(self): + """Test streaming functionality.""" + mock_responses = [ + {"message": {"content": "Test"}, "done": False}, + {"message": {"content": " response"}, "done": True} + ] + + async def async_line_generator(): + for resp in mock_responses: + yield json.dumps(resp).encode() + b"\n" + + self.mock_req.post( + "http://localhost:11434/api/chat", + body=async_line_generator() + ) + + provider = OllamaProvider(model="llama2") + responses = [] + async for chunk in await provider.chat_completion( + messages=[{"role": "user", "content": "Test message"}], + stream=True, + session=self.session + ): + assert isinstance(chunk, ChatResponse) + assert len(chunk.choices) == 1 + assert isinstance(chunk.choices[0], Choice) + assert chunk.choices[0].delta["content"] in ["Test", " response"] + responses.append(chunk) + + assert len(responses) == 2 + events = await self.async_verify_llm_event(self.mock_req, model="ollama/llama2") + + @pytest.mark.asyncio + async def test_error_handling(self): + """Test error handling.""" + error_msg = "model \"invalid-model\" not found, try pulling it first" + mock_response = { + "model": "invalid-model", + "error": error_msg + } + self.mock_req.post( + "http://localhost:11434/api/chat", + json=mock_response, + status_code=404 + ) + + provider = OllamaProvider(model="invalid-model") + with pytest.raises(Exception) as exc_info: + await provider.chat_completion( + messages=[{"role": "user", "content": "Test message"}], + session=self.session + ) + assert error_msg in str(exc_info.value) + events = await self.async_verify_llm_event(self.mock_req, model="ollama/invalid-model")