From 497f8b953ee1605463db8b1347328c105af92be1 Mon Sep 17 00:00:00 2001
From: garyzhang99 <46197280+garyzhang99@users.noreply.github.com>
Date: Sat, 11 May 2024 14:27:05 +0800
Subject: [PATCH] Support LiteLLM chat API in AgentScope (#204)
---
README.md | 1 +
README_ZH.md | 1 +
.../en/source/tutorial/203-model.md | 22 ++
.../zh_CN/source/tutorial/203-model.md | 22 ++
.../litellm_chat_template.json | 11 +
setup.py | 1 +
src/agentscope/_init.py | 2 +-
src/agentscope/models/__init__.py | 4 +
src/agentscope/models/litellm_model.py | 256 ++++++++++++++++++
tests/format_test.py | 27 ++
tests/litellm_test.py | 61 +++++
11 files changed, 407 insertions(+), 1 deletion(-)
create mode 100644 examples/model_configs_template/litellm_chat_template.json
create mode 100644 src/agentscope/models/litellm_model.py
create mode 100644 tests/litellm_test.py
diff --git a/README.md b/README.md
index 62a4c8758..dcc2cb432 100644
--- a/README.md
+++ b/README.md
@@ -91,6 +91,7 @@ services and third-party model APIs.
| ollama | Chat | [`OllamaChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)<br>[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_chat_template.json) | llama3, llama2, Mistral, ... |
| | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)<br>[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_embedding_template.json) | llama2, Mistral, ... |
| | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)<br>[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_generate_template.json) | llama2, Mistral, ... |
+| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#litellm-api)<br>[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/litellm_chat_template.json) | [models supported by litellm](https://docs.litellm.ai/docs/)... |
| Post Request based API | - | [`PostAPIModelWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#post-request-api)<br>[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/postapi_model_config_template.json) | - |
**Supported Local Model Deployment**
diff --git a/README_ZH.md b/README_ZH.md
index 56fa23d36..f3ef8c72d 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -80,6 +80,7 @@ AgentScope提供了一系列`ModelWrapper`来支持本地模型服务和第三
| ollama | Chat | [`OllamaChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)<br>[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_chat_template.json) | llama3, llama2, Mistral, ... |
| | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)<br>[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_embedding_template.json) | llama2, Mistral, ... |
| | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)<br>[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_generate_template.json) | llama2, Mistral, ... |
+| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#litellm-api)<br>[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/litellm_chat_template.json) | [models supported by litellm](https://docs.litellm.ai/docs/)... |
| Post Request based API | - | [`PostAPIModelWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#post-request-api)<br>[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/postapi_model_config_template.json) | - |
**支持的本地模型部署**
diff --git a/docs/sphinx_doc/en/source/tutorial/203-model.md b/docs/sphinx_doc/en/source/tutorial/203-model.md
index 08ef18dc5..d6e153d0f 100644
--- a/docs/sphinx_doc/en/source/tutorial/203-model.md
+++ b/docs/sphinx_doc/en/source/tutorial/203-model.md
@@ -16,6 +16,7 @@ Currently, AgentScope supports the following model service APIs:
- Gemini API, including chat and embedding.
- ZhipuAI API, including chat and embedding.
- Ollama API, including chat, embedding and generation.
+- LiteLLM API, including chat, supporting various model APIs.
- Post Request API, model inference services based on Post
requests, including Huggingface/ModelScope Inference API and various
post request based model APIs.
@@ -87,6 +88,7 @@ In the current AgentScope, the supported `model_type` types, the corresponding
| ollama | Chat | [`OllamaChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_chat"` | llama2, ... |
| | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_embedding"` | llama2, ... |
| | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... |
+| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - |
| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - |
| | Chat | [`PostAPIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... |
@@ -440,6 +442,26 @@ Here we provide example configurations for different model wrappers.
+
+#### LiteLLM Chat API
+
+<details>
+<summary>LiteLLM Chat API (<code>agentscope.models.LiteLLMChatWrapper</code>)</summary>
+
+```python
+{
+    "config_name": "lite_llm_openai_chat_gpt-3.5-turbo",
+    "model_type": "litellm_chat",
+    "model_name": "gpt-3.5-turbo" # For different models, set the corresponding environment variables (e.g. OPENAI_API_KEY); see https://docs.litellm.ai/docs/ for details.
+},
+```
+
+</details>
+
+
#### Post Request Chat API
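The tutorial entry above only shows the config dict itself. As a point of reference, here is a minimal usage sketch, assuming the config is registered via `agentscope.init` and consumed by a `DialogAgent`; the agent name, system prompt, and messages are illustrative and not part of this patch:

```python
import os

import agentscope
from agentscope.agents import DialogAgent
from agentscope.message import Msg

# LiteLLM reads provider credentials from environment variables,
# e.g. OPENAI_API_KEY for gpt-3.5-turbo (see https://docs.litellm.ai/docs/).
os.environ["OPENAI_API_KEY"] = "your-api-key"

agentscope.init(
    model_configs=[
        {
            "config_name": "lite_llm_openai_chat_gpt-3.5-turbo",
            "model_type": "litellm_chat",
            "model_name": "gpt-3.5-turbo",
        },
    ],
)

# A hypothetical agent that uses the config registered above
agent = DialogAgent(
    name="assistant",
    sys_prompt="You are a helpful assistant",
    model_config_name="lite_llm_openai_chat_gpt-3.5-turbo",
)

reply = agent(Msg(name="user", content="What is the weather today?", role="user"))
print(reply.content)
```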
diff --git a/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md b/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md
index 0528abae8..7b912cbf2 100644
--- a/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md
+++ b/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md
@@ -13,6 +13,7 @@ AgentScope中,模型的部署和调用是通过`ModelWrapper`来解耦开的
- Gemini API,包括对话(Chat)和嵌入(Embedding)。
- ZhipuAi API,包括对话(Chat)和嵌入(Embedding)。
- Ollama API,包括对话(Chat),嵌入(Embedding)和生成(Generation)。
+- LiteLLM API,包括对话(Chat),支持各种模型的API。
- Post请求API,基于Post请求实现的模型推理服务,包括Huggingface/ModelScope
Inference API和各种符合Post请求格式的API。
@@ -107,6 +108,7 @@ API如下:
| ollama | Chat | [`OllamaChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_chat"` | llama2, ... |
| | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_embedding"` | llama2, ... |
| | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... |
+| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - |
| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - |
| | Chat | [`PostAPIChatModelWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... |
@@ -435,6 +437,26 @@ API如下:
+
+#### LiteLLM Chat API
+
+<details>
+<summary>LiteLLM Chat API (<code>agentscope.models.LiteLLMChatWrapper</code>)</summary>
+
+```python
+{
+    "config_name": "lite_llm_openai_chat_gpt-3.5-turbo",
+    "model_type": "litellm_chat",
+    "model_name": "gpt-3.5-turbo" # For different models, set the corresponding environment variables (e.g. OPENAI_API_KEY); see https://docs.litellm.ai/docs/ for details.
+},
+```
+
+</details>
+
+
#### Post Request API
diff --git a/examples/model_configs_template/litellm_chat_template.json b/examples/model_configs_template/litellm_chat_template.json
new file mode 100644
index 000000000..f1711dca9
--- /dev/null
+++ b/examples/model_configs_template/litellm_chat_template.json
@@ -0,0 +1,11 @@
+[{
+    "config_name": "lite_llm_openai_chat_gpt-3.5-turbo",
+    "model_type": "litellm_chat",
+    "model_name": "gpt-3.5-turbo"
+},
+{
+    "config_name": "lite_llm_claude3",
+    "model_type": "litellm_chat",
+    "model_name": "claude-3-opus-20240229"
+}
+]
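A minimal sketch of loading this template by file path and calling the resulting wrapper. It assumes the script runs from the repository root and that the provider credential for the chosen entry is set (e.g. ANTHROPIC_API_KEY for the claude-3 config); these assumptions are not part of the patch:

```python
import os

import agentscope
from agentscope.models import load_model_by_config_name

# Credential for the claude-3 entry in the template above (assumption).
os.environ["ANTHROPIC_API_KEY"] = "your-api-key"

agentscope.init(
    model_configs="examples/model_configs_template/litellm_chat_template.json",
)

model = load_model_by_config_name("lite_llm_claude3")
response = model(messages=[{"role": "user", "content": "Hello!"}])
print(response.text)
```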
diff --git a/setup.py b/setup.py
index 7dca2181b..2259f592f 100644
--- a/setup.py
+++ b/setup.py
@@ -70,6 +70,7 @@
"ollama>=0.1.7",
"google-generativeai>=0.4.0",
"zhipuai",
+ "litellm",
]
distribute_requires = minimal_requires + rpc_requires
diff --git a/src/agentscope/_init.py b/src/agentscope/_init.py
index 7d1f44d7b..dff68e585 100644
--- a/src/agentscope/_init.py
+++ b/src/agentscope/_init.py
@@ -25,7 +25,7 @@ def init(
    save_dir: str = _DEFAULT_DIR,
    save_log: bool = True,
    save_code: bool = True,
-    save_api_invoke: bool = True,
+    save_api_invoke: bool = False,
    use_monitor: bool = True,
    logger_level: LOG_LEVEL = _DEFAULT_LOG_LEVEL,
    runtime_id: Optional[str] = None,
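With this hunk the default of `save_api_invoke` becomes `False`, so invocation records such as those written by `_save_model_invocation` in the new wrapper below are only persisted when the caller opts in. A minimal sketch, with a hypothetical config path:

```python
import agentscope

# Opt back in to saving API invocation records now that the default is False.
# The config path here is hypothetical.
agentscope.init(
    model_configs="path/to/model_configs.json",
    save_api_invoke=True,
)
```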
diff --git a/src/agentscope/models/__init__.py b/src/agentscope/models/__init__.py
index 1e607c0e4..832829993 100644
--- a/src/agentscope/models/__init__.py
+++ b/src/agentscope/models/__init__.py
@@ -37,6 +37,9 @@
    ZhipuAIChatWrapper,
    ZhipuAIEmbeddingWrapper,
)
+from .litellm_model import (
+    LiteLLMChatWrapper,
+)
__all__ = [
@@ -59,6 +62,7 @@
"GeminiEmbeddingWrapper",
"ZhipuAIChatWrapper",
"ZhipuAIEmbeddingWrapper",
+ "LiteLLMChatWrapper",
"load_model_by_config_name",
"read_model_configs",
"clear_model_configs",
diff --git a/src/agentscope/models/litellm_model.py b/src/agentscope/models/litellm_model.py
new file mode 100644
index 000000000..242830a38
--- /dev/null
+++ b/src/agentscope/models/litellm_model.py
@@ -0,0 +1,256 @@
+# -*- coding: utf-8 -*-
+"""Model wrapper based on litellm https://docs.litellm.ai/docs/"""
+from abc import ABC
+from typing import Union, Any, List, Sequence
+
+from loguru import logger
+
+from .model import ModelWrapperBase, ModelResponse
+from ..message import MessageBase
+from ..utils.tools import _convert_to_str
+
+try:
+    import litellm
+except ImportError:
+    litellm = None
+
+
+class LiteLLMWrapperBase(ModelWrapperBase, ABC):
+    """The model wrapper based on the LiteLLM API."""
+
+    def __init__(
+        self,
+        config_name: str,
+        model_name: str = None,
+        generate_args: dict = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        To use the LiteLLM wrapper, environment variables must be set.
+        Different values of model_name may require different environment
+        variables. For example:
+        - for model_name: "gpt-3.5-turbo", you need to set "OPENAI_API_KEY"
+            ```
+            os.environ["OPENAI_API_KEY"] = "your-api-key"
+            ```
+        - for model_name: "claude-2", you need to set "ANTHROPIC_API_KEY"
+        - for Azure OpenAI, you need to set "AZURE_API_KEY",
+          "AZURE_API_BASE" and "AZURE_API_VERSION"
+        Refer to the docs at https://docs.litellm.ai/docs/ for details.
+
+        Args:
+            config_name (`str`):
+                The name of the model config.
+            model_name (`str`, default `None`):
+                The name of the model to use with the LiteLLM API.
+            generate_args (`dict`, default `None`):
+                The extra keyword arguments used in litellm api generation,
+                e.g. `temperature`, `seed`.
+                For generate_args, please refer to
+                https://docs.litellm.ai/docs/completion/input
+                for more details.
+        """
+
+        if model_name is None:
+            model_name = config_name
+            logger.warning("model_name is not set, use config_name instead.")
+
+        super().__init__(config_name=config_name)
+
+        if litellm is None:
+            raise ImportError(
+                "Cannot import litellm package in current python "
+                "environment. You should try:"
+                "1. Install litellm by `pip install litellm`"
+                "2. If you still have the import error, try to "
+                "update openai to a higher version, e.g. "
+                "by running `pip install openai==1.25.1`.",
+            )
+
+        self.model_name = model_name
+        self.generate_args = generate_args or {}
+        self._register_default_metrics()
+
+    def format(
+        self,
+        *args: Union[MessageBase, Sequence[MessageBase]],
+    ) -> Union[List[dict], str]:
+        raise RuntimeError(
+            f"Model Wrapper [{type(self).__name__}] doesn't "
+            f"need to format the input. Please try to use the "
+            f"model wrapper directly.",
+        )
+
+class LiteLLMChatWrapper(LiteLLMWrapperBase):
+    """The model wrapper based on the litellm chat API.
+
+    To use the LiteLLM wrapper, environment variables must be set.
+    Different values of model_name may require different environment
+    variables. For example:
+    - for model_name: "gpt-3.5-turbo", you need to set "OPENAI_API_KEY"
+        ```
+        os.environ["OPENAI_API_KEY"] = "your-api-key"
+        ```
+    - for model_name: "claude-2", you need to set "ANTHROPIC_API_KEY"
+    - for Azure OpenAI, you need to set "AZURE_API_KEY",
+      "AZURE_API_BASE" and "AZURE_API_VERSION"
+    Refer to the docs at https://docs.litellm.ai/docs/ for details.
+    """
+
+    model_type: str = "litellm_chat"
+
+    def _register_default_metrics(self) -> None:
+        # Set monitor accordingly
+        # TODO: set quota to the following metrics
+        self.monitor.register(
+            self._metric("call_counter"),
+            metric_unit="times",
+        )
+        self.monitor.register(
+            self._metric("prompt_tokens"),
+            metric_unit="token",
+        )
+        self.monitor.register(
+            self._metric("completion_tokens"),
+            metric_unit="token",
+        )
+        self.monitor.register(
+            self._metric("total_tokens"),
+            metric_unit="token",
+        )
+
+    def __call__(
+        self,
+        messages: list,
+        **kwargs: Any,
+    ) -> ModelResponse:
+        """
+        Args:
+            messages (`list`):
+                A list of messages to process.
+            **kwargs (`Any`):
+                The keyword arguments to litellm chat completions API,
+                e.g. `temperature`, `max_tokens`, `top_p`, etc. Please refer to
+                https://docs.litellm.ai/docs/completion/input
+                for more detailed arguments.
+
+        Returns:
+            `ModelResponse`:
+                The response text in the text field, and the raw response in
+                the raw field.
+        """
+
+        # step1: prepare keyword arguments
+        kwargs = {**self.generate_args, **kwargs}
+
+        # step2: checking messages
+        if not isinstance(messages, list):
+            raise ValueError(
+                "LiteLLM `messages` field expected type `list`, "
+                f"got `{type(messages)}` instead.",
+            )
+        if not all("role" in msg and "content" in msg for msg in messages):
+            raise ValueError(
+                "Each message in the 'messages' list must contain a 'role' "
+                "and 'content' key for LiteLLM API.",
+            )
+
+        # step3: forward to generate response
+        response = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            **kwargs,
+        )
+
+        # step4: record the api invocation if needed
+        self._save_model_invocation(
+            arguments={
+                "model": self.model_name,
+                "messages": messages,
+                **kwargs,
+            },
+            response=response.model_dump(),
+        )
+
+        # step5: update monitor accordingly
+        self.update_monitor(call_counter=1, **response.usage.model_dump())
+
+        # step6: return response
+        return ModelResponse(
+            text=response.choices[0].message.content,
+            raw=response.model_dump(),
+        )
+
+    def format(
+        self,
+        *args: Union[MessageBase, Sequence[MessageBase]],
+    ) -> List[dict]:
+        """Format the input string and dictionary into the unified format.
+        Note that the format function might not be the optimal way to
+        construct prompts for every model, but a common way to do so.
+        Developers are encouraged to implement their own prompt
+        engineering strategies if they have strong performance concerns.
+
+        Args:
+            args (`Union[MessageBase, Sequence[MessageBase]]`):
+                The input arguments to be formatted, where each argument
+                should be a `Msg` object, or a list of `Msg` objects.
+                In distributed mode, placeholders are also allowed.
+
+        Returns:
+            `List[dict]`:
+                The formatted messages in the format that the LiteLLM chat
+                API requires.
+        """
+
+        # Parse all information into a list of messages
+        input_msgs = []
+        for _ in args:
+            if _ is None:
+                continue
+            if isinstance(_, MessageBase):
+                input_msgs.append(_)
+            elif isinstance(_, list) and all(
+                isinstance(__, MessageBase) for __ in _
+            ):
+                input_msgs.extend(_)
+            else:
+                raise TypeError(
+                    f"The input should be a Msg object or a list "
+                    f"of Msg objects, got {type(_)}.",
+                )
+
+        # record dialog history as a list of strings
+        system_content_template = []
+        dialogue = []
+        for i, unit in enumerate(input_msgs):
+            if i == 0 and unit.role == "system":
+                # system prompt
+                system_prompt = _convert_to_str(unit.content)
+                if not system_prompt.endswith("\n"):
+                    system_prompt += "\n"
+                system_content_template.append(system_prompt)
+            else:
+                # Merge all messages into a dialogue history prompt
+                dialogue.append(
+                    f"{unit.name}: {_convert_to_str(unit.content)}",
+                )
+
+        if len(dialogue) != 0:
+            system_content_template.extend(
+                ["## Dialogue History", "{dialogue_history}"],
+            )
+
+        dialogue_history = "\n".join(dialogue)
+
+        system_content_template = "\n".join(system_content_template)
+
+        messages = [
+            {
+                "role": "user",
+                "content": system_content_template.format(
+                    dialogue_history=dialogue_history,
+                ),
+            },
+        ]
+
+        return messages
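A minimal sketch of using `LiteLLMChatWrapper` directly, mirroring the unit test added below: `format()` merges the system prompt and dialogue history into a single user message, which is then passed to the wrapper. The messages and API key handling here are illustrative assumptions, not part of this patch:

```python
import os

from agentscope.message import Msg
from agentscope.models import LiteLLMChatWrapper

# LiteLLM reads provider credentials from environment variables,
# e.g. OPENAI_API_KEY for OpenAI-hosted models (assumption for this sketch).
os.environ["OPENAI_API_KEY"] = "your-api-key"

model = LiteLLMChatWrapper(
    config_name="lite_llm_openai_chat_gpt-3.5-turbo",
    model_name="gpt-3.5-turbo",
)

# format() collapses the system prompt and the dialogue history
# into a single user message, as the unit test below asserts.
prompt = model.format(
    Msg(name="system", content="You are a helpful assistant", role="system"),
    [
        Msg(name="user", content="What is the weather today?", role="user"),
        Msg(name="assistant", content="It is sunny today", role="assistant"),
    ],
)

response = model(messages=prompt)
print(response.text)
```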
diff --git a/tests/format_test.py b/tests/format_test.py
index 42df7b960..661950743 100644
--- a/tests/format_test.py
+++ b/tests/format_test.py
@@ -12,6 +12,7 @@
    ZhipuAIChatWrapper,
    DashScopeChatWrapper,
    DashScopeMultiModalWrapper,
+    LiteLLMChatWrapper,
)
@@ -211,6 +212,32 @@ def test_zhipuai_chat(self) -> None:
        with self.assertRaises(TypeError):
            model.format(*self.wrong_inputs)  # type: ignore[arg-type]
+    def test_litellm_chat(self) -> None:
+        """Unit test for the format function in litellm chat api wrapper."""
+        model = LiteLLMChatWrapper(
+            config_name="",
+            model_name="gpt-3.5-turbo",
+            api_key="xxx",
+        )
+
+        ground_truth = [
+            {
+                "role": "user",
+                "content": (
+                    "You are a helpful assistant\n\n"
+                    "## Dialogue History\nuser: What is the weather today?\n"
+                    "assistant: It is sunny today"
+                ),
+            },
+        ]
+
+        prompt = model.format(*self.inputs)
+        self.assertListEqual(prompt, ground_truth)
+
+        # wrong format
+        with self.assertRaises(TypeError):
+            model.format(*self.wrong_inputs)  # type: ignore[arg-type]
+
    def test_dashscope_multimodal_image(self) -> None:
        """Unit test for the format function in dashscope multimodal
        conversation api wrapper for image."""
diff --git a/tests/litellm_test.py b/tests/litellm_test.py
new file mode 100644
index 000000000..3ee4a8503
--- /dev/null
+++ b/tests/litellm_test.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+"""litellm test"""
+import unittest
+from unittest.mock import patch, MagicMock
+
+import agentscope
+from agentscope.models import load_model_by_config_name
+
+
+class TestLiteLLMChatWrapper(unittest.TestCase):
+ """Test LiteLLM Chat Wrapper"""
+
+ def setUp(self) -> None:
+ self.api_key = "test_api_key.secret_key"
+ self.messages = [
+ {"role": "user", "content": "Hello, litellm!"},
+ {"role": "assistant", "content": "How can I assist you?"},
+ ]
+
+ @patch("agentscope.models.litellm_model.litellm")
+ def test_chat(self, mock_litellm: MagicMock) -> None:
+ """
+ Test chat"""
+ mock_response = MagicMock()
+ mock_response.model_dump.return_value = {
+ "choices": [
+ {"message": {"content": "Hello, this is a mocked response!"}},
+ ],
+ "usage": {
+ "prompt_tokens": 100,
+ "completion_tokens": 5,
+ "total_tokens": 105,
+ },
+ }
+ mock_response.choices[
+ 0
+ ].message.content = "Hello, this is a mocked response!"
+
+ mock_litellm.completion.return_value = mock_response
+
+ agentscope.init(
+ model_configs={
+ "config_name": "test_config",
+ "model_type": "litellm_chat",
+ "model_name": "ollama/llama3:8b",
+ "api_key": self.api_key,
+ },
+ )
+
+ model = load_model_by_config_name("test_config")
+
+ response = model(
+ messages=self.messages,
+ api_base="http://localhost:11434",
+ )
+
+ self.assertEqual(response.text, "Hello, this is a mocked response!")
+
+
+if __name__ == "__main__":
+ unittest.main()