Skip to content

Commit

Permalink
Handle Anthropic API errors gracefully
Browse files Browse the repository at this point in the history
Fixes #71, #38
  • Loading branch information
kharvd committed May 15, 2024
1 parent 3c37b7c commit e61e4fe
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 32 deletions.
26 changes: 18 additions & 8 deletions gptcli/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
from typing import Iterator, List
import anthropic

from gptcli.completion import CompletionProvider, Message
from gptcli.completion import (
CompletionProvider,
Message,
CompletionError,
BadRequestError,
)

api_key = os.environ.get("ANTHROPIC_API_KEY")

Expand Down Expand Up @@ -53,13 +58,18 @@ def complete(
kwargs["messages"] = messages

client = get_client()
if stream:
with client.messages.stream(**kwargs) as stream:
for text in stream.text_stream:
yield text
else:
response = client.messages.create(**kwargs, stream=False)
yield "".join(c.text for c in response.content)
try:
if stream:
with client.messages.stream(**kwargs) as completion:
for text in completion.text_stream:
yield text
else:
response = client.messages.create(**kwargs, stream=False)
yield "".join(c.text for c in response.content)
except anthropic.BadRequestError as e:
raise BadRequestError(e.message) from e
except anthropic.APIError as e:
raise CompletionError(e.message) from e


def num_tokens_from_messages_anthropic(messages: List[Message], model: str) -> int:
Expand Down
8 changes: 8 additions & 0 deletions gptcli/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,11 @@ def complete(
self, messages: List[Message], args: dict, stream: bool = False
) -> Iterator[str]:
pass


class CompletionError(Exception):
pass


class BadRequestError(CompletionError):
pass
52 changes: 31 additions & 21 deletions gptcli/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

import tiktoken

from gptcli.completion import CompletionProvider, Message
from gptcli.completion import (
CompletionProvider,
Message,
CompletionError,
BadRequestError,
)


class OpenAICompletionProvider(CompletionProvider):
Expand All @@ -21,28 +26,33 @@ def complete(
if "top_p" in args:
kwargs["top_p"] = args["top_p"]

if stream:
response_iter = self.client.chat.completions.create(
messages=cast(List[ChatCompletionMessageParam], messages),
stream=True,
model=args["model"],
**kwargs,
)
try:
if stream:
response_iter = self.client.chat.completions.create(
messages=cast(List[ChatCompletionMessageParam], messages),
stream=True,
model=args["model"],
**kwargs,
)

for response in response_iter:
for response in response_iter:
next_choice = response.choices[0]
if next_choice.finish_reason is None and next_choice.delta.content:
yield next_choice.delta.content
else:
response = self.client.chat.completions.create(
messages=cast(List[ChatCompletionMessageParam], messages),
model=args["model"],
stream=False,
**kwargs,
)
next_choice = response.choices[0]
if next_choice.finish_reason is None and next_choice.delta.content:
yield next_choice.delta.content
else:
response = self.client.chat.completions.create(
messages=cast(List[ChatCompletionMessageParam], messages),
model=args["model"],
stream=False,
**kwargs,
)
next_choice = response.choices[0]
if next_choice.message.content:
yield next_choice.message.content
if next_choice.message.content:
yield next_choice.message.content
except openai.BadRequestError as e:
raise BadRequestError(e.message) from e
except openai.APIError as e:
raise CompletionError(e.message) from e


def num_tokens_from_messages_openai(messages: List[Message], model: str) -> int:
Expand Down
5 changes: 2 additions & 3 deletions gptcli/session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from abc import abstractmethod
from typing_extensions import TypeGuard
from gptcli.assistant import Assistant
from gptcli.completion import Message, ModelOverrides
from openai import BadRequestError, OpenAIError
from gptcli.completion import Message, ModelOverrides, CompletionError, BadRequestError
from typing import Any, Dict, List, Tuple


Expand Down Expand Up @@ -115,7 +114,7 @@ def _respond(self, args: ModelOverrides) -> bool:
except BadRequestError as e:
self.listener.on_error(e)
return False
except OpenAIError as e:
except CompletionError as e:
self.listener.on_error(e)
return True

Expand Down

0 comments on commit e61e4fe

Please sign in to comment.