Skip to content

Commit

Permalink
Support vision models in OpenAI token counting function; Add unit test for OpenAI vision models
Browse files Browse the repository at this point in the history
  • Loading branch information
DavdGao committed Oct 31, 2024
1 parent aeb5f7e commit 8c31ae6
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/agentscope/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
LEVEL_SAVE_MSG = "SAVE_MSG"

_DEFAULT_LOG_FORMAT = (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{"
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{"
"level: <8}</level> | <cyan>{name}</cyan>:<cyan>{"
"function}</cyan>:<cyan>{line}</cyan> - <level>{"
"message}</level>"
Expand Down
3 changes: 2 additions & 1 deletion src/agentscope/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@

def get_help() -> None:
    """Log a help message listing all available service functions.

    The names are taken from this module's ``__all__`` and printed one
    per line via the module logger.
    """
    # NOTE: the first entry is not prefixed with " - " because join only
    # inserts the separator between items.
    tools = "\n - ".join(__all__)
    help_msg = f"The following services are available:\n{tools}"
    logger.info(help_msg)


Expand Down
95 changes: 90 additions & 5 deletions src/agentscope/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""The tokens interface for agentscope."""
import os
from http import HTTPStatus
from typing import Callable, Union, Optional
from typing import Callable, Union, Optional, Any

from loguru import logger

Expand Down Expand Up @@ -51,7 +51,7 @@ def count(model_name: str, messages: list[dict[str, str]]) -> int:
return count_gemini_tokens(model_name, messages)

# Dashscope
elif model_name in ["qwen-"]:
elif model_name.startswith("qwen-"):
return count_dashscope_tokens(model_name, messages)

else:
Expand All @@ -63,7 +63,77 @@ def count(model_name: str, messages: list[dict[str, str]]) -> int:
)


def count_openai_tokens(
def _count_content_tokens_for_openai_vision_model(
content: list[dict],
encoding: Any,
) -> int:
"""Yield the number of tokens for the content of an OpenAI vision model.
Implemented according to https://platform.openai.com/docs/guides/vision.
Args:
content (`list[dict]`):
A list of dictionaries.
encoding (`Any`):
The encoding object.
Example:
.. code-block:: python
_yield_tokens_for_openai_vision_model(
[
{
"type": "text",
"text": "xxx",
},
{
"type": "image_url",
"image_url": {
"url": "xxx",
"detail": "auto",
}
},
# ...
]
)
Returns:
`Generator[int, None, None]`: Generate the number of tokens in a
generator.
"""
num_tokens = 0
for item in content:
if not isinstance(item, dict):
raise TypeError(
"If you're using a vision model for OpenAI models,"
"The content field should be a list of "
f"dictionaries, but got {type(item)}.",
)

typ = item.get("type", None)
if typ == "text":
num_tokens += len(encoding.encode(item["text"]))

elif typ == "image_url":
# By default, we use high here to avoid undercounting tokens
detail = item.get("image_url").get("detail", "high")
if detail == "low":
num_tokens += 85
elif detail in ["auto", "high"]:
num_tokens += 170
else:
raise ValueError(
f"Unsupported image detail {detail}, expected "
f"one of ['low', 'auto', 'high'].",
)
else:
raise ValueError(
"The type field currently only supports 'text' "
f"and 'image_url', but got {typ}.",
)
return num_tokens


def count_openai_tokens( # pylint: disable=too-many-branches
model_name: str,
messages: list[dict[str, str]],
) -> int:
Expand All @@ -81,8 +151,8 @@ def count_openai_tokens(
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
print("Warning: model not found. Using o200k_base encoding.")
encoding = tiktoken.get_encoding("o200k_base")

if model_name in {
"gpt-3.5-turbo-0125",
"gpt-4-0314",
Expand Down Expand Up @@ -121,9 +191,24 @@ def count_openai_tokens(
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
# Considering vision models
if key == "content" and isinstance(value, list):
num_tokens += _count_content_tokens_for_openai_vision_model(
value,
encoding,
)

elif isinstance(value, str):
num_tokens += len(encoding.encode(value))

else:
raise TypeError(
f"Invalid type {type(value)} in the {key} field.",
)

if key == "name":
num_tokens += tokens_per_name

num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens

Expand Down
21 changes: 21 additions & 0 deletions tests/tokens_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ def setUp(self) -> None:
"name": "Friday",
},
]
self.messages_openai_vision = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "I want to book a flight to Paris.",
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image.jpg",
"detail": "auto",
},
},
],
},
]

def test_openai_token_counting(self) -> None:
"""Test OpenAI token counting functions."""
Expand All @@ -59,6 +77,9 @@ def test_openai_token_counting(self) -> None:
n_tokens = count_openai_tokens("gpt-4o", self.messages)
self.assertEqual(n_tokens, 32)

n_tokens = count_openai_tokens("gpt-4o", self.messages_openai_vision)
self.assertEqual(n_tokens, 186)

@patch("dashscope.Tokenization.call")
def test_dashscope_token_counting(self, mock_call: MagicMock) -> None:
"""Test Dashscope token counting functions."""
Expand Down

0 comments on commit 8c31ae6

Please sign in to comment.