From 8c31ae634c0b8a715bfa122033674bd049634549 Mon Sep 17 00:00:00 2001 From: DavdGao Date: Thu, 31 Oct 2024 14:21:48 +0800 Subject: [PATCH] Support vision models in OpenAI token counting function; Add unit test for OpenAI vision models --- src/agentscope/logging.py | 2 +- src/agentscope/service/__init__.py | 3 +- src/agentscope/tokens.py | 95 ++++++++++++++++++++++++++++-- tests/tokens_test.py | 21 +++++++ 4 files changed, 114 insertions(+), 7 deletions(-) diff --git a/src/agentscope/logging.py b/src/agentscope/logging.py index 951de472a..498c007b2 100644 --- a/src/agentscope/logging.py +++ b/src/agentscope/logging.py @@ -32,7 +32,7 @@ LEVEL_SAVE_MSG = "SAVE_MSG" _DEFAULT_LOG_FORMAT = ( - "{time:YYYY-MM-DD HH:mm:ss.SSS} | {" + "{time:YYYY-MM-DD HH:mm:ss} | {" "level: <8} | {name}:{" "function}:{line} - {" "message}" diff --git a/src/agentscope/service/__init__.py b/src/agentscope/service/__init__.py index 7d33e6501..472379d6c 100644 --- a/src/agentscope/service/__init__.py +++ b/src/agentscope/service/__init__.py @@ -66,7 +66,8 @@ def get_help() -> None: """Get help message.""" - help_msg = f"The following service are available:\n{__all__}" + tools = "\n - ".join(__all__) + help_msg = f"The following service are available:\n{tools}" logger.info(help_msg) diff --git a/src/agentscope/tokens.py b/src/agentscope/tokens.py index 31c9e42de..cd172b42a 100644 --- a/src/agentscope/tokens.py +++ b/src/agentscope/tokens.py @@ -2,7 +2,7 @@ """The tokens interface for agentscope.""" import os from http import HTTPStatus -from typing import Callable, Union, Optional +from typing import Callable, Union, Optional, Any from loguru import logger @@ -51,7 +51,7 @@ def count(model_name: str, messages: list[dict[str, str]]) -> int: return count_gemini_tokens(model_name, messages) # Dashscope - elif model_name in ["qwen-"]: + elif model_name.startswith("qwen-"): return count_dashscope_tokens(model_name, messages) else: @@ -63,7 +63,77 @@ def count(model_name: str, messages: list[dict[str, str]]) -> int: ) -def count_openai_tokens( +def _count_content_tokens_for_openai_vision_model( + content: list[dict], + encoding: Any, +) -> int: + """Yield the number of tokens for the content of an OpenAI vision model. + Implemented according to https://platform.openai.com/docs/guides/vision. + + Args: + content (`list[dict]`): + A list of dictionaries. + encoding (`Any`): + The encoding object. + + Example: + .. code-block:: python + + _yield_tokens_for_openai_vision_model( + [ + { + "type": "text", + "text": "xxx", + }, + { + "type": "image_url", + "image_url": { + "url": "xxx", + "detail": "auto", + } + }, + # ... + ] + ) + + Returns: + `Generator[int, None, None]`: Generate the number of tokens in a + generator. + """ + num_tokens = 0 + for item in content: + if not isinstance(item, dict): + raise TypeError( + "If you're using a vision model for OpenAI models," + "The content field should be a list of " + f"dictionaries, but got {type(item)}.", + ) + + typ = item.get("type", None) + if typ == "text": + num_tokens += len(encoding.encode(item["text"])) + + elif typ == "image_url": + # By default, we use high here to avoid undercounting tokens + detail = item.get("image_url").get("detail", "high") + if detail == "low": + num_tokens += 85 + elif detail in ["auto", "high"]: + num_tokens += 170 + else: + raise ValueError( + f"Unsupported image detail {detail}, expected " + f"one of ['low', 'auto', 'high'].", + ) + else: + raise ValueError( + "The type field currently only supports 'text' " + f"and 'image_url', but got {typ}.", + ) + return num_tokens + + +def count_openai_tokens( # pylint: disable=too-many-branches model_name: str, messages: list[dict[str, str]], ) -> int: @@ -81,8 +151,8 @@ def count_openai_tokens( try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: - print("Warning: model not found. Using o200k_base encoding.") encoding = tiktoken.get_encoding("o200k_base") + if model_name in { "gpt-3.5-turbo-0125", "gpt-4-0314", @@ -121,9 +191,24 @@ def count_openai_tokens( for message in messages: num_tokens += tokens_per_message for key, value in message.items(): - num_tokens += len(encoding.encode(value)) + # Considering vision models + if key == "content" and isinstance(value, list): + num_tokens += _count_content_tokens_for_openai_vision_model( + value, + encoding, + ) + + elif isinstance(value, str): + num_tokens += len(encoding.encode(value)) + + else: + raise TypeError( + f"Invalid type {type(value)} in the {key} field.", + ) + if key == "name": num_tokens += tokens_per_name + num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> return num_tokens diff --git a/tests/tokens_test.py b/tests/tokens_test.py index b545625ca..193034b50 100644 --- a/tests/tokens_test.py +++ b/tests/tokens_test.py @@ -50,6 +50,24 @@ def setUp(self) -> None: "name": "Friday", }, ] + self.messages_openai_vision = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "I want to book a flight to Paris.", + }, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image.jpg", + "detail": "auto", + }, + }, + ], + }, + ] def test_openai_token_counting(self) -> None: """Test OpenAI token counting functions.""" @@ -59,6 +77,9 @@ def test_openai_token_counting(self) -> None: n_tokens = count_openai_tokens("gpt-4o", self.messages) self.assertEqual(n_tokens, 32) + n_tokens = count_openai_tokens("gpt-4o", self.messages_openai_vision) + self.assertEqual(n_tokens, 186) + @patch("dashscope.Tokenization.call") def test_dashscope_token_counting(self, mock_call: MagicMock) -> None: """Test Dashscope token counting functions."""