diff --git a/src/agentscope/models/gemini_model.py b/src/agentscope/models/gemini_model.py
index 7625bf6f1..3deca77d1 100644
--- a/src/agentscope/models/gemini_model.py
+++ b/src/agentscope/models/gemini_model.py
@@ -13,8 +13,12 @@
 
 try:
     import google.generativeai as genai
+
+    # This package will be installed when the google-generativeai is installed
+    import google.ai.generativelanguage as glm
 except ImportError:
     genai = None
+    glm = None
 
 
 class GeminiWrapperBase(ModelWrapperBase, ABC):
@@ -42,6 +46,13 @@ def __init__(
         """
         super().__init__(config_name=config_name)
 
+        # Test if the required package is installed
+        if genai is None:
+            raise ImportError(
+                "The google-generativeai package is not installed, "
+                "please install it first.",
+            )
+
         # Load the api_key from argument or environment variable
         api_key = api_key or os.environ.get("GOOGLE_API_KEY")
 
@@ -149,7 +160,55 @@ def __call__(
             **kwargs,
         )
 
-        # step3: record the api invocation if needed
+        # step3: Check for candidates and handle accordingly. The candidate
+        # list can be empty (e.g. when the prompt itself is blocked), so it
+        # must be checked before indexing into it.
+        candidates = response.candidates
+        if (
+            not candidates
+            or not candidates[0].content
+            or not candidates[0].content.parts
+            or not candidates[0].content.parts[0].text
+        ):
+            # If we cannot get the response text from the model
+            reasons = glm.Candidate.FinishReason
+            # Without any candidate there is no finish reason to report
+            finish_reason = candidates[0].finish_reason if candidates else None
+
+            if finish_reason == reasons.STOP:
+                error_info = (
+                    "Natural stop point of the model or provided stop "
+                    "sequence."
+                )
+            elif finish_reason == reasons.MAX_TOKENS:
+                error_info = (
+                    "The maximum number of tokens as specified in the request "
+                    "was reached."
+                )
+            elif finish_reason == reasons.SAFETY:
+                error_info = (
+                    "The candidate content was flagged for safety reasons."
+                )
+            elif finish_reason == reasons.RECITATION:
+                error_info = (
+                    "The candidate content was flagged for recitation reasons."
+                )
+            elif finish_reason in [
+                reasons.FINISH_REASON_UNSPECIFIED,
+                reasons.OTHER,
+            ]:
+                error_info = "Unknown error."
+            else:
+                error_info = "No information provided from Gemini API."
+
+            raise ValueError(
+                "The Google Gemini API failed to generate text response with "
+                f"the following finish reason: {error_info}\n"
+                f"YOUR INPUT: {contents}\n"
+                f"RAW RESPONSE FROM GEMINI API: {response}\n",
+            )
+
+        # step4: record the api invocation if needed
         self._save_model_invocation(
             arguments={
                 "contents": contents,
@@ -160,9 +219,6 @@ def __call__(
         )
 
         # step5: update monitor accordingly
-        # TODO: Up to 2024/03/11, the response from Gemini doesn't contain
-        #  the detailed information about cost. Here we simply count
-        #  the tokens manually.
         token_prompt = self.model.count_tokens(contents).total_tokens
         token_response = self.model.count_tokens(response.text).total_tokens
         self.update_monitor(
diff --git a/tests/gemini_test.py b/tests/gemini_test.py
index 20bb75cbb..f854fa7ac 100644
--- a/tests/gemini_test.py
+++ b/tests/gemini_test.py
@@ -20,11 +20,31 @@ def flush() -> None:
     MonitorFactory.flush()
 
 
+class DummyPart:
+    """Dummy part for testing."""
+
+    text = "Hello! How can I help you?"
+
+
+class DummyContent:
+    """Dummy content for testing."""
+
+    parts = [DummyPart()]
+
+
+class DummyCandidate:
+    """Dummy candidate for testing."""
+
+    content = DummyContent()
+
+
 class DummyResponse:
     """Dummy response for testing."""
 
     text = "Hello! How can I help you?"
 
+    candidates = [DummyCandidate()]
+
     def __str__(self) -> str:
         """Return string representation."""
         return str({"text": self.text})