diff --git a/example/src/main/java/example/OpenAiApiFunctionsExample.java b/example/src/main/java/example/OpenAiApiFunctionsExample.java index 954b9104..50c29160 100644 --- a/example/src/main/java/example/OpenAiApiFunctionsExample.java +++ b/example/src/main/java/example/OpenAiApiFunctionsExample.java @@ -7,6 +7,7 @@ import com.theokanning.openai.service.FunctionExecutor; import com.theokanning.openai.service.OpenAiService; +import java.net.SocketTimeoutException; import java.util.*; class OpenAiApiFunctionsExample { @@ -38,7 +39,7 @@ public WeatherResponse(String location, WeatherUnit unit, int temperature, Strin } } - public static void main(String... args) { + public static void main(String... args) throws SocketTimeoutException { String token = System.getenv("OPENAI_TOKEN"); OpenAiService service = new OpenAiService(token); diff --git a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java index 7114531b..75f4d3c4 100644 --- a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java +++ b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java @@ -49,6 +49,7 @@ import javax.validation.constraints.NotNull; import java.io.IOException; +import java.net.SocketTimeoutException; import java.time.Duration; import java.time.LocalDate; import java.util.List; @@ -134,8 +135,15 @@ public Flowable streamCompletion(CompletionRequest request) { return stream(api.createCompletionStream(request), CompletionChunk.class); } - public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) { - return execute(api.createChatCompletion(request)); + public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) throws SocketTimeoutException{ + try { + return execute(api.createChatCompletion(request)); + } catch (RuntimeException e) { + if (e.getCause() != null && e.getCause() instanceof SocketTimeoutException) + throw new SocketTimeoutException(e.getCause().getMessage()); + else + throw e; + } } public Flowable streamChatCompletion(ChatCompletionRequest request) { diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java index 3d26bf03..eab0af41 100644 --- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java @@ -7,6 +7,7 @@ import com.theokanning.openai.completion.chat.*; import org.junit.jupiter.api.Test; +import java.net.SocketTimeoutException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -47,7 +48,7 @@ public WeatherResponse(String location, WeatherUnit unit, int temperature, Strin OpenAiService service = new OpenAiService(token); @Test - void createChatCompletion() { + void createChatCompletion() throws SocketTimeoutException { final List messages = new ArrayList<>(); final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a dog and will speak as such."); messages.add(systemMessage); @@ -88,7 +89,7 @@ void streamChatCompletion() { } @Test - void createChatCompletionWithFunctions() { + void createChatCompletionWithFunctions() throws SocketTimeoutException { final List functions = Collections.singletonList(ChatFunction.builder() .name("get_weather") .description("Get the current weather in a given location")