From 2e3d4a688b16b318d634178d05d4c65907e969a7 Mon Sep 17 00:00:00 2001 From: vacuity Date: Mon, 11 Dec 2023 16:55:31 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8feature(base=5Furl):=20user=20can=20de?= =?UTF-8?q?fine=20custom=20openai=20base=20url?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../openai/service/OpenAiService.java | 33 ++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java index 52ab6b0f..50ad560d 100644 --- a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java +++ b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java @@ -84,6 +84,16 @@ public OpenAiService(final String token) { this(token, DEFAULT_TIMEOUT); } + /** + * Creates a new OpenAiService that wraps OpenAiApi + * + * @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + * @param baseUrl OpenAi baseUrl, default is "https://api.openai.com/" + */ + public OpenAiService(final String token, String baseUrl) { + this(token, DEFAULT_TIMEOUT, baseUrl); + } + /** * Creates a new OpenAiService that wraps OpenAiApi * @@ -91,9 +101,20 @@ public OpenAiService(final String token) { * @param timeout http read timeout, Duration.ZERO means no timeout */ public OpenAiService(final String token, final Duration timeout) { + this(token, timeout, null); + } + + /** + * Creates a new OpenAiService that wraps OpenAiApi + * + * @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + * @param timeout http read timeout, Duration.ZERO means no timeout + * @param baseUrl OpenAi baseUrl, default is "https://api.openai.com/" + */ + public OpenAiService(final String token, final Duration timeout, String baseUrl) { ObjectMapper mapper = defaultObjectMapper(); OkHttpClient client = defaultClient(token, timeout); - Retrofit retrofit = defaultRetrofit(client, mapper); + Retrofit retrofit = defaultRetrofit(client, mapper, baseUrl); this.api = retrofit.create(OpenAiApi.class); this.executorService = client.dispatcher().executorService(); @@ -572,7 +593,7 @@ public void shutdownExecutor() { public static OpenAiApi buildApi(String token, Duration timeout) { ObjectMapper mapper = defaultObjectMapper(); OkHttpClient client = defaultClient(token, timeout); - Retrofit retrofit = defaultRetrofit(client, mapper); + Retrofit retrofit = defaultRetrofit(client, mapper, null); return retrofit.create(OpenAiApi.class); } @@ -596,14 +617,18 @@ public static OkHttpClient defaultClient(String token, Duration timeout) { .build(); } - public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper) { + public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper, String baseUrl) { + if (baseUrl == null || "".equals(baseUrl)) { + baseUrl = BASE_URL; + } return new Retrofit.Builder() - .baseUrl(BASE_URL) + .baseUrl(baseUrl) .client(client) .addConverterFactory(JacksonConverterFactory.create(mapper)) .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) .build(); } + public Flowable mapStreamToAccumulator(Flowable flowable) { ChatFunctionCall functionCall = new ChatFunctionCall(null, null);