Skip to content

Commit

Permalink
feat(dotAI): Adding fallback mechanism when it comes to send models t…
Browse files Browse the repository at this point in the history
…o AI Provider (OpenAI)

Refs: #29284
  • Loading branch information
victoralfaro-dotcms committed Aug 16, 2024
1 parent 4e3d335 commit 4c589c3
Show file tree
Hide file tree
Showing 35 changed files with 1,204 additions and 380 deletions.
2 changes: 0 additions & 2 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package com.dotcms.ai.api;

import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Lazy;

import java.io.OutputStream;

Expand Down
46 changes: 24 additions & 22 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.app.AIModelType;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.client.AIProxyClient;
import com.dotcms.ai.db.EmbeddingsDTO;
import com.dotcms.ai.domain.AIResponse;
import com.dotcms.ai.domain.JSONObjectAIRequest;
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotcms.ai.util.EncodingUtil;
import com.dotcms.ai.util.OpenAIRequest;
import com.dotcms.api.web.HttpServletRequestThreadLocal;
import com.dotcms.mock.request.FakeHttpRequest;
import com.dotcms.mock.response.BaseResponse;
Expand All @@ -27,7 +30,6 @@

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.HttpMethod;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -60,7 +62,7 @@ public JSONObject prompt(final String systemPrompt,
final String modelIn,
final float temperature,
final int maxTokens) {
final AIModel model = config.resolveModelOrThrow(modelIn);
final AIModel model = config.resolveModelOrThrow(modelIn, AIModelType.TEXT);
final JSONObject json = new JSONObject();

json.put(AiKeys.TEMPERATURE, temperature);
Expand All @@ -78,7 +80,9 @@ public JSONObject prompt(final String systemPrompt,
@Override
public JSONObject summarize(final CompletionsForm summaryRequest) {
final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build();
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher);
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI()
.getEmbeddingsAPI()
.getEmbeddingResults(searcher);

// send all this as a json blob to OpenAI
final JSONObject json = buildRequestJson(summaryRequest, localResults);
Expand All @@ -87,27 +91,25 @@ public JSONObject summarize(final CompletionsForm summaryRequest) {
}

json.put(AiKeys.STREAM, false);
final String openAiResponse =
Try.of(() -> OpenAIRequest.doRequest(
config.getApiUrl(),
HttpMethod.POST,
config,
json))
.getOrElseThrow(DotRuntimeException::new);
final JSONObject dotCMSResponse = APILocator.getDotAIAPI().getEmbeddingsAPI().reduceChunksToContent(searcher, localResults);
final String openAiResponse = Try.of(() -> sendRequest(config, json))
.getOrElseThrow(DotRuntimeException::new)
.getResponse();
final JSONObject dotCMSResponse = APILocator.getDotAIAPI()
.getEmbeddingsAPI()
.reduceChunksToContent(searcher, localResults);
dotCMSResponse.put(AiKeys.OPEN_AI_RESPONSE, new JSONObject(openAiResponse));

return dotCMSResponse;
}

@Override
public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream out) {
public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream output) {
final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build();
final List<EmbeddingsDTO> localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher);

final JSONObject json = buildRequestJson(summaryRequest, localResults);
json.put(AiKeys.STREAM, true);
OpenAIRequest.doPost(config.getApiUrl(), config, json, out);
AIProxyClient.get().sendRequest(JSONObjectAIRequest.quickText(config, json), output);
}

@Override
Expand All @@ -116,11 +118,7 @@ public JSONObject raw(final JSONObject json) {
Logger.info(this.getClass(), "OpenAI request:" + json.toString(2));
}

final String response = OpenAIRequest.doRequest(
config.getApiUrl(),
HttpMethod.POST,
config,
json);
final String response = sendRequest(config, json).getResponse();
if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(this.getClass(), "OpenAI response:" + response);
}
Expand All @@ -135,10 +133,14 @@ public JSONObject raw(CompletionsForm promptForm) {
}

@Override
public void rawStream(final CompletionsForm promptForm, final OutputStream out) {
public void rawStream(final CompletionsForm promptForm, final OutputStream output) {
final JSONObject json = buildRequestJson(promptForm);
json.put(AiKeys.STREAM, true);
OpenAIRequest.doRequest(config.getApiUrl(), HttpMethod.POST, config, json, out);
AIProxyClient.get().sendRequest(JSONObjectAIRequest.quickText(config, json), output);
}

private AIResponse sendRequest(final AppConfig appConfig, final JSONObject payload) {
return AIProxyClient.get().sendRequest(JSONObjectAIRequest.quickText(appConfig, payload));
}

private void buildMessages(final String systemPrompt, final String userPrompt, final JSONObject json) {
Expand All @@ -151,7 +153,7 @@ private void buildMessages(final String systemPrompt, final String userPrompt, f
}

private JSONObject buildRequestJson(final CompletionsForm form, final List<EmbeddingsDTO> searchResults) {
final AIModel model = config.resolveModelOrThrow(form.model);
final AIModel model = config.resolveModelOrThrow(form.model, AIModelType.TEXT);
// aggregate matching results into text
final StringBuilder supportingContent = new StringBuilder();
searchResults.forEach(s -> supportingContent.append(s.extractedText).append(" "));
Expand Down
12 changes: 5 additions & 7 deletions dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.client.AIProxyClient;
import com.dotcms.ai.db.EmbeddingsDTO;
import com.dotcms.ai.db.EmbeddingsDTO.Builder;
import com.dotcms.ai.db.EmbeddingsFactory;
import com.dotcms.ai.domain.JSONObjectAIRequest;
import com.dotcms.ai.util.ContentToStringUtil;
import com.dotcms.ai.util.EncodingUtil;
import com.dotcms.ai.util.OpenAIRequest;
import com.dotcms.ai.util.VelocityContextFactory;
import com.dotcms.api.web.HttpServletRequestThreadLocal;
import com.dotcms.api.web.HttpServletResponseThreadLocal;
Expand Down Expand Up @@ -43,7 +44,6 @@
import org.apache.velocity.context.Context;

import javax.validation.constraints.NotNull;
import javax.ws.rs.HttpMethod;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -428,11 +428,9 @@ private List<Float> sendTokensToOpenAI(final String contentId, @NotNull final Li
json.put(AiKeys.MODEL, config.getEmbeddingsModel().getCurrentModel());
json.put(AiKeys.INPUT, tokens);
debugLogger(this.getClass(), () -> String.format("Content tokens for content ID '%s': %s", contentId, tokens));
final String responseString = OpenAIRequest.doRequest(
config.getApiEmbeddingsUrl(),
HttpMethod.POST,
config,
json);
final String responseString = AIProxyClient.get()
.sendRequest(JSONObjectAIRequest.quickEmbeddings(config, json))
.getResponse();
debugLogger(this.getClass(), () -> String.format("OpenAI Response for content ID '%s': %s",
contentId, responseString.replace("\n", BLANK)));
final JSONObject jsonResponse = Try.of(() -> new JSONObject(responseString)).getOrElseThrow(e -> {
Expand Down
11 changes: 5 additions & 6 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public static AIAppUtil get() {
public AIModel createTextModel(final Map<String, Secret> secrets) {
return AIModel.builder()
.withType(AIModelType.TEXT)
.withNames(splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES))
.withModelNames(splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES))
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_MAX_TOKENS))
Expand All @@ -59,7 +59,7 @@ public AIModel createTextModel(final Map<String, Secret> secrets) {
public AIModel createImageModel(final Map<String, Secret> secrets) {
return AIModel.builder()
.withType(AIModelType.IMAGE)
.withNames(splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES))
.withModelNames(splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES))
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_MAX_TOKENS))
Expand All @@ -76,7 +76,7 @@ public AIModel createImageModel(final Map<String, Secret> secrets) {
public AIModel createEmbeddingsModel(final Map<String, Secret> secrets) {
return AIModel.builder()
.withType(AIModelType.EMBEDDINGS)
.withNames(splitDiscoveredSecret(secrets, AppKeys.EMBEDDINGS_MODEL_NAMES))
.withModelNames(splitDiscoveredSecret(secrets, AppKeys.EMBEDDINGS_MODEL_NAMES))
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_MAX_TOKENS))
Expand Down Expand Up @@ -117,9 +117,8 @@ public String discoverSecret(final Map<String, Secret> secrets, final AppKeys ke
* @return the list of split secret values
*/
public List<String> splitDiscoveredSecret(final Map<String, Secret> secrets, final AppKeys key) {
return Arrays.stream(Optional.ofNullable(discoverSecret(secrets, key)).orElse(StringPool.BLANK).split(","))
.map(String::trim)
.map(String::toLowerCase)
return Arrays
.stream(Optional.ofNullable(discoverSecret(secrets, key)).orElse(StringPool.BLANK).split(","))
.collect(Collectors.toList());
}

Expand Down
Loading

0 comments on commit 4c589c3

Please sign in to comment.