Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Vision models support in Anthropic #32103

Merged
merged 13 commits into from
Apr 16, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.external.plugins.models.AnthropicRequestDTO;
import com.external.plugins.utils.AnthropicMethodStrategy;
import com.external.plugins.utils.RequestUtils;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.gson.Gson;
Expand Down Expand Up @@ -137,7 +138,18 @@ public Mono<ActionExecutionResult> executeParameterized(
return Mono.just(apiKeyNotPresentErrorResult);
}

return RequestUtils.makeRequest(httpMethod, uri, apiKeyAuth, BodyInserters.fromValue(anthropicRequestDTO))
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
vivonk marked this conversation as resolved.
Show resolved Hide resolved
String requestBody;
try {
requestBody = objectMapper.writeValueAsString(anthropicRequestDTO);
} catch (Exception e) {
errorResult.setIsExecutionSuccess(false);
errorResult.setErrorInfo(
new AppsmithPluginException(AppsmithPluginError.PLUGIN_JSON_PARSE_ERROR, e.getMessage()));
return Mono.just(errorResult);
}
vivonk marked this conversation as resolved.
Show resolved Hide resolved

return RequestUtils.makeRequest(httpMethod, uri, apiKeyAuth, BodyInserters.fromValue(requestBody))
.flatMap(responseEntity -> {
HttpStatusCode statusCode = responseEntity.getStatusCode();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,38 @@
import com.appsmith.external.models.ActionConfiguration;
import com.external.plugins.constants.AnthropicConstants;
import com.external.plugins.models.AnthropicRequestDTO;
import com.external.plugins.models.Message;
import com.external.plugins.models.Role;
import com.external.plugins.utils.RequestUtils;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import org.springframework.http.HttpMethod;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

import java.lang.reflect.Type;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static com.external.plugins.constants.AnthropicConstants.ANTHROPIC;
import static com.external.plugins.constants.AnthropicConstants.CHAT;
import static com.external.plugins.constants.AnthropicConstants.CHAT_MODEL_SELECTOR;
import static com.external.plugins.constants.AnthropicConstants.CHAT_V2;
import static com.external.plugins.constants.AnthropicConstants.CLOUD_SERVICES;
import static com.external.plugins.constants.AnthropicConstants.COMMAND;
import static com.external.plugins.constants.AnthropicConstants.CONTENT;
import static com.external.plugins.constants.AnthropicConstants.DATA;
import static com.external.plugins.constants.AnthropicConstants.DEFAULT_MAX_TOKEN;
import static com.external.plugins.constants.AnthropicConstants.DEFAULT_TEMPERATURE;
import static com.external.plugins.constants.AnthropicConstants.JSON;
import static com.external.plugins.constants.AnthropicConstants.MAX_TOKENS;
import static com.external.plugins.constants.AnthropicConstants.MESSAGES;
import static com.external.plugins.constants.AnthropicConstants.MODELS_API;
import static com.external.plugins.constants.AnthropicConstants.PROVIDER;
import static com.external.plugins.constants.AnthropicConstants.ROLE;
import static com.external.plugins.constants.AnthropicConstants.TEMPERATURE;
import static com.external.plugins.constants.AnthropicConstants.VIEW_TYPE;
import static com.external.plugins.constants.AnthropicErrorMessages.BAD_MAX_TOKEN_CONFIGURATION;
import static com.external.plugins.constants.AnthropicErrorMessages.BAD_TEMPERATURE_CONFIGURATION;
import static com.external.plugins.constants.AnthropicConstants.SYSTEM_PROMPT;
import static com.external.plugins.constants.AnthropicErrorMessages.EXECUTION_FAILURE;
import static com.external.plugins.constants.AnthropicErrorMessages.MODEL_NOT_SELECTED;
import static com.external.plugins.constants.AnthropicErrorMessages.QUERY_NOT_CONFIGURED;
import static com.external.plugins.constants.AnthropicErrorMessages.STRING_APPENDER;
import static com.external.plugins.utils.CommandUtils.getMaxTokenFromFormData;
import static com.external.plugins.utils.CommandUtils.getMessages;
import static com.external.plugins.utils.CommandUtils.getTemperatureFromFormData;

public class ChatCommand implements AnthropicCommand {
private final Gson gson = new Gson();
Expand All @@ -60,7 +55,7 @@ public HttpMethod getExecutionMethod() {
public URI createTriggerUri() {
return UriComponentsBuilder.fromUriString(CLOUD_SERVICES + MODELS_API)
.queryParam(PROVIDER, ANTHROPIC)
.queryParam(COMMAND, CHAT.toLowerCase())
.queryParam(COMMAND, CHAT_V2)
vivonk marked this conversation as resolved.
Show resolved Hide resolved
.build()
.toUri();
}
Expand All @@ -87,16 +82,47 @@ public AnthropicRequestDTO makeRequestBody(ActionConfiguration actionConfigurati
AppsmithPluginError.PLUGIN_EXECUTE_ARGUMENT_ERROR,
String.format(STRING_APPENDER, EXECUTION_FAILURE, MODEL_NOT_SELECTED));
}

anthropicRequestDTO.setModel(model);

Float temperature = getTemperatureFromFormData(formData);
anthropicRequestDTO.setTemperature(temperature);
anthropicRequestDTO.setMaxTokensToSample(getMaxTokenFromFormData(formData));
anthropicRequestDTO.setPrompt(createPrompt(formData));
anthropicRequestDTO.setModel(model);

anthropicRequestDTO.setMaxTokens(getMaxTokenFromFormData(formData));
anthropicRequestDTO.setMessages(createMessages(formData));
if (formData.containsKey(SYSTEM_PROMPT) && formData.get(SYSTEM_PROMPT) != null) {
anthropicRequestDTO.setSystem(RequestUtils.extractDataFromFormData(formData, SYSTEM_PROMPT));
}

return anthropicRequestDTO;
}

private List<Message> createMessages(Map<String, Object> formData) {
if (!formData.containsKey(MESSAGES)) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_DATASOURCE_ARGUMENT_ERROR,
"messages are not provided in the configuration");
}
List<Map<String, String>> messageMaps = getMessages((Map<String, Object>) formData.get(MESSAGES));
if (messageMaps == null) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_DATASOURCE_ARGUMENT_ERROR,
"messages are not provided in the configuration correctly");
}
List<Message> messages = new ArrayList<>();
for (Map<String, String> messageMap : messageMaps) {
if (messageMap != null && messageMap.containsKey(ROLE) && messageMap.containsKey(CONTENT)) {
vivonk marked this conversation as resolved.
Show resolved Hide resolved
Message message = new Message();
Message.TextContent textContent = new Message.TextContent();
textContent.setText(messageMap.get(CONTENT));

message.setRole(messageMap.get(ROLE));
message.setContent(List.of(textContent));

messages.add(message);
}
}
return messages;
}

