From e61e4fe32c35d35f616f9f053fd116f0eca541d9 Mon Sep 17 00:00:00 2001 From: Valery Kharitonov Date: Tue, 14 May 2024 21:33:49 -0400 Subject: [PATCH] Handle Anthropic API errors gracefully Fixes #71, #38 --- gptcli/anthropic.py | 26 +++++++++++++++------- gptcli/completion.py | 8 +++++++ gptcli/openai.py | 52 ++++++++++++++++++++++++++------------------ gptcli/session.py | 5 ++--- 4 files changed, 59 insertions(+), 32 deletions(-) diff --git a/gptcli/anthropic.py b/gptcli/anthropic.py index a9c343e..259cee6 100644 --- a/gptcli/anthropic.py +++ b/gptcli/anthropic.py @@ -2,7 +2,12 @@ from typing import Iterator, List import anthropic -from gptcli.completion import CompletionProvider, Message +from gptcli.completion import ( + CompletionProvider, + Message, + CompletionError, + BadRequestError, +) api_key = os.environ.get("ANTHROPIC_API_KEY") @@ -53,13 +58,18 @@ def complete( kwargs["messages"] = messages client = get_client() - if stream: - with client.messages.stream(**kwargs) as stream: - for text in stream.text_stream: - yield text - else: - response = client.messages.create(**kwargs, stream=False) - yield "".join(c.text for c in response.content) + try: + if stream: + with client.messages.stream(**kwargs) as completion: + for text in completion.text_stream: + yield text + else: + response = client.messages.create(**kwargs, stream=False) + yield "".join(c.text for c in response.content) + except anthropic.BadRequestError as e: + raise BadRequestError(e.message) from e + except anthropic.APIError as e: + raise CompletionError(e.message) from e def num_tokens_from_messages_anthropic(messages: List[Message], model: str) -> int: diff --git a/gptcli/completion.py b/gptcli/completion.py index ea82d8f..effa500 100644 --- a/gptcli/completion.py +++ b/gptcli/completion.py @@ -19,3 +19,11 @@ def complete( self, messages: List[Message], args: dict, stream: bool = False ) -> Iterator[str]: pass + + +class CompletionError(Exception): + pass + + +class BadRequestError(CompletionError): + pass diff --git a/gptcli/openai.py b/gptcli/openai.py index 70acc9a..286711b 100644 --- a/gptcli/openai.py +++ b/gptcli/openai.py @@ -5,7 +5,12 @@ import tiktoken -from gptcli.completion import CompletionProvider, Message +from gptcli.completion import ( + CompletionProvider, + Message, + CompletionError, + BadRequestError, +) class OpenAICompletionProvider(CompletionProvider): @@ -21,28 +26,33 @@ def complete( if "top_p" in args: kwargs["top_p"] = args["top_p"] - if stream: - response_iter = self.client.chat.completions.create( - messages=cast(List[ChatCompletionMessageParam], messages), - stream=True, - model=args["model"], - **kwargs, - ) + try: + if stream: + response_iter = self.client.chat.completions.create( + messages=cast(List[ChatCompletionMessageParam], messages), + stream=True, + model=args["model"], + **kwargs, + ) - for response in response_iter: + for response in response_iter: + next_choice = response.choices[0] + if next_choice.finish_reason is None and next_choice.delta.content: + yield next_choice.delta.content + else: + response = self.client.chat.completions.create( + messages=cast(List[ChatCompletionMessageParam], messages), + model=args["model"], + stream=False, + **kwargs, + ) next_choice = response.choices[0] - if next_choice.finish_reason is None and next_choice.delta.content: - yield next_choice.delta.content - else: - response = self.client.chat.completions.create( - messages=cast(List[ChatCompletionMessageParam], messages), - model=args["model"], - stream=False, - **kwargs, - ) - next_choice = response.choices[0] - if next_choice.message.content: - yield next_choice.message.content + if next_choice.message.content: + yield next_choice.message.content + except openai.BadRequestError as e: + raise BadRequestError(e.message) from e + except openai.APIError as e: + raise CompletionError(e.message) from e def num_tokens_from_messages_openai(messages: List[Message], model: str) -> int: diff --git a/gptcli/session.py b/gptcli/session.py index 3e7ae1d..120c322 100644 --- a/gptcli/session.py +++ b/gptcli/session.py @@ -1,8 +1,7 @@ from abc import abstractmethod from typing_extensions import TypeGuard from gptcli.assistant import Assistant -from gptcli.completion import Message, ModelOverrides -from openai import BadRequestError, OpenAIError +from gptcli.completion import Message, ModelOverrides, CompletionError, BadRequestError from typing import Any, Dict, List, Tuple @@ -115,7 +114,7 @@ def _respond(self, args: ModelOverrides) -> bool: except BadRequestError as e: self.listener.on_error(e) return False - except OpenAIError as e: + except CompletionError as e: self.listener.on_error(e) return True