Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Vision models support in Anthropic #32103

Merged
merged 13 commits into from
Apr 16, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
import com.external.plugins.commands.AnthropicCommand;
import com.external.plugins.constants.AnthropicConstants;
import com.external.plugins.models.AnthropicRequestDTO;
import com.external.plugins.models.CompletionDTO;
import com.external.plugins.models.MessageDTO;
import com.external.plugins.utils.AnthropicMethodStrategy;
import com.external.plugins.utils.RequestUtils;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.gson.Gson;
Expand All @@ -45,6 +48,7 @@

import static com.external.plugins.constants.AnthropicConstants.ANTHROPIC_MODELS;
import static com.external.plugins.constants.AnthropicConstants.BODY;
import static com.external.plugins.constants.AnthropicConstants.CLAUDE3_PREFIX;
import static com.external.plugins.constants.AnthropicConstants.LABEL;
import static com.external.plugins.constants.AnthropicConstants.TEST_MODEL;
import static com.external.plugins.constants.AnthropicConstants.TEST_PROMPT;
Expand Down Expand Up @@ -137,7 +141,21 @@ public Mono<ActionExecutionResult> executeParameterized(
return Mono.just(apiKeyNotPresentErrorResult);
}

return RequestUtils.makeRequest(httpMethod, uri, apiKeyAuth, BodyInserters.fromValue(anthropicRequestDTO))
String model = anthropicRequestDTO.getModel();

// we don't want to serialise null values as Anthropic throws bad request otherwise
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
vivonk marked this conversation as resolved.
Show resolved Hide resolved
String requestBody;
try {
requestBody = objectMapper.writeValueAsString(anthropicRequestDTO);
} catch (Exception e) {
errorResult.setIsExecutionSuccess(false);
errorResult.setErrorInfo(
new AppsmithPluginException(AppsmithPluginError.PLUGIN_JSON_PARSE_ERROR, e.getMessage()));
return Mono.just(errorResult);
}
vivonk marked this conversation as resolved.
Show resolved Hide resolved

return RequestUtils.makeRequest(httpMethod, uri, apiKeyAuth, BodyInserters.fromValue(requestBody))
.flatMap(responseEntity -> {
HttpStatusCode statusCode = responseEntity.getStatusCode();

Expand Down Expand Up @@ -171,7 +189,12 @@ public Mono<ActionExecutionResult> executeParameterized(
Object body;
try {
body = objectMapper.readValue(responseEntity.getBody(), Object.class);
actionExecutionResult.setBody(body);
if (model.contains(CLAUDE3_PREFIX)) {
actionExecutionResult.setBody(body);
} else {
actionExecutionResult.setBody(
formatResponseBodyAsCompletionAPI(model, responseEntity.getBody()));
}
} catch (IOException ex) {
actionExecutionResult.setIsExecutionSuccess(false);
actionExecutionResult.setErrorInfo(new AppsmithPluginException(
Expand Down Expand Up @@ -204,6 +227,24 @@ public Mono<ActionExecutionResult> executeParameterized(
});
}

/**
* To keep things backward compatible, if model doesn't belong to claude 3, format response in form of claude completion API
*/
private Object formatResponseBodyAsCompletionAPI(String model, byte[] response) {
try {
MessageDTO messageDTO = objectMapper.readValue(response, MessageDTO.class);
CompletionDTO completionDTO = new CompletionDTO();
completionDTO.setId(messageDTO.getId());
completionDTO.setType("completion");
completionDTO.setStopReason(messageDTO.getStopReason());
completionDTO.setModel(model);
completionDTO.setCompletion(messageDTO.getFirstMessage());
return completionDTO;
} catch (IOException e) {
throw new AppsmithPluginException(AppsmithPluginError.PLUGIN_JSON_PARSE_ERROR, new String(response));
}
}

@Override
public Mono<TriggerResultDTO> trigger(
APIConnection connection, DatasourceConfiguration datasourceConfiguration, TriggerRequestDTO request) {
Expand Down Expand Up @@ -252,7 +293,11 @@ public Mono<TriggerResultDTO> trigger(
})
.onErrorResume(error -> {
log.debug("Error while fetching Anthropic models list", error);
return Mono.just(getDataToMap(ANTHROPIC_MODELS));
if (ANTHROPIC_MODELS.containsKey(requestType)) {
return Mono.just(getDataToMap(ANTHROPIC_MODELS.get(requestType)));
}
return Mono.error(new AppsmithPluginException(
AppsmithPluginError.PLUGIN_EXECUTE_ARGUMENT_ERROR, error.getMessage()));
})
.map(trigger -> {
TriggerResultDTO triggerResult = new TriggerResultDTO(trigger);
Expand All @@ -268,7 +313,7 @@ public Set<String> validateDatasource(DatasourceConfiguration datasourceConfigur
}

private List<Map<String, String>> getDataToMap(List<String> data) {
return data.stream().sorted().map(x -> Map.of(LABEL, x, VALUE, x)).collect(Collectors.toList());
return data.stream().map(x -> Map.of(LABEL, x, VALUE, x)).collect(Collectors.toList());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,38 @@
import com.appsmith.external.models.ActionConfiguration;
import com.external.plugins.constants.AnthropicConstants;
import com.external.plugins.models.AnthropicRequestDTO;
import com.external.plugins.models.Role;
import com.external.plugins.models.Message;
import com.external.plugins.utils.CommandUtils;
import com.external.plugins.utils.RequestUtils;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import org.springframework.http.HttpMethod;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

import java.lang.reflect.Type;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static com.external.plugins.constants.AnthropicConstants.ANTHROPIC;
import static com.external.plugins.constants.AnthropicConstants.CHAT;
import static com.external.plugins.constants.AnthropicConstants.CHAT_MODEL_SELECTOR;
import static com.external.plugins.constants.AnthropicConstants.CHAT_V2;
import static com.external.plugins.constants.AnthropicConstants.CLOUD_SERVICES;
import static com.external.plugins.constants.AnthropicConstants.COMMAND;
import static com.external.plugins.constants.AnthropicConstants.CONTENT;
import static com.external.plugins.constants.AnthropicConstants.DATA;
import static com.external.plugins.constants.AnthropicConstants.DEFAULT_MAX_TOKEN;
import static com.external.plugins.constants.AnthropicConstants.DEFAULT_TEMPERATURE;
import static com.external.plugins.constants.AnthropicConstants.JSON;
import static com.external.plugins.constants.AnthropicConstants.MAX_TOKENS;
import static com.external.plugins.constants.AnthropicConstants.MESSAGES;
import static com.external.plugins.constants.AnthropicConstants.MODELS_API;
import static com.external.plugins.constants.AnthropicConstants.PROVIDER;
import static com.external.plugins.constants.AnthropicConstants.ROLE;
import static com.external.plugins.constants.AnthropicConstants.TEMPERATURE;
import static com.external.plugins.constants.AnthropicConstants.VIEW_TYPE;
import static com.external.plugins.constants.AnthropicErrorMessages.BAD_MAX_TOKEN_CONFIGURATION;
import static com.external.plugins.constants.AnthropicErrorMessages.BAD_TEMPERATURE_CONFIGURATION;
import static com.external.plugins.constants.AnthropicConstants.SYSTEM_PROMPT;
import static com.external.plugins.constants.AnthropicErrorMessages.EXECUTION_FAILURE;
import static com.external.plugins.constants.AnthropicErrorMessages.MODEL_NOT_SELECTED;
import static com.external.plugins.constants.AnthropicErrorMessages.QUERY_NOT_CONFIGURED;
import static com.external.plugins.constants.AnthropicErrorMessages.STRING_APPENDER;
import static com.external.plugins.utils.CommandUtils.getMaxTokenFromFormData;
import static com.external.plugins.utils.CommandUtils.getMessages;
import static com.external.plugins.utils.CommandUtils.getTemperatureFromFormData;

public class ChatCommand implements AnthropicCommand {
private final Gson gson = new Gson();
Expand All @@ -60,7 +55,7 @@ public HttpMethod getExecutionMethod() {
public URI createTriggerUri() {
return UriComponentsBuilder.fromUriString(CLOUD_SERVICES + MODELS_API)
.queryParam(PROVIDER, ANTHROPIC)
.queryParam(COMMAND, CHAT.toLowerCase())
.queryParam(COMMAND, CHAT_V2)
vivonk marked this conversation as resolved.
Show resolved Hide resolved
.build()
.toUri();
}
Expand All @@ -87,94 +82,44 @@ public AnthropicRequestDTO makeRequestBody(ActionConfiguration actionConfigurati
AppsmithPluginError.PLUGIN_EXECUTE_ARGUMENT_ERROR,
String.format(STRING_APPENDER, EXECUTION_FAILURE, MODEL_NOT_SELECTED));
}

anthropicRequestDTO.setModel(model);

Float temperature = getTemperatureFromFormData(formData);
anthropicRequestDTO.setTemperature(temperature);
anthropicRequestDTO.setMaxTokensToSample(getMaxTokenFromFormData(formData));
anthropicRequestDTO.setPrompt(createPrompt(formData));
anthropicRequestDTO.setModel(model);

anthropicRequestDTO.setMaxTokens(getMaxTokenFromFormData(formData));
anthropicRequestDTO.setMessages(createMessages(formData));
if (formData.containsKey(SYSTEM_PROMPT) && formData.get(SYSTEM_PROMPT) != null) {
anthropicRequestDTO.setSystem(RequestUtils.extractDataFromFormData(formData, SYSTEM_PROMPT));
}

return anthropicRequestDTO;
}

/**
* This is the kind of format we want to build from the messages as a prompt.
* Example Prompt: `\n\nHuman: ${query}\n\nAssistant:`
* Lastly, we leave it with an additional Assistant: so that it can respond back as an assistant
*/
private String createPrompt(Map<String, Object> formData) {
StringBuilder stringBuilder = new StringBuilder();
if (formData.containsKey(MESSAGES)) {
List<Map<String, String>> messageMaps = getMessages((Map<String, Object>) formData.get(MESSAGES));
if (messageMaps == null) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_DATASOURCE_ARGUMENT_ERROR,
"messages are not provided in the configuration correctly");
}
for (Map<String, String> messageMap : messageMaps) {
if (messageMap != null && messageMap.containsKey(ROLE) && messageMap.containsKey(CONTENT)) {
stringBuilder
.append("\n\n")
.append(messageMap.get(ROLE))
.append(": ")
.append(messageMap.get(CONTENT));
}
}
return stringBuilder.append("\n").append(Role.Assistant).append(":").toString();
} else {
private List<Message> createMessages(Map<String, Object> formData) {
if (!formData.containsKey(MESSAGES)) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_DATASOURCE_ARGUMENT_ERROR,
"messages are not provided in the configuration");
}
}

/**
* When JS is enabled in form component, value is stored in data key only. Difference is if viewType is json,
* it's stored as JSON string otherwise it's Java serialized object
*/
private List<Map<String, String>> getMessages(Map<String, Object> messages) {
Type listType = new TypeToken<List<Map<String, String>>>() {}.getType();
if (messages.containsKey(VIEW_TYPE) && JSON.equals(messages.get(VIEW_TYPE))) {
// data is present in data key as String
return gson.fromJson((String) messages.get(DATA), listType);
}
// return object stored in data key
return (List<Map<String, String>>) messages.get(DATA);
}

private int getMaxTokenFromFormData(Map<String, Object> formData) {
String maxTokenAsString = RequestUtils.extractValueFromFormData(formData, MAX_TOKENS);

if (!StringUtils.hasText(maxTokenAsString)) {
return DEFAULT_MAX_TOKEN;
}

try {
return Integer.parseInt(maxTokenAsString);
} catch (IllegalArgumentException illegalArgumentException) {
return DEFAULT_MAX_TOKEN;
} catch (Exception exception) {
List<Map<String, String>> messageMaps = getMessages((Map<String, Object>) formData.get(MESSAGES));
if (messageMaps == null) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_EXECUTE_ARGUMENT_ERROR,
String.format(STRING_APPENDER, EXECUTION_FAILURE, BAD_MAX_TOKEN_CONFIGURATION));
AppsmithPluginError.PLUGIN_DATASOURCE_ARGUMENT_ERROR,
"messages are not provided in the configuration correctly");
}
}

private Float getTemperatureFromFormData(Map<String, Object> formData) {
String temperatureString = RequestUtils.extractValueFromFormData(formData, TEMPERATURE);
List<Message> messages = new ArrayList<>();
for (Map<String, String> messageMap : messageMaps) {
if (messageMap != null && messageMap.containsKey(ROLE) && messageMap.containsKey(CONTENT)) {
vivonk marked this conversation as resolved.
Show resolved Hide resolved
Message message = new Message();
Message.TextContent textContent = new Message.TextContent();
textContent.setText(messageMap.get(CONTENT));

if (!StringUtils.hasText(temperatureString)) {
return DEFAULT_TEMPERATURE;
}
message.setRole(CommandUtils.getActualRoleValue(messageMap.get(ROLE)));
message.setContent(List.of(textContent));

try {
return Float.parseFloat(temperatureString);
} catch (IllegalArgumentException illegalArgumentException) {
return DEFAULT_TEMPERATURE;
} catch (Exception exception) {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_EXECUTE_ARGUMENT_ERROR,
String.format(STRING_APPENDER, EXECUTION_FAILURE, BAD_TEMPERATURE_CONFIGURATION));
messages.add(message);
}
}
return messages;
}
}
Loading
Loading