Skip to content

Commit

Permalink
Add host url option in Ollama model wrappers (#365)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: wenhao <[email protected]>
  • Loading branch information
garyzhang99 and wenhao authored Jul 26, 2024
1 parent 8771ba9 commit 17aec27
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
16 changes: 13 additions & 3 deletions src/agentscope/models/ollama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
model_name: str,
options: dict = None,
keep_alive: str = "5m",
host: Optional[Union[str, None]] = None,
**kwargs: Any,
) -> None:
"""Initialize the model wrapper for Ollama API.
Expand All @@ -57,13 +58,17 @@ def __init__(
keep_alive (`str`, default `5m`):
Controls how long the model will stay loaded into memory
following the request.
host (`str`, default `None`):
The host URL of the ollama server.
Defaults to `None`, which corresponds to `http://127.0.0.1:11434`.
"""

super().__init__(config_name=config_name)

self.model_name = model_name
self.options = options
self.keep_alive = keep_alive
self.client = ollama.Client(host=host, **kwargs)

self._register_default_metrics()

Expand All @@ -80,6 +85,7 @@ def __init__(
stream: bool = False,
options: dict = None,
keep_alive: str = "5m",
host: Optional[Union[str, None]] = None,
**kwargs: Any,
) -> None:
"""Initialize the model wrapper for Ollama API.
Expand All @@ -95,13 +101,17 @@ def __init__(
keep_alive (`str`, default `5m`):
Controls how long the model will stay loaded into memory
following the request.
host (`str`, default `None`):
The host URL of the ollama server.
Defaults to `None`, which corresponds to `http://127.0.0.1:11434`.
"""

super().__init__(
config_name=config_name,
model_name=model_name,
options=options,
keep_alive=keep_alive,
host=host,
**kwargs,
)

Expand Down Expand Up @@ -161,7 +171,7 @@ def __call__(
},
)

response = ollama.chat(**kwargs)
response = self.client.chat(**kwargs)

if stream:

Expand Down Expand Up @@ -396,7 +406,7 @@ def __call__(
keep_alive = keep_alive or self.keep_alive

# step2: forward to generate response
response = ollama.embeddings(
response = self.client.embeddings(
model=self.model_name,
prompt=prompt,
options=options,
Expand Down Expand Up @@ -485,7 +495,7 @@ def __call__(
keep_alive = keep_alive or self.keep_alive

# step2: forward to generate response
response = ollama.generate(
response = self.client.generate(
model=self.model_name,
prompt=prompt,
options=options,
Expand Down
24 changes: 15 additions & 9 deletions tests/ollama_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ def setUp(self) -> None:
}
flush()

@patch("ollama.chat")
def test_ollama_chat(self, mock_chat: MagicMock) -> None:
@patch("agentscope.models.ollama_model.ollama")
def test_ollama_chat(self, mock_ollama: MagicMock) -> None:
"""Unit test for ollama chat API."""
# prepare the mock
mock_chat.return_value = self.dummy_response
mock_ollama_client = MagicMock()
mock_ollama.Client.return_value = mock_ollama_client
mock_ollama_client.chat.return_value = self.dummy_response

# run test
agentscope.init(
Expand All @@ -118,11 +120,13 @@ def test_ollama_chat(self, mock_chat: MagicMock) -> None:

self.assertEqual(response.raw, self.dummy_response)

@patch("ollama.embeddings")
def test_ollama_embedding(self, mock_embeddings: MagicMock) -> None:
@patch("agentscope.models.ollama_model.ollama")
def test_ollama_embedding(self, mock_ollama: MagicMock) -> None:
"""Unit test for ollama embeddings API."""
# prepare the mock
mock_embeddings.return_value = self.dummy_embedding
mock_ollama_client = MagicMock()
mock_ollama.Client.return_value = mock_ollama_client
mock_ollama_client.embeddings.return_value = self.dummy_embedding

# run test
agentscope.init(
Expand All @@ -142,11 +146,13 @@ def test_ollama_embedding(self, mock_embeddings: MagicMock) -> None:

self.assertEqual(response.raw, self.dummy_embedding)

@patch("ollama.generate")
def test_ollama_generate(self, mock_generate: MagicMock) -> None:
@patch("agentscope.models.ollama_model.ollama")
def test_ollama_generate(self, mock_ollama: MagicMock) -> None:
"""Unit test for ollama generate API."""
# prepare the mock
mock_generate.return_value = self.dummy_generate
mock_ollama_client = MagicMock()
mock_ollama.Client.return_value = mock_ollama_client
mock_ollama_client.generate.return_value = self.dummy_generate

# run test
agentscope.init(
Expand Down

0 comments on commit 17aec27

Please sign in to comment.