Skip to content

Commit

Permalink
Fix retry decorator exception handling logic (#19)
Browse files Browse the repository at this point in the history
* updated retry logic and added test

* changed exception type to tuple

* updated test

* update openai_exception type

* fixed test names

* tuple updates

* split test into two
  • Loading branch information
edwardcqian authored Aug 22, 2023
1 parent adf21d3 commit 1ace1f4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
2 changes: 1 addition & 1 deletion llm_gateway/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from openai.error import Timeout, APIError, APIConnectionError, TryAgain

OPENAI_EXCEPTIONS = [Timeout, APIError, APIConnectionError, TryAgain]
OPENAI_EXCEPTIONS = (Timeout, APIError, APIConnectionError, TryAgain)
SUPPORTED_OPENAI_ENDPOINTS = {
"Model": ["list", "retrieve"],
"ChatCompletion": ["create"],
Expand Down
2 changes: 1 addition & 1 deletion llm_gateway/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
logger = logging.getLogger(__name__)


def max_retries(times: int, exceptions: list = [Exception]):
def max_retries(times: int, exceptions: tuple = (Exception,)):
"""
Max Retry Decorator
Retries the wrapped function/method `times` times
Expand Down
29 changes: 29 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from unittest.mock import Mock

import pytest
from llm_gateway.utils import max_retries
from openai.error import APIError


def test_retry_decorator_mismatch_exception():
retry_mock = Mock()
retry_mock.side_effect = [APIError("test"), "success"]

@max_retries(1, exceptions=(ValueError,))
def mismatch_exception():
return retry_mock()

with pytest.raises(APIError):
mismatch_exception()


def test_retry_decorator_matching_exception():
retry_mock = Mock()
retry_mock.side_effect = [APIError("test"), "success"]

## Matching retry exception
@max_retries(1, exceptions=(APIError,))
def matching_exception():
return retry_mock()

assert matching_exception() == "success"

0 comments on commit 1ace1f4

Please sign in to comment.