/**
* This is the kind of format we want to build from the messages as a prompt.
* Example Prompt: `\n\nHuman: ${query}\n\nAssistant:`
Expand Down Expand Up @@ -127,54 +153,4 @@ private String createPrompt(Map<String, Object> formData) {
"messages are not provided in the configuration");
}
}

/**
* When JS is enabled in form component, value is stored in data key only. Difference is if viewType is json,
* it's stored as JSON string otherwise it's Java serialized object
*/
private List<Map<String, String>> getMessages(Map<String, Object> messages) {
Type listType = new TypeToken<List<Map<String, String>>>() {}.getType();
if (messages.containsKey(VIEW_TYPE) && JSON.equals(messages.get(VIEW_TYPE))) {
// data is present in data key as String
return gson.fromJson((String) messages.get(DATA), listType);
}
// return object stored in data key
return (List<Map<String, String>>) messages.get(DATA);
}

private int getMaxTokenFromFormData(Map<String, Object> formData) {
String maxTokenAsString = RequestUtils.extractValueFromFormData(formData, MAX_TOKENS);

if (!StringUtils.hasText(maxTokenAsString)) {
return DEFAULT_MAX_TOKEN;
}

try {
return Integer.parseInt(maxTokenAsString);
} catch (IllegalArgumentException illegalArgumentException) {
return DEFAULT_MAX_TOKEN;
} catch (Exception exception) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_EXECUTE_ARGUMENT_ERROR,
String.format(STRING_APPENDER, EXECUTION_FAILURE, BAD_MAX_TOKEN_CONFIGURATION));
}
}

