From 656a21974da48e3cb4324d5cff934932c0f39e07 Mon Sep 17 00:00:00 2001 From: Steve Phelps Date: Tue, 7 Nov 2023 08:42:23 +0000 Subject: [PATCH] Retry on APIConnectionError --- src/openai_pygenerator/openai_pygenerator.py | 8 +++++++- tests/test_gpt.py | 14 ++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/openai_pygenerator/openai_pygenerator.py b/src/openai_pygenerator/openai_pygenerator.py index 9edb053..b763a54 100644 --- a/src/openai_pygenerator/openai_pygenerator.py +++ b/src/openai_pygenerator/openai_pygenerator.py @@ -26,7 +26,12 @@ import openai import urllib3.exceptions -from openai.error import APIError, RateLimitError, ServiceUnavailableError +from openai.error import ( + APIConnectionError, + APIError, + RateLimitError, + ServiceUnavailableError, +) Completion = Dict[str, str] Seconds = NewType("Seconds", int) @@ -118,6 +123,7 @@ def generate_completions( openai.error.Timeout, urllib3.exceptions.TimeoutError, RateLimitError, + APIConnectionError, APIError, ServiceUnavailableError, ) as err: diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 6383c74..8886b6c 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -23,7 +23,12 @@ import openai.error import pytest import urllib3.exceptions as urlex -from openai.error import APIError, RateLimitError, ServiceUnavailableError +from openai.error import ( + APIConnectionError, + APIError, + RateLimitError, + ServiceUnavailableError, +) from openai.openai_object import OpenAIObject from openai_pygenerator import ( @@ -72,6 +77,7 @@ def make_test_completion(role: str) -> Completion: "error", [ RateLimitError("rate limited", http_status=429), + APIConnectionError("connection timeout"), APIError("Gateway Timeout", http_status=524), ServiceUnavailableError( message=( @@ -89,7 +95,7 @@ def test_generate_completion(mock_openai, mock_sleep, error): MockChoices(["Test completion 1", "Test completion 2"]), ] - completions = list(gpt_completions([])) + completions = list(gpt_completions([])) # type: ignore assert completions == ["Test completion 1", "Test completion 2"] assert mock_sleep.call_count == 2 @@ -102,7 +108,7 @@ def test_generate_completion(mock_openai, mock_sleep, error): APIError("Gateway Timeout", http_status=524), APIError("Server shutdown", http_status=500), ServiceUnavailableError("Service unavailable"), - urlex.ReadTimeoutError("test-pool", "http://test", "read timeout"), + urlex.ReadTimeoutError("test-pool", "http://test", "read timeout"), # type: ignore openai.error.Timeout, ], ) @@ -110,7 +116,7 @@ def test_generate_completion_error(mock_openai, mock_sleep, error): mock_openai.side_effect = [error] * GPT_MAX_RETRIES with pytest.raises(Exception): - _ = list(gpt_completions([])) + _ = list(gpt_completions([])) # type: ignore assert mock_sleep.call_count == GPT_MAX_RETRIES