diff --git a/src/agentscope/models/ollama_model.py b/src/agentscope/models/ollama_model.py index 38d679f2b..8d90cf803 100644 --- a/src/agentscope/models/ollama_model.py +++ b/src/agentscope/models/ollama_model.py @@ -44,6 +44,7 @@ def __init__( model_name: str, options: dict = None, keep_alive: str = "5m", + host: Optional[str] = None, **kwargs: Any, ) -> None: """Initialize the model wrapper for Ollama API. @@ -57,6 +58,9 @@ def __init__( keep_alive (`str`, default `5m`): Controls how long the model will stay loaded into memory following the request. + host (`str`, default `None`): + The host URL of the ollama server. + Defaults to `None`, which is 127.0.0.1:11434. """ super().__init__(config_name=config_name) @@ -64,6 +68,7 @@ def __init__( self.model_name = model_name self.options = options self.keep_alive = keep_alive + self.client = ollama.Client(host=host, **kwargs) self._register_default_metrics() @@ -80,6 +85,7 @@ def __init__( stream: bool = False, options: dict = None, keep_alive: str = "5m", + host: Optional[str] = None, **kwargs: Any, ) -> None: """Initialize the model wrapper for Ollama API. @@ -95,6 +101,9 @@ def __init__( keep_alive (`str`, default `5m`): Controls how long the model will stay loaded into memory following the request. + host (`str`, default `None`): + The host URL of the ollama server. + Defaults to `None`, which is 127.0.0.1:11434. 
""" super().__init__( @@ -102,6 +111,7 @@ def __init__( model_name=model_name, options=options, keep_alive=keep_alive, + host=host, **kwargs, ) @@ -161,7 +171,7 @@ def __call__( }, ) - response = ollama.chat(**kwargs) + response = self.client.chat(**kwargs) if stream: @@ -396,7 +406,7 @@ def __call__( keep_alive = keep_alive or self.keep_alive # step2: forward to generate response - response = ollama.embeddings( + response = self.client.embeddings( model=self.model_name, prompt=prompt, options=options, @@ -485,7 +495,7 @@ def __call__( keep_alive = keep_alive or self.keep_alive # step2: forward to generate response - response = ollama.generate( + response = self.client.generate( model=self.model_name, prompt=prompt, options=options, diff --git a/tests/ollama_test.py b/tests/ollama_test.py index 094b44121..8a03e5c98 100644 --- a/tests/ollama_test.py +++ b/tests/ollama_test.py @@ -94,11 +94,13 @@ def setUp(self) -> None: } flush() - @patch("ollama.chat") - def test_ollama_chat(self, mock_chat: MagicMock) -> None: + @patch("agentscope.models.ollama_model.ollama") + def test_ollama_chat(self, mock_ollama: MagicMock) -> None: """Unit test for ollama chat API.""" # prepare the mock - mock_chat.return_value = self.dummy_response + mock_ollama_client = MagicMock() + mock_ollama.Client.return_value = mock_ollama_client + mock_ollama_client.chat.return_value = self.dummy_response # run test agentscope.init( @@ -118,11 +120,13 @@ def test_ollama_chat(self, mock_chat: MagicMock) -> None: self.assertEqual(response.raw, self.dummy_response) - @patch("ollama.embeddings") - def test_ollama_embedding(self, mock_embeddings: MagicMock) -> None: + @patch("agentscope.models.ollama_model.ollama") + def test_ollama_embedding(self, mock_ollama: MagicMock) -> None: """Unit test for ollama embeddings API.""" # prepare the mock - mock_embeddings.return_value = self.dummy_embedding + mock_ollama_client = MagicMock() + mock_ollama.Client.return_value = mock_ollama_client + 
mock_ollama_client.embeddings.return_value = self.dummy_embedding # run test agentscope.init( @@ -142,11 +146,13 @@ def test_ollama_embedding(self, mock_embeddings: MagicMock) -> None: self.assertEqual(response.raw, self.dummy_embedding) - @patch("ollama.generate") - def test_ollama_generate(self, mock_generate: MagicMock) -> None: + @patch("agentscope.models.ollama_model.ollama") + def test_ollama_generate(self, mock_ollama: MagicMock) -> None: """Unit test for ollama generate API.""" # prepare the mock - mock_generate.return_value = self.dummy_generate + mock_ollama_client = MagicMock() + mock_ollama.Client.return_value = mock_ollama_client + mock_ollama_client.generate.return_value = self.dummy_generate # run test agentscope.init(