private Float getTemperatureFromFormData(Map<String, Object> formData) {
String temperatureString = RequestUtils.extractValueFromFormData(formData, TEMPERATURE);

if (!StringUtils.hasText(temperatureString)) {
return DEFAULT_TEMPERATURE;
}

try {
return Float.parseFloat(temperatureString);
} catch (IllegalArgumentException illegalArgumentException) {
return DEFAULT_TEMPERATURE;
} catch (Exception exception) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_EXECUTE_ARGUMENT_ERROR,
String.format(STRING_APPENDER, EXECUTION_FAILURE, BAD_TEMPERATURE_CONFIGURATION));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package com.external.plugins.commands;

import com.appsmith.external.exceptions.pluginExceptions.AppsmithPluginError;
import com.appsmith.external.exceptions.pluginExceptions.AppsmithPluginException;
import com.appsmith.external.models.ActionConfiguration;
import com.external.plugins.models.AnthropicRequestDTO;
import com.external.plugins.models.Message;
import com.external.plugins.utils.RequestUtils;
import com.google.gson.Gson;
import org.springframework.http.HttpMethod;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static com.external.plugins.constants.AnthropicConstants.ANTHROPIC;
import static com.external.plugins.constants.AnthropicConstants.BASE64;
import static com.external.plugins.constants.AnthropicConstants.CLOUD_SERVICES;
import static com.external.plugins.constants.AnthropicConstants.COMMAND;
import static com.external.plugins.constants.AnthropicConstants.CONTENT;
import static com.external.plugins.constants.AnthropicConstants.IMAGE;
import static com.external.plugins.constants.AnthropicConstants.MESSAGES;
import static com.external.plugins.constants.AnthropicConstants.MODELS_API;
import static com.external.plugins.constants.AnthropicConstants.PROVIDER;
import static com.external.plugins.constants.AnthropicConstants.ROLE;
import static com.external.plugins.constants.AnthropicConstants.SYSTEM_PROMPT;
import static com.external.plugins.constants.AnthropicConstants.TEXT;
import static com.external.plugins.constants.AnthropicConstants.TYPE;
import static com.external.plugins.constants.AnthropicConstants.VISION;
import static com.external.plugins.constants.AnthropicConstants.VISION_MODEL_SELECTOR;
import static com.external.plugins.constants.AnthropicErrorMessages.EXECUTION_FAILURE;
import static com.external.plugins.constants.AnthropicErrorMessages.MODEL_NOT_SELECTED;
import static com.external.plugins.constants.AnthropicErrorMessages.QUERY_NOT_CONFIGURED;
import static com.external.plugins.constants.AnthropicErrorMessages.STRING_APPENDER;
import static com.external.plugins.utils.CommandUtils.getMaxTokenFromFormData;
import static com.external.plugins.utils.CommandUtils.getMessages;
import static com.external.plugins.utils.CommandUtils.getTemperatureFromFormData;

public class VisionCommand implements AnthropicCommand {
private final Gson gson = new Gson();

@Override
public HttpMethod getTriggerHTTPMethod() {
return HttpMethod.GET;
}

@Override
public HttpMethod getExecutionMethod() {
return HttpMethod.POST;
}

@Override
public URI createTriggerUri() {
return UriComponentsBuilder.fromUriString(CLOUD_SERVICES + MODELS_API)
.queryParam(PROVIDER, ANTHROPIC)
.queryParam(COMMAND, VISION)
.build()
.toUri();
}

@Override
public URI createExecutionUri() {
return RequestUtils.createUriFromCommand(VISION);
}

@Override
public AnthropicRequestDTO makeRequestBody(ActionConfiguration actionConfiguration) {
Map<String, Object> formData = actionConfiguration.getFormData();
if (CollectionUtils.isEmpty(formData)) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_EXECUTE_ARGUMENT_ERROR,
String.format(STRING_APPENDER, EXECUTION_FAILURE, QUERY_NOT_CONFIGURED));
}

