Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Support vision #423

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ public class ChatCompletionChoice {
Integer index;

/**
* The {@link ChatMessageRole#assistant} message or delta (when streaming) which was generated
* The {@link ChatMessageRole#ASSISTANT} message or delta (when streaming) which was generated
*/
@JsonAlias("delta")
ChatMessage message;
ChatMessage<String> message;

/**
* The reason why GPT stopped generating, for example "length".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
* <p>Each object has a role (either "system", "user", or "assistant") and content (the content of the message). Conversations can be as short as 1 message or fill many pages.</p>
Expand All @@ -16,32 +18,34 @@
*/
@Data
@NoArgsConstructor(force = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class ChatMessage {
public class ChatMessage<T> {

/**
* Must be either 'system', 'user', 'assistant' or 'function'.<br>
* You may use {@link ChatMessageRole} enum.
*/
@NonNull
String role;
/**
* An array of content parts with a defined type, each can be of type text or image_url when passing in images. You
* can pass multiple images by adding multiple image_url content parts. Image input is only supported when using the
* gpt-4-visual-preview model.
*/
@JsonInclude() // content should always exist in the call, even if it is null
String content;
T content;
//name is optional, The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
String name;
@JsonProperty("function_call")
ChatFunctionCall functionCall;

public ChatMessage(String role, String content) {
public ChatMessage(String role, T content) {
this.role = role;
this.content = content;
}

public ChatMessage(String role, String content, String name) {
public ChatMessage(String role, T content, String name) {
this.role = role;
this.content = content;
this.name = name;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.theokanning.openai.completion.chat;

import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
public class ChatMessageContent {

/**
* The type of the content part
*
* @see ChatMessageContentType
*/
private String type;

/**
* The text content.
*/
private String text;

/**
* Image input is only supported when using the gpt-4-visual-preview model.
*/
@JsonProperty("image_url")
private ImageUrl imageUrl;

public ChatMessageContent(String text) {
this.type = ChatMessageContentType.TEXT.value();
this.text = text;
}

public ChatMessageContent(ImageUrl imageUrl) {
this.type = ChatMessageContentType.IMAGE_URL.value();
this.imageUrl = imageUrl;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.theokanning.openai.completion.chat;

/**
* see {@link ChatMessage} documentation.
*/
public enum ChatMessageContentType {

TEXT("text"),
IMAGE_URL("image_url");

private final String value;

ChatMessageContentType(final String value) {
this.value = value;
}

public String value() {
return value;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.theokanning.openai.completion.chat;

import lombok.*;

@Data
@AllArgsConstructor
@NoArgsConstructor
@RequiredArgsConstructor
public class ImageUrl {

/**
* Either a URL of the image or the base64 encoded image data.
*/
@NonNull
private String url;

/**
* Specifies the detail level of the image. Learn more in the
* <a href="https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding">
* Vision guide</a>.
*/
private String detail;
}
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ public static int tokens(String modelName, List<ChatMessage> messages) {
int sum = 0;
for (ChatMessage msg : messages) {
sum += tokensPerMessage;
sum += tokens(encoding, msg.getContent());
if(msg.getContent() instanceof String){
sum += tokens(encoding, msg.getContent().toString());
}
sum += tokens(encoding, msg.getRole());
sum += tokens(encoding, msg.getName());
if (isNotBlank(msg.getName())) {
Expand Down
49 changes: 49 additions & 0 deletions api/src/main/java/com/theokanning/openai/utils/VisionUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.theokanning.openai.utils;

import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageContent;
import com.theokanning.openai.completion.chat.ImageUrl;

import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* Vision tool class
*
* @author cong
* @since 2023/11/17
*/
public class VisionUtil {

private static final Pattern pattern = Pattern.compile("(https?://\\S+)");

public static ChatMessage<List<ChatMessageContent>> convertForVision(ChatMessage<String> msg) {
List<ChatMessageContent> content = new ArrayList<>();
String sourceText = msg.getContent();
// Regular expression to match image URLs
Matcher matcher = pattern.matcher(sourceText);
// Find image URLs and split the string
int lastIndex = 0;
while (matcher.find()) {
String url = matcher.group();
// Add the text before the image URL
if (matcher.start() > lastIndex) {
String text = sourceText.substring(lastIndex, matcher.start()).trim();
content.add(new ChatMessageContent(text));
}
// Add the image URL
ImageUrl imageUrl = new ImageUrl();
imageUrl.setUrl(url);
content.add(new ChatMessageContent(imageUrl));
lastIndex = matcher.end();
}
// Add the remaining text
if (lastIndex < sourceText.length()) {
String text = sourceText.substring(lastIndex).trim();
content.add(new ChatMessageContent(text));
}
return new ChatMessage<>(msg.getRole(), content, msg.getName());
}
}
42 changes: 42 additions & 0 deletions example/src/main/java/example/OpenAiApiVisionExample.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package example;

import com.theokanning.openai.completion.chat.*;
import com.theokanning.openai.service.OpenAiService;
import com.theokanning.openai.utils.VisionUtil;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

class OpenAiApiVisionExample {
public static void main(String... args) {
String token = System.getenv("OPENAI_TOKEN");
OpenAiService service = new OpenAiService(token, Duration.ofSeconds(30));

System.out.println("Streaming chat completion...");
final List<ChatMessage> messages = new ArrayList<>();
List<ChatMessageContent> content = new ArrayList<>();
content.add(new ChatMessageContent("What’s in this image?"));
content.add(new ChatMessageContent(new ImageUrl(
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")));
messages.add(new ChatMessage<>(ChatMessageRole.USER.value(), content));

// use VisionUtil to convert image prompt to OpenAI format
System.out.println("Converting image to OpenAI format...");
ChatMessage<List<ChatMessageContent>> visionChatMessage = VisionUtil.convertForVision(
new ChatMessage<>(ChatMessageRole.USER.value(),
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg "
+ "What are in these images? Is there any difference between them?"));
messages.add(visionChatMessage);

ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-4-vision-preview")
.messages(messages)
.maxTokens(300)
.build();

service.streamChatCompletion(chatCompletionRequest).blockingForEach(System.out::println);
service.shutdownExecutor();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.theokanning.openai.service;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;

/**
* @author cong
* @since 2023/11/17
*/
public abstract class ChatMessageMixIn {
@JsonProperty("content")
@JsonSerialize(using = ChatMessageSerializerAndDeserializer.ChatMessageContentSerializer.class)
@JsonDeserialize(using = ChatMessageSerializerAndDeserializer.ChatMessageContentDeserializer.class)
abstract Object getContent();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package com.theokanning.openai.service;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.*;
import com.theokanning.openai.completion.chat.ChatMessageContent;
import com.theokanning.openai.completion.chat.ChatMessageContentType;
import com.theokanning.openai.completion.chat.ImageUrl;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

public class ChatMessageSerializerAndDeserializer {

public static class ChatMessageContentSerializer extends JsonSerializer<Object> {
@Override
public void serialize(Object content, JsonGenerator gen, SerializerProvider serializers) throws IOException {
if (content == null) {
gen.writeNull();
return;
}
if (content instanceof String) {
gen.writeString((String)content);
return;
}
if (content instanceof List) {
gen.writeStartArray();
List<?> contentList = (List<?>)content;
for (Object item : contentList) {
if (item instanceof ChatMessageContent) {
ChatMessageContent contentItem = (ChatMessageContent)item;
gen.writeStartObject();
gen.writeStringField("type", contentItem.getType());
if (ChatMessageContentType.TEXT.value().equals(contentItem.getType())) {
gen.writeStringField("text", contentItem.getText());
} else if (ChatMessageContentType.IMAGE_URL.value().equals(contentItem.getType())) {
gen.writeObjectFieldStart("image_url");
gen.writeStringField("url", contentItem.getImageUrl().getUrl());
gen.writeStringField("detail", contentItem.getImageUrl().getDetail());
gen.writeEndObject();
}
gen.writeEndObject();
}
}
gen.writeEndArray();
}
}
}

public static class ChatMessageContentDeserializer extends JsonDeserializer<Object> {
@Override
public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
JsonNode contentNode = p.readValueAsTree();
if (contentNode.isTextual()) {
return contentNode.asText();
}
if (contentNode.isArray()) {
List<Object> contentList = new ArrayList<>();
for (JsonNode itemNode : contentNode) {
String type = itemNode.get("type").asText();
if (ChatMessageContentType.TEXT.value().equals(type)) {
contentList.add(new ChatMessageContent(itemNode.get("text").asText()));
} else if (ChatMessageContentType.IMAGE_URL.value().equals(type)) {
JsonNode imageUrlJsonNode = itemNode.get("image_url");
ImageUrl imageUrl = new ImageUrl();
imageUrl.setUrl(Optional.ofNullable(imageUrlJsonNode.get("url"))
.map(JsonNode::asText).orElse(null));
imageUrl.setDetail(Optional.ofNullable(imageUrlJsonNode.get("detail"))
.map(JsonNode::asText).orElse(null));
contentList.add(new ChatMessageContent(imageUrl));
}
}
return contentList;
}
return null;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ public static ObjectMapper defaultObjectMapper() {
mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class);
mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class);
mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class);
mapper.addMixIn(ChatMessage.class, ChatMessageMixIn.class);
return mapper;
}

Expand All @@ -607,10 +608,10 @@ public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper)

public Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ChatCompletionChunk> flowable) {
ChatFunctionCall functionCall = new ChatFunctionCall(null, null);
ChatMessage accumulatedMessage = new ChatMessage(ChatMessageRole.ASSISTANT.value(), null);
ChatMessage<String> accumulatedMessage = new ChatMessage<>(ChatMessageRole.ASSISTANT.value(), "");

return flowable.map(chunk -> {
ChatMessage messageChunk = chunk.getChoices().get(0).getMessage();
ChatMessage<String> messageChunk = chunk.getChoices().get(0).getMessage();
if (messageChunk.getFunctionCall() != null) {
if (messageChunk.getFunctionCall().getName() != null) {
String namePart = messageChunk.getFunctionCall().getName();
Expand Down
Loading