Skip to content

Commit

Permalink
add litellm chat test
Browse files Browse the repository at this point in the history
  • Loading branch information
garyzhang99 committed May 9, 2024
1 parent a6b8d0b commit dafc60e
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions tests/litellm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""litellm test"""
import unittest
from unittest.mock import patch, MagicMock

import agentscope
from agentscope.models import load_model_by_config_name


class TestLiteLLMChatWrapper(unittest.TestCase):
    """Unit tests for the LiteLLM chat model wrapper.

    The underlying ``litellm`` package is patched out so no network
    call is made; the wrapper is exercised end-to-end via
    ``agentscope.init`` + ``load_model_by_config_name``.
    """

    def setUp(self) -> None:
        """Prepare a dummy API key and a short conversation history."""
        self.api_key = "test_api_key.secret_key"
        self.messages = [
            {"role": "user", "content": "Hello, litellm!"},
            {"role": "assistant", "content": "How can I assist you?"},
        ]

    @patch("agentscope.models.litellm_model.litellm")
    def test_chat(self, mock_litellm: MagicMock) -> None:
        """The wrapper should surface the mocked completion text."""
        reply = "Hello, this is a mocked response!"

        # Build a fake response object shaped like litellm's completion
        # result: ``model_dump`` supplies the dict view (incl. usage),
        # while ``choices[0].message.content`` supplies the text path.
        fake_response = MagicMock()
        fake_response.model_dump.return_value = {
            "choices": [
                {"message": {"content": reply}},
            ],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 5,
                "total_tokens": 105,
            },
        }
        fake_response.choices[0].message.content = reply
        mock_litellm.completion.return_value = fake_response

        # Register a litellm chat config and load the wrapper under test.
        agentscope.init(
            model_configs={
                "config_name": "test_config",
                "model_type": "litellm_chat",
                "model_name": "ollama/llama3:8b",
                "api_key": self.api_key,
            },
        )
        wrapper = load_model_by_config_name("test_config")

        result = wrapper(
            messages=self.messages,
            api_base="http://localhost:11434",
        )

        self.assertEqual(result.text, reply)


# Allow running this test module directly (e.g. ``python tests/litellm_test.py``)
# in addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()

0 comments on commit dafc60e

Please sign in to comment.