From 51a8633658ec21c17e65b81586ce899aa74413d4 Mon Sep 17 00:00:00 2001 From: cong <274902531@qq.com> Date: Sun, 19 Nov 2023 14:45:52 +0800 Subject: [PATCH] Add support for gpt-4-vision support the content in chat completion with format https://platform.openai.com/docs/guides/vision --- .../completion/chat/ChatCompletionChoice.java | 4 +- .../openai/completion/chat/ChatMessage.java | 20 +++-- .../completion/chat/ChatMessageContent.java | 39 +++++++++ .../chat/ChatMessageContentType.java | 20 +++++ .../openai/completion/chat/ImageUrl.java | 23 ++++++ .../openai/utils/TikTokensUtil.java | 4 +- .../theokanning/openai/utils/VisionUtil.java | 49 +++++++++++ .../java/example/OpenAiApiVisionExample.java | 42 ++++++++++ .../openai/service/ChatMessageMixIn.java | 16 ++++ .../ChatMessageSerializerAndDeserializer.java | 81 +++++++++++++++++++ .../openai/service/OpenAiService.java | 11 +-- .../openai/service/ChatCompletionTest.java | 42 ++++++++++ 12 files changed, 333 insertions(+), 18 deletions(-) create mode 100644 api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageContent.java create mode 100644 api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageContentType.java create mode 100644 api/src/main/java/com/theokanning/openai/completion/chat/ImageUrl.java create mode 100644 api/src/main/java/com/theokanning/openai/utils/VisionUtil.java create mode 100644 example/src/main/java/example/OpenAiApiVisionExample.java create mode 100644 service/src/main/java/com/theokanning/openai/service/ChatMessageMixIn.java create mode 100644 service/src/main/java/com/theokanning/openai/service/ChatMessageSerializerAndDeserializer.java diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionChoice.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionChoice.java index 7bb88698..35fc2eea 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionChoice.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionChoice.java @@ -15,10 +15,10 @@ public class ChatCompletionChoice { Integer index; /** - * The {@link ChatMessageRole#assistant} message or delta (when streaming) which was generated + * The {@link ChatMessageRole#ASSISTANT} message or delta (when streaming) which was generated */ @JsonAlias("delta") - ChatMessage message; + ChatMessage message; /** * The reason why GPT stopped generating, for example "length". diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java index 912a71f0..175f082c 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java @@ -2,7 +2,9 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import lombok.*; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; /** *

Each object has a role (either "system", "user", or "assistant") and content (the content of the message). Conversations can be as short as 1 message or fill many pages.

@@ -16,32 +18,34 @@ */ @Data @NoArgsConstructor(force = true) -@RequiredArgsConstructor @AllArgsConstructor -public class ChatMessage { +public class ChatMessage { /** * Must be either 'system', 'user', 'assistant' or 'function'.
* You may use {@link ChatMessageRole} enum. */ - @NonNull String role; + /** + * An array of content parts with a defined type, each can be of type text or image_url when passing in images. You + * can pass multiple images by adding multiple image_url content parts. Image input is only supported when using the + * gpt-4-visual-preview model. + */ @JsonInclude() // content should always exist in the call, even if it is null - String content; + T content; //name is optional, The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters. String name; @JsonProperty("function_call") ChatFunctionCall functionCall; - public ChatMessage(String role, String content) { + public ChatMessage(String role, T content) { this.role = role; this.content = content; } - public ChatMessage(String role, String content, String name) { + public ChatMessage(String role, T content, String name) { this.role = role; this.content = content; this.name = name; } - } diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageContent.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageContent.java new file mode 100644 index 00000000..6b58cf3e --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageContent.java @@ -0,0 +1,39 @@ +package com.theokanning.openai.completion.chat; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +public class ChatMessageContent { + + /** + * The type of the content part + * + * @see ChatMessageContentType + */ + private String type; + + /** + * The text content. + */ + private String text; + + /** + * Image input is only supported when using the gpt-4-visual-preview model. + */ + @JsonProperty("image_url") + private ImageUrl imageUrl; + + public ChatMessageContent(String text) { + this.type = ChatMessageContentType.TEXT.value(); + this.text = text; + } + + public ChatMessageContent(ImageUrl imageUrl) { + this.type = ChatMessageContentType.IMAGE_URL.value(); + this.imageUrl = imageUrl; + } +} \ No newline at end of file diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageContentType.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageContentType.java new file mode 100644 index 00000000..4a429cec --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageContentType.java @@ -0,0 +1,20 @@ +package com.theokanning.openai.completion.chat; + +/** + * see {@link ChatMessage} documentation. + */ +public enum ChatMessageContentType { + + TEXT("text"), + IMAGE_URL("image_url"); + + private final String value; + + ChatMessageContentType(final String value) { + this.value = value; + } + + public String value() { + return value; + } +} diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ImageUrl.java b/api/src/main/java/com/theokanning/openai/completion/chat/ImageUrl.java new file mode 100644 index 00000000..5d2935af --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ImageUrl.java @@ -0,0 +1,23 @@ +package com.theokanning.openai.completion.chat; + +import lombok.*; + +@Data +@AllArgsConstructor +@NoArgsConstructor +@RequiredArgsConstructor +public class ImageUrl { + + /** + * Either a URL of the image or the base64 encoded image data. + */ + @NonNull + private String url; + + /** + * Specifies the detail level of the image. Learn more in the + * + * Vision guide. + */ + private String detail; +} \ No newline at end of file diff --git a/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java b/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java index 0a50907e..bee73e98 100644 --- a/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java +++ b/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java @@ -186,7 +186,9 @@ public static int tokens(String modelName, List messages) { int sum = 0; for (ChatMessage msg : messages) { sum += tokensPerMessage; - sum += tokens(encoding, msg.getContent()); + if(msg.getContent() instanceof String){ + sum += tokens(encoding, msg.getContent().toString()); + } sum += tokens(encoding, msg.getRole()); sum += tokens(encoding, msg.getName()); if (isNotBlank(msg.getName())) { diff --git a/api/src/main/java/com/theokanning/openai/utils/VisionUtil.java b/api/src/main/java/com/theokanning/openai/utils/VisionUtil.java new file mode 100644 index 00000000..0174902d --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/utils/VisionUtil.java @@ -0,0 +1,49 @@ +package com.theokanning.openai.utils; + +import com.theokanning.openai.completion.chat.ChatMessage; +import com.theokanning.openai.completion.chat.ChatMessageContent; +import com.theokanning.openai.completion.chat.ImageUrl; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Vision tool class + * + * @author cong + * @since 2023/11/17 + */ +public class VisionUtil { + + private static final Pattern pattern = Pattern.compile("(https?://\\S+)"); + + public static ChatMessage> convertForVision(ChatMessage msg) { + List content = new ArrayList<>(); + String sourceText = msg.getContent(); + // Regular expression to match image URLs + Matcher matcher = pattern.matcher(sourceText); + // Find image URLs and split the string + int lastIndex = 0; + while (matcher.find()) { + String url = matcher.group(); + // Add the text before the image URL + if (matcher.start() > lastIndex) { + String text = sourceText.substring(lastIndex, matcher.start()).trim(); + content.add(new ChatMessageContent(text)); + } + // Add the image URL + ImageUrl imageUrl = new ImageUrl(); + imageUrl.setUrl(url); + content.add(new ChatMessageContent(imageUrl)); + lastIndex = matcher.end(); + } + // Add the remaining text + if (lastIndex < sourceText.length()) { + String text = sourceText.substring(lastIndex).trim(); + content.add(new ChatMessageContent(text)); + } + return new ChatMessage<>(msg.getRole(), content, msg.getName()); + } +} diff --git a/example/src/main/java/example/OpenAiApiVisionExample.java b/example/src/main/java/example/OpenAiApiVisionExample.java new file mode 100644 index 00000000..340b4148 --- /dev/null +++ b/example/src/main/java/example/OpenAiApiVisionExample.java @@ -0,0 +1,42 @@ +package example; + +import com.theokanning.openai.completion.chat.*; +import com.theokanning.openai.service.OpenAiService; +import com.theokanning.openai.utils.VisionUtil; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +class OpenAiApiVisionExample { + public static void main(String... args) { + String token = System.getenv("OPENAI_TOKEN"); + OpenAiService service = new OpenAiService(token, Duration.ofSeconds(30)); + + System.out.println("Streaming chat completion..."); + final List messages = new ArrayList<>(); + List content = new ArrayList<>(); + content.add(new ChatMessageContent("What’s in this image?")); + content.add(new ChatMessageContent(new ImageUrl( + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"))); + messages.add(new ChatMessage<>(ChatMessageRole.USER.value(), content)); + + // use VisionUtil to convert image prompt to OpenAI format + System.out.println("Converting image to OpenAI format..."); + ChatMessage> visionChatMessage = VisionUtil.convertForVision( + new ChatMessage<>(ChatMessageRole.USER.value(), + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg " + + "What are in these images? Is there any difference between them?")); + messages.add(visionChatMessage); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest + .builder() + .model("gpt-4-vision-preview") + .messages(messages) + .maxTokens(300) + .build(); + + service.streamChatCompletion(chatCompletionRequest).blockingForEach(System.out::println); + service.shutdownExecutor(); + } +} diff --git a/service/src/main/java/com/theokanning/openai/service/ChatMessageMixIn.java b/service/src/main/java/com/theokanning/openai/service/ChatMessageMixIn.java new file mode 100644 index 00000000..023c3446 --- /dev/null +++ b/service/src/main/java/com/theokanning/openai/service/ChatMessageMixIn.java @@ -0,0 +1,16 @@ +package com.theokanning.openai.service; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +/** + * @author cong + * @since 2023/11/17 + */ +public abstract class ChatMessageMixIn { + @JsonProperty("content") + @JsonSerialize(using = ChatMessageSerializerAndDeserializer.ChatMessageContentSerializer.class) + @JsonDeserialize(using = ChatMessageSerializerAndDeserializer.ChatMessageContentDeserializer.class) + abstract Object getContent(); +} diff --git a/service/src/main/java/com/theokanning/openai/service/ChatMessageSerializerAndDeserializer.java b/service/src/main/java/com/theokanning/openai/service/ChatMessageSerializerAndDeserializer.java new file mode 100644 index 00000000..bda3081d --- /dev/null +++ b/service/src/main/java/com/theokanning/openai/service/ChatMessageSerializerAndDeserializer.java @@ -0,0 +1,81 @@ +package com.theokanning.openai.service; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.*; +import com.theokanning.openai.completion.chat.ChatMessageContent; +import com.theokanning.openai.completion.chat.ChatMessageContentType; +import com.theokanning.openai.completion.chat.ImageUrl; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +public class ChatMessageSerializerAndDeserializer { + + public static class ChatMessageContentSerializer extends JsonSerializer { + @Override + public void serialize(Object content, JsonGenerator gen, SerializerProvider serializers) throws IOException { + if (content == null) { + gen.writeNull(); + return; + } + if (content instanceof String) { + gen.writeString((String)content); + return; + } + if (content instanceof List) { + gen.writeStartArray(); + List contentList = (List)content; + for (Object item : contentList) { + if (item instanceof ChatMessageContent) { + ChatMessageContent contentItem = (ChatMessageContent)item; + gen.writeStartObject(); + gen.writeStringField("type", contentItem.getType()); + if (ChatMessageContentType.TEXT.value().equals(contentItem.getType())) { + gen.writeStringField("text", contentItem.getText()); + } else if (ChatMessageContentType.IMAGE_URL.value().equals(contentItem.getType())) { + gen.writeObjectFieldStart("image_url"); + gen.writeStringField("url", contentItem.getImageUrl().getUrl()); + gen.writeStringField("detail", contentItem.getImageUrl().getDetail()); + gen.writeEndObject(); + } + gen.writeEndObject(); + } + } + gen.writeEndArray(); + } + } + } + + public static class ChatMessageContentDeserializer extends JsonDeserializer { + @Override + public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + JsonNode contentNode = p.readValueAsTree(); + if (contentNode.isTextual()) { + return contentNode.asText(); + } + if (contentNode.isArray()) { + List contentList = new ArrayList<>(); + for (JsonNode itemNode : contentNode) { + String type = itemNode.get("type").asText(); + if (ChatMessageContentType.TEXT.value().equals(type)) { + contentList.add(new ChatMessageContent(itemNode.get("text").asText())); + } else if (ChatMessageContentType.IMAGE_URL.value().equals(type)) { + JsonNode imageUrlJsonNode = itemNode.get("image_url"); + ImageUrl imageUrl = new ImageUrl(); + imageUrl.setUrl(Optional.ofNullable(imageUrlJsonNode.get("url")) + .map(JsonNode::asText).orElse(null)); + imageUrl.setDetail(Optional.ofNullable(imageUrlJsonNode.get("detail")) + .map(JsonNode::asText).orElse(null)); + contentList.add(new ChatMessageContent(imageUrl)); + } + } + return contentList; + } + return null; + } + } + +} diff --git a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java index 9c15522b..df2b3305 100644 --- a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java +++ b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java @@ -8,11 +8,7 @@ import com.fasterxml.jackson.databind.node.TextNode; import com.theokanning.openai.*; import com.theokanning.openai.assistants.*; -import com.theokanning.openai.audio.CreateSpeechRequest; -import com.theokanning.openai.audio.CreateTranscriptionRequest; -import com.theokanning.openai.audio.CreateTranslationRequest; -import com.theokanning.openai.audio.TranscriptionResult; -import com.theokanning.openai.audio.TranslationResult; +import com.theokanning.openai.audio.*; import com.theokanning.openai.billing.BillingUsage; import com.theokanning.openai.billing.Subscription; import com.theokanning.openai.client.OpenAiApi; @@ -534,6 +530,7 @@ public static ObjectMapper defaultObjectMapper() { mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class); mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class); mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class); + mapper.addMixIn(ChatMessage.class, ChatMessageMixIn.class); return mapper; } @@ -556,10 +553,10 @@ public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper) public Flowable mapStreamToAccumulator(Flowable flowable) { ChatFunctionCall functionCall = new ChatFunctionCall(null, null); - ChatMessage accumulatedMessage = new ChatMessage(ChatMessageRole.ASSISTANT.value(), null); + ChatMessage accumulatedMessage = new ChatMessage<>(ChatMessageRole.ASSISTANT.value(), ""); return flowable.map(chunk -> { - ChatMessage messageChunk = chunk.getChoices().get(0).getMessage(); + ChatMessage messageChunk = chunk.getChoices().get(0).getMessage(); if (messageChunk.getFunctionCall() != null) { if (messageChunk.getFunctionCall().getName() != null) { String namePart = messageChunk.getFunctionCall().getName(); diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java index 25f0defb..190e826b 100644 --- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java @@ -300,4 +300,46 @@ void streamChatCompletionWithDynamicFunctions() { assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit")); } + @Test + void createChatCompletionWithImageInput() { + final List messages = new ArrayList<>(); + List content = new ArrayList<>(); + content.add(new ChatMessageContent("What’s in this image?")); + content.add(new ChatMessageContent(new ImageUrl( + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"))); + messages.add(new ChatMessage<>(ChatMessageRole.USER.value(), content)); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest + .builder() + .model("gpt-4-vision-preview") + .messages(messages) + .maxTokens(300) + .build(); + + List choices = service.createChatCompletion(chatCompletionRequest).getChoices(); + assertFalse(choices.isEmpty()); + } + + @Test + void streamChatCompletionWithImageInput() { + final List messages = new ArrayList<>(); + List content = new ArrayList<>(); + content.add(new ChatMessageContent("What’s in this image?")); + content.add(new ChatMessageContent(new ImageUrl( + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"))); + messages.add(new ChatMessage<>(ChatMessageRole.USER.value(), content)); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest + .builder() + .model("gpt-4-vision-preview") + .messages(messages) + .maxTokens(300) + .build(); + + List chunks = new ArrayList<>(); + service.streamChatCompletion(chatCompletionRequest).blockingForEach(chunks::add); + assertFalse(chunks.isEmpty()); + assertNotNull(chunks.get(0).getChoices().get(0)); + } + }