AnthropicRequestDTO anthropicRequestDTO = new AnthropicRequestDTO();
String model = RequestUtils.extractDataFromFormData(formData, VISION_MODEL_SELECTOR);
if (!StringUtils.hasText(model)) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_EXECUTE_ARGUMENT_ERROR,
String.format(STRING_APPENDER, EXECUTION_FAILURE, MODEL_NOT_SELECTED));
}
Float temperature = getTemperatureFromFormData(formData);
anthropicRequestDTO.setTemperature(temperature);
anthropicRequestDTO.setModel(model);

anthropicRequestDTO.setMaxTokens(getMaxTokenFromFormData(formData));
anthropicRequestDTO.setMessages(createMessages(formData));
if (formData.containsKey(SYSTEM_PROMPT) && formData.get(SYSTEM_PROMPT) != null) {
anthropicRequestDTO.setSystem(RequestUtils.extractDataFromFormData(formData, SYSTEM_PROMPT));
}

return anthropicRequestDTO;
}

private List<Message> createMessages(Map<String, Object> formData) {
if (!formData.containsKey(MESSAGES)) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_DATASOURCE_ARGUMENT_ERROR,
"messages are not provided in the configuration");
}
List<Map<String, String>> messageMaps = getMessages((Map<String, Object>) formData.get(MESSAGES));
if (messageMaps == null) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_DATASOURCE_ARGUMENT_ERROR,
"messages are not provided in the configuration correctly");
}
List<Message> messages = new ArrayList<>();
for (Map<String, String> messageMap : messageMaps) {
if (messageMap != null && messageMap.containsKey(ROLE) && messageMap.containsKey(CONTENT)) {
Message message = new Message();
String type = messageMap.get(TYPE);
message.setRole(messageMap.get(ROLE));
if (TEXT.equals(type)) {
Message.TextContent textContent = new Message.TextContent();
textContent.setText(messageMap.get(CONTENT));
message.setContent(List.of(textContent));
} else if (IMAGE.equals(type)) {
String content = messageMap.get(CONTENT);
if (!isValidImageContent(content)) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_EXECUTE_ARGUMENT_ERROR,
"Image content provided in the configuration is not valid");
}
Message.ImageContent imageContent = new Message.ImageContent();
Message.Source source = new Message.Source();

source.setType(BASE64);
source.setMediaType(getMediaType(messageMap.get(CONTENT)));
source.setData(getImageData(messageMap.get(CONTENT)));

imageContent.setSource(source);
message.setContent(List.of(imageContent));
}
message.setRole(messageMap.get(ROLE));
messages.add(message);
}
}
// As per Anthropic API, two content by same role in row are not allowed. It should be followed like user and
// assistant
// That's why we have to club the messages to have user and assistant in alternate order
List<Message> orderedMessages = new ArrayList<>();
for (Message message : messages) {
if (orderedMessages.isEmpty()) {
orderedMessages.add(message);
} else {
Message lastMessage = orderedMessages.get(orderedMessages.size() - 1);
if (!lastMessage.getRole().equals(message.getRole())) {
// different roles so can be added in the order
orderedMessages.add(message);
} else {
// add last message content to the current message since both are same role
List<Message.Content> content = new ArrayList<>(lastMessage.getContent());
content.addAll(message.getContent());
message.setContent(content);
orderedMessages.remove(lastMessage);
orderedMessages.add(message);
}
}
}
return orderedMessages;
}

private boolean isValidImageContent(String content) {
return StringUtils.hasText(content) && content.startsWith("data:image");
}

private String getMediaType(String content) {
return content.split(";", 2)[0].split(":")[1];
}

private String getImageData(String content) {
return content.split(",", 2)[1];
}
}
Loading
Loading