From 5ab2557f4035289df7eda00bc77a48eb2099908f Mon Sep 17 00:00:00 2001 From: Raj Chauhan Date: Mon, 1 Apr 2024 21:43:14 -0300 Subject: [PATCH] Commited after Refactoring. --- .../openai/utils/TikTokensUtil.java | 38 ++++++--- build.gradle | 3 + .../openai/AuthenticationInterceptor.java | 1 + .../main/java/example/OpenAiApiExample.java | 6 +- ...allArgumentsSerializerAndDeserializer.java | 67 +++++++++------ .../openai/service/FunctionExecutor.java | 62 +++++++------- .../openai/service/ResponseBodyCallback.java | 84 +++++++++++-------- .../openai/service/AssistantFunctionTest.java | 2 +- .../openai/service/AssistantTest.java | 2 +- .../theokanning/openai/service/RunTest.java | 2 +- 10 files changed, 157 insertions(+), 110 deletions(-) diff --git a/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java b/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java index 0a50907e..5157f747 100644 --- a/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java +++ b/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java @@ -173,30 +173,42 @@ public static int tokens(String modelName, List messages) { Encoding encoding = getEncoding(modelName); int tokensPerMessage = 0; int tokensPerName = 0; - //3.5统一处理 + + // Constants for token counts per message and name + final int TOKENS_PER_MESSAGE_GPT_3_5_TURBO = 4; + final int TOKENS_PER_MESSAGE_GPT_4 = 3; + final int TOKENS_PER_NAME = 1; + + // Determine token counts based on model if (modelName.equals("gpt-3.5-turbo-0301") || modelName.equals("gpt-3.5-turbo")) { - tokensPerMessage = 4; + tokensPerMessage = TOKENS_PER_MESSAGE_GPT_3_5_TURBO; tokensPerName = -1; } - //4.0统一处理 if (modelName.equals("gpt-4") || modelName.equals("gpt-4-0314")) { - tokensPerMessage = 3; - tokensPerName = 1; + tokensPerMessage = TOKENS_PER_MESSAGE_GPT_4; + tokensPerName = TOKENS_PER_NAME; } - int sum = 0; + + int totalTokens = 0; // Variable to hold total tokens + for (ChatMessage msg : messages) { - sum += tokensPerMessage; - sum += tokens(encoding, msg.getContent()); - sum += tokens(encoding, msg.getRole()); - sum += tokens(encoding, msg.getName()); + int messageTokens = tokens(encoding, msg.getContent()) + + tokens(encoding, msg.getRole()) + + tokens(encoding, msg.getName()); + if (isNotBlank(msg.getName())) { - sum += tokensPerName; + messageTokens += tokensPerName; } + + totalTokens += tokensPerMessage + messageTokens; } - sum += 3; - return sum; + + totalTokens += 3; // Additional tokens for processing + + return totalTokens; } + /** * Reverse the string text through the model name and the encoded array. * diff --git a/build.gradle b/build.gradle index 23e4f934..dd3ce57b 100644 --- a/build.gradle +++ b/build.gradle @@ -9,3 +9,6 @@ allprojects { } } } + + + diff --git a/client/src/main/java/com/theokanning/openai/AuthenticationInterceptor.java b/client/src/main/java/com/theokanning/openai/AuthenticationInterceptor.java index fbe9a5b4..4cea5b08 100644 --- a/client/src/main/java/com/theokanning/openai/AuthenticationInterceptor.java +++ b/client/src/main/java/com/theokanning/openai/AuthenticationInterceptor.java @@ -1,5 +1,6 @@ package com.theokanning.openai; + /** * OkHttp Interceptor that adds an authorization token header * diff --git a/example/src/main/java/example/OpenAiApiExample.java b/example/src/main/java/example/OpenAiApiExample.java index 52ae1ccf..bd8b935b 100644 --- a/example/src/main/java/example/OpenAiApiExample.java +++ b/example/src/main/java/example/OpenAiApiExample.java @@ -29,7 +29,7 @@ public static void main(String... args) { System.out.println("\nCreating Image..."); CreateImageRequest request = CreateImageRequest.builder() - .prompt("A cow breakdancing with a turtle") + .prompt("A+ in coding assignment") .build(); System.out.println("\nImage is located at:"); @@ -48,10 +48,6 @@ public static void main(String... args) { .logitBias(new HashMap<>()) .build(); - service.streamChatCompletion(chatCompletionRequest) - .doOnError(Throwable::printStackTrace) - .blockingForEach(System.out::println); - service.shutdownExecutor(); } } diff --git a/service/src/main/java/com/theokanning/openai/service/ChatFunctionCallArgumentsSerializerAndDeserializer.java b/service/src/main/java/com/theokanning/openai/service/ChatFunctionCallArgumentsSerializerAndDeserializer.java index 9b7be0f9..b2beecd8 100644 --- a/service/src/main/java/com/theokanning/openai/service/ChatFunctionCallArgumentsSerializerAndDeserializer.java +++ b/service/src/main/java/com/theokanning/openai/service/ChatFunctionCallArgumentsSerializerAndDeserializer.java @@ -1,24 +1,32 @@ package com.theokanning.openai.service; - import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; -import com.fasterxml.jackson.databind.*; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.JsonNodeType; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; + import com.fasterxml.jackson.databind.node.TextNode; +import com.fasterxml.jackson.databind.*; + import java.io.IOException; +import java.util.HashMap; +import java.util.Map; -public class ChatFunctionCallArgumentsSerializerAndDeserializer { - private final static ObjectMapper MAPPER = new ObjectMapper(); + +public class ChatFunctionCallArgumentsSerializerAndDeserializer { + private static final ObjectMapper MAPPER = new ObjectMapper(); private ChatFunctionCallArgumentsSerializerAndDeserializer() { } public static class Serializer extends JsonSerializer { - private Serializer() { } @@ -32,33 +40,40 @@ public void serialize(JsonNode value, JsonGenerator gen, SerializerProvider seri } } + public abstract static class JsonNodeHandler { + public abstract JsonNode handle(JsonParser p, DeserializationContext ctxt) throws IOException; + } + + public static class MissingNodeHandler extends JsonNodeHandler { + @Override + public JsonNode handle(JsonParser p, DeserializationContext ctxt) { + return JsonNodeFactory.instance.missingNode(); + } + } + + public static class DefaultNodeHandler extends JsonNodeHandler { + @Override + public JsonNode handle(JsonParser p, DeserializationContext ctxt) throws IOException { + return MAPPER.readTree(p); + } + } + public static class Deserializer extends JsonDeserializer { + private static final Map HANDLERS = initializeHandlers(); - private Deserializer() { + private static Map initializeHandlers() { + Map handlers = new HashMap<>(); + handlers.put(JsonToken.VALUE_NULL, new MissingNodeHandler()); + // Add more handlers for different token types if needed + return handlers; } @Override public JsonNode deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { - String json = p.getValueAsString(); - if (json == null || p.currentToken() == JsonToken.VALUE_NULL) { - return null; - } - - try { - JsonNode node = null; - try { - node = MAPPER.readTree(json); - } catch (JsonParseException ignored) { - } - if (node == null || node.getNodeType() == JsonNodeType.MISSING) { - node = MAPPER.readTree(p); - } - return node; - } catch (Exception ex) { - ex.printStackTrace(); - return null; - } + JsonToken currentToken = p.getCurrentToken(); + JsonNodeHandler handler = HANDLERS.getOrDefault(currentToken, new DefaultNodeHandler()); + return handler.handle(p, ctxt); } } -} +} \ No newline at end of file diff --git a/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java b/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java index 5d143a95..b9648f3f 100644 --- a/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java +++ b/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java @@ -28,7 +28,7 @@ public FunctionExecutor(List functions, ObjectMapper objectMapper) public Optional executeAndConvertToMessageSafely(ChatFunctionCall call) { try { - return Optional.ofNullable(executeAndConvertToMessage(call)); + return Optional.ofNullable(MessageConverter.executeAndConvertToMessage(this, call)); } catch (Exception ignored) { return Optional.empty(); } @@ -36,7 +36,7 @@ public Optional executeAndConvertToMessageSafely(ChatFunctionCall c public ChatMessage executeAndConvertToMessageHandlingExceptions(ChatFunctionCall call) { try { - return executeAndConvertToMessage(call); + return MessageConverter.executeAndConvertToMessage(this, call); } catch (Exception exception) { exception.printStackTrace(); return convertExceptionToMessage(exception); @@ -48,34 +48,6 @@ public ChatMessage convertExceptionToMessage(Exception exception) { return new ChatMessage(ChatMessageRole.FUNCTION.value(), "{\"error\": \"" + error + "\"}", "error"); } - public ChatMessage executeAndConvertToMessage(ChatFunctionCall call) { - return new ChatMessage(ChatMessageRole.FUNCTION.value(), executeAndConvertToJson(call).toPrettyString(), call.getName()); - } - - public JsonNode executeAndConvertToJson(ChatFunctionCall call) { - try { - Object execution = execute(call); - if (execution instanceof TextNode) { - JsonNode objectNode = MAPPER.readTree(((TextNode) execution).asText()); - if (objectNode.isMissingNode()) - return (JsonNode) execution; - return objectNode; - } - if (execution instanceof ObjectNode) { - return (JsonNode) execution; - } - if (execution instanceof String) { - JsonNode objectNode = MAPPER.readTree((String) execution); - if (objectNode.isMissingNode()) - throw new RuntimeException("Parsing exception"); - return objectNode; - } - return MAPPER.readValue(MAPPER.writeValueAsString(execution), JsonNode.class); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - @SuppressWarnings("unchecked") public T execute(ChatFunctionCall call) { ChatFunction function = FUNCTIONS.get(call.getName()); @@ -102,4 +74,34 @@ public void setObjectMapper(ObjectMapper objectMapper) { this.MAPPER = objectMapper; } + // Inner class to handle message conversion + private static class MessageConverter { + public static ChatMessage executeAndConvertToMessage(FunctionExecutor executor, ChatFunctionCall call) { + return new ChatMessage(ChatMessageRole.FUNCTION.value(), executeAndConvertToJson(executor, call).toPrettyString(), call.getName()); + } + + public static JsonNode executeAndConvertToJson(FunctionExecutor executor, ChatFunctionCall call) { + try { + Object execution = executor.execute(call); + if (execution instanceof TextNode) { + JsonNode objectNode = executor.MAPPER.readTree(((TextNode) execution).asText()); + if (objectNode.isMissingNode()) + return (JsonNode) execution; + return objectNode; + } + if (execution instanceof ObjectNode) { + return (JsonNode) execution; + } + if (execution instanceof String) { + JsonNode objectNode = executor.MAPPER.readTree((String) execution); + if (objectNode.isMissingNode()) + throw new RuntimeException("Parsing exception"); + return objectNode; + } + return executor.MAPPER.readValue(executor.MAPPER.writeValueAsString(execution), JsonNode.class); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } } diff --git a/service/src/main/java/com/theokanning/openai/service/ResponseBodyCallback.java b/service/src/main/java/com/theokanning/openai/service/ResponseBodyCallback.java index c5404e0f..04bdeb0b 100644 --- a/service/src/main/java/com/theokanning/openai/service/ResponseBodyCallback.java +++ b/service/src/main/java/com/theokanning/openai/service/ResponseBodyCallback.java @@ -40,61 +40,79 @@ public void onResponse(Call call, Response response) try { if (!response.isSuccessful()) { - HttpException e = new HttpException(response); - ResponseBody errorBody = response.errorBody(); - - if (errorBody == null) { - throw e; - } else { - OpenAiError error = mapper.readValue( - errorBody.string(), - OpenAiError.class - ); - throw new OpenAiHttpException(error, e, e.code()); - } + handleUnsuccessfulResponse(response); + return; } InputStream in = response.body().byteStream(); reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8)); - String line; - SSE sse = null; + parseSSE(reader); + + emitter.onComplete(); + + } catch (Throwable t) { + onFailure(call, t); + } finally { + if (reader != null) { + try { + reader.close(); + } catch (IOException e) { + // do nothing + } + } + } + } + private void handleUnsuccessfulResponse(Response response) throws IOException { + HttpException e = new HttpException(response); + ResponseBody errorBody = response.errorBody(); + + if (errorBody == null) { + throw e; + } else { + OpenAiError error = mapper.readValue( + errorBody.string(), + OpenAiError.class + ); + throw new OpenAiHttpException(error, e, e.code()); + } + } + + private void parseSSE(BufferedReader reader) throws IOException { + String line; + SSE sse = null; + + try { while (!emitter.isCancelled() && (line = reader.readLine()) != null) { if (line.startsWith("data:")) { String data = line.substring(5).trim(); sse = new SSE(data); } else if (line.equals("") && sse != null) { - if (sse.isDone()) { - if (emitDone) { - emitter.onNext(sse); - } - break; - } - - emitter.onNext(sse); + handleSSELine(sse); sse = null; } else { throw new SSEFormatException("Invalid sse format! " + line); } } + } catch (SSEFormatException e) { + throw new IOException("Error parsing SSE", e); + } + } - emitter.onComplete(); - } catch (Throwable t) { - onFailure(call, t); - } finally { - if (reader != null) { - try { - reader.close(); - } catch (IOException e) { - // do nothing - } + private void handleSSELine(SSE sse) { + if (sse.isDone()) { + if (emitDone) { + emitter.onNext(sse); } + return; } + + emitter.onNext(sse); } @Override public void onFailure(Call call, Throwable t) { emitter.onError(t); } -} +} \ No newline at end of file diff --git a/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java b/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java index 9ad819a7..ae69734b 100644 --- a/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java @@ -82,7 +82,7 @@ void createRetrieveRun() throws JsonProcessingException { AssistantRequest assistantRequest = AssistantRequest.builder() - .model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName()) + .model(TikTokensUtil.ModelEnum.GPT_3_5_TURBO.getName()) .name("MATH_TUTOR") .instructions("You are a personal Math Tutor.") .tools(toolList) diff --git a/service/src/test/java/com/theokanning/openai/service/AssistantTest.java b/service/src/test/java/com/theokanning/openai/service/AssistantTest.java index 8b687e34..904fa275 100644 --- a/service/src/test/java/com/theokanning/openai/service/AssistantTest.java +++ b/service/src/test/java/com/theokanning/openai/service/AssistantTest.java @@ -39,7 +39,7 @@ static void teardown() { @Test @Order(1) void createAssistant() { - AssistantRequest assistantRequest = AssistantRequest.builder().model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName()).name("Math Tutor").instructions("You are a personal Math Tutor.").tools(Collections.singletonList(new Tool(AssistantToolsEnum.CODE_INTERPRETER, null))).build(); + AssistantRequest assistantRequest = AssistantRequest.builder().model(TikTokensUtil.ModelEnum.GPT_3_5_TURBO.getName()).name("Math Tutor").instructions("You are a personal Math Tutor.").tools(Collections.singletonList(new Tool(AssistantToolsEnum.CODE_INTERPRETER, null))).build(); Assistant assistant = service.createAssistant(assistantRequest); assistantId = assistant.getId(); diff --git a/service/src/test/java/com/theokanning/openai/service/RunTest.java b/service/src/test/java/com/theokanning/openai/service/RunTest.java index 2bd0c166..a2b0742b 100644 --- a/service/src/test/java/com/theokanning/openai/service/RunTest.java +++ b/service/src/test/java/com/theokanning/openai/service/RunTest.java @@ -26,7 +26,7 @@ class RunTest { @Timeout(10) void createRetrieveRun() { AssistantRequest assistantRequest = AssistantRequest.builder() - .model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName()) + .model(TikTokensUtil.ModelEnum.GPT_3_5_TURBO.getName()) .name("MATH_TUTOR") .instructions("You are a personal Math Tutor.") .build();