[HOTFIX] Add warning for different finish reasons in Gemini model API (#233)

---------

Co-authored-by: DavdGao <[email protected]>
zyzhang1130 and DavdGao authored May 24, 2024
1 parent 7995c4c commit 5d7638d
Showing 2 changed files with 75 additions and 4 deletions.
59 changes: 55 additions & 4 deletions src/agentscope/models/gemini_model.py
@@ -13,8 +13,12 @@

try:
import google.generativeai as genai

# This package is installed alongside google-generativeai
import google.ai.generativelanguage as glm
except ImportError:
genai = None
glm = None


class GeminiWrapperBase(ModelWrapperBase, ABC):
@@ -42,6 +46,13 @@ def __init__(
"""
super().__init__(config_name=config_name)

# Test if the required package is installed
if genai is None:
raise ImportError(
"The google-generativeai package is not installed, "
"please install it first.",
)

# Load the api_key from argument or environment variable
api_key = api_key or os.environ.get("GOOGLE_API_KEY")

@@ -149,7 +160,50 @@ def __call__(
**kwargs,
)

# step3: record the api invocation if needed
# step3: Check for candidates and handle accordingly
if (
not response.candidates[0].content
or not response.candidates[0].content.parts
or not response.candidates[0].content.parts[0].text
):
# If we cannot get the response text from the model
finish_reason = response.candidates[0].finish_reason
reasons = glm.Candidate.FinishReason

if finish_reason == reasons.STOP:
error_info = (
"Natural stop point of the model or provided stop "
"sequence."
)
elif finish_reason == reasons.MAX_TOKENS:
error_info = (
"The maximum number of tokens as specified in the request "
"was reached."
)
elif finish_reason == reasons.SAFETY:
error_info = (
"The candidate content was flagged for safety reasons."
)
elif finish_reason == reasons.RECITATION:
error_info = (
"The candidate content was flagged for recitation reasons."
)
elif finish_reason in [
reasons.FINISH_REASON_UNSPECIFIED,
reasons.OTHER,
]:
error_info = "Unknown error."
else:
error_info = "No information provided from Gemini API."

raise ValueError(
"The Google Gemini API failed to generate text response with "
f"the following finish reason: {error_info}\n"
f"YOUR INPUT: {contents}\n"
f"RAW RESPONSE FROM GEMINI API: {response}\n",
)

# step4: record the api invocation if needed
self._save_model_invocation(
arguments={
"contents": contents,
@@ -160,9 +214,6 @@
)

# step5: update monitor accordingly
# TODO: Up to 2024/03/11, the response from Gemini doesn't contain
# the detailed information about cost. Here we simply count
# the tokens manually.
token_prompt = self.model.count_tokens(contents).total_tokens
token_response = self.model.count_tokens(response.text).total_tokens
self.update_monitor(
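For quick reference, the sketch below restates the guard introduced above outside of the wrapper, using stand-in classes (an illustrative FinishReason enum and SimpleNamespace responses) instead of the real google.generativeai / google.ai.generativelanguage objects, and condensing the unspecified/other/fallback branches into a single default message. It is not part of the commit, only an illustration of the behaviour the patch adds.

from enum import Enum
from types import SimpleNamespace


class FinishReason(Enum):
    """Stand-in for glm.Candidate.FinishReason (illustrative only)."""

    FINISH_REASON_UNSPECIFIED = 0
    STOP = 1
    MAX_TOKENS = 2
    SAFETY = 3
    RECITATION = 4
    OTHER = 5


def check_candidate(response) -> None:
    """Raise ValueError when the first candidate carries no usable text."""
    candidate = response.candidates[0]
    if candidate.content and candidate.content.parts and candidate.content.parts[0].text:
        return  # the response contains text, nothing to report

    messages = {
        FinishReason.STOP: "Natural stop point of the model or provided stop sequence.",
        FinishReason.MAX_TOKENS: "The maximum number of tokens as specified in the request was reached.",
        FinishReason.SAFETY: "The candidate content was flagged for safety reasons.",
        FinishReason.RECITATION: "The candidate content was flagged for recitation reasons.",
    }
    error_info = messages.get(candidate.finish_reason, "Unknown error.")
    raise ValueError(
        "The Google Gemini API failed to generate text response with "
        f"the following finish reason: {error_info}",
    )


# A response blocked for safety: no parts, finish_reason == SAFETY.
blocked = SimpleNamespace(
    candidates=[
        SimpleNamespace(
            content=SimpleNamespace(parts=[]),
            finish_reason=FinishReason.SAFETY,
        ),
    ],
)

try:
    check_candidate(blocked)
except ValueError as err:
    print(err)  # ... flagged for safety reasons.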
20 changes: 20 additions & 0 deletions tests/gemini_test.py
@@ -20,11 +20,31 @@ def flush() -> None:
MonitorFactory.flush()


class DummyPart:
"""Dummy part for testing."""

text = "Hello! How can I help you?"


class DummyContent:
"""Dummy content for testing."""

parts = [DummyPart()]


class DummyCandidate:
"""Dummy candidate for testing."""

content = DummyContent()


class DummyResponse:
"""Dummy response for testing."""

text = "Hello! How can I help you?"

candidates = [DummyCandidate]

def __str__(self) -> str:
"""Return string representation."""
return str({"text": self.text})
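These test doubles reproduce the candidates[0].content.parts[0].text chain that the new guard in __call__ dereferences, so the existing chat test keeps exercising the happy path. A quick sanity check, reusing the dummy classes above (illustrative only, not part of the commit):

response = DummyResponse()

# The attribute chain matches what the new guard inspects before using .text.
assert response.candidates[0].content.parts[0].text == "Hello! How can I help you?"

# The string form is unchanged from before the patch.
assert str(response) == str({"text": "Hello! How can I help you?"})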
