Skip to content

Commit

Permalink
feat(dotAI): Adding fallback mechanism when it comes to send models
Browse files Browse the repository at this point in the history
Removing the `OpenAIRequest` class in favor of the set of classes explained in `src/main/java/com/dotcms/ai/client/README.md` and integrating it with the corresponding consuming components. Integration tests added/updated.

Refs: #29284
  • Loading branch information
victoralfaro-dotcms committed Aug 29, 2024
1 parent acbd281 commit d7074e3
Show file tree
Hide file tree
Showing 48 changed files with 1,378 additions and 643 deletions.
13 changes: 9 additions & 4 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package com.dotcms.ai.api;

import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Lazy;

import java.io.OutputStream;

Expand Down Expand Up @@ -37,9 +35,10 @@ public interface CompletionsAPI {
* this method takes a prompt in the form of json and returns a json AI response based upon that prompt
*
* @param promptJSON
* @param userId
* @return
*/
JSONObject raw(JSONObject promptJSON);
JSONObject raw(JSONObject promptJSON, String userId);

/**
* this method takes a prompt and returns the AI response based upon that prompt
Expand All @@ -58,9 +57,15 @@ public interface CompletionsAPI {
* @param model
* @param temperature
* @param maxTokens
* @param userId
* @return
*/
JSONObject prompt(String systemPrompt, String userPrompt, String model, float temperature, int maxTokens);
JSONObject prompt(String systemPrompt,
String userPrompt,
String model,
float temperature,
int maxTokens,
String userId);

/**
* this method takes a prompt in the form of json and returns streaming AI response based upon that prompt
Expand Down
100 changes: 55 additions & 45 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,35 @@

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.app.AIModelType;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.client.AIProxyClient;
import com.dotcms.ai.db.EmbeddingsDTO;
import com.dotcms.ai.domain.AIResponse;
import com.dotcms.ai.client.JSONObjectAIRequest;
import com.dotcms.ai.domain.Model;
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotcms.ai.util.EncodingUtil;
import com.dotcms.ai.util.OpenAIRequest;
import com.dotcms.api.web.HttpServletRequestThreadLocal;
import com.dotcms.mock.request.FakeHttpRequest;
import com.dotcms.mock.response.BaseResponse;
import com.dotcms.rendering.velocity.util.VelocityUtil;
import com.dotmarketing.business.APILocator;
import com.dotmarketing.business.web.WebAPILocator;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.UtilMethods;
import com.dotmarketing.util.json.JSONArray;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Lazy;
import io.vavr.Tuple2;
import io.vavr.control.Try;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.velocity.context.Context;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.HttpMethod;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -42,15 +45,13 @@
public class CompletionsAPIImpl implements CompletionsAPI {

private final AppConfig config;
private final Lazy<AppConfig> defaultConfig;

public CompletionsAPIImpl(final AppConfig config) {
defaultConfig =
Lazy.of(() -> ConfigService.INSTANCE.config(
Try.of(() -> WebAPILocator
.getHostWebAPI()
.getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest()))
.getOrElse(APILocator.systemHost())));
final Lazy<AppConfig> defaultConfig = Lazy.of(() -> ConfigService.INSTANCE.config(
Try.of(() -> WebAPILocator
.getHostWebAPI()
.getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest()))
.getOrElse(APILocator.systemHost())));
this.config = Optional.ofNullable(config).orElse(defaultConfig.get());
}

Expand All @@ -59,8 +60,9 @@ public JSONObject prompt(final String systemPrompt,
final String userPrompt,
final String modelIn,
final float temperature,
final int maxTokens) {
final AIModel model = config.resolveModelOrThrow(modelIn);
final int maxTokens,
final String userId) {
final Model model = config.resolveModelOrThrow(modelIn, AIModelType.TEXT)._2;
final JSONObject json = new JSONObject();

json.put(AiKeys.TEMPERATURE, temperature);
Expand All @@ -70,15 +72,17 @@ public JSONObject prompt(final String systemPrompt,
json.put(AiKeys.MAX_TOKENS, maxTokens);
}

json.put(AiKeys.MODEL, model.getCurrentModel());
json.put(AiKeys.MODEL, model.getName());

return raw(json);
return raw(json, userId);
}

@Override
public JSONObject summarize(final CompletionsForm summaryRequest) {
final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build();
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher);
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI()
.getEmbeddingsAPI()
.getEmbeddingResults(searcher);

// send all this as a json blob to OpenAI
final JSONObject json = buildRequestJson(summaryRequest, localResults);
Expand All @@ -87,58 +91,64 @@ public JSONObject summarize(final CompletionsForm summaryRequest) {
}

json.put(AiKeys.STREAM, false);
final String openAiResponse =
Try.of(() -> OpenAIRequest.doRequest(
config.getApiUrl(),
HttpMethod.POST,
config,
json))
.getOrElseThrow(DotRuntimeException::new);
final JSONObject dotCMSResponse = APILocator.getDotAIAPI().getEmbeddingsAPI().reduceChunksToContent(searcher, localResults);
final String openAiResponse = Try
.of(() -> sendRequest(config, json, UtilMethods.extractUserIdOrNull(summaryRequest.user)))
.getOrElseThrow(DotRuntimeException::new)
.getResponse();
final JSONObject dotCMSResponse = APILocator.getDotAIAPI()
.getEmbeddingsAPI()
.reduceChunksToContent(searcher, localResults);
dotCMSResponse.put(AiKeys.OPEN_AI_RESPONSE, new JSONObject(openAiResponse));

return dotCMSResponse;
}

@Override
public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream out) {
public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream output) {
final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build();
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher);
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI()
.getEmbeddingsAPI()
.getEmbeddingResults(searcher);

final JSONObject json = buildRequestJson(summaryRequest, localResults);
json.put(AiKeys.STREAM, true);
OpenAIRequest.doPost(config.getApiUrl(), config, json, out);
AIProxyClient.get().callToAI(
JSONObjectAIRequest.quickText(
config,
json,
UtilMethods.extractUserIdOrNull(summaryRequest.user)),
output);
}

@Override
public JSONObject raw(final JSONObject json) {
if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(this.getClass(), "OpenAI request:" + json.toString(2));
}
public JSONObject raw(final JSONObject json, final String userId) {
AppConfig.debugLogger(this.getClass(), () -> "OpenAI request:" + json.toString(2));

final String response = OpenAIRequest.doRequest(
config.getApiUrl(),
HttpMethod.POST,
config,
json);
if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(this.getClass(), "OpenAI response:" + response);
}
final String response = sendRequest(config, json, userId).getResponse();
AppConfig.debugLogger(this.getClass(), () -> "OpenAI response:" + response);

return new JSONObject(response);
}

@Override
public JSONObject raw(CompletionsForm promptForm) {
public JSONObject raw(final CompletionsForm promptForm) {
JSONObject jsonObject = buildRequestJson(promptForm);
return raw(jsonObject);
return raw(jsonObject, UtilMethods.extractUserIdOrNull(promptForm.user));
}

@Override
public void rawStream(final CompletionsForm promptForm, final OutputStream out) {
public void rawStream(final CompletionsForm promptForm, final OutputStream output) {
final JSONObject json = buildRequestJson(promptForm);
json.put(AiKeys.STREAM, true);
OpenAIRequest.doRequest(config.getApiUrl(), HttpMethod.POST, config, json, out);
AIProxyClient.get().callToAI(JSONObjectAIRequest.quickText(
config,
json,
UtilMethods.extractUserIdOrNull(promptForm.user)),
output);
}

/**
 * Wraps the given payload in a request created by {@code JSONObjectAIRequest.quickText}
 * and dispatches it through the {@code AIProxyClient} singleton.
 *
 * @param appConfig configuration passed along to the request factory
 * @param payload   JSON body to send
 * @param userId    id of the user the request is made for; callers in this class obtain it
 *                  via {@code UtilMethods.extractUserIdOrNull}, so it is presumably nullable
 * @return the {@link AIResponse} returned by {@code AIProxyClient#callToAI}
 */
private AIResponse sendRequest(final AppConfig appConfig, final JSONObject payload, final String userId) {
    return AIProxyClient
            .get()
            .callToAI(
                    JSONObjectAIRequest.quickText(appConfig, payload, userId));
}

private void buildMessages(final String systemPrompt, final String userPrompt, final JSONObject json) {
Expand All @@ -151,7 +161,7 @@ private void buildMessages(final String systemPrompt, final String userPrompt, f
}

private JSONObject buildRequestJson(final CompletionsForm form, final List<EmbeddingsDTO> searchResults) {
final AIModel model = config.resolveModelOrThrow(form.model);
final Tuple2<AIModel, Model> modelTuple = config.resolveModelOrThrow(form.model, AIModelType.TEXT);
// aggregate matching results into text
final StringBuilder supportingContent = new StringBuilder();
searchResults.forEach(s -> supportingContent.append(s.extractedText).append(" "));
Expand All @@ -162,7 +172,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List<Embed
final int systemPromptTokens = countTokens(systemPrompt);
textPrompt = reduceStringToTokenSize(
textPrompt,
model.getMaxTokens() - form.responseLengthTokens - systemPromptTokens);
modelTuple._1.getMaxTokens() - form.responseLengthTokens - systemPromptTokens);

final JSONObject json = new JSONObject();
json.put(AiKeys.STREAM, form.stream);
Expand All @@ -171,7 +181,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List<Embed
buildMessages(systemPrompt, textPrompt, json);

if (UtilMethods.isSet(form.model)) {
json.put(AiKeys.MODEL, model.getCurrentModel());
json.put(AiKeys.MODEL, modelTuple._2.getName());
}

json.put(AiKeys.MAX_TOKENS, form.responseLengthTokens);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ public static class DefaultChatAPIProvider implements ChatAPIProvider {
@Override
public ChatAPI getChatAPI(final Object... initArguments) {
if (Objects.nonNull(initArguments) && initArguments.length > 0 && initArguments[0] instanceof AppConfig) {
return new OpenAIChatAPIImpl((AppConfig) initArguments[0]);
final User user = initArguments.length > 1 && initArguments[1] instanceof User
? (User) initArguments[1]
: null;
return new OpenAIChatAPIImpl((AppConfig) initArguments[0], user);
}

throw new IllegalArgumentException("To create a ChatAPI you need to provide an AppConfig");
Expand Down Expand Up @@ -147,7 +150,7 @@ public static final void setDefaultImageAPIProvider(final ImageAPIProvider image

/**
* Set the default chat API Provider.
* @param chatAPIProviderq
* @param chatAPIProvider
*/
public static final void setDefaultChatAPIProvider(final ChatAPIProvider chatAPIProvider) {
chatProviderMap.put(DEFAULT, chatAPIProvider);
Expand Down
6 changes: 4 additions & 2 deletions dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,11 @@ public interface EmbeddingsAPI {
* Embeddings
*
* @param content The content that will be tokenized and sent to OpenAI.
* @param userId The ID of the user making the request.
*
* @return Tuple(Count of Tokens Input, List of Embeddings Output)
*/
Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String content);
Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(String content, String userId);

/**
* this method takes a snippet of content and will try to see if we have already generated
Expand All @@ -154,10 +155,11 @@ public interface EmbeddingsAPI {
*
* @param contentId The ID of the Contentlet being sent to the OpenAI Endpoint.
* @param content The actual indexable data that will be tokenized and sent to OpenAI service.
* @param userId The ID of the user making the request.
*
* @return Tuple(Count of Tokens Input, List of Embeddings Output)
*/
Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String contentId, final String content);
Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(String contentId, String content, String userId);

/**
* Checks if the embeddings for the given inode, indexName, and extractedText already exist in the database.
Expand Down
33 changes: 19 additions & 14 deletions dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.client.AIProxyClient;
import com.dotcms.ai.db.EmbeddingsDTO;
import com.dotcms.ai.db.EmbeddingsDTO.Builder;
import com.dotcms.ai.db.EmbeddingsFactory;
import com.dotcms.ai.client.JSONObjectAIRequest;
import com.dotcms.ai.util.ContentToStringUtil;
import com.dotcms.ai.util.EncodingUtil;
import com.dotcms.ai.util.OpenAIRequest;
import com.dotcms.ai.util.VelocityContextFactory;
import com.dotcms.api.web.HttpServletRequestThreadLocal;
import com.dotcms.api.web.HttpServletResponseThreadLocal;
Expand Down Expand Up @@ -43,7 +44,6 @@
import org.apache.velocity.context.Context;

import javax.validation.constraints.NotNull;
import javax.ws.rs.HttpMethod;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -311,13 +311,15 @@ public void initEmbeddingsTable() {
}

@Override
public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(@NotNull final String content) {
return pullOrGenerateEmbeddings("N/A", content);
public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(@NotNull final String content, final String userId) {
return pullOrGenerateEmbeddings("N/A", content, userId);
}

@WrapInTransaction
@Override
public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String contentId, @NotNull final String content) {
public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String contentId,
@NotNull final String content,
final String userId) {
if (UtilMethods.isEmpty(content)) {
return Tuple.of(0, List.of());
}
Expand Down Expand Up @@ -349,7 +351,7 @@ public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String conten

final Tuple2<Integer, List<Float>> openAiEmbeddings = Tuple.of(
tokens.size(),
sendTokensToOpenAI(contentId, tokens));
sendTokensToOpenAI(contentId, tokens, userId));
saveEmbeddingsForCache(content, openAiEmbeddings);
EMBEDDING_CACHE.put(hashed, openAiEmbeddings);

Expand Down Expand Up @@ -420,19 +422,20 @@ private void saveEmbeddingsForCache(final String content, final Tuple2<Integer,
*
* @param contentId The ID of the Contentlet that will be sent to the OpenAI Endpoint.
* @param tokens The encoded tokens representing the indexable data of a Contentlet.
* @param userId The ID of the user making the request.
*
* @return A {@link List} of {@link Float} values representing the embeddings.
*/
private List<Float> sendTokensToOpenAI(final String contentId, @NotNull final List<Integer> tokens) {
private List<Float> sendTokensToOpenAI(final String contentId,
@NotNull final List<Integer> tokens,
final String userId) {
final JSONObject json = new JSONObject();
json.put(AiKeys.MODEL, config.getEmbeddingsModel().getCurrentModel());
json.put(AiKeys.INPUT, tokens);
debugLogger(this.getClass(), () -> String.format("Content tokens for content ID '%s': %s", contentId, tokens));
final String responseString = OpenAIRequest.doRequest(
config.getApiEmbeddingsUrl(),
HttpMethod.POST,
config,
json);
final String responseString = AIProxyClient.get()
.callToAI(JSONObjectAIRequest.quickEmbeddings(config, json, userId))
.getResponse();
debugLogger(this.getClass(), () -> String.format("OpenAI Response for content ID '%s': %s",
contentId, responseString.replace("\n", BLANK)));
final JSONObject jsonResponse = Try.of(() -> new JSONObject(responseString)).getOrElseThrow(e -> {
Expand Down Expand Up @@ -490,8 +493,10 @@ private List<Float> getEmbeddingsFromJSON(final String contentId, final JSONObje
}
}

private EmbeddingsDTO getSearcher(EmbeddingsDTO searcher) {
final List<Float> queryEmbeddings = pullOrGenerateEmbeddings(searcher.query)._2;
private EmbeddingsDTO getSearcher(final EmbeddingsDTO searcher) {
final List<Float> queryEmbeddings = pullOrGenerateEmbeddings(
searcher.query,
UtilMethods.extractUserIdOrNull(searcher.user))._2;
return EmbeddingsDTO.copy(searcher).withEmbeddings(queryEmbeddings).build();
}

Expand Down
Loading

0 comments on commit d7074e3

Please sign in to comment.