Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dotAI): Adding fallback mechanism when it comes to send models to AI Provider (OpenAI) #29516

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package com.dotcms.ai.api;

import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Lazy;

import java.io.OutputStream;

Expand Down Expand Up @@ -37,9 +35,10 @@ public interface CompletionsAPI {
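* Takes a prompt in the form of JSON and returns the AI response, also as JSON, based upon that prompt.
*
* @param promptJSON the prompt to send, as a JSON object
* @param userId the ID of the user making the request
* @return the AI response as a JSON object
*/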
JSONObject raw(JSONObject promptJSON);
JSONObject raw(JSONObject promptJSON, String userId);

/**
* this method takes a prompt and returns the AI response based upon that prompt
Expand All @@ -58,9 +57,15 @@ public interface CompletionsAPI {
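* @param model the name of the model to use for the request
* @param temperature the temperature value to send with the request
* @param maxTokens the max-tokens value to send with the request
* @param userId the ID of the user making the request
* @return the AI response as a JSON object
*/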
JSONObject prompt(String systemPrompt, String userPrompt, String model, float temperature, int maxTokens);
JSONObject prompt(String systemPrompt,
String userPrompt,
String model,
float temperature,
int maxTokens,
String userId);

/**
* this method takes a prompt in the form of json and returns streaming AI response based upon that prompt
Expand Down
100 changes: 55 additions & 45 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,35 @@

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.app.AIModelType;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.client.AIProxyClient;
import com.dotcms.ai.db.EmbeddingsDTO;
import com.dotcms.ai.domain.AIResponse;
import com.dotcms.ai.domain.JSONObjectAIRequest;
import com.dotcms.ai.domain.Model;
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotcms.ai.util.EncodingUtil;
import com.dotcms.ai.util.OpenAIRequest;
import com.dotcms.api.web.HttpServletRequestThreadLocal;
import com.dotcms.mock.request.FakeHttpRequest;
import com.dotcms.mock.response.BaseResponse;
import com.dotcms.rendering.velocity.util.VelocityUtil;
import com.dotmarketing.business.APILocator;
import com.dotmarketing.business.web.WebAPILocator;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.UtilMethods;
import com.dotmarketing.util.json.JSONArray;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Lazy;
import io.vavr.Tuple2;
import io.vavr.control.Try;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.velocity.context.Context;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.HttpMethod;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -42,15 +45,13 @@
public class CompletionsAPIImpl implements CompletionsAPI {

private final AppConfig config;
private final Lazy<AppConfig> defaultConfig;

public CompletionsAPIImpl(final AppConfig config) {
defaultConfig =
Lazy.of(() -> ConfigService.INSTANCE.config(
Try.of(() -> WebAPILocator
.getHostWebAPI()
.getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest()))
.getOrElse(APILocator.systemHost())));
final Lazy<AppConfig> defaultConfig = Lazy.of(() -> ConfigService.INSTANCE.config(
Try.of(() -> WebAPILocator
.getHostWebAPI()
.getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest()))
.getOrElse(APILocator.systemHost())));
this.config = Optional.ofNullable(config).orElse(defaultConfig.get());
}

Expand All @@ -59,8 +60,9 @@ public JSONObject prompt(final String systemPrompt,
final String userPrompt,
final String modelIn,
final float temperature,
final int maxTokens) {
final AIModel model = config.resolveModelOrThrow(modelIn);
final int maxTokens,
final String userId) {
final Model model = config.resolveModelOrThrow(modelIn, AIModelType.TEXT)._2;
final JSONObject json = new JSONObject();

json.put(AiKeys.TEMPERATURE, temperature);
Expand All @@ -70,15 +72,17 @@ public JSONObject prompt(final String systemPrompt,
json.put(AiKeys.MAX_TOKENS, maxTokens);
}

json.put(AiKeys.MODEL, model.getCurrentModel());
json.put(AiKeys.MODEL, model.getName());

return raw(json);
return raw(json, userId);
}

@Override
public JSONObject summarize(final CompletionsForm summaryRequest) {
final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build();
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher);
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI()
.getEmbeddingsAPI()
.getEmbeddingResults(searcher);

// send all this as a json blob to OpenAI
final JSONObject json = buildRequestJson(summaryRequest, localResults);
Expand All @@ -87,58 +91,64 @@ public JSONObject summarize(final CompletionsForm summaryRequest) {
}

json.put(AiKeys.STREAM, false);
final String openAiResponse =
Try.of(() -> OpenAIRequest.doRequest(
config.getApiUrl(),
HttpMethod.POST,
config,
json))
.getOrElseThrow(DotRuntimeException::new);
final JSONObject dotCMSResponse = APILocator.getDotAIAPI().getEmbeddingsAPI().reduceChunksToContent(searcher, localResults);
final String openAiResponse = Try
.of(() -> sendRequest(config, json, UtilMethods.extractUserIdOrNull(summaryRequest.user)))
.getOrElseThrow(DotRuntimeException::new)
.getResponse();
final JSONObject dotCMSResponse = APILocator.getDotAIAPI()
.getEmbeddingsAPI()
.reduceChunksToContent(searcher, localResults);
dotCMSResponse.put(AiKeys.OPEN_AI_RESPONSE, new JSONObject(openAiResponse));

return dotCMSResponse;
}

@Override
public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream out) {
public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream output) {
final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build();
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher);
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI()
.getEmbeddingsAPI()
.getEmbeddingResults(searcher);

final JSONObject json = buildRequestJson(summaryRequest, localResults);
json.put(AiKeys.STREAM, true);
OpenAIRequest.doPost(config.getApiUrl(), config, json, out);
AIProxyClient.get().callToAI(
JSONObjectAIRequest.quickText(
config,
json,
UtilMethods.extractUserIdOrNull(summaryRequest.user)),
output);
}

@Override
public JSONObject raw(final JSONObject json) {
if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(this.getClass(), "OpenAI request:" + json.toString(2));
}
public JSONObject raw(final JSONObject json, final String userId) {
AppConfig.debugLogger(this.getClass(), () -> "OpenAI request:" + json.toString(2));

final String response = OpenAIRequest.doRequest(
config.getApiUrl(),
HttpMethod.POST,
config,
json);
if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(this.getClass(), "OpenAI response:" + response);
}
final String response = sendRequest(config, json, userId).getResponse();
AppConfig.debugLogger(this.getClass(), () -> "OpenAI response:" + response);

return new JSONObject(response);
}

@Override
public JSONObject raw(CompletionsForm promptForm) {
public JSONObject raw(final CompletionsForm promptForm) {
JSONObject jsonObject = buildRequestJson(promptForm);
return raw(jsonObject);
return raw(jsonObject, UtilMethods.extractUserIdOrNull(promptForm.user));
}

@Override
public void rawStream(final CompletionsForm promptForm, final OutputStream out) {
public void rawStream(final CompletionsForm promptForm, final OutputStream output) {
final JSONObject json = buildRequestJson(promptForm);
json.put(AiKeys.STREAM, true);
OpenAIRequest.doRequest(config.getApiUrl(), HttpMethod.POST, config, json, out);
AIProxyClient.get().callToAI(JSONObjectAIRequest.quickText(
config,
json,
UtilMethods.extractUserIdOrNull(promptForm.user)),
output);
}

/**
 * Wraps the given JSON payload in a text request and hands it to the AI proxy client.
 * <p>
 * The request is built with {@code JSONObjectAIRequest.quickText(appConfig, payload, userId)} and
 * dispatched through {@code AIProxyClient.get().callToAI(...)}; whatever that call returns is returned
 * unchanged.
 *
 * @param appConfig the app configuration passed along when building the request
 * @param payload the JSON body of the request
 * @param userId the ID of the user making the request; NOTE(review): callers in this class obtain it via
 *               {@code UtilMethods.extractUserIdOrNull}, so it is presumably allowed to be {@code null} - confirm
 * @return the {@link AIResponse} produced by the proxy client
 */
private AIResponse sendRequest(final AppConfig appConfig, final JSONObject payload, final String userId) {
return AIProxyClient.get().callToAI(JSONObjectAIRequest.quickText(appConfig, payload, userId));
}

private void buildMessages(final String systemPrompt, final String userPrompt, final JSONObject json) {
Expand All @@ -151,7 +161,7 @@ private void buildMessages(final String systemPrompt, final String userPrompt, f
}

private JSONObject buildRequestJson(final CompletionsForm form, final List<EmbeddingsDTO> searchResults) {
final AIModel model = config.resolveModelOrThrow(form.model);
final Tuple2<AIModel, Model> modelTuple = config.resolveModelOrThrow(form.model, AIModelType.TEXT);
// aggregate matching results into text
final StringBuilder supportingContent = new StringBuilder();
searchResults.forEach(s -> supportingContent.append(s.extractedText).append(" "));
Expand All @@ -162,7 +172,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List<Embed
final int systemPromptTokens = countTokens(systemPrompt);
textPrompt = reduceStringToTokenSize(
textPrompt,
model.getMaxTokens() - form.responseLengthTokens - systemPromptTokens);
modelTuple._1.getMaxTokens() - form.responseLengthTokens - systemPromptTokens);

final JSONObject json = new JSONObject();
json.put(AiKeys.STREAM, form.stream);
Expand All @@ -171,7 +181,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List<Embed
buildMessages(systemPrompt, textPrompt, json);

if (UtilMethods.isSet(form.model)) {
json.put(AiKeys.MODEL, model.getCurrentModel());
json.put(AiKeys.MODEL, modelTuple._2.getName());
}

json.put(AiKeys.MAX_TOKENS, form.responseLengthTokens);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ public static class DefaultChatAPIProvider implements ChatAPIProvider {
@Override
public ChatAPI getChatAPI(final Object... initArguments) {
if (Objects.nonNull(initArguments) && initArguments.length > 0 && initArguments[0] instanceof AppConfig) {
return new OpenAIChatAPIImpl((AppConfig) initArguments[0]);
final User user = initArguments.length > 1 && initArguments[1] instanceof User ? (User) initArguments[1] : null;
return new OpenAIChatAPIImpl((AppConfig) initArguments[0], user);
}

throw new IllegalArgumentException("To create a ChatAPI you need to provide an AppConfig");
Expand Down
6 changes: 4 additions & 2 deletions dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,11 @@ public interface EmbeddingsAPI {
* Embeddings
*
* @param content The content that will be tokenized and sent to OpenAI.
* @param userId The ID of the user making the request.
*
* @return Tuple(Count of Tokens Input, List of Embeddings Output)
*/
Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String content);
Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(String content, String userId);

/**
* this method takes a snippet of content and will try to see if we have already generated
Expand All @@ -154,10 +155,11 @@ public interface EmbeddingsAPI {
*
* @param contentId The ID of the Contentlet being sent to the OpenAI Endpoint.
* @param content The actual indexable data that will be tokenized and sent to OpenAI service.
* @param userId The ID of the user making the request.
*
* @return Tuple(Count of Tokens Input, List of Embeddings Output)
*/
Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String contentId, final String content);
Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(String contentId, String content, String userId);

/**
* Checks if the embeddings for the given inode, indexName, and extractedText already exist in the database.
Expand Down
33 changes: 19 additions & 14 deletions dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.client.AIProxyClient;
import com.dotcms.ai.db.EmbeddingsDTO;
import com.dotcms.ai.db.EmbeddingsDTO.Builder;
import com.dotcms.ai.db.EmbeddingsFactory;
import com.dotcms.ai.domain.JSONObjectAIRequest;
import com.dotcms.ai.util.ContentToStringUtil;
import com.dotcms.ai.util.EncodingUtil;
import com.dotcms.ai.util.OpenAIRequest;
import com.dotcms.ai.util.VelocityContextFactory;
import com.dotcms.api.web.HttpServletRequestThreadLocal;
import com.dotcms.api.web.HttpServletResponseThreadLocal;
Expand Down Expand Up @@ -43,7 +44,6 @@
import org.apache.velocity.context.Context;

import javax.validation.constraints.NotNull;
import javax.ws.rs.HttpMethod;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -311,13 +311,15 @@ public void initEmbeddingsTable() {
}

@Override
public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(@NotNull final String content) {
return pullOrGenerateEmbeddings("N/A", content);
public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(@NotNull final String content, final String userId) {
return pullOrGenerateEmbeddings("N/A", content, userId);
}

@WrapInTransaction
@Override
public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String contentId, @NotNull final String content) {
public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String contentId,
@NotNull final String content,
final String userId) {
if (UtilMethods.isEmpty(content)) {
return Tuple.of(0, List.of());
}
Expand Down Expand Up @@ -349,7 +351,7 @@ public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String conten

final Tuple2<Integer, List<Float>> openAiEmbeddings = Tuple.of(
tokens.size(),
sendTokensToOpenAI(contentId, tokens));
sendTokensToOpenAI(contentId, tokens, userId));
saveEmbeddingsForCache(content, openAiEmbeddings);
EMBEDDING_CACHE.put(hashed, openAiEmbeddings);

Expand Down Expand Up @@ -420,19 +422,20 @@ private void saveEmbeddingsForCache(final String content, final Tuple2<Integer,
*
* @param contentId The ID of the Contentlet that will be sent to the OpenAI Endpoint.
* @param tokens The encoded tokens representing the indexable data of a Contentlet.
* @param userId The ID of the user making the request.
*
* @return A {@link List} of {@link Float} values representing the embeddings.
*/
private List<Float> sendTokensToOpenAI(final String contentId, @NotNull final List<Integer> tokens) {
private List<Float> sendTokensToOpenAI(final String contentId,
@NotNull final List<Integer> tokens,
final String userId) {
final JSONObject json = new JSONObject();
json.put(AiKeys.MODEL, config.getEmbeddingsModel().getCurrentModel());
json.put(AiKeys.INPUT, tokens);
debugLogger(this.getClass(), () -> String.format("Content tokens for content ID '%s': %s", contentId, tokens));
final String responseString = OpenAIRequest.doRequest(
config.getApiEmbeddingsUrl(),
HttpMethod.POST,
config,
json);
final String responseString = AIProxyClient.get()
.callToAI(JSONObjectAIRequest.quickEmbeddings(config, json, userId))
.getResponse();
debugLogger(this.getClass(), () -> String.format("OpenAI Response for content ID '%s': %s",
contentId, responseString.replace("\n", BLANK)));
final JSONObject jsonResponse = Try.of(() -> new JSONObject(responseString)).getOrElseThrow(e -> {
Expand Down Expand Up @@ -490,8 +493,10 @@ private List<Float> getEmbeddingsFromJSON(final String contentId, final JSONObje
}
}

private EmbeddingsDTO getSearcher(EmbeddingsDTO searcher) {
final List<Float> queryEmbeddings = pullOrGenerateEmbeddings(searcher.query)._2;
private EmbeddingsDTO getSearcher(final EmbeddingsDTO searcher) {
final List<Float> queryEmbeddings = pullOrGenerateEmbeddings(
searcher.query,
UtilMethods.extractUserIdOrNull(searcher.user))._2;
return EmbeddingsDTO.copy(searcher).withEmbeddings(queryEmbeddings).build();
}

Expand Down
6 changes: 5 additions & 1 deletion dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.dotcms.ai.util.EncodingUtil;
import com.dotcms.business.WrapInTransaction;
import com.dotcms.exception.ExceptionUtil;
import com.dotmarketing.business.APILocator;
import com.dotmarketing.portlets.contentlet.model.Contentlet;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.UtilMethods;
Expand Down Expand Up @@ -119,7 +120,10 @@ private void saveEmbedding(@NotNull final String initial) {
}

final Tuple2<Integer, List<Float>> embeddings =
this.embeddingsAPI.pullOrGenerateEmbeddings(this.contentlet.getIdentifier(), normalizedContent);
this.embeddingsAPI.pullOrGenerateEmbeddings(
contentlet.getIdentifier(),
normalizedContent,
APILocator.systemUser().getUserId());
if (embeddings._2.isEmpty()) {
Logger.info(this.getClass(), String.format("No tokens for Content Type " +
"'%s'. Normalized content: %s", this.contentlet.getContentType().variable(), normalizedContent));
Expand Down
Loading
Loading