From b0e41871036dfa29ffb8dd0527776b9949ae7d6d Mon Sep 17 00:00:00 2001 From: Victor Alfaro Date: Thu, 8 Aug 2024 19:26:42 -0600 Subject: [PATCH] feat(dotAI): Adding fallback mechanism when it comes to send models to AI Provider (OpenAI) Refs: #29284 --- .../com/dotcms/ai/api/CompletionsAPI.java | 13 +- .../com/dotcms/ai/api/CompletionsAPIImpl.java | 99 +++---- .../com/dotcms/ai/api/DotAIAPIFacadeImpl.java | 4 +- .../java/com/dotcms/ai/api/EmbeddingsAPI.java | 6 +- .../com/dotcms/ai/api/EmbeddingsAPIImpl.java | 31 ++- .../com/dotcms/ai/api/EmbeddingsRunner.java | 6 +- .../com/dotcms/ai/api/OpenAIChatAPIImpl.java | 15 +- .../com/dotcms/ai/api/OpenAIImageAPIImpl.java | 10 +- .../java/com/dotcms/ai/app/AIAppUtil.java | 30 ++- .../main/java/com/dotcms/ai/app/AIModel.java | 110 ++++---- .../main/java/com/dotcms/ai/app/AIModels.java | 242 ++++++++++++------ .../java/com/dotcms/ai/app/AppConfig.java | 48 ++-- .../java/com/dotcms/ai/app/ConfigService.java | 8 +- .../java/com/dotcms/ai/client/AIClient.java | 108 ++++++++ .../dotcms/ai/client/AIClientStrategy.java | 42 +++ .../dotcms/ai/client/AIDefaultStrategy.java | 34 +++ .../ai/client/AIModelFallbackStrategy.java | 238 +++++++++++++++++ .../com/dotcms/ai/client/AIProxiedClient.java | 86 +++++++ .../com/dotcms/ai/client/AIProxyClient.java | 121 +++++++++ .../com/dotcms/ai/client/AIProxyStrategy.java | 34 +++ .../dotcms/ai/client/AIResponseEvaluator.java | 35 +++ .../dotcms/ai/client/openai/OpenAIClient.java | 159 ++++++++++++ .../openai/OpenAIResponseEvaluator.java | 87 +++++++ .../java/com/dotcms/ai/db/EmbeddingsDTO.java | 3 +- .../java/com/dotcms/ai/domain/AIProvider.java | 35 +++ .../java/com/dotcms/ai/domain/AIRequest.java | 219 ++++++++++++++++ .../java/com/dotcms/ai/domain/AIResponse.java | 50 ++++ .../com/dotcms/ai/domain/AIResponseData.java | 68 +++++ .../dotcms/ai/domain/JSONObjectAIRequest.java | 105 ++++++++ .../main/java/com/dotcms/ai/domain/Model.java | 104 ++++++++ .../com/dotcms/ai/domain/ModelStatus.java | 30 +++ .../DotAIAllModelsExhaustedException.java | 22 ++ .../DotAIAppConfigDisabledException.java | 22 ++ .../DotAIClientConnectException.java | 21 ++ .../DotAIModelNotFoundException.java | 21 ++ .../DotAIModelNotOperationalException.java | 21 ++ .../com/dotcms/ai/listener/AIAppListener.java | 21 +- .../ai/listener/EmbeddingContentListener.java | 4 +- .../dotcms/ai/model/AIImageRequestDTO.java | 9 +- .../java/com/dotcms/ai/model/SimpleModel.java | 25 +- .../dotcms/ai/rest/CompletionsResource.java | 38 +-- .../com/dotcms/ai/rest/ImageResource.java | 2 +- .../java/com/dotcms/ai/rest/TextResource.java | 15 +- .../dotcms/ai/rest/forms/CompletionsForm.java | 19 +- .../dotcms/ai/rest/forms/EmbeddingsForm.java | 9 +- .../com/dotcms/ai/{ => util}/Marshaller.java | 2 +- .../com/dotcms/ai/util/OpenAIRequest.java | 189 -------------- .../dotcms/ai/validator/AIAppValidator.java | 95 +++++++ .../com/dotcms/ai/viewtool/AIViewTool.java | 8 +- .../dotcms/ai/viewtool/CompletionsTool.java | 22 +- .../dotcms/ai/viewtool/EmbeddingsTool.java | 22 +- .../ai/workflow/OpenAIAutoTagActionlet.java | 21 +- .../ai/workflow/OpenAIAutoTagRunner.java | 2 +- .../OpenAIContentPromptActionlet.java | 37 ++- .../workflow/OpenAIContentPromptRunner.java | 4 +- dotCMS/src/main/resources/apps/dotAI.yml | 4 +- .../WEB-INF/messages/Language.properties | 2 + .../webapp/html/portlet/ext/dotai/dotai.js | 5 +- .../OpenAIChatAPIImplTest.java} | 17 +- .../OpenAIImageAPIImplTest.java} | 8 +- .../java/com/dotcms/ai/app/AIAppUtilTest.java | 9 +- .../dotcms/ai/client/AIProxyClientTest.java | 65 +++++ .../ai/client/openai/AIProxiedClientTest.java | 102 ++++++++ .../openai/OpenAIResponseEvaluatorTest.java | 104 ++++++++ .../src/test/java/com/dotcms/MainSuite2b.java | 2 + .../src/test/java/com/dotcms/ai/AiTest.java | 49 ++-- .../java/com/dotcms/ai/app/AIModelsTest.java | 228 ++++++++++++----- .../com/dotcms/ai/app/ConfigServiceTest.java | 101 ++++++++ .../EmbeddingContentListenerTest.java | 5 +- .../dotcms/ai/viewtool/AIViewToolTest.java | 6 +- .../ai/viewtool/CompletionsToolTest.java | 21 +- .../ai/viewtool/EmbeddingsToolTest.java | 13 +- .../workflow/OpenAIAutoTagActionletTest.java | 7 +- .../OpenAIContentPromptActionletTest.java | 5 +- dotcms-postman/pom.xml | 3 - .../postman/AI.postman_collection.json | 6 +- 76 files changed, 2957 insertions(+), 636 deletions(-) create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIResponseEvaluator.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluator.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/Model.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/ModelStatus.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAllModelsExhaustedException.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAppConfigDisabledException.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/exception/DotAIClientConnectException.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotFoundException.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotOperationalException.java rename dotCMS/src/main/java/com/dotcms/ai/{ => util}/Marshaller.java (98%) delete mode 100644 dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java rename dotCMS/src/test/java/com/dotcms/ai/{service/OpenAIChatServiceImplTest.java => api/OpenAIChatAPIImplTest.java} (86%) rename dotCMS/src/test/java/com/dotcms/ai/{service/OpenAIImageServiceImplTest.java => api/OpenAIImageAPIImplTest.java} (97%) create mode 100644 dotCMS/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java create mode 100644 dotCMS/src/test/java/com/dotcms/ai/client/openai/AIProxiedClientTest.java create mode 100644 dotCMS/src/test/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluatorTest.java create mode 100644 dotcms-integration/src/test/java/com/dotcms/ai/app/ConfigServiceTest.java diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java index d980647f7b41..38ea42606703 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java @@ -1,9 +1,7 @@ package com.dotcms.ai.api; -import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.rest.forms.CompletionsForm; import com.dotmarketing.util.json.JSONObject; -import io.vavr.Lazy; import java.io.OutputStream; @@ -37,9 +35,10 @@ public interface CompletionsAPI { * this method takes a prompt in the form of json and returns a json AI response based upon that prompt * * @param promptJSON + * @param userId * @return */ - JSONObject raw(JSONObject promptJSON); + JSONObject raw(JSONObject promptJSON, String userId); /** * this method takes a prompt and returns the AI response based upon that prompt @@ -58,9 +57,15 @@ public interface CompletionsAPI { * @param model * @param temperature * @param maxTokens + * @param userId * @return */ - JSONObject prompt(String systemPrompt, String userPrompt, String model, float temperature, int maxTokens); + JSONObject prompt(String systemPrompt, + String userPrompt, + String model, + float temperature, + int maxTokens, + String userId); /** * this method takes a prompt in the form of json and returns streaming AI response based upon that prompt diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java index 008f65609096..56af6ddab956 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java @@ -2,13 +2,17 @@ import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.client.AIProxyClient; import com.dotcms.ai.db.EmbeddingsDTO; +import com.dotcms.ai.domain.AIResponse; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import com.dotcms.ai.domain.Model; import com.dotcms.ai.rest.forms.CompletionsForm; import com.dotcms.ai.util.EncodingUtil; -import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.api.web.HttpServletRequestThreadLocal; import com.dotcms.mock.request.FakeHttpRequest; import com.dotcms.mock.response.BaseResponse; @@ -16,18 +20,18 @@ import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.exception.DotRuntimeException; -import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; import com.dotmarketing.util.json.JSONArray; import com.dotmarketing.util.json.JSONObject; +import com.liferay.portal.model.User; import io.vavr.Lazy; +import io.vavr.Tuple2; import io.vavr.control.Try; import org.apache.commons.lang3.ArrayUtils; import org.apache.velocity.context.Context; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import javax.ws.rs.HttpMethod; import java.io.OutputStream; import java.util.ArrayList; import java.util.List; @@ -42,15 +46,13 @@ public class CompletionsAPIImpl implements CompletionsAPI { private final AppConfig config; - private final Lazy defaultConfig; public CompletionsAPIImpl(final AppConfig config) { - defaultConfig = - Lazy.of(() -> ConfigService.INSTANCE.config( - Try.of(() -> WebAPILocator - .getHostWebAPI() - .getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest())) - .getOrElse(APILocator.systemHost()))); + final Lazy defaultConfig = Lazy.of(() -> ConfigService.INSTANCE.config( + Try.of(() -> WebAPILocator + .getHostWebAPI() + .getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest())) + .getOrElse(APILocator.systemHost()))); this.config = Optional.ofNullable(config).orElse(defaultConfig.get()); } @@ -59,8 +61,9 @@ public JSONObject prompt(final String systemPrompt, final String userPrompt, final String modelIn, final float temperature, - final int maxTokens) { - final AIModel model = config.resolveModelOrThrow(modelIn); + final int maxTokens, + final String userId) { + final Model model = config.resolveModelOrThrow(modelIn, AIModelType.TEXT)._2; final JSONObject json = new JSONObject(); json.put(AiKeys.TEMPERATURE, temperature); @@ -70,15 +73,17 @@ public JSONObject prompt(final String systemPrompt, json.put(AiKeys.MAX_TOKENS, maxTokens); } - json.put(AiKeys.MODEL, model.getCurrentModel()); + json.put(AiKeys.MODEL, model.getName()); - return raw(json); + return raw(json, userId); } @Override public JSONObject summarize(final CompletionsForm summaryRequest) { final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build(); - final List localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher); + final List localResults = APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .getEmbeddingResults(searcher); // send all this as a json blob to OpenAI final JSONObject json = buildRequestJson(summaryRequest, localResults); @@ -87,58 +92,64 @@ public JSONObject summarize(final CompletionsForm summaryRequest) { } json.put(AiKeys.STREAM, false); - final String openAiResponse = - Try.of(() -> OpenAIRequest.doRequest( - config.getApiUrl(), - HttpMethod.POST, - config, - json)) - .getOrElseThrow(DotRuntimeException::new); - final JSONObject dotCMSResponse = APILocator.getDotAIAPI().getEmbeddingsAPI().reduceChunksToContent(searcher, localResults); + final String openAiResponse = Try.of(() -> sendRequest(config, json, getUserIdIfNotNull(summaryRequest.user))) + .getOrElseThrow(DotRuntimeException::new) + .getResponse(); + final JSONObject dotCMSResponse = APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .reduceChunksToContent(searcher, localResults); dotCMSResponse.put(AiKeys.OPEN_AI_RESPONSE, new JSONObject(openAiResponse)); return dotCMSResponse; } @Override - public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream out) { + public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream output) { final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build(); final List localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher); final JSONObject json = buildRequestJson(summaryRequest, localResults); json.put(AiKeys.STREAM, true); - OpenAIRequest.doPost(config.getApiUrl(), config, json, out); + AIProxyClient.get().sendRequest(JSONObjectAIRequest.quickText( + config, + json, + getUserIdIfNotNull(summaryRequest.user)), + output); } @Override - public JSONObject raw(final JSONObject json) { - if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.info(this.getClass(), "OpenAI request:" + json.toString(2)); - } + public JSONObject raw(final JSONObject json, final String userId) { + AppConfig.debugLogger(this.getClass(), () -> "OpenAI request:" + json.toString(2)); - final String response = OpenAIRequest.doRequest( - config.getApiUrl(), - HttpMethod.POST, - config, - json); - if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.info(this.getClass(), "OpenAI response:" + response); - } + final String response = sendRequest(config, json, userId).getResponse(); + AppConfig.debugLogger(this.getClass(), () -> "OpenAI response:" + response); return new JSONObject(response); } @Override - public JSONObject raw(CompletionsForm promptForm) { + public JSONObject raw(final CompletionsForm promptForm) { JSONObject jsonObject = buildRequestJson(promptForm); - return raw(jsonObject); + return raw(jsonObject, getUserIdIfNotNull(promptForm.user)); } @Override - public void rawStream(final CompletionsForm promptForm, final OutputStream out) { + public void rawStream(final CompletionsForm promptForm, final OutputStream output) { final JSONObject json = buildRequestJson(promptForm); json.put(AiKeys.STREAM, true); - OpenAIRequest.doRequest(config.getApiUrl(), HttpMethod.POST, config, json, out); + AIProxyClient.get().sendRequest(JSONObjectAIRequest.quickText( + config, + json, + getUserIdIfNotNull(promptForm.user)), + output); + } + + private String getUserIdIfNotNull(final User user) { + return Optional.ofNullable(user).map(User::getUserId).orElse(null); + } + + private AIResponse sendRequest(final AppConfig appConfig, final JSONObject payload, final String userId) { + return AIProxyClient.get().sendRequest(JSONObjectAIRequest.quickText(appConfig, payload, userId)); } private void buildMessages(final String systemPrompt, final String userPrompt, final JSONObject json) { @@ -151,7 +162,7 @@ private void buildMessages(final String systemPrompt, final String userPrompt, f } private JSONObject buildRequestJson(final CompletionsForm form, final List searchResults) { - final AIModel model = config.resolveModelOrThrow(form.model); + final Tuple2 modelTuple = config.resolveModelOrThrow(form.model, AIModelType.TEXT); // aggregate matching results into text final StringBuilder supportingContent = new StringBuilder(); searchResults.forEach(s -> supportingContent.append(s.extractedText).append(" ")); @@ -162,7 +173,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List 0 && initArguments[0] instanceof AppConfig) { - return new OpenAIChatAPIImpl((AppConfig) initArguments[0]); + if (Objects.nonNull(initArguments) && initArguments.length > 1 && initArguments[0] instanceof AppConfig) { + return new OpenAIChatAPIImpl((AppConfig) initArguments[0], (User) initArguments[1]); } throw new IllegalArgumentException("To create a ChatAPI you need to provide an AppConfig"); diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java index 2ffaa8e702db..e619ae11ab03 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java @@ -140,10 +140,11 @@ public interface EmbeddingsAPI { * Embeddings * * @param content The content that will be tokenized and sent to OpenAI. + * @param userId The ID of the user making the request. * * @return Tuple(Count of Tokens Input, List of Embeddings Output) */ - Tuple2> pullOrGenerateEmbeddings(final String content); + Tuple2> pullOrGenerateEmbeddings(String content, String userId); /** * this method takes a snippet of content and will try to see if we have already generated @@ -154,10 +155,11 @@ public interface EmbeddingsAPI { * * @param contentId The ID of the Contentlet being sent to the OpenAI Endpoint. * @param content The actual indexable data that will be tokenized and sent to OpenAI service. + * @param userId The ID of the user making the request. * * @return Tuple(Count of Tokens Input, List of Embeddings Output) */ - Tuple2> pullOrGenerateEmbeddings(final String contentId, final String content); + Tuple2> pullOrGenerateEmbeddings(String contentId, String content, String userId); /** * Checks if the embeddings for the given inode, indexName, and extractedText already exist in the database. diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java index c40a8e1692e4..56e4c1804874 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java @@ -4,12 +4,13 @@ import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.client.AIProxyClient; import com.dotcms.ai.db.EmbeddingsDTO; import com.dotcms.ai.db.EmbeddingsDTO.Builder; import com.dotcms.ai.db.EmbeddingsFactory; +import com.dotcms.ai.domain.JSONObjectAIRequest; import com.dotcms.ai.util.ContentToStringUtil; import com.dotcms.ai.util.EncodingUtil; -import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.ai.util.VelocityContextFactory; import com.dotcms.api.web.HttpServletRequestThreadLocal; import com.dotcms.api.web.HttpServletResponseThreadLocal; @@ -43,7 +44,6 @@ import org.apache.velocity.context.Context; import javax.validation.constraints.NotNull; -import javax.ws.rs.HttpMethod; import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; @@ -311,13 +311,15 @@ public void initEmbeddingsTable() { } @Override - public Tuple2> pullOrGenerateEmbeddings(@NotNull final String content) { - return pullOrGenerateEmbeddings("N/A", content); + public Tuple2> pullOrGenerateEmbeddings(@NotNull final String content, final String userId) { + return pullOrGenerateEmbeddings("N/A", content, userId); } @WrapInTransaction @Override - public Tuple2> pullOrGenerateEmbeddings(final String contentId, @NotNull final String content) { + public Tuple2> pullOrGenerateEmbeddings(final String contentId, + @NotNull final String content, + final String userId) { if (UtilMethods.isEmpty(content)) { return Tuple.of(0, List.of()); } @@ -349,7 +351,7 @@ public Tuple2> pullOrGenerateEmbeddings(final String conten final Tuple2> openAiEmbeddings = Tuple.of( tokens.size(), - sendTokensToOpenAI(contentId, tokens)); + sendTokensToOpenAI(contentId, tokens, userId)); saveEmbeddingsForCache(content, openAiEmbeddings); EMBEDDING_CACHE.put(hashed, openAiEmbeddings); @@ -420,19 +422,20 @@ private void saveEmbeddingsForCache(final String content, final Tuple2 sendTokensToOpenAI(final String contentId, @NotNull final List tokens) { + private List sendTokensToOpenAI(final String contentId, + @NotNull final List tokens, + final String userId) { final JSONObject json = new JSONObject(); json.put(AiKeys.MODEL, config.getEmbeddingsModel().getCurrentModel()); json.put(AiKeys.INPUT, tokens); debugLogger(this.getClass(), () -> String.format("Content tokens for content ID '%s': %s", contentId, tokens)); - final String responseString = OpenAIRequest.doRequest( - config.getApiEmbeddingsUrl(), - HttpMethod.POST, - config, - json); + final String responseString = AIProxyClient.get() + .sendRequest(JSONObjectAIRequest.quickEmbeddings(config, json, userId)) + .getResponse(); debugLogger(this.getClass(), () -> String.format("OpenAI Response for content ID '%s': %s", contentId, responseString.replace("\n", BLANK))); final JSONObject jsonResponse = Try.of(() -> new JSONObject(responseString)).getOrElseThrow(e -> { @@ -490,8 +493,8 @@ private List getEmbeddingsFromJSON(final String contentId, final JSONObje } } - private EmbeddingsDTO getSearcher(EmbeddingsDTO searcher) { - final List queryEmbeddings = pullOrGenerateEmbeddings(searcher.query)._2; + private EmbeddingsDTO getSearcher(final EmbeddingsDTO searcher) { + final List queryEmbeddings = pullOrGenerateEmbeddings(searcher.query, searcher.user.getUserId())._2; return EmbeddingsDTO.copy(searcher).withEmbeddings(queryEmbeddings).build(); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java index 47058164d4ca..ed2c4c5c385d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java @@ -6,6 +6,7 @@ import com.dotcms.ai.util.EncodingUtil; import com.dotcms.business.WrapInTransaction; import com.dotcms.exception.ExceptionUtil; +import com.dotmarketing.business.APILocator; import com.dotmarketing.portlets.contentlet.model.Contentlet; import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; @@ -119,7 +120,10 @@ private void saveEmbedding(@NotNull final String initial) { } final Tuple2> embeddings = - this.embeddingsAPI.pullOrGenerateEmbeddings(this.contentlet.getIdentifier(), normalizedContent); + this.embeddingsAPI.pullOrGenerateEmbeddings( + contentlet.getIdentifier(), + normalizedContent, + APILocator.systemUser().getUserId()); if (embeddings._2.isEmpty()) { Logger.info(this.getClass(), String.format("No tokens for Content Type " + "'%s'. Normalized content: %s", this.contentlet.getContentType().variable(), normalizedContent)); diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java index 2e94dbe4218d..0a0f5e3ce46a 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java @@ -3,21 +3,24 @@ import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; -import com.dotcms.ai.util.OpenAIRequest; +import com.dotcms.ai.client.AIProxyClient; +import com.dotcms.ai.domain.JSONObjectAIRequest; import com.dotmarketing.util.UtilMethods; import com.dotmarketing.util.json.JSONObject; import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.model.User; -import javax.ws.rs.HttpMethod; import java.util.List; import java.util.Map; public class OpenAIChatAPIImpl implements ChatAPI { private final AppConfig config; + private final User user; - public OpenAIChatAPIImpl(final AppConfig appConfig) { + public OpenAIChatAPIImpl(final AppConfig appConfig, final User user) { this.config = appConfig; + this.user = user; } @Override @@ -36,7 +39,7 @@ public JSONObject sendRawRequest(final JSONObject prompt) { prompt.remove(AiKeys.PROMPT); - return new JSONObject(doRequest(config.getApiUrl(), prompt)); + return new JSONObject(doRequest(prompt, user.getUserId())); } @Override @@ -47,8 +50,8 @@ public JSONObject sendTextPrompt(final String textPrompt) { } @VisibleForTesting - public String doRequest(final String urlIn, final JSONObject json) { - return OpenAIRequest.doRequest(urlIn, HttpMethod.POST, config, json); + String doRequest(final JSONObject json, final String userId) { + return AIProxyClient.get().sendRequest(JSONObjectAIRequest.quickText(config, json, userId)).getResponse(); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java index 041a49b6e2d1..8c58de67e3a9 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java @@ -2,8 +2,9 @@ import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.client.AIProxyClient; +import com.dotcms.ai.domain.JSONObjectAIRequest; import com.dotcms.ai.model.AIImageRequestDTO; -import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.ai.util.OpenAiRequestUtil; import com.dotcms.ai.util.StopWordsUtil; import com.dotcms.api.web.HttpServletRequestThreadLocal; @@ -24,7 +25,6 @@ import io.vavr.control.Try; import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.HttpMethod; import java.net.URL; import java.text.SimpleDateFormat; import java.util.Date; @@ -173,8 +173,10 @@ private String generateFileName(final String originalPrompt) { } @VisibleForTesting - public String doRequest(final String urlIn, final JSONObject json) { - return OpenAIRequest.doRequest(urlIn, HttpMethod.POST, config, json); + String doRequest(final String urlIn, final JSONObject json) { + return AIProxyClient.get() + .sendRequest(JSONObjectAIRequest.quickImage(config, json, user.getUserId())) + .getResponse(); } @VisibleForTesting diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java index a4f6d2c8fb12..3cb2a1dad162 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java @@ -6,6 +6,7 @@ import com.liferay.util.StringPool; import io.vavr.Lazy; import io.vavr.control.Try; +import org.apache.commons.collections4.CollectionUtils; import java.util.Arrays; import java.util.List; @@ -40,9 +41,14 @@ public static AIAppUtil get() { * @return the created text model instance */ public AIModel createTextModel(final Map secrets) { + final List modelNames = splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES); + if (CollectionUtils.isEmpty(modelNames)) { + return AIModel.NOOP_MODEL; + } + return AIModel.builder() .withType(AIModelType.TEXT) - .withNames(splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES)) + .withModelNames(modelNames) .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_TOKENS_PER_MINUTE)) .withApiPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_API_PER_MINUTE)) .withMaxTokens(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_MAX_TOKENS)) @@ -57,9 +63,14 @@ public AIModel createTextModel(final Map secrets) { * @return the created image model instance */ public AIModel createImageModel(final Map secrets) { + final List modelNames = splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES); + if (CollectionUtils.isEmpty(modelNames)) { + return AIModel.NOOP_MODEL; + } + return AIModel.builder() .withType(AIModelType.IMAGE) - .withNames(splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES)) + .withModelNames(modelNames) .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_TOKENS_PER_MINUTE)) .withApiPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_API_PER_MINUTE)) .withMaxTokens(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_MAX_TOKENS)) @@ -74,9 +85,14 @@ public AIModel createImageModel(final Map secrets) { * @return the created embeddings model instance */ public AIModel createEmbeddingsModel(final Map secrets) { + final List modelNames = splitDiscoveredSecret(secrets, AppKeys.EMBEDDINGS_MODEL_NAMES); + if (CollectionUtils.isEmpty(modelNames)) { + return AIModel.NOOP_MODEL; + } + return AIModel.builder() .withType(AIModelType.EMBEDDINGS) - .withNames(splitDiscoveredSecret(secrets, AppKeys.EMBEDDINGS_MODEL_NAMES)) + .withModelNames(modelNames) .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_TOKENS_PER_MINUTE)) .withApiPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_API_PER_MINUTE)) .withMaxTokens(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_MAX_TOKENS)) @@ -117,9 +133,11 @@ public String discoverSecret(final Map secrets, final AppKeys ke * @return the list of split secret values */ public List splitDiscoveredSecret(final Map secrets, final AppKeys key) { - return Arrays.stream(Optional.ofNullable(discoverSecret(secrets, key)).orElse(StringPool.BLANK).split(",")) - .map(String::trim) - .map(String::toLowerCase) + return Arrays + .stream(Optional + .ofNullable(discoverSecret(secrets, key)) + .map(secret -> secret.split(StringPool.COMMA)) + .orElse(new String[0])) .collect(Collectors.toList()); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java index efbcc09a0872..7bd00b2af668 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java @@ -1,12 +1,15 @@ package com.dotcms.ai.app; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.exception.DotAIModelNotFoundException; import com.dotcms.util.DotPreconditions; import com.dotmarketing.util.Logger; import java.util.List; import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; /** * Represents an AI model with various attributes such as type, names, tokens per minute, @@ -20,41 +23,34 @@ public class AIModel { public static final AIModel NOOP_MODEL = AIModel.builder() .withType(AIModelType.UNKNOWN) - .withNames(List.of()) + .withModelNames(List.of()) .build(); private final AIModelType type; - private final List names; + private final List models; private final int tokensPerMinute; private final int apiPerMinute; private final int maxTokens; private final boolean isCompletion; - private final AtomicInteger current; - private final AtomicBoolean decommissioned; - - private AIModel(final AIModelType type, - final List names, - final int tokensPerMinute, - final int apiPerMinute, - final int maxTokens, - final boolean isCompletion) { - DotPreconditions.checkNotNull(type, "type cannot be null"); - this.type = type; - this.names = Optional.ofNullable(names).orElse(List.of()); - this.tokensPerMinute = tokensPerMinute; - this.apiPerMinute = apiPerMinute; - this.maxTokens = maxTokens; - this.isCompletion = isCompletion; - current = new AtomicInteger(this.names.isEmpty() ? -1 : 0); - decommissioned = new AtomicBoolean(false); + private final AtomicInteger currentModelIndex; + + private AIModel(final Builder builder) { + DotPreconditions.checkNotNull(builder.type, "type cannot be null"); + this.type = builder.type; + this.models = builder.models; + this.tokensPerMinute = builder.tokensPerMinute; + this.apiPerMinute = builder.apiPerMinute; + this.maxTokens = builder.maxTokens; + this.isCompletion = builder.isCompletion; + currentModelIndex = new AtomicInteger(this.models.isEmpty() ? -1 : 0); } public AIModelType getType() { return type; } - public List getNames() { - return names; + public List getModels() { + return models; } public int getTokensPerMinute() { @@ -73,38 +69,41 @@ public boolean isCompletion() { return isCompletion; } - public int getCurrent() { - return current.get(); + public int getCurrentModelIndex() { + return currentModelIndex.get(); } - public void setCurrent(final int current) { - if (!isCurrentValid(current)) { + public void setCurrentModelIndex(final int currentModelIndex) { + if (!isCurrentValid(currentModelIndex)) { logInvalidModelMessage(); return; } - this.current.set(current); - } - - public boolean isDecommissioned() { - return decommissioned.get(); - } - - public void setDecommissioned(final boolean decommissioned) { - this.decommissioned.set(decommissioned); + this.currentModelIndex.set(currentModelIndex); } public boolean isOperational() { return this != NOOP_MODEL; } - public String getCurrentModel() { - final int currentIndex = this.current.get(); + public Model getCurrent() { + final int currentIndex = this.currentModelIndex.get(); if (!isCurrentValid(currentIndex)) { logInvalidModelMessage(); return null; } + return models.get(currentIndex); + } + + public String getCurrentModel() { + return getCurrent().getName(); + } - return names.get(currentIndex); + public Model getModel(final String modelName) { + final String normalized = modelName.trim().toLowerCase(); + return models.stream() + .filter(model -> normalized.equals(model.getName())) + .findFirst() + .orElseThrow(() -> new DotAIModelNotFoundException(String.format("Model [%s] not found", modelName))); } public long minIntervalBetweenCalls() { @@ -115,22 +114,21 @@ public long minIntervalBetweenCalls() { public String toString() { return "AIModel{" + "type=" + type + - ", names=" + names + + ", models='" + models + '\'' + ", tokensPerMinute=" + tokensPerMinute + ", apiPerMinute=" + apiPerMinute + ", maxTokens=" + maxTokens + ", isCompletion=" + isCompletion + - ", current=" + current + - ", decommissioned=" + decommissioned + + ", currentModelIndex=" + currentModelIndex.get() + '}'; } private boolean isCurrentValid(final int current) { - return !names.isEmpty() && current >= 0 && current < names.size(); + return !models.isEmpty() && current >= 0 && current < models.size(); } private void logInvalidModelMessage() { - Logger.debug(getClass(), String.format("Current model index must be between 0 and %d", names.size())); + Logger.debug(getClass(), String.format("Current model index must be between 0 and %d", models.size())); } public static Builder builder() { @@ -140,7 +138,7 @@ public static Builder builder() { public static class Builder { private AIModelType type; - private List names; + private List models; private int tokensPerMinute; private int apiPerMinute; private int maxTokens; @@ -154,13 +152,25 @@ public Builder withType(final AIModelType type) { return this; } - public Builder withNames(final List names) { - this.names = names; + public Builder withModels(final List models) { + this.models = Optional.ofNullable(models).orElse(List.of()); return this; } - public Builder withNames(final String... names) { - return withNames(List.of(names)); + public Builder withModelNames(final List names) { + return withModels( + Optional.ofNullable(names) + .map(modelNames -> IntStream.range(0, modelNames.size()) + .mapToObj(index -> Model.builder() + .withName(modelNames.get(index)) + .withIndex(index) + .build()) + .collect(Collectors.toList())) + .orElse(List.of())); + } + + public Builder withModelNames(final String... names) { + return withModelNames(List.of(names)); } public Builder withTokensPerMinute(final int tokensPerMinute) { @@ -184,7 +194,7 @@ public Builder withIsCompletion(final boolean isCompletion) { } public AIModel build() { - return new AIModel(type, names, tokensPerMinute, apiPerMinute, maxTokens, isCompletion); + return new AIModel(this); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java index 388afb7545e3..f48d752eb516 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java @@ -1,5 +1,9 @@ package com.dotcms.ai.app; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIModelNotFoundException; +import com.dotcms.ai.exception.DotAIModelNotOperationalException; import com.dotcms.ai.model.OpenAIModel; import com.dotcms.ai.model.OpenAIModels; import com.dotcms.ai.model.SimpleModel; @@ -13,17 +17,17 @@ import io.vavr.Lazy; import io.vavr.Tuple; import io.vavr.Tuple2; +import io.vavr.Tuple3; import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -40,28 +44,55 @@ public class AIModels { private static final int AI_MODELS_FETCH_ATTEMPTS = Config.getIntProperty(AI_MODELS_FETCH_ATTEMPTS_KEY, 3); private static final String AI_MODELS_FETCH_TIMEOUT_KEY = "ai.models.fetch.timeout"; private static final int AI_MODELS_FETCH_TIMEOUT = Config.getIntProperty(AI_MODELS_FETCH_TIMEOUT_KEY, 4000); - private static final Lazy INSTANCE = Lazy.of(AIModels::new); - private static final String AI_MODELS_API_URL_KEY = "DOT_AI_MODELS_API_URL"; + private static final String AI_MODELS_API_URL_KEY = "AI_MODELS_API_URL"; + private static final String AI_MODELS_API_URL_DEFAULT = "https://api.openai.com/v1/models"; private static final String AI_MODELS_API_URL = Config.getStringProperty( AI_MODELS_API_URL_KEY, - "https://api.openai.com/v1/models"); + AI_MODELS_API_URL_DEFAULT); private static final int AI_MODELS_CACHE_TTL = 28800; // 8 hours - private static final int AI_MODELS_CACHE_SIZE = 128; + private static final int AI_MODELS_CACHE_SIZE = 256; + private static final Lazy INSTANCE = Lazy.of(AIModels::new); - private final ConcurrentMap>> internalModels = new ConcurrentHashMap<>(); - private final ConcurrentMap, AIModel> modelsByName = new ConcurrentHashMap<>(); - private final Cache> supportedModelsCache = - Caffeine.newBuilder() - .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) - .maximumSize(AI_MODELS_CACHE_SIZE) - .build(); - private Supplier appConfigSupplier = ConfigService.INSTANCE::config; + private final ConcurrentMap>> internalModels; + private final ConcurrentMap, AIModel> modelsByName; + private final Cache> supportedModelsCache; public static AIModels get() { return INSTANCE.get(); } + private static CircuitBreakerUrl.Response fetchOpenAIModels(final String apiKey) { + final CircuitBreakerUrl.Response response = CircuitBreakerUrl.builder() + .setMethod(CircuitBreakerUrl.Method.GET) + .setUrl(AI_MODELS_API_URL) + .setTimeout(AI_MODELS_FETCH_TIMEOUT) + .setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS) + .setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + apiKey)) + .setThrowWhenNot2xx(true) + .build() + .doResponse(OpenAIModels.class); + + if (!CircuitBreakerUrl.isSuccessResponse(response)) { + AppConfig.debugLogger( + AIModels.class, + () -> String.format( + "Error fetching OpenAI supported models from [%s] (status code: [%d])", + AI_MODELS_API_URL, + response.getStatusCode())); + throw new DotRuntimeException("Error fetching OpenAI supported models"); + } + + return response; + } + private AIModels() { + internalModels = new ConcurrentHashMap<>(); + modelsByName = new ConcurrentHashMap<>(); + supportedModelsCache = + Caffeine.newBuilder() + .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) + .maximumSize(AI_MODELS_CACHE_SIZE) + .build(); } /** @@ -73,46 +104,47 @@ private AIModels() { * @param loading the list of AI models to load */ public void loadModels(final String host, final List loading) { - Optional.ofNullable(internalModels.get(host)) - .ifPresentOrElse( - model -> {}, - () -> internalModels.putIfAbsent( - host, - loading.stream() - .map(model -> Tuple.of(model.getType(), model)) - .collect(Collectors.toList()))); - loading.forEach(model -> model - .getNames() - .forEach(name -> { - final Tuple2 key = Tuple.of( - host, - name.toLowerCase().trim()); + final List> added = internalModels.putIfAbsent( + host, + loading.stream() + .map(model -> Tuple.of(model.getType(), model)) + .collect(Collectors.toList())); + loading.forEach(aiModel -> aiModel + .getModels() + .forEach(model -> { + final Tuple3 key = Tuple.of(host, model, aiModel.getType()); if (modelsByName.containsKey(key)) { - Logger.debug( - this, - String.format( + AppConfig.debugLogger( + getClass(), + () -> String.format( "Model [%s] already exists for host [%s], ignoring it", - name, + model, host)); return; } - modelsByName.putIfAbsent(key, model); + modelsByName.putIfAbsent(key, aiModel); })); + activateModels(host, added == null); } /** * Finds an AI model by the host and model name. The search is case-insensitive. * - * @param host the host for which the model is being searched + * @param appConfig the AppConfig for the host * @param modelName the name of the model to find + * @param type the type of the model to find * @return an Optional containing the found AIModel, or an empty Optional if not found */ - public Optional findModel(final String host, final String modelName) { + public Optional findModel(final AppConfig appConfig, + final String modelName, + final AIModelType type) { final String lowered = modelName.toLowerCase(); - final Set supported = getOrPullSupportedModels(); - return supported.contains(lowered) - ? Optional.ofNullable(modelsByName.get(Tuple.of(host, lowered))) - : Optional.empty(); + return Optional.ofNullable( + modelsByName.get( + Tuple.of( + appConfig.getHost(), + Model.builder().withName(lowered).build(), + type))); } /** @@ -130,6 +162,50 @@ public Optional findModel(final String host, final AIModelType type) { .findFirst()); } + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * + * @param host the host for which the model is being resolved + * @param type the type of the model to find + */ + public AIModel resolveModel(final String host, final AIModelType type) { + return findModel(host, type).orElse(AIModel.NOOP_MODEL); + } + + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * + * @param appConfig the AppConfig for the host + * @param modelName the name of the model to find + * @param type the type of the model to find + */ + public AIModel resolveAIModelOrThrow(final AppConfig appConfig, final String modelName, final AIModelType type) { + return findModel(appConfig, modelName, type) + .orElseThrow(() -> { + final String supported = String.join( + ", ", + AIModels.get().getOrPullSupportedModels(appConfig.getApiKey())); + return new DotAIModelNotFoundException( + String.format("Unable to find model: [%s]. Only [%s] are supported", modelName, supported)); + }); + } + + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * If the model is not found or is not operational, it throws an appropriate exception. + * + * @param appConfig the AppConfig for the host + * @param modelName the name of the model to find + * @param type the type of the model to find + * @return a Tuple2 containing the AIModel and the Model + */ + public Tuple2 resolveModelOrThrow(final AppConfig appConfig, + final String modelName, + final AIModelType type) { + final AIModel aiModel = resolveAIModelOrThrow(appConfig, modelName, type); + return Tuple.of(aiModel, aiModel.getModel(modelName)); + } + /** * Resets the internal models cache for the specified host. * @@ -145,27 +221,28 @@ public void resetModels(final String host) { .filter(key -> key._1.equals(host)) .collect(Collectors.toSet()) .forEach(modelsByName::remove); + cleanSupportedModelsCache(); } /** * Retrieves the list of supported models, either from the cache or by fetching them * from an external source if the cache is empty or expired. * + * @param apiKey the API key to use for fetching the supported models * @return a set of supported model names */ - public Set getOrPullSupportedModels() { + public Set getOrPullSupportedModels(final String apiKey) { final Set cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY); if (CollectionUtils.isNotEmpty(cached)) { return cached; } - final AppConfig appConfig = appConfigSupplier.get(); - if (!appConfig.isEnabled()) { - AppConfig.debugLogger(getClass(), () -> "dotAI is not enabled, returning empty list of supported models"); - throw new DotRuntimeException("App dotAI config without API urls or API key"); + if (StringUtils.isBlank(apiKey)) { + Logger.debug(this, "OpenAI is not enabled, returning empty list of supported models"); + throw new DotAIModelNotOperationalException("OpenAI is not enabled, cannot get list of models from OpenAI"); } - final CircuitBreakerUrl.Response response = fetchOpenAIModels(appConfig); + final CircuitBreakerUrl.Response response = fetchOpenAIModels(apiKey); if (Objects.nonNull(response.getResponse().getError())) { throw new DotRuntimeException("Found error in AI response: " + response.getResponse().getError().getMessage()); } @@ -176,7 +253,9 @@ public Set getOrPullSupportedModels() { .stream() .map(OpenAIModel::getId) .map(String::toLowerCase) + .map(String::trim) .collect(Collectors.toSet()); + supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported); return supported; @@ -188,53 +267,50 @@ public Set getOrPullSupportedModels() { * @return a list of available model names */ public List getAvailableModels() { - final Set configured = internalModels.entrySet() + return internalModels.entrySet() .stream() .flatMap(entry -> entry.getValue().stream()) .map(Tuple2::_2) - .flatMap(model -> model.getNames().stream().map(name -> new SimpleModel(name, model.getType()))) - .collect(Collectors.toSet()); - final Set supported = getOrPullSupportedModels() - .stream() - .map(SimpleModel::new) - .collect(Collectors.toSet()); - configured.retainAll(supported); - - return new ArrayList<>(configured); + .filter(AIModel::isOperational) + .flatMap(aiModel -> aiModel.getModels() + .stream() + .filter(Model::isOperational) + .map(model -> new SimpleModel( + model.getName(), + aiModel.getType(), + aiModel.getCurrentModelIndex() == model.getIndex()))) + .distinct() + .collect(Collectors.toList()); } - private static CircuitBreakerUrl.Response fetchOpenAIModels(final AppConfig appConfig) { - final CircuitBreakerUrl.Response response = CircuitBreakerUrl.builder() - .setMethod(CircuitBreakerUrl.Method.GET) - .setUrl(AI_MODELS_API_URL) - .setTimeout(AI_MODELS_FETCH_TIMEOUT) - .setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS) - .setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + appConfig.getApiKey())) - .setThrowWhenNot2xx(true) - .build() - .doResponse(OpenAIModels.class); + @VisibleForTesting + void cleanSupportedModelsCache() { + supportedModelsCache.invalidate(SUPPORTED_MODELS_KEY); + } - if (!CircuitBreakerUrl.isSuccessResponse(response)) { - Logger.debug( - AIModels.class, - String.format( - "Error fetching OpenAI supported models from [%s] (status code: [%d])", - AI_MODELS_API_URL, - response.getStatusCode())); - throw new DotRuntimeException("Error fetching OpenAI supported models"); + private void activateModels(final String host, boolean wasAdded) { + if (!wasAdded) { + return; } - return response; - } - - @VisibleForTesting - void setAppConfigSupplier(final Supplier appConfigSupplier) { - this.appConfigSupplier = appConfigSupplier; - } + final List aiModels = internalModels.get(host) + .stream() + .map(tuple -> tuple._2) + .collect(Collectors.toList()); - @VisibleForTesting - void cleanSupportedModelsCache() { - supportedModelsCache.invalidateAll(); + aiModels.forEach(aiModel -> + aiModel.getModels().forEach(model -> { + final String modelName = model.getName().trim().toLowerCase(); + final ModelStatus status; + status = ModelStatus.ACTIVE; + if (aiModel.getCurrentModelIndex() == -1) { + aiModel.setCurrentModelIndex(model.getIndex()); + } + Logger.debug( + this, + String.format("Model [%s] is supported by OpenAI, marking it as [%s]", modelName, status)); + model.setStatus(status); + })); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java index 1053537f79f9..20df6fb078e7 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java @@ -1,11 +1,12 @@ package com.dotcms.ai.app; +import com.dotcms.ai.domain.Model; import com.dotcms.security.apps.Secret; -import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Config; import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; import com.liferay.util.StringPool; +import io.vavr.Tuple2; import io.vavr.control.Try; import org.apache.commons.lang3.StringUtils; @@ -116,6 +117,15 @@ public static void setSystemHostConfig(final AppConfig systemHostConfig) { AppConfig.SYSTEM_HOST_CONFIG.set(systemHostConfig); } + /** + * Retrieves the host. + * + * @return the host + */ + public String getHost() { + return host; + } + /** * Retrieves the API URL. * @@ -137,7 +147,7 @@ public String getApiImageUrl() { /** * Retrieves the API Embeddings URL. * - * @return + * @return the API Embeddings URL */ public String getApiEmbeddingsUrl() { return UtilMethods.isEmpty(apiEmbeddingsUrl) ? AppKeys.API_EMBEDDINGS_URL.defaultValue : apiEmbeddingsUrl; @@ -287,33 +297,29 @@ public String getConfig(final AppKeys appKey) { * @param type the type of the model to find */ public AIModel resolveModel(final AIModelType type) { - return AIModels.get().findModel(host, type).orElse(AIModel.NOOP_MODEL); + return AIModels.get().resolveModel(host, type); } /** * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. * * @param modelName the name of the model to find + * @param type the type of the model to find */ - public AIModel resolveModelOrThrow(final String modelName) { - final AIModel aiModel = AIModels.get() - .findModel(host, modelName) - .orElseThrow(() -> { - final String supported = String.join(", ", AIModels.get().getOrPullSupportedModels()); - return new DotRuntimeException( - "Unable to find model: [" + modelName + "]. Only [" + supported + "] are supported "); - }); - - if (!aiModel.isOperational()) { - debugLogger( - AppConfig.class, - () -> String.format( - "Resolved model [%s] is not operational, avoiding its usage", - aiModel.getCurrentModel())); - throw new DotRuntimeException(String.format("Model [%s] is not operational", aiModel.getCurrentModel())); - } + public AIModel resolveAIModelOrThrow(final String modelName, final AIModelType type) { + return AIModels.get().resolveAIModelOrThrow(this, modelName, type); + } - return aiModel; + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * If the model is not found or is not operational, it throws an appropriate exception. + * + * @param modelName the name of the model to find + * @param type the type of the model to find + * @return the resolved Model + */ + public Tuple2 resolveModelOrThrow(final String modelName, final AIModelType type) { + return AIModels.get().resolveModelOrThrow(this, modelName, type); } /** diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java b/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java index 115439388cc2..50f70eea4ad7 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java @@ -7,6 +7,7 @@ import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.Logger; +import com.google.common.annotations.VisibleForTesting; import com.liferay.portal.model.User; import io.vavr.control.Try; @@ -23,7 +24,12 @@ public class ConfigService { private final LicenseValiditySupplier licenseValiditySupplier; private ConfigService() { - licenseValiditySupplier = new LicenseValiditySupplier() {}; + this(new LicenseValiditySupplier() {}); + } + + @VisibleForTesting + ConfigService(final LicenseValiditySupplier licenseValiditySupplier) { + this.licenseValiditySupplier = licenseValiditySupplier; } /** diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java new file mode 100644 index 000000000000..5393bcaa35ac --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java @@ -0,0 +1,108 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIProvider; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import org.apache.http.client.methods.HttpDelete; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPatch; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; +import org.apache.http.client.methods.HttpUriRequest; + +import javax.ws.rs.HttpMethod; +import java.io.OutputStream; +import java.io.Serializable; + +/** + * Interface representing an AI client capable of sending requests to an AI service. + * + *

+ * This interface defines methods for obtaining the AI provider and sending requests + * to the AI service. Implementations of this interface should handle the specifics + * of interacting with the AI service, including request formatting and response handling. + *

+ * + *

+ * The interface also provides a NOOP implementation that throws an + * {@link UnsupportedOperationException} for all operations. + *

+ * + * @author vico + */ +public interface AIClient { + + AIClient NOOP = new AIClient() { + @Override + public AIProvider getProvider() { + return AIProvider.NONE; + } + + @Override + public void sendRequest(final AIRequest request, final OutputStream output) { + throwUnsupported(); + } + + private void throwUnsupported() { + throw new UnsupportedOperationException("Noop client does not support sending requests"); + } + }; + + /** + * Resolves the appropriate HTTP method for the given method name and URL. + * + * @param method the HTTP method name (e.g., "GET", "POST", "PUT", "DELETE", "patch") + * @param url the URL to which the request will be sent + * @return the corresponding {@link HttpUriRequest} for the given method and URL + */ + static HttpUriRequest resolveMethod(final String method, final String url) { + switch(method) { + case HttpMethod.POST: + return new HttpPost(url); + case HttpMethod.PUT: + return new HttpPut(url); + case HttpMethod.DELETE: + return new HttpDelete(url); + case "patch": + return new HttpPatch(url); + case HttpMethod.GET: + default: + return new HttpGet(url); + } + } + + /** + * Validates and casts the given AI request to a {@link JSONObjectAIRequest}. + * + * @param the type of the request payload + * @param request the AI request to be validated and cast + * @return the validated and cast {@link JSONObjectAIRequest} + * @throws UnsupportedOperationException if the request is not an instance of {@link JSONObjectAIRequest} + */ + static JSONObjectAIRequest useRequestOrThrow(final AIRequest request) { + // When we get rid of JSONObject usage, we can remove this check + if (request instanceof JSONObjectAIRequest) { + return (JSONObjectAIRequest) request; + } + + throw new UnsupportedOperationException("Only JSONObjectAIRequest (JSONObject) is supported"); + } + + /** + * Returns the AI provider associated with this client. + * + * @return the AI provider + */ + AIProvider getProvider(); + + /** + * Sends the given AI request to the AI service and writes the response to the provided output stream. + * + * @param the type of the request payload + * @param request the AI request to be sent + * @param output the output stream to which the response will be written + * @throws Exception if any error occurs during the request execution + */ + void sendRequest(AIRequest request, OutputStream output); + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java new file mode 100644 index 000000000000..f015c5f83c13 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java @@ -0,0 +1,42 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; + +import java.io.OutputStream; +import java.io.Serializable; + +/** + * Interface representing a strategy for handling AI client requests and responses. + * + *

+ * This interface defines a method for applying a strategy to an AI client request, + * allowing for different handling mechanisms to be implemented. The NOOP strategy + * is provided as a default implementation that performs no operations. + *

+ * + *

+ * Implementations of this interface should define how to process the AI request + * and handle the response, potentially writing the response to an output stream. + *

+ * + * @author vico + */ +public interface AIClientStrategy { + + AIClientStrategy NOOP = (client, handler, request, output) -> AIResponse.builder().build(); + + /** + * Applies the strategy to the given AI client request and handles the response. + * + * @param client the AI client to which the request is sent + * @param handler the response evaluator to handle the response + * @param request the AI request to be processed + * @param output the output stream to which the response will be written + */ + void applyStrategy(AIClient client, + AIResponseEvaluator handler, + AIRequest request, + OutputStream output); + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java new file mode 100644 index 000000000000..22c3a5aec788 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java @@ -0,0 +1,34 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; + +import java.io.OutputStream; +import java.io.Serializable; + +/** + * Default implementation of the {@link AIClientStrategy} interface. + * + *

+ * This class provides a default strategy for handling AI client requests by + * directly sending the request using the provided AI client and writing the + * response to the given output stream. + *

+ * + *

+ * The default strategy does not perform any additional processing or handling + * of the request or response, delegating the entire operation to the AI client. + *

+ * + * @author vico + */ +public class AIDefaultStrategy implements AIClientStrategy { + + @Override + public void applyStrategy(final AIClient client, + final AIResponseEvaluator handler, + final AIRequest request, + final OutputStream output) { + client.sendRequest(request, output); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java new file mode 100644 index 000000000000..a29c1a48454e --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java @@ -0,0 +1,238 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponseData; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.exception.DotAIAllModelsExhaustedException; +import com.dotcms.ai.validator.AIAppValidator; +import com.dotmarketing.exception.DotRuntimeException; +import io.vavr.Tuple; +import io.vavr.Tuple2; +import io.vavr.control.Try; +import org.apache.commons.io.IOUtils; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.Optional; + +/** + * Implementation of the {@link AIClientStrategy} interface that provides a fallback mechanism + * for handling AI client requests. + * + *

+ * This class attempts to send a request using a primary AI model and, if the request fails, + * it falls back to alternative models until a successful response is obtained or all models + * are exhausted. + *

+ * + *

+ * The fallback strategy ensures that the AI client can continue to function even if some models + * are not operational or fail to process the request. + *

+ * + * @author vico + */ +public class AIModelFallbackStrategy implements AIClientStrategy { + + /** + * Applies the fallback strategy to the given AI client request and handles the response. + * + *

+ * This method first attempts to send the request using the primary model. If the request + * fails, it falls back to alternative models until a successful response is obtained or + * all models are exhausted. + *

+ * + * @param client the AI client to which the request is sent + * @param handler the response evaluator to handle the response + * @param request the AI request to be processed + * @param originalOutput the output stream to which the response will be written + * @throws DotAIAllModelsExhaustedException if all models are exhausted and no successful response is obtained + */ + @Override + public void applyStrategy(final AIClient client, + final AIResponseEvaluator handler, + final AIRequest request, + final OutputStream originalOutput) { + final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request); + final Tuple2 modelTuple = resolveModel(jsonRequest); + + final AIResponseData firstAttempt = sendAttempt(client, handler, jsonRequest, originalOutput, modelTuple); + if (firstAttempt.isSuccess()) { + return; + } + + runFallbacks(client, handler, jsonRequest, originalOutput, modelTuple); + } + + private static Tuple2 resolveModel(final JSONObjectAIRequest request) { + final String modelName = request.getPayload().optString(AiKeys.MODEL); + return request.getConfig().resolveModelOrThrow(modelName, request.getType()); + } + + private static boolean isSameAsFirst(final Model firstAttempt, final Model model) { + if (firstAttempt.equals(model)) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] is the same as the current one [%s].", + model.getName(), + firstAttempt.getName())); + return true; + } + + return false; + } + + private static boolean isOperational(final Model model) { + if (!model.isOperational()) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format("Model [%s] is not operational. Skipping.", model.getName())); + return false; + } + + return true; + } + + private static AIResponseData doSend(final AIClient client, final AIRequest request) { + final ByteArrayOutputStream output = new ByteArrayOutputStream(); + client.sendRequest(request, output); + + final AIResponseData responseData = new AIResponseData(); + responseData.setResponse(output.toString()); + IOUtils.closeQuietly(output); + + return responseData; + } + + private static void redirectOutput(final OutputStream originalOutput, final String response) { + try (final InputStream input = new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))) { + IOUtils.copy(input, originalOutput); + } catch (IOException e) { + throw new DotRuntimeException(e); + } + } + + private static void notifyFailure(final AIModel aiModel, final AIRequest request) { + AIAppValidator.get().validateModelsUsage(aiModel, request.getUserId()); + } + + private static void handleFailure(final Tuple2 modelTuple, + final AIRequest request, + final AIResponseData responseData) { + final AIModel aiModel = modelTuple._1; + final Model model = modelTuple._2; + + if (!responseData.getStatus().doesNeedToThrow()) { + model.setStatus(responseData.getStatus()); + } + + if (model.getIndex() == aiModel.getModels().size() - 1) { + aiModel.setCurrentModelIndex(-1); + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] is the last one. Cannot fallback anymore.", + model.getName())); + + notifyFailure(aiModel, request); + + throw new DotAIAllModelsExhaustedException( + String.format("All models for type [%s] has been exhausted.", aiModel.getType())); + } else { + aiModel.setCurrentModelIndex(model.getIndex() + 1); + } + } + + private static AIResponseData sendAttempt(final AIClient client, + final AIResponseEvaluator evaluator, + final JSONObjectAIRequest request, + final OutputStream originalOutput, + final Tuple2 modelTuple) { + + final AIResponseData responseData = Try + .of(() -> doSend(client, request)) + .getOrElseGet(exception -> fromException(evaluator, exception)); + + if (!responseData.isSuccess()) { + if (responseData.getStatus().doesNeedToThrow()) { + throw responseData.getException(); + } + } else { + evaluator.fromResponse(responseData.getResponse(), responseData); + } + + if (responseData.isSuccess()) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName())); + redirectOutput(originalOutput, responseData.getResponse()); + } else { + logFailure(modelTuple, responseData); + + handleFailure(modelTuple, request, responseData); + } + + return responseData; + } + + private static void logFailure(final Tuple2 modelTuple, final AIResponseData responseData) { + Optional + .ofNullable(responseData.getResponse()) + .ifPresentOrElse( + response -> AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] failed with response:%s%s%s. Trying next model.", + modelTuple._2.getName(), + System.lineSeparator(), + response, + System.lineSeparator())), + () -> AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] failed with error: [%s]. Trying next model.", + modelTuple._2.getName(), + responseData.getError()))); + } + + private static AIResponseData fromException(final AIResponseEvaluator evaluator, final Throwable exception) { + final AIResponseData metadata = new AIResponseData(); + evaluator.fromException(exception, metadata); + return metadata; + } + + private static void runFallbacks(final AIClient client, + final AIResponseEvaluator evaluator, + final JSONObjectAIRequest request, + final OutputStream originalOutput, + final Tuple2 modelTuple) { + for(final Model model : modelTuple._1.getModels()) { + if (isSameAsFirst(modelTuple._2, model) || !isOperational(model)) { + continue; + } + + request.getPayload().put(AiKeys.MODEL, model.getName()); + final AIResponseData responseData = sendAttempt( + client, + evaluator, + request, + originalOutput, + Tuple.of(modelTuple._1, model)); + if (responseData.isSuccess()) { + return; + } + } + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java new file mode 100644 index 000000000000..839a27cedc1e --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java @@ -0,0 +1,86 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; + +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Optional; + +/** + * A proxy client for interacting with an AI service using a specified strategy. + * + *

+ * This class provides a mechanism to send requests to an AI service through a proxied client, + * applying a given strategy for handling the requests and responses. It supports a NOOP implementation + * that performs no operations. + *

+ * + *

+ * The class allows for the creation of proxied clients with different strategies and response evaluators, + * enabling flexible handling of AI service interactions. + *

+ * + * @author vico + */ +public class AIProxiedClient { + + public static final AIProxiedClient NOOP = new AIProxiedClient(null, AIClientStrategy.NOOP, null); + + private final AIClient client; + private final AIClientStrategy strategy; + private final AIResponseEvaluator responseEvaluator; + + private AIProxiedClient(final AIClient client, + final AIClientStrategy strategy, + final AIResponseEvaluator responseEvaluator) { + this.client = client; + this.strategy = strategy; + this.responseEvaluator = responseEvaluator; + } + + /** + * Creates an AIProxiedClient with the specified client, strategy, and response evaluator. + * + * @param client the AI client to be proxied + * @param strategy the strategy to be applied for handling requests and responses + * @param responseParser the response evaluator to process responses + * @return a new instance of AIProxiedClient + */ + public static AIProxiedClient of(final AIClient client, + final AIProxyStrategy strategy, + final AIResponseEvaluator responseParser) { + return new AIProxiedClient(client, strategy.getStrategy(), responseParser); + } + + /** + * Creates an AIProxiedClient with the specified client and strategy. + * + * @param client the AI client to be proxied + * @param strategy the strategy to be applied for handling requests and responses + * @return a new instance of AIProxiedClient + */ + public static AIProxiedClient of(final AIClient client, final AIProxyStrategy strategy) { + return of(client, strategy, null); + } + + /** + * Sends the given AI request to the AI service and writes the response to the provided output stream. + * + * @param the type of the request payload + * @param request the AI request to be sent + * @param output the output stream to which the response will be written + * @return the AI response + */ + public AIResponse callToAI(final AIRequest request, final OutputStream output) { + final OutputStream finalOutput = Optional.ofNullable(output).orElseGet(ByteArrayOutputStream::new); + + strategy.applyStrategy(client, responseEvaluator, request, finalOutput); + + return Optional.ofNullable(output) + .map(out -> AIResponse.EMPTY) + .orElseGet(() -> AIResponse.builder().withResponse(finalOutput.toString()).build()); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java new file mode 100644 index 000000000000..c4c2f9643da4 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java @@ -0,0 +1,121 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.client.openai.OpenAIClient; +import com.dotcms.ai.client.openai.OpenAIResponseEvaluator; +import com.dotcms.ai.domain.AIProvider; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; +import com.google.common.annotations.VisibleForTesting; +import io.vavr.Lazy; + +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; + +/** + * A proxy client for managing and interacting with multiple AI service providers. + * + *

+ * This class provides a mechanism to send requests to various AI service providers through proxied clients, + * applying different strategies for handling the requests and responses. It supports adding new clients and + * switching between different AI providers. + *

+ * + *

+ * The class allows for flexible handling of AI service interactions by maintaining a map of proxied clients + * and providing methods to send requests to the current or specified provider. + *

+ * + * @author vico + */ +public class AIProxyClient { + + private static final Lazy INSTANCE = Lazy.of(AIProxyClient::new); + + private final ConcurrentMap proxiedClients; + private final AtomicReference currentProvider; + + private AIProxyClient() { + proxiedClients = new ConcurrentHashMap<>(); + addClient( + AIProvider.OPEN_AI, + AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategy.MODEL_FALLBACK, OpenAIResponseEvaluator.get())); + currentProvider = new AtomicReference<>(AIProvider.OPEN_AI); + } + + @VisibleForTesting + AIProxyClient(final AIProxiedClient client) { + proxiedClients = new ConcurrentHashMap<>(); + addClient(AIProvider.OPEN_AI, client); + currentProvider = new AtomicReference<>(AIProvider.OPEN_AI); + } + + public static AIProxyClient get() { + return INSTANCE.get(); + } + + /** + * Adds a proxied client for the specified AI provider. + * + * @param provider the AI provider for which the client is added + * @param client the proxied client to be added + */ + public void addClient(final AIProvider provider, final AIProxiedClient client) { + proxiedClients.put(provider, client); + } + + /** + * Sends the given AI request to the specified AI provider and writes the response to the provided output stream. + * + * @param provider the AI provider to which the request is sent + * @param request the AI request to be sent + * @param output the output stream to which the response will be written + * @return the AI response + */ + public AIResponse sendRequest(final AIProvider provider, + final AIRequest request, + final OutputStream output) { + return Optional.ofNullable(proxiedClients.getOrDefault(provider, AIProxiedClient.NOOP)) + .map(client -> client.callToAI(request, output)) + .orElse(AIResponse.EMPTY); + } + + /** + * Sends the given AI request to the specified AI provider. + * + * @param the type of the request payload + * @param provider the AI provider to which the request is sent + * @param request the AI request to be sent + * @return the AI response + */ + public AIResponse sendRequest(final AIProvider provider, final AIRequest request) { + return sendRequest(provider, request, null); + } + + /** + * Sends the given AI request to the current AI provider and writes the response to the provided output stream. + * + * @param the type of the request payload + * @param request the AI request to be sent + * @param output the output stream to which the response will be written + * @return the AI response + */ + public AIResponse sendRequest(final AIRequest request, final OutputStream output) { + return sendRequest(currentProvider.get(), request, output); + } + + /** + * Sends the given AI request to the current AI provider. + * + * @param the type of the request payload + * @param request the AI request to be sent + * @return the AI response + */ + public AIResponse sendRequest(final AIRequest request) { + return sendRequest(request, null); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java new file mode 100644 index 000000000000..08b2c34f0a6b --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java @@ -0,0 +1,34 @@ +package com.dotcms.ai.client; + +/** + * Enumeration representing different strategies for proxying AI client requests. + * + *

+ * This enum provides different strategies for handling AI client requests, including + * a default strategy and a model fallback strategy. Each strategy is associated with + * an implementation of the {@link AIClientStrategy} interface. + *

+ * + *

+ * The strategies can be used to customize the behavior of AI client interactions, + * allowing for flexible handling of requests and responses. + *

+ * + * @author vico + */ +public enum AIProxyStrategy { + + DEFAULT(new AIDefaultStrategy()), + MODEL_FALLBACK(new AIModelFallbackStrategy()); + + private final AIClientStrategy strategy; + + AIProxyStrategy(final AIClientStrategy strategy) { + this.strategy = strategy; + } + + public AIClientStrategy getStrategy() { + return strategy; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseEvaluator.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseEvaluator.java new file mode 100644 index 000000000000..428cb86e6f32 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseEvaluator.java @@ -0,0 +1,35 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIResponseData; + +/** + * Interface for evaluating AI responses. + * It provides methods to process responses and exceptions, updating the provided metadata. + * + *

Methods:

+ *
    + *
  • \fromResponse\ - Processes a response string and updates the metadata.
  • + *
  • \fromThrowable\ - Processes an exception and updates the metadata.
  • + *
+ * + * @author vico + */ +public interface AIResponseEvaluator { + + /** + * Processes a response string and updates the metadata. + * + * @param response the response string to process + * @param metadata the metadata to update based on the response + */ + void fromResponse(String response, AIResponseData metadata); + + /** + * Processes an exception and updates the metadata. + * + * @param exception the exception to process + * @param metadata the metadata to update based on the exception + */ + void fromException(Throwable exception, AIResponseData metadata); + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java new file mode 100644 index 000000000000..a698477bd381 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java @@ -0,0 +1,159 @@ +package com.dotcms.ai.client.openai; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.app.AppKeys; +import com.dotcms.ai.client.AIClient; +import com.dotcms.ai.domain.AIProvider; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.exception.DotAIAppConfigDisabledException; +import com.dotcms.ai.exception.DotAIClientConnectException; +import com.dotcms.ai.exception.DotAIModelNotOperationalException; +import com.dotmarketing.util.Logger; +import com.dotmarketing.util.json.JSONObject; +import io.vavr.Lazy; +import io.vavr.Tuple2; +import io.vavr.control.Try; +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpEntityEnclosingRequestBase; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; + +import javax.ws.rs.core.MediaType; +import java.io.BufferedInputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Implementation of the {@link AIClient} interface for interacting with the OpenAI service. + * + *

+ * This class provides methods to send requests to the OpenAI service and handle responses. + * It includes functionality to manage rate limiting and ensure that models are operational + * before sending requests. + *

+ * + *

+ * The class uses a singleton pattern to ensure a single instance of the client is used + * throughout the application. It also maintains a record of the last REST call for each + * model to enforce rate limiting. + *

+ * + * @auhor vico + */ +public class OpenAIClient implements AIClient { + + private static final Lazy INSTANCE = Lazy.of(OpenAIClient::new); + + private final ConcurrentHashMap lastRestCall; + + public static OpenAIClient get() { + return INSTANCE.get(); + } + + private OpenAIClient() { + lastRestCall = new ConcurrentHashMap<>(); + } + + /** + * {@inheritDoc} + */ + @Override + public AIProvider getProvider() { + return AIProvider.OPEN_AI; + } + + /** + * {@inheritDoc} + */ + @Override + public void sendRequest(final AIRequest request, final OutputStream output) { + final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request); + final AppConfig appConfig = jsonRequest.getConfig(); + + AppConfig.debugLogger( + OpenAIClient.class, + () -> String.format( + "Posting to [%s] with method [%s]%s with app config:%s%s the payload: %s", + jsonRequest.getUrl(), + jsonRequest.getMethod(), + System.lineSeparator(), + appConfig.toString(), + System.lineSeparator(), + jsonRequest.payloadToString())); + + if (!appConfig.isEnabled()) { + AppConfig.debugLogger(OpenAIClient.class, () -> "App dotAI is not enabled and will not send request."); + throw new DotAIAppConfigDisabledException("App dotAI config without API urls or API key"); + } + + final JSONObject payload = jsonRequest.getPayload(); + final String modelName = payload.optString(AiKeys.MODEL); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(modelName, jsonRequest.getType()); + final AIModel aiModel = modelTuple._1; + + if (!modelTuple._2.isOperational()) { + AppConfig.debugLogger( + getClass(), + () -> String.format("Resolved model [%s] is not operational, avoiding its usage", modelName)); + throw new DotAIModelNotOperationalException(String.format("Model [%s] is not operational", modelName)); + } + + final long sleep = lastRestCall.computeIfAbsent(aiModel, m -> 0L) + + aiModel.minIntervalBetweenCalls() + - System.currentTimeMillis(); + if (sleep > 0) { + Logger.info( + this, + "Rate limit:" + + aiModel.getApiPerMinute() + + "/minute, or 1 every " + + aiModel.minIntervalBetweenCalls() + + "ms. Sleeping:" + + sleep); + Try.run(() -> Thread.sleep(sleep)); + } + + lastRestCall.put(aiModel, System.currentTimeMillis()); + + try (CloseableHttpClient httpClient = HttpClients.createDefault()) { + final StringEntity jsonEntity = new StringEntity(payload.toString(), ContentType.APPLICATION_JSON); + final HttpUriRequest httpRequest = AIClient.resolveMethod(jsonRequest.getMethod(), jsonRequest.getUrl()); + httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON); + httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + appConfig.getApiKey()); + + if (!payload.getAsMap().isEmpty()) { + Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity)); + } + + try (CloseableHttpResponse response = httpClient.execute(httpRequest)) { + final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent()); + final byte[] buffer = new byte[1024]; + int len; + while ((len = in.read(buffer)) != -1) { + output.write(buffer, 0, len); + output.flush(); + } + } + } catch (Exception e) { + if (appConfig.getConfigBoolean(AppKeys.DEBUG_LOGGING)){ + Logger.warn(this, "INVALID REQUEST: " + e.getMessage(), e); + } else { + Logger.warn(this, "INVALID REQUEST: " + e.getMessage()); + } + + Logger.warn(this, " - " + jsonRequest.getMethod() + " : " + payload); + + throw new DotAIClientConnectException("Error while sending request to OpenAI", e); + } + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluator.java b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluator.java new file mode 100644 index 000000000000..7be176a4f446 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluator.java @@ -0,0 +1,87 @@ +package com.dotcms.ai.client.openai; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.client.AIResponseEvaluator; +import com.dotcms.ai.domain.AIResponseData; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIModelNotFoundException; +import com.dotcms.ai.exception.DotAIModelNotOperationalException; +import com.dotmarketing.exception.DotRuntimeException; +import com.dotmarketing.util.json.JSONObject; +import io.vavr.Lazy; + +import java.util.Optional; +import java.util.stream.Stream; + +/** + * Evaluates AI responses from OpenAI and updates the provided metadata. + * This class implements the singleton pattern and provides methods to process responses and exceptions. + * + *

Methods:

+ *
    + *
  • \fromResponse\ - Processes a response string and updates the metadata.
  • + *
  • \fromThrowable\ - Processes an exception and updates the metadata.
  • + *
+ * + * @author vico + */ +public class OpenAIResponseEvaluator implements AIResponseEvaluator { + + private static final Lazy INSTANCE = Lazy.of(OpenAIResponseEvaluator::new); + + public static OpenAIResponseEvaluator get() { + return INSTANCE.get(); + } + + private OpenAIResponseEvaluator() { + } + + /** + * {@inheritDoc} + */ + @Override + public void fromResponse(final String response, final AIResponseData metadata) { + Optional.ofNullable(response) + .ifPresent(resp -> { + final JSONObject jsonResponse = new JSONObject(resp); + if (jsonResponse.has(AiKeys.ERROR)) { + final String error = jsonResponse.getString(AiKeys.ERROR); + metadata.setError(error); + metadata.setStatus(resolveStatus(error)); + } + }); + } + + /** + * {@inheritDoc} + */ + @Override + public void fromException(final Throwable exception, final AIResponseData metadata) { + metadata.setError(exception.getMessage()); + metadata.setStatus(resolveStatus(exception)); + metadata.setException(exception instanceof DotRuntimeException + ? (DotRuntimeException) exception + : new DotRuntimeException(exception)); + } + + private ModelStatus resolveStatus(final String error) { + if (error.contains("has been deprecated")) { + return ModelStatus.DECOMMISSIONED; + } else if (error.contains("does not exist or you do not have access to it")) { + return ModelStatus.INVALID; + } else { + return ModelStatus.UNKNOWN; + } + } + + private ModelStatus resolveStatus(final Throwable throwable) { + if (Stream + .of(DotAIModelNotFoundException.class, DotAIModelNotOperationalException.class) + .anyMatch(exception -> exception.isInstance(throwable))) { + return ModelStatus.INVALID; + } else { + return ModelStatus.UNKNOWN; + } + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java b/dotCMS/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java index a5afa4b2bdc9..d621b2e0cd57 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java +++ b/dotCMS/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java @@ -87,7 +87,8 @@ public static Builder from(final CompletionsForm form) { .withOperator(form.operator) .withThreshold(form.threshold) .withTemperature(form.temperature) - .withTokenCount(form.responseLengthTokens); + .withTokenCount(form.responseLengthTokens) + .withUser(form.user); } public static Builder from(final Map form) { diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java new file mode 100644 index 000000000000..9e844d47619f --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java @@ -0,0 +1,35 @@ +package com.dotcms.ai.domain; + +/** + * Enumeration representing different AI service providers. + * + *

+ * This enum defines various AI service providers that can be used within the application. + * Each provider is associated with a specific name that identifies the AI service. + *

+ * + *

+ * The providers can be used to configure and manage interactions with different AI services, + * allowing for flexible integration and switching between multiple AI providers. + *

+ * + * @author vico + */ +public enum AIProvider { + + NONE("None"), + OPEN_AI("OpenAI"), + BEDROCK("Amazon Bedrock"), + GEMINI("Google Gemini"); + + private final String provider; + + AIProvider(final String provider) { + this.provider = provider; + } + + public String getProvider() { + return provider; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java new file mode 100644 index 000000000000..2635267c3e2a --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java @@ -0,0 +1,219 @@ +package com.dotcms.ai.domain; + +import com.dotcms.ai.app.AIModelType; +import com.dotcms.ai.app.AppConfig; + +import javax.ws.rs.HttpMethod; +import java.io.Serializable; + +/** + * Represents a request to an AI service. + * + *

+ * This class encapsulates the details of an AI request, including the URL, HTTP method, + * configuration, model type, payload, and user ID. It provides methods to create and + * configure AI requests for different model types such as text, image, and embeddings. + *

+ * + * @param the type of the request payload + * @author vico + */ +public class AIRequest { + + private final String url; + private final String method; + private final AppConfig config; + private final AIModelType type; + private final T payload; + private final String userId; + + > AIRequest(final Builder builder) { + this.url = builder.url; + this.method = builder.method; + this.config = builder.config; + this.type = builder.type; + this.payload = builder.payload; + this.userId = builder.userId; + } + + /** + * Creates a quick text AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @param the type of the request payload + * @param the type of the AIRequest + * @return a new AIRequest instance + */ + public static > R quickText(final AppConfig appConfig, + final T payload, + final String userId) { + return quick(AIModelType.TEXT, appConfig, payload, userId); + } + + /** + * Creates a quick image AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @param the type of the request payload + * @param the type of the AIRequest + * @return a new AIRequest instance + */ + public static > R quickImage(final AppConfig appConfig, + final T payload, + final String userId) { + return quick(AIModelType.IMAGE, appConfig, payload, userId); + } + + /** + * Creates a quick embeddings AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @param the type of the request payload + * @param the type of the AIRequest + * @return a new AIRequest instance + */ + public static > R quickEmbeddings(final AppConfig appConfig, + final T payload, + final String userId) { + return quick(AIModelType.EMBEDDINGS, appConfig, payload, userId); + } + + public static > Builder builder() { + return new Builder<>(); + } + + /** + * Resolves the URL for the specified model type and application configuration. + * + * @param type the AI model type + * @param appConfig the application configuration + * @return the resolved URL + */ + static String resolveUrl(final AIModelType type, final AppConfig appConfig) { + switch (type) { + case TEXT: + return appConfig.getApiUrl(); + case IMAGE: + return appConfig.getApiImageUrl(); + case EMBEDDINGS: + return appConfig.getApiEmbeddingsUrl(); + default: + throw new IllegalArgumentException("Invalid AIModelType: " + type); + } + } + + @SuppressWarnings("unchecked") + private static , R extends AIRequest> R quick( + final String url, + final AppConfig appConfig, + final AIModelType type, + final T payload, + final String usderId) { + return (R) AIRequest.builder() + .withUrl(url) + .withConfig(appConfig) + .withType(type) + .withPayload(payload) + .withUserId(usderId) + .build(); + } + + private static > R quick( + final AIModelType type, + final AppConfig appConfig, + final T payload, + final String userId) { + return quick(resolveUrl(type, appConfig), appConfig, type, payload, userId); + } + + public String getUrl() { + return url; + } + + public String getMethod() { + return method; + } + + public AppConfig getConfig() { + return config; + } + + public AIModelType getType() { + return type; + } + + public T getPayload() { + return payload; + } + + public String getUserId() { + return userId; + } + + @Override + public String toString() { + return "AIRequest{" + + "url='" + url + '\'' + + ", method='" + method + '\'' + + ", config=" + config + + ", type=" + type + + ", payload=" + payloadToString() + + ", userId='" + userId + '\'' + + '}'; + } + + public String payloadToString() { + return payload.toString(); + } + + public static class Builder> { + + String url; + String method = HttpMethod.POST; + AppConfig config; + AIModelType type; + T payload; + String userId; + + @SuppressWarnings("unchecked") + B self() { + return (B) this; + } + + public B withUrl(final String url) { + this.url = url; + return self(); + } + + public B withConfig(final AppConfig config) { + this.config = config; + return self(); + } + + public B withType(final AIModelType type) { + this.type = type; + return self(); + } + + public B withPayload(final T payload) { + this.payload = payload; + return self(); + } + + public B withUserId(final String userId) { + this.userId = userId; + return self(); + } + + public AIRequest build() { + return new AIRequest<>(this); + } + + } +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java new file mode 100644 index 000000000000..8d9887b24571 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java @@ -0,0 +1,50 @@ +package com.dotcms.ai.domain; + +/** + * Represents a response from an AI service. + * + *

+ * This class encapsulates the details of an AI response, including the response content. + * It provides methods to build and retrieve the response. + *

+ * + *

+ * The class also provides a static instance representing an empty response. + *

+ * + * @author vico + */ +public class AIResponse { + + public static final AIResponse EMPTY = builder().build(); + + private final String response; + + private AIResponse(final Builder builder) { + this.response = builder.response; + } + + public static Builder builder() { + return new Builder(); + } + + public String getResponse() { + return response; + } + + public static class Builder { + + private String response; + + public Builder withResponse(final String response) { + this.response = response; + return this; + } + + + public AIResponse build() { + return new AIResponse(this); + } + + } +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java new file mode 100644 index 000000000000..85ac2d9d0483 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java @@ -0,0 +1,68 @@ +package com.dotcms.ai.domain; + +import com.dotmarketing.exception.DotRuntimeException; +import org.apache.commons.lang3.StringUtils; + +/** + * Represents the data of a response from an AI service. + * + *

+ * This class encapsulates the details of an AI response, including the response content, error message, + * status, and any exceptions that may have occurred. It provides methods to retrieve and set these details, + * as well as a method to check if the response was successful. + *

+ * + * @author vico + */ +public class AIResponseData { + + private String response; + private String error; + private ModelStatus status; + private DotRuntimeException exception; + + public String getResponse() { + return response; + } + + public void setResponse(String response) { + this.response = response; + } + + public String getError() { + return error; + } + + public void setError(final String error) { + this.error = error; + } + + public ModelStatus getStatus() { + return status; + } + + public void setStatus(ModelStatus status) { + this.status = status; + } + + public DotRuntimeException getException() { + return exception; + } + + public void setException(DotRuntimeException exception) { + this.exception = exception; + } + + public boolean isSuccess() { + return StringUtils.isBlank(error); + } + + @Override + public String toString() { + return "AIResponseData{" + + "response='" + response + '\'' + + ", error='" + error + '\'' + + ", status=" + status + + '}'; + } +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java new file mode 100644 index 000000000000..e52e08c5df14 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java @@ -0,0 +1,105 @@ +package com.dotcms.ai.domain; + +import com.dotcms.ai.app.AIModelType; +import com.dotcms.ai.app.AppConfig; +import com.dotmarketing.util.json.JSONObject; + +/** + * Represents a request to an AI service with a JSON payload. + * + *

+ * This class encapsulates the details of an AI request with a JSON payload, including the URL, HTTP method, + * configuration, model type, payload, and user ID. It provides methods to create and configure AI requests + * for different model types such as text, image, and embeddings. + *

+ * + * @author vico + */ +public class JSONObjectAIRequest extends AIRequest { + + JSONObjectAIRequest(final Builder builder) { + super(builder); + } + + /** + * Creates a quick text AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @return a new JSONObjectAIRequest instance + */ + public static JSONObjectAIRequest quickText(final AppConfig appConfig, + final JSONObject payload, + final String userId) { + return quick(AIModelType.TEXT, appConfig, payload, userId); + } + + /** + * Creates a quick image AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @return a new JSONObjectAIRequest instance + */ + public static JSONObjectAIRequest quickImage(final AppConfig appConfig, + final JSONObject payload, + final String userId) { + return quick(AIModelType.IMAGE, appConfig, payload, userId); + } + + /** + * Creates a quick embeddings AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @return a new JSONObjectAIRequest instance + */ + public static JSONObjectAIRequest quickEmbeddings(final AppConfig appConfig, + final JSONObject payload, + final String userId) { + return quick(AIModelType.EMBEDDINGS, appConfig, payload, userId); + } + + private static JSONObjectAIRequest quick(final String url, + final AppConfig appConfig, + final AIModelType type, + final JSONObject payload, + final String userId) { + return JSONObjectAIRequest.builder() + .withUrl(url) + .withConfig(appConfig) + .withType(type) + .withPayload(payload) + .withUserId(userId) + .build(); + } + + private static JSONObjectAIRequest quick(final AIModelType type, + final AppConfig appConfig, + final JSONObject payload, + final String userId) { + return quick(resolveUrl(type, appConfig), appConfig, type, payload, userId); + } + + @Override + public String payloadToString() { + return getPayload().toString(2); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder extends AIRequest.Builder { + + public JSONObjectAIRequest build() { + return new JSONObjectAIRequest(this); + } + + } + + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/Model.java b/dotCMS/src/main/java/com/dotcms/ai/domain/Model.java new file mode 100644 index 000000000000..7b2b9ca150ba --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/Model.java @@ -0,0 +1,104 @@ +package com.dotcms.ai.domain; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Represents an AI model with a name, status, and index. + * + *

+ * This class encapsulates the details of an AI model, including its name, status, and index. + * It provides methods to retrieve and set these details, as well as methods to check if the model is operational. + *

+ * + *

+ * The class also provides a builder for constructing instances of the model. + *

+ * + * @author vico + */ +public class Model { + + private final String name; + private final AtomicReference status; + private final AtomicInteger index; + + private Model(final Builder builder) { + name = builder.name; + status = new AtomicReference<>(null); + index = new AtomicInteger(builder.index); + } + + public static Builder builder() { + return new Builder(); + } + + public String getName() { + return name; + } + + public ModelStatus getStatus() { + return status.get(); + } + + public void setStatus(final ModelStatus status) { + this.status.set(status); + } + + public int getIndex() { + return index.get(); + } + + public void setIndex(final int index) { + this.index.set(index); + } + + public boolean isOperational() { + return ModelStatus.ACTIVE == status.get(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Model model = (Model) o; + return Objects.equals(name, model.name); + } + + @Override + public int hashCode() { + return Objects.hashCode(name); + } + + @Override + public String toString() { + return "Model{" + + "name='" + name + '\'' + + ", status=" + status + + ", index=" + index.get() + + '}'; + } + + public static class Builder { + + private String name; + private int index; + + public Builder withName(final String name) { + this.name = name.toLowerCase().trim(); + return this; + } + + public Builder withIndex(final int index) { + this.index = index; + return this; + } + + public Model build() { + return new Model(this); + } + + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/ModelStatus.java b/dotCMS/src/main/java/com/dotcms/ai/domain/ModelStatus.java new file mode 100644 index 000000000000..15aa6bd9b69c --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/ModelStatus.java @@ -0,0 +1,30 @@ +package com.dotcms.ai.domain; + +/** + * Represents the status of an AI model. + * + *

+ * This enum defines various statuses that an AI model can have, such as active, invalid, decommissioned, or unknown. + * Each status may have different implications for the operation of the model. + *

+ * + * @author vico + */ +public enum ModelStatus { + + ACTIVE(false), + INVALID(false), + DECOMMISSIONED(false), + UNKNOWN(true); + + private final boolean needsToThrow; + + ModelStatus(final boolean needsToThrow) { + this.needsToThrow = needsToThrow; + } + + public boolean doesNeedToThrow() { + return needsToThrow; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAllModelsExhaustedException.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAllModelsExhaustedException.java new file mode 100644 index 000000000000..6833fdabb252 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAllModelsExhaustedException.java @@ -0,0 +1,22 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +/** + * Exception thrown when all AI models have been exhausted. + * + *

+ * This exception is used to indicate that all available AI models have been exhausted and no further models + * are available for processing. It extends the {@link DotRuntimeException} to provide additional context + * specific to AI model exhaustion scenarios. + *

+ * + * @author vico + */ +public class DotAIAllModelsExhaustedException extends DotRuntimeException { + + public DotAIAllModelsExhaustedException(final String message) { + super(message); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAppConfigDisabledException.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAppConfigDisabledException.java new file mode 100644 index 000000000000..5b549c74fed5 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAppConfigDisabledException.java @@ -0,0 +1,22 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +/** + * Exception thrown when the AI application configuration is disabled. + * + *

+ * This exception is used to indicate that the AI application configuration is disabled and cannot be used. + * It extends the {@link DotRuntimeException} to provide additional context specific to AI application configuration + * disabled scenarios. + *

+ * + * @author vico + */ +public class DotAIAppConfigDisabledException extends DotRuntimeException { + + public DotAIAppConfigDisabledException(final String message) { + super(message); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIClientConnectException.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIClientConnectException.java new file mode 100644 index 000000000000..e17bf9b8d09f --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIClientConnectException.java @@ -0,0 +1,21 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +/** + * Exception thrown when there is a connection error with the AI client. + * + *

+ * This exception is used to indicate that there is a connection error with the AI client. It extends the {@link DotRuntimeException} + * to provide additional context specific to AI client connection error scenarios. + *

+ * + * @author vico + */ +public class DotAIClientConnectException extends DotRuntimeException { + + public DotAIClientConnectException(final String message, final Throwable cause) { + super(message, cause); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotFoundException.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotFoundException.java new file mode 100644 index 000000000000..3bd70811b123 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotFoundException.java @@ -0,0 +1,21 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +/** + * Exception thrown when an AI model is not found. + * + *

+ * This exception is used to indicate that a specific AI model could not be found. It extends the {@link DotRuntimeException} + * to provide additional context specific to AI model not found scenarios. + *

+ * + * @author vico + */ +public class DotAIModelNotFoundException extends DotRuntimeException { + + public DotAIModelNotFoundException(final String message) { + super(message); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotOperationalException.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotOperationalException.java new file mode 100644 index 000000000000..ddea5f6866f8 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotOperationalException.java @@ -0,0 +1,21 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +/** + * Exception thrown when there is a connection error with the AI client. + * + *

+ * This exception is used to indicate that there is a connection error with the AI client. It extends the {@link DotRuntimeException} + * to provide additional context specific to AI client connection error scenarios. + *

+ * + * @author vico + */ +public class DotAIModelNotOperationalException extends DotRuntimeException { + + public DotAIModelNotOperationalException(final String message) { + super(message); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java b/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java index 32bd759b58b2..a95b8f8df520 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java +++ b/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java @@ -1,8 +1,10 @@ package com.dotcms.ai.listener; import com.dotcms.ai.app.AIModels; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.validator.AIAppValidator; import com.dotcms.security.apps.AppSecretSavedEvent; import com.dotcms.system.event.local.model.EventSubscriber; import com.dotcms.system.event.local.model.KeyFilterable; @@ -35,6 +37,21 @@ public AIAppListener() { this(APILocator.getHostAPI()); } + /** + * Notifies the listener of an {@link AppSecretSavedEvent}. + * + *

+ * This method is called when an {@link AppSecretSavedEvent} occurs. It performs the following actions: + *

    + *
  • Logs a debug message if the event is null or the event's host identifier is blank.
  • + *
  • Finds the host associated with the event's host identifier.
  • + *
  • Resets the AI models for the found host's hostname.
  • + *
  • Validates the AI configuration using the {@link AIAppValidator}.
  • + *
+ *

+ * + * @param event the {@link AppSecretSavedEvent} that triggered the notification + */ @Override public void notify(final AppSecretSavedEvent event) { if (Objects.isNull(event)) { @@ -51,7 +68,9 @@ public void notify(final AppSecretSavedEvent event) { final Host host = Try.of(() -> hostAPI.find(hostId, APILocator.systemUser(), false)).getOrNull(); Optional.ofNullable(host).ifPresent(found -> AIModels.get().resetModels(found.getHostname())); - ConfigService.INSTANCE.config(host); + final AppConfig appConfig = ConfigService.INSTANCE.config(host); + + AIAppValidator.get().validateAIConfig(appConfig, event.getUserId()); } @Override diff --git a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java index 9739bab313eb..5c5a7b24d5ef 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java +++ b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java @@ -3,6 +3,7 @@ import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.db.EmbeddingsDTO; +import com.dotcms.ai.exception.DotAIAppConfigDisabledException; import com.dotcms.content.elasticsearch.business.event.ContentletArchiveEvent; import com.dotcms.content.elasticsearch.business.event.ContentletDeletedEvent; import com.dotcms.content.elasticsearch.business.event.ContentletPublishEvent; @@ -10,7 +11,6 @@ import com.dotcms.system.event.local.model.Subscriber; import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; -import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.portlets.contentlet.model.Contentlet; import com.dotmarketing.portlets.contentlet.model.ContentletListener; import com.dotmarketing.util.Logger; @@ -86,7 +86,7 @@ private AppConfig getAppConfig(final String hostId) { AppConfig.debugLogger( getClass(), () -> "dotAI is not enabled since no API urls or API key found in app config"); - throw new DotRuntimeException("App dotAI config without API urls or API key"); + throw new DotAIAppConfigDisabledException("App dotAI config without API urls or API key"); } return appConfig; diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java b/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java index 53f83c3ab149..e289b64c9a0d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java +++ b/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java @@ -1,6 +1,7 @@ package com.dotcms.ai.model; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.fasterxml.jackson.annotation.JsonSetter; import com.fasterxml.jackson.annotation.Nulls; @@ -15,12 +16,11 @@ public class AIImageRequestDTO { private final String model; - public AIImageRequestDTO(Builder builder) { + public AIImageRequestDTO(final Builder builder) { this.numberOfImages = builder.numberOfImages; this.model = builder.model; this.prompt = builder.prompt; this.size = builder.size; - } public String getSize() { @@ -40,14 +40,15 @@ public String getModel() { } public static class Builder { + private AppConfig appConfig = ConfigService.INSTANCE.config(); @JsonSetter(nulls = Nulls.SKIP) private String prompt; @JsonSetter(nulls = Nulls.SKIP) private int numberOfImages = 1; @JsonSetter(nulls = Nulls.SKIP) - private String size = ConfigService.INSTANCE.config().getImageSize(); + private String size = appConfig.getImageSize(); @JsonSetter(nulls = Nulls.SKIP) - private String model = ConfigService.INSTANCE.config().getImageModel().getCurrentModel(); + private String model = appConfig.getImageModel().getCurrentModel(); public AIImageRequestDTO build() { return new AIImageRequestDTO(this); diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java b/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java index c5486b61191f..b24e042e853f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java +++ b/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java @@ -17,16 +17,20 @@ public class SimpleModel implements Serializable { private final String name; private final AIModelType type; + private final boolean current; @JsonCreator - public SimpleModel(@JsonProperty("name") final String name, @JsonProperty("type") final AIModelType type) { + public SimpleModel(@JsonProperty("name") final String name, + @JsonProperty("type") final AIModelType type, + @JsonProperty("current") final boolean current) { this.name = name; this.type = type; + this.current = current; } @JsonCreator public SimpleModel(@JsonProperty("name") final String name) { - this(name, null); + this(name, null, false); } public String getName() { @@ -37,17 +41,30 @@ public AIModelType getType() { return type; } + public boolean isCurrent() { + return current; + } + @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; SimpleModel that = (SimpleModel) o; - return Objects.equals(name, that.name); + return Objects.equals(name, that.name) && type == that.type; } @Override public int hashCode() { - return Objects.hashCode(name); + return Objects.hash(name, type); + } + + @Override + public String toString() { + return "SimpleModel{" + + "name='" + name + '\'' + + ", type=" + type + + ", current=" + current + + '}'; } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java index e7b62cf46712..5499de4ce660 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java @@ -61,7 +61,9 @@ public final Response summarizeFromContent(@Context final HttpServletRequest req response, formIn, () -> APILocator.getDotAIAPI().getCompletionsAPI().summarize(formIn), - out -> APILocator.getDotAIAPI().getCompletionsAPI().summarizeStream(formIn, new LineReadingOutputStream(out))); + output -> APILocator.getDotAIAPI() + .getCompletionsAPI() + .summarizeStream(formIn, new LineReadingOutputStream(output))); } /** @@ -84,7 +86,9 @@ public final Response rawPrompt(@Context final HttpServletRequest request, response, formIn, () -> APILocator.getDotAIAPI().getCompletionsAPI().raw(formIn), - out -> APILocator.getDotAIAPI().getCompletionsAPI().rawStream(formIn, new LineReadingOutputStream(out))); + output -> APILocator.getDotAIAPI() + .getCompletionsAPI() + .rawStream(formIn, new LineReadingOutputStream(output))); } /** @@ -107,16 +111,15 @@ public final Response getConfig(@Context final HttpServletRequest request, .init() .getUser(); final Host host = WebAPILocator.getHostWebAPI().getCurrentHostNoThrow(request); - final AppConfig app = ConfigService.INSTANCE.config(host); - + final AppConfig appConfig = ConfigService.INSTANCE.config(host); final Map map = new HashMap<>(); map.put(AiKeys.CONFIG_HOST, host.getHostname() + " (falls back to system host)"); for (final AppKeys config : AppKeys.values()) { - map.put(config.key, app.getConfig(config)); + map.put(config.key, appConfig.getConfig(config)); } - final String apiKey = UtilMethods.isSet(app.getApiKey()) ? "*****" : "NOT SET"; + final String apiKey = UtilMethods.isSet(appConfig.getApiKey()) ? "*****" : "NOT SET"; map.put(AppKeys.API_KEY.key, apiKey); final List models = AIModels.get().getAvailableModels(); @@ -140,19 +143,25 @@ private static CompletionsForm resolveForm(final HttpServletRequest request, .init() .getUser(); final Host host = WebAPILocator.getHostWebAPI().getCurrentHostNoThrow(request); - return (!user.isAdmin()) - ? CompletionsForm - .copy(formIn) - .model(ConfigService.INSTANCE.config(host).getModel().getCurrentModel()) - .build() - : formIn; + return withUserId( + !user.isAdmin() + ? CompletionsForm + .copy(formIn) + .model(ConfigService.INSTANCE.config(host).getModel().getCurrentModel()) + .build() + : formIn, + user); + } + + private static CompletionsForm withUserId(final CompletionsForm completionsForm, final User user) { + return CompletionsForm.copy(completionsForm).user(user).build(); } private static Response getResponse(final HttpServletRequest request, final HttpServletResponse response, final CompletionsForm formIn, final Supplier noStream, - final Consumer stream) { + final Consumer outputStream) { if (StringUtils.isBlank(formIn.prompt)) { return badRequestResponse(); } @@ -162,7 +171,7 @@ private static Response getResponse(final HttpServletRequest request, if (resolvedForm.stream) { final StreamingOutput streaming = output -> { - stream.accept(output); + outputStream.accept(output); output.flush(); output.close(); }; @@ -174,5 +183,4 @@ private static Response getResponse(final HttpServletRequest request, return Response.ok(jsonResponse.toString(), MediaType.APPLICATION_JSON).build(); } - } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java index 375625d58adf..e536de66e87c 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java @@ -1,7 +1,7 @@ package com.dotcms.ai.rest; import com.dotcms.ai.AiKeys; -import com.dotcms.ai.Marshaller; +import com.dotcms.ai.util.Marshaller; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.model.AIImageRequestDTO; diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java index fae06a565d3b..f0a05c50f4a4 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java @@ -10,6 +10,7 @@ import com.dotmarketing.util.UtilMethods; import com.dotmarketing.util.json.JSONArray; import com.dotmarketing.util.json.JSONObject; +import com.liferay.portal.model.User; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -56,7 +57,7 @@ public Response doGet(@Context final HttpServletRequest request, * * @param request the HTTP request * @param response the HTTP response - * @param formIn the form data containing the prompt + * @param form the form data containing the prompt * @return a Response object containing the generated text * @throws IOException if an I/O error occurs */ @@ -65,13 +66,14 @@ public Response doGet(@Context final HttpServletRequest request, @Produces(MediaType.APPLICATION_JSON) public Response doPost(@Context final HttpServletRequest request, @Context final HttpServletResponse response, - final CompletionsForm formIn) throws IOException { + final CompletionsForm form) throws IOException { - new WebResource.InitBuilder(request, response) + final User user = new WebResource.InitBuilder(request, response) .requiredBackendUser(true) .requiredFrontendUser(true) .init() .getUser(); + final CompletionsForm formIn = CompletionsForm.copy(form).user(user).build(); if (UtilMethods.isEmpty(formIn.prompt)) { return Response @@ -82,7 +84,12 @@ public Response doPost(@Context final HttpServletRequest request, final AppConfig config = ConfigService.INSTANCE.config(WebAPILocator.getHostWebAPI().getHost(request)); - return Response.ok(APILocator.getDotAIAPI().getCompletionsAPI().raw(generateRequest(formIn, config)).toString()).build(); + return Response.ok( + APILocator.getDotAIAPI() + .getCompletionsAPI() + .raw(generateRequest(formIn, config), user.getUserId()) + .toString()) + .build(); } /** diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java index f4eb199d4bf2..2e1f58923556 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java @@ -8,6 +8,7 @@ import com.fasterxml.jackson.annotation.JsonSetter; import com.fasterxml.jackson.annotation.Nulls; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.liferay.portal.model.User; import io.vavr.control.Try; import javax.validation.constraints.Max; @@ -49,6 +50,7 @@ public class CompletionsForm { public final String model; public final String operator; public final String site; + public final User user; @Override public boolean equals(final Object o) { @@ -88,6 +90,7 @@ public String toString() { ", operator='" + operator + '\'' + ", site='" + site + '\'' + ", contentType=" + Arrays.toString(contentType) + + ", user=" + user + '}'; } @@ -118,6 +121,7 @@ private CompletionsForm(final Builder builder) { this.temperature = builder.temperature >= 2 ? 2 : builder.temperature; } this.model = UtilMethods.isSet(builder.model) ? builder.model : ConfigService.INSTANCE.config().getModel().getCurrentModel(); + this.user = builder.user; } private String validateBuilderQuery(final String query) { @@ -131,7 +135,6 @@ private long validateLanguage(final String language) { return Try.of(() -> Long.parseLong(language)) .recover(x -> APILocator.getLanguageAPI().getLanguage(language).getId()) .getOrElseTry(() -> APILocator.getLanguageAPI().getDefaultLanguage().getId()); - } public static Builder copy(final CompletionsForm form) { @@ -149,7 +152,8 @@ public static Builder copy(final CompletionsForm form) { .operator(form.operator) .indexName(form.indexName) .threshold(form.threshold) - .stream(form.stream); + .stream(form.stream) + .user(form.user); } public static final class Builder { @@ -182,6 +186,8 @@ public static final class Builder { private String operator = "cosine"; @JsonSetter(nulls = Nulls.SKIP) private String site; + @JsonSetter(nulls = Nulls.SKIP) + private User user; public Builder prompt(String queryOrPrompt) { this.prompt = queryOrPrompt; @@ -224,7 +230,7 @@ public Builder fieldVar(String fieldVar) { } public Builder model(String model) { - this.model =model; + this.model = model; return this; } @@ -254,7 +260,12 @@ public Builder operator(String operator) { } public Builder site(String site) { - this.site =site; + this.site = site; + return this; + } + + public Builder user(User user) { + this.user = user; return this; } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java index 61815b1307eb..62c61fa9d229 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java @@ -1,7 +1,6 @@ package com.dotcms.ai.rest.forms; import com.dotcms.ai.app.AppConfig; -import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.business.APILocator; import com.dotmarketing.util.UtilMethods; @@ -65,8 +64,6 @@ public static final Builder copy(EmbeddingsForm form) { .fields(String.join(",", form.fields)) .velocityTemplate(form.velocityTemplate) .indexName(form.indexName); - - } @Override @@ -103,7 +100,6 @@ public String toString() { '}'; } - public static final class Builder { @JsonSetter(nulls = Nulls.SKIP) public String fields; @@ -135,7 +131,6 @@ public Builder limit(int limit) { return this; } - public Builder offset(int offset) { this.offset = offset; return this; @@ -161,10 +156,10 @@ public Builder velocityTemplate(String velocityTemplate) { return this; } - public EmbeddingsForm build() { return new EmbeddingsForm(this); - } + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/Marshaller.java b/dotCMS/src/main/java/com/dotcms/ai/util/Marshaller.java similarity index 98% rename from dotCMS/src/main/java/com/dotcms/ai/Marshaller.java rename to dotCMS/src/main/java/com/dotcms/ai/util/Marshaller.java index fc39f5f88e8c..0f92396e50be 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/Marshaller.java +++ b/dotCMS/src/main/java/com/dotcms/ai/util/Marshaller.java @@ -1,4 +1,4 @@ -package com.dotcms.ai; +package com.dotcms.ai.util; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java deleted file mode 100644 index b2a9b9adf789..000000000000 --- a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java +++ /dev/null @@ -1,189 +0,0 @@ -package com.dotcms.ai.util; - -import com.dotcms.ai.AiKeys; -import com.dotcms.ai.app.AIModel; -import com.dotcms.ai.app.AppConfig; -import com.dotcms.ai.app.AppKeys; -import com.dotcms.ai.app.ConfigService; -import com.dotmarketing.exception.DotRuntimeException; -import com.dotmarketing.util.Logger; -import com.dotmarketing.util.json.JSONObject; -import io.vavr.control.Try; -import org.apache.http.HttpHeaders; -import org.apache.http.client.methods.*; -import org.apache.http.entity.ContentType; -import org.apache.http.entity.StringEntity; -import org.apache.http.impl.client.CloseableHttpClient; -import org.apache.http.impl.client.HttpClients; - -import javax.ws.rs.HttpMethod; -import javax.ws.rs.core.MediaType; -import java.io.BufferedInputStream; -import java.io.ByteArrayOutputStream; -import java.io.OutputStream; -import java.util.concurrent.ConcurrentHashMap; - -/** - * The OpenAIRequest class is a utility class that handles HTTP requests to the OpenAI API. - * It provides methods for sending GET, POST, PUT, DELETE, and PATCH requests. - * This class also manages rate limiting for the OpenAI API by keeping track of the last time a request was made. - * - * This class is implemented as a singleton, meaning that only one instance of the class is created throughout the execution of the program. - */ -public class OpenAIRequest { - - private static final ConcurrentHashMap lastRestCall = new ConcurrentHashMap<>(); - - private OpenAIRequest() {} - - /** - * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload. - * The response from the request is written to the provided OutputStream. - * This method also manages rate limiting for the OpenAI API by keeping track of the last time a request was made. - * - * @param urlIn the URL to send the request to - * @param method the HTTP method to use for the request - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doRequest(final String urlIn, - final String method, - final AppConfig appConfig, - final JSONObject json, - final OutputStream out) { - AppConfig.debugLogger( - OpenAIRequest.class, - () -> String.format( - "Posting to [%s] with method [%s]%s with app config:%s%s the payload: %s", - urlIn, - method, - System.lineSeparator(), - appConfig.toString(), - System.lineSeparator(), - json.toString(2))); - - if (!appConfig.isEnabled()) { - AppConfig.debugLogger(OpenAIRequest.class, () -> "App dotAI is not enabled and will not send request."); - throw new DotRuntimeException("App dotAI config without API urls or API key"); - } - - final AIModel model = appConfig.resolveModelOrThrow(json.optString(AiKeys.MODEL)); - final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L) - + model.minIntervalBetweenCalls() - - System.currentTimeMillis(); - if (sleep > 0) { - Logger.info( - OpenAIRequest.class, - "Rate limit:" - + model.getApiPerMinute() - + "/minute, or 1 every " - + model.minIntervalBetweenCalls() - + "ms. Sleeping:" - + sleep); - Try.run(() -> Thread.sleep(sleep)); - } - - lastRestCall.put(model, System.currentTimeMillis()); - - try (CloseableHttpClient httpClient = HttpClients.createDefault()) { - final StringEntity jsonEntity = new StringEntity(json.toString(), ContentType.APPLICATION_JSON); - final HttpUriRequest httpRequest = resolveMethod(method, urlIn); - httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON); - httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + appConfig.getApiKey()); - - if (!json.getAsMap().isEmpty()) { - Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity)); - } - - try (CloseableHttpResponse response = httpClient.execute(httpRequest)) { - final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent()); - final byte[] buffer = new byte[1024]; - int len; - while ((len = in.read(buffer)) != -1) { - out.write(buffer, 0, len); - out.flush(); - } - } - } catch (Exception e) { - if (ConfigService.INSTANCE.config().getConfigBoolean(AppKeys.DEBUG_LOGGING)){ - Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage(), e); - } else { - Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage()); - } - - Logger.warn(OpenAIRequest.class, " - " + method + " : " +json); - - throw new DotRuntimeException(e); - } - } - - /** - * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload. - * The response from the request is returned as a string. - * - * @param url the URL to send the request to - * @param method the HTTP method to use for the request - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @return the response from the request as a string - */ - public static String doRequest(final String url, - final String method, - final AppConfig appConfig, - final JSONObject json) { - final ByteArrayOutputStream out = new ByteArrayOutputStream(); - doRequest(url, method, appConfig, json, out); - - return out.toString(); - } - - /** - * Sends a POST request to the specified URL with the specified OpenAI API key and JSON payload. - * The response from the request is written to the provided OutputStream. - * - * @param urlIn the URL to send the request to - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doPost(final String urlIn, - final AppConfig appConfig, - final JSONObject json, - final OutputStream out) { - doRequest(urlIn, HttpMethod.POST, appConfig, json, out); - } - - /** - * Sends a GET request to the specified URL with the specified OpenAI API key and JSON payload. - * The response from the request is written to the provided OutputStream. - * - * @param urlIn the URL to send the request to - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doGet(final String urlIn, - final AppConfig appConfig, - final JSONObject json, - final OutputStream out) { - doRequest(urlIn, HttpMethod.GET, appConfig, json, out); - } - - private static HttpUriRequest resolveMethod(final String method, final String urlIn) { - switch(method) { - case HttpMethod.POST: - return new HttpPost(urlIn); - case HttpMethod.PUT: - return new HttpPut(urlIn); - case HttpMethod.DELETE: - return new HttpDelete(urlIn); - case "patch": - return new HttpPatch(urlIn); - case HttpMethod.GET: - default: - return new HttpGet(urlIn); - } - } - -} diff --git a/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java b/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java new file mode 100644 index 000000000000..344d4eaced34 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java @@ -0,0 +1,95 @@ +package com.dotcms.ai.validator; + +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModels; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.domain.Model; +import com.dotcms.api.system.event.message.MessageSeverity; +import com.dotcms.api.system.event.message.SystemMessageEventUtil; +import com.dotcms.api.system.event.message.builder.SystemMessage; +import com.dotcms.api.system.event.message.builder.SystemMessageBuilder; +import com.dotmarketing.util.DateUtil; +import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.language.LanguageUtil; +import io.vavr.Lazy; +import io.vavr.control.Try; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class AIAppValidator { + + private static final Lazy INSTANCE = Lazy.of(AIAppValidator::new); + + private SystemMessageEventUtil systemMessageEventUtil; + + private AIAppValidator() { + setSystemMessageEventUtil(SystemMessageEventUtil.getInstance()); + } + + public static AIAppValidator get() { + return INSTANCE.get(); + } + + public void validateAIConfig(final AppConfig appConfig, final String userId) { + if (Objects.isNull(userId)) { + AppConfig.debugLogger(getClass(), () -> "User Id is null, skipping AI configuration validation"); + return; + } + + final Set supportedModels = AIModels.get().getOrPullSupportedModels(appConfig.getApiKey()); + final Set unsupportedModels = Stream.of( + appConfig.getModel(), + appConfig.getImageModel(), + appConfig.getEmbeddingsModel()) + .flatMap(aiModel -> aiModel.getModels().stream()) + .map(Model::getName) + .filter(model -> !supportedModels.contains(model)) + .collect(Collectors.toSet()); + if (unsupportedModels.isEmpty()) { + return; + } + + final String unsupported = String.join(", ", unsupportedModels); + final String message = Try + .of(() -> LanguageUtil.get("ai.unsupported.models", unsupported)) + .getOrElse(String.format("The following models are not supported: [%s]", unsupported)); + final SystemMessage systemMessage = new SystemMessageBuilder() + .setMessage(message) + .setSeverity(MessageSeverity.WARNING) + .setLife(DateUtil.SEVEN_SECOND_MILLIS) + .create(); + + systemMessageEventUtil.pushMessage(systemMessage, Collections.singletonList(userId)); + } + + public void validateModelsUsage(final AIModel aiModel, final String userId) { + final String unavailableModels = aiModel.getModels() + .stream() + .map(Model::getName) + .collect(Collectors.joining(", ")); + final String message = Try + .of(() -> LanguageUtil.get("ai.models.exhausted", aiModel.getType(), unavailableModels)). + getOrElse( + String.format( + "All the %s models: [%s] have been exhausted since they are invalid or has been decommissioned", + aiModel.getType(), + unavailableModels)); + final SystemMessage systemMessage = new SystemMessageBuilder() + .setMessage(message) + .setSeverity(MessageSeverity.WARNING) + .setLife(DateUtil.SEVEN_SECOND_MILLIS) + .create(); + + systemMessageEventUtil.pushMessage(systemMessage, Collections.singletonList(userId)); + } + + @VisibleForTesting + void setSystemMessageEventUtil(SystemMessageEventUtil systemMessageEventUtil) { + this.systemMessageEventUtil = systemMessageEventUtil; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java index 050b56b1e535..0ad6d7837a2d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java @@ -4,9 +4,7 @@ import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.api.ChatAPI; -import com.dotcms.ai.api.OpenAIChatAPIImpl; import com.dotcms.ai.api.ImageAPI; -import com.dotcms.ai.api.OpenAIImageAPIImpl; import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.json.JSONObject; @@ -30,11 +28,13 @@ public class AIViewTool implements ViewTool { private AppConfig config; private ChatAPI chatService; private ImageAPI imageService; + private User user; @Override public void init(final Object obj) { context = (ViewContext) obj; config = config(); + user = user(); chatService = chatService(); imageService = imageService(); } @@ -128,12 +128,12 @@ User user() { @VisibleForTesting ChatAPI chatService() { - return APILocator.getDotAIAPI().getChatAPI(config); + return APILocator.getDotAIAPI().getChatAPI(config, user); } @VisibleForTesting ImageAPI imageService() { - return APILocator.getDotAIAPI().getImageAPI(config, user(), APILocator.getHostAPI(), APILocator.getTempFileAPI()); + return APILocator.getDotAIAPI().getImageAPI(config, user, APILocator.getHostAPI(), APILocator.getTempFileAPI()); } private

Try generate(final P prompt, final Function serviceCall) { diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java index 5508a23f4e32..899f69efe93a 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java @@ -9,6 +9,8 @@ import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.json.JSONObject; import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.model.User; +import com.liferay.portal.util.PortalUtil; import org.apache.velocity.tools.view.context.ViewContext; import org.apache.velocity.tools.view.tools.ViewTool; @@ -17,6 +19,7 @@ import java.io.PrintWriter; import java.io.StringWriter; import java.util.Map; +import java.util.Optional; /** * This class is a ViewTool that provides functionality related to completions. @@ -24,14 +27,18 @@ */ public class CompletionsTool implements ViewTool { + private final ViewContext context; private final HttpServletRequest request; private final Host host; private final AppConfig config; + private final User user; CompletionsTool(Object initData) { - this.request = ((ViewContext) initData).getRequest(); + this.context = (ViewContext) initData; + this.request = this.context.getRequest(); this.host = host(); this.config = config(); + this.user = user(); } @Override @@ -69,7 +76,11 @@ public Object summarize(final String prompt) { * @return The summarized object. */ public Object summarize(final String prompt, final String indexName) { - final CompletionsForm form = new CompletionsForm.Builder().indexName(indexName).prompt(prompt).build(); + final CompletionsForm form = new CompletionsForm.Builder() + .indexName(indexName) + .prompt(prompt) + .user(user) + .build(); try { return APILocator.getDotAIAPI().getCompletionsAPI(config).summarize(form); } catch (Exception e) { @@ -112,7 +123,7 @@ public Object raw(String prompt) { */ public Object raw(final JSONObject prompt) { try { - return APILocator.getDotAIAPI().getCompletionsAPI(config).raw(prompt); + return APILocator.getDotAIAPI().getCompletionsAPI(config).raw(prompt, user.getUserId()); } catch (Exception e) { return handleException(e); } @@ -141,4 +152,9 @@ AppConfig config() { return ConfigService.INSTANCE.config(this.host); } + @VisibleForTesting + User user() { + return PortalUtil.getUser(context.getRequest()); + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java index daf7e1756139..89414823aebb 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java @@ -8,12 +8,15 @@ import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.Logger; import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.model.User; +import com.liferay.portal.util.PortalUtil; import org.apache.velocity.tools.view.context.ViewContext; import org.apache.velocity.tools.view.tools.ViewTool; import javax.servlet.http.HttpServletRequest; import java.util.List; import java.util.Map; +import java.util.Optional; /** * This class provides functionality for generating and managing embeddings. @@ -22,9 +25,11 @@ */ public class EmbeddingsTool implements ViewTool { + private final ViewContext context; private final HttpServletRequest request; private final Host host; private final AppConfig appConfig; + private final User user; /** * Constructor for the EmbeddingsTool class. @@ -33,9 +38,11 @@ public class EmbeddingsTool implements ViewTool { * @param initData Initialization data for the tool. */ EmbeddingsTool(Object initData) { - this.request = ((ViewContext) initData).getRequest(); + this.context = (ViewContext) initData; + this.request = this.context.getRequest(); this.host = host(); this.appConfig = appConfig(); + this.user = user(); } @Override @@ -71,10 +78,14 @@ public List generateEmbeddings(final String prompt) { if (tokens > maxTokens) { Logger.warn( EmbeddingsTool.class, - "Prompt is too long. Maximum prompt size is " + maxTokens + " tokens (roughly ~" + maxTokens * .75 + " words). Your prompt was " + tokens + " tokens "); + "Prompt is too long. Maximum prompt size is " + maxTokens + " tokens (roughly ~" + + maxTokens * .75 + " words). Your prompt was " + tokens + " tokens "); } - return APILocator.getDotAIAPI().getEmbeddingsAPI().pullOrGenerateEmbeddings(prompt)._2; + return APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .pullOrGenerateEmbeddings(prompt, Optional.ofNullable(user).map(User::getUserId).orElse(null)) + ._2; } /** @@ -96,4 +107,9 @@ AppConfig appConfig() { return ConfigService.INSTANCE.config(host); } + @VisibleForTesting + User user() { + return PortalUtil.getUser(context.getRequest()); + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java index 87a068eca10c..d8d01341bd0d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java @@ -1,5 +1,6 @@ package com.dotcms.ai.workflow; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.portlets.workflows.actionlet.Actionlet; import com.dotmarketing.portlets.workflows.actionlet.WorkFlowActionlet; @@ -9,7 +10,6 @@ import com.dotmarketing.portlets.workflows.model.WorkflowActionFailureException; import com.dotmarketing.portlets.workflows.model.WorkflowActionletParameter; import com.dotmarketing.portlets.workflows.model.WorkflowProcessor; -import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Map; @@ -24,7 +24,7 @@ public List getParameters() { WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key, "Overwrite tags ", Boolean.toString(true), true, - () -> ImmutableList.of( + () -> List.of( new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)), new MultiKeyValue(Boolean.toString(true), Boolean.toString(true))) ); @@ -32,16 +32,26 @@ public List getParameters() { WorkflowActionletParameter limitTagsToHost = new MultiSelectionWorkflowActionletParameter( OpenAIParams.LIMIT_TAGS_TO_HOST.key, "Limit the keywords to pre-existing tags", "Limit", false, - () -> ImmutableList.of( + () -> List.of( new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)), new MultiKeyValue(Boolean.toString(true), Boolean.toString(true)) ) ); + + final AppConfig appConfig = ConfigService.INSTANCE.config(); return List.of( overwriteParameter, limitTagsToHost, - new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getModel().getCurrentModel(), ConfigService.INSTANCE.config().getModel().getCurrentModel(), false), - new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0.", ".1", false) + new WorkflowActionletParameter( + OpenAIParams.MODEL.key, + "The AI model to use, defaults to " + appConfig.getModel().getCurrentModel(), + appConfig.getModel().getCurrentModel(), + false), + new WorkflowActionletParameter( + OpenAIParams.TEMPERATURE.key, + "The AI temperature for the response. Between .1 and 2.0.", + ".1", + false) ); } @@ -63,5 +73,4 @@ public void executeAction(final WorkflowProcessor processor, new OpenAIAutoTagRunner(processor, params).run(); } - } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java index 2f7013bfd362..a63c03743686 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java @@ -146,7 +146,7 @@ private String openAIRequest(final Contentlet workingContentlet, final String co final String parsedContentPrompt = VelocityUtil.eval(contentToTag, ctx); final JSONObject openAIResponse = APILocator.getDotAIAPI().getCompletionsAPI() - .prompt(parsedSystemPrompt, parsedContentPrompt, model, temperature, 2000); + .prompt(parsedSystemPrompt, parsedContentPrompt, model, temperature, 2000, user.getUserId()); return openAIResponse.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content"); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java index b6a14ab22d44..8a8e9320293b 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java @@ -1,5 +1,6 @@ package com.dotcms.ai.workflow; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.portlets.workflows.actionlet.Actionlet; @@ -10,7 +11,6 @@ import com.dotmarketing.portlets.workflows.model.WorkflowActionFailureException; import com.dotmarketing.portlets.workflows.model.WorkflowActionletParameter; import com.dotmarketing.portlets.workflows.model.WorkflowProcessor; -import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Map; @@ -23,22 +23,39 @@ public class OpenAIContentPromptActionlet extends WorkFlowActionlet { @Override public List getParameters() { - WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key, + final WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key, "Overwrite existing content (true|false)", Boolean.toString(true), true, - () -> ImmutableList.of( + () -> List.of( new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)), new MultiKeyValue(Boolean.toString(true), Boolean.toString(true))) ); - + final AppConfig appConfig = ConfigService.INSTANCE.config(); return List.of( - new WorkflowActionletParameter(OpenAIParams.FIELD_TO_WRITE.key, "The field where you want to write the results. " + - "
If your response is being returned as a json object, this field can be left blank" + - "
and the keys of the json object will be used to update the content fields.", "", false), + new WorkflowActionletParameter( + OpenAIParams.FIELD_TO_WRITE.key, + "The field where you want to write the results. " + + "
If your response is being returned as a json object, this field can be left blank" + + "
and the keys of the json object will be used to update the content fields.", + "", + false), overwriteParameter, - new WorkflowActionletParameter(OpenAIParams.OPEN_AI_PROMPT.key, "The prompt that will be sent to the AI", "We need an attractive search result in Google. Return a json object that includes the fields \"pageTitle\" for a meta title of less than 55 characters and \"metaDescription\" for the meta description of less than 300 characters using this content:\\n\\n${fieldContent}\\n\\n", true), - new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getModel().getCurrentModel(), ConfigService.INSTANCE.config().getModel().getCurrentModel(), false), - new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0. Defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), false) + new WorkflowActionletParameter( + OpenAIParams.OPEN_AI_PROMPT.key, + "The prompt that will be sent to the AI", + "We need an attractive search result in Google. Return a json object that includes the fields \"pageTitle\" for a meta title of less than 55 characters and \"metaDescription\" for the meta description of less than 300 characters using this content:\\n\\n${fieldContent}\\n\\n", + true), + new WorkflowActionletParameter( + OpenAIParams.MODEL.key, + "The AI model to use, defaults to " + appConfig.getModel().getCurrentModel(), + appConfig.getModel().getCurrentModel(), + false), + new WorkflowActionletParameter( + OpenAIParams.TEMPERATURE.key, + "The AI temperature for the response. Between .1 and 2.0. Defaults to " + + appConfig.getConfig(AppKeys.COMPLETION_TEMPERATURE), + appConfig.getConfig(AppKeys.COMPLETION_TEMPERATURE), + false) ); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java index 176b12d860c4..5ebca943a044 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java @@ -145,7 +145,9 @@ private String openAIRequest(final Contentlet workingContentlet) throws Exceptio final Context ctx = VelocityContextFactory.getMockContext(workingContentlet, user); final String parsedPrompt = VelocityUtil.eval(prompt, ctx); - final JSONObject openAIResponse = APILocator.getDotAIAPI().getCompletionsAPI().raw(buildRequest(parsedPrompt, model, temperature)); + final JSONObject openAIResponse = APILocator.getDotAIAPI() + .getCompletionsAPI() + .raw(buildRequest(parsedPrompt, model, temperature), user.getUserId()); try { return openAIResponse diff --git a/dotCMS/src/main/resources/apps/dotAI.yml b/dotCMS/src/main/resources/apps/dotAI.yml index e8f03c0c5f80..425c9b42b058 100644 --- a/dotCMS/src/main/resources/apps/dotAI.yml +++ b/dotCMS/src/main/resources/apps/dotAI.yml @@ -15,11 +15,11 @@ params: hint: "Your ChatGPT API key" required: true textModelNames: - value: "gpt-4o" + value: "gpt-4o-mini" hidden: false type: "STRING" label: "Model Names" - hint: "Comma delimited list of models used to generate OpenAI API response (e.g. gpt-3.5-turbo-16k)" + hint: "Comma delimited list of models used to generate OpenAI API response (e.g. gpt-4o-mini)" required: true rolePrompt: value: "You are dotCMSbot, and AI assistant to help content creators generate and rewrite content in their content management system." diff --git a/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties b/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties index 782f5f57793a..63c55fbc5354 100644 --- a/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties +++ b/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties @@ -189,6 +189,8 @@ anonymous=Anonymous another-layout-already-exists=Another Tool Group already exists in the system with the same name Any-Structure-Type=Any Content Type Any-Structure=Any Content Type +ai.unsupported.models=The following models are not supported: [{0}] +ai.models.exhausted=All the {0} models: [{1}] have been exhausted since they are invalid or has been decommissioned api.ruleengine.system.conditionlet.CurrentSessionLanguage.inputs.comparison.placeholder=Comparison api.ruleengine.system.conditionlet.CurrentSessionLanguage.inputs.language.placeholder=Language api.ruleengine.system.conditionlet.CurrentSessionLanguage.name=Selected Language diff --git a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js index 088436aef605..f2db45f1a34c 100644 --- a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js +++ b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js @@ -136,11 +136,12 @@ const writeModelToDropdown = async () => { } const newOption = document.createElement("option"); + console.log(JSON.stringify(dotAiState.config, null, 2)); newOption.value = dotAiState.config.availableModels[i].name; newOption.text = `${dotAiState.config.availableModels[i].name}` - if (dotAiState.config.availableModels[i] === dotAiState.config.model) { + if (dotAiState.config.availableModels[i].current) { newOption.selected = true; - newOption.text = `${dotAiState.config.availableModels[i]} (default)` + newOption.text = `${dotAiState.config.availableModels[i].name} (default)` } modelName.appendChild(newOption); } diff --git a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java similarity index 86% rename from dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java rename to dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java index e4c43486c3f1..c51e9c6323a5 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java @@ -1,12 +1,11 @@ -package com.dotcms.ai.service; +package com.dotcms.ai.api; -import com.dotcms.ai.api.ChatAPI; -import com.dotcms.ai.api.OpenAIChatAPIImpl; import com.dotcms.ai.app.AIModel; import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotmarketing.util.json.JSONObject; +import com.liferay.portal.model.User; import org.junit.Before; import org.junit.Test; @@ -17,17 +16,19 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class OpenAIChatServiceImplTest { +public class OpenAIChatAPIImplTest { private static final String RESPONSE_JSON = "{\"data\":[{\"url\":\"http://localhost:8080\",\"value\":\"this is a response\"}]}"; private AppConfig config; private ChatAPI service; + private User user; @Before public void setUp() { config = mock(AppConfig.class); + user = mock(User.class); service = prepareService(RESPONSE_JSON); } @@ -54,11 +55,9 @@ public void test_sendTextPrompt() { } private ChatAPI prepareService(final String response) { - return new OpenAIChatAPIImpl(config) { - - + return new OpenAIChatAPIImpl(config, user) { @Override - public String doRequest(final String urlIn, final JSONObject json) { + String doRequest(final JSONObject json, final String userId) { return response; } }; @@ -66,7 +65,7 @@ public String doRequest(final String urlIn, final JSONObject json) { private JSONObject prepareJsonObject(final String prompt) { when(config.getModel()) - .thenReturn(AIModel.builder().withType(AIModelType.TEXT).withNames("some-model").build()); + .thenReturn(AIModel.builder().withType(AIModelType.TEXT).withModelNames("some-model").build()); when(config.getConfigFloat(AppKeys.COMPLETION_TEMPERATURE)).thenReturn(123.321F); when(config.getRolePrompt()).thenReturn("some-role-prompt"); diff --git a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java similarity index 97% rename from dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java rename to dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java index 6c3fc6822473..e73d9352a59b 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java @@ -1,7 +1,5 @@ -package com.dotcms.ai.service; +package com.dotcms.ai.api; -import com.dotcms.ai.api.ImageAPI; -import com.dotcms.ai.api.OpenAIImageAPIImpl; import com.dotcms.ai.app.AIModel; import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; @@ -27,7 +25,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class OpenAIImageServiceImplTest { +public class OpenAIImageAPIImplTest { private static final String RESPONSE_JSON = "{\"data\":[{\"url\":\"http://localhost:8080\",\"value\":\"this is a response\"}]}"; @@ -220,7 +218,7 @@ public AIImageRequestDTO.Builder getDtoBuilder() { } private JSONObject prepareJsonObject(final String prompt, final boolean tempFileError) throws Exception { - when(config.getImageModel()).thenReturn(AIModel.builder().withType(AIModelType.IMAGE).withNames("some-image-model").build()); + when(config.getImageModel()).thenReturn(AIModel.builder().withType(AIModelType.IMAGE).withModelNames("some-image-model").build()); when(config.getImageSize()).thenReturn("some-image-size"); final File file = mock(File.class); when(file.getName()).thenReturn(UUIDGenerator.shorty()); diff --git a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java index c4d5c93b7627..8c1cd1e79c4e 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java @@ -1,10 +1,12 @@ package com.dotcms.ai.app; +import com.dotcms.ai.domain.Model; import com.dotcms.security.apps.Secret; import org.junit.Before; import org.junit.Test; import java.util.Map; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -127,7 +129,7 @@ public void testCreateTextModel() { AIModel model = aiAppUtil.createTextModel(secrets); assertNotNull(model); assertEquals(AIModelType.TEXT, model.getType()); - assertTrue(model.getNames().contains("textmodel")); + assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()).contains("textmodel")); } /** @@ -143,7 +145,7 @@ public void testCreateImageModel() { AIModel model = aiAppUtil.createImageModel(secrets); assertNotNull(model); assertEquals(AIModelType.IMAGE, model.getType()); - assertTrue(model.getNames().contains("imagemodel")); + assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()).contains("imagemodel")); } /** @@ -159,7 +161,8 @@ public void testCreateEmbeddingsModel() { AIModel model = aiAppUtil.createEmbeddingsModel(secrets); assertNotNull(model); assertEquals(AIModelType.EMBEDDINGS, model.getType()); - assertTrue(model.getNames().contains("embeddingsmodel")); + assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()) + .contains("embeddingsmodel")); } @Test diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java new file mode 100644 index 000000000000..9109d502f60d --- /dev/null +++ b/dotCMS/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java @@ -0,0 +1,65 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; +import org.junit.Before; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; +import java.io.Serializable; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Unit tests for the AIProxyClient class. + */ +public class AIProxyClientTest { + + private AIProxyClient proxyClient; + private AIProxiedClient mockProxiedClient; + + @Before + public void setUp() { + mockProxiedClient = mock(AIProxiedClient.class); + proxyClient = new AIProxyClient(mockProxiedClient); + } + + /** + * Scenario: Sending a valid AI request with an output stream + * Given a valid AI request and an output stream + * When the request is sent to the AI service + * Then the response should be written to the output stream + */ + @Test + public void testSendRequest_withValidRequestAndOutput() { + AIRequest request = mock(AIRequest.class); + OutputStream output = new ByteArrayOutputStream(); + + AIResponse response = proxyClient.sendRequest(request, output); + + verify(mockProxiedClient).callToAI(request, output); + assertEquals(AIResponse.EMPTY, response); + } + + /** + * Scenario: Sending a valid AI request with a null output stream + * Given a valid AI request and a null output stream + * When the request is sent to the AI service + * Then the response should be returned as a string + */ + @Test + public void testSendRequest_withValidRequestAndNullOutput() { + AIRequest request = mock(AIRequest.class); + OutputStream output = null; + + AIResponse response = proxyClient.sendRequest(request, output); + + verify(mockProxiedClient).callToAI(request, output); + assertNotNull(response); + } + +} \ No newline at end of file diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/openai/AIProxiedClientTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/openai/AIProxiedClientTest.java new file mode 100644 index 000000000000..e2f890cfa463 --- /dev/null +++ b/dotCMS/src/test/java/com/dotcms/ai/client/openai/AIProxiedClientTest.java @@ -0,0 +1,102 @@ +package com.dotcms.ai.client.openai; + +import com.dotcms.ai.client.AIClient; +import com.dotcms.ai.client.AIClientStrategy; +import com.dotcms.ai.client.AIProxiedClient; +import com.dotcms.ai.client.AIProxyStrategy; +import com.dotcms.ai.client.AIResponseEvaluator; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; +import org.junit.Before; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; +import java.io.Serializable; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Unit tests for the AIProxiedClient class. + * + * @author vico + */ +public class AIProxiedClientTest { + + private AIClient mockClient; + private AIProxyStrategy mockProxyStrategy; + private AIClientStrategy mockClientStrategy; + private AIResponseEvaluator mockResponseEvaluator; + private AIProxiedClient proxiedClient; + + @Before + public void setUp() { + mockClient = mock(AIClient.class); + mockProxyStrategy = mock(AIProxyStrategy.class); + mockClientStrategy = mock(AIClientStrategy.class); + when(mockProxyStrategy.getStrategy()).thenReturn(mockClientStrategy); + mockResponseEvaluator = mock(AIResponseEvaluator.class); + proxiedClient = AIProxiedClient.of(mockClient, mockProxyStrategy, mockResponseEvaluator); + } + + /** + * Scenario: Sending a valid AI request + * Given a valid AI request + * When the request is sent to the AI service + * Then the strategy should be applied + * And the response should be written to the output stream + */ + @Test + public void testCallToAI_withValidRequest() { + AIRequest request = mock(AIRequest.class); + OutputStream output = mock(OutputStream.class); + + AIResponse response = proxiedClient.callToAI(request, output); + + verify(mockClientStrategy).applyStrategy(mockClient, mockResponseEvaluator, request, output); + assertEquals(AIResponse.EMPTY, response); + } + + /** + * Scenario: Sending an AI request with null output stream + * Given a valid AI request and a null output stream + * When the request is sent to the AI service + * Then the strategy should be applied + * And the response should be returned as a string + */ + @Test + public void testCallToAI_withNullOutput() { + AIRequest request = mock(AIRequest.class); + AIResponse response = proxiedClient.callToAI(request, null); + + verify(mockClientStrategy).applyStrategy( + eq(mockClient), + eq(mockResponseEvaluator), + eq(request), + any(OutputStream.class)); + assertEquals("", response.getResponse()); + } + + /** + * Scenario: Sending an AI request with NOOP client + * Given a valid AI request and a NOOP client + * When the request is sent to the AI service + * Then no operations should be performed + * And the response should be empty + */ + @Test + public void testCallToAI_withNoopClient() { + proxiedClient = AIProxiedClient.NOOP; + AIRequest request = AIRequest.builder().build(); + OutputStream output = new ByteArrayOutputStream(); + + AIResponse response = proxiedClient.callToAI(request, output); + + assertEquals(AIResponse.EMPTY, response); + } +} \ No newline at end of file diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluatorTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluatorTest.java new file mode 100644 index 000000000000..9ce5f40b9257 --- /dev/null +++ b/dotCMS/src/test/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluatorTest.java @@ -0,0 +1,104 @@ +package com.dotcms.ai.client.openai; + +import com.dotcms.ai.domain.AIResponseData; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIModelNotFoundException; +import com.dotmarketing.exception.DotRuntimeException; +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + + +/** + * Tests for the OpenAIResponseEvaluator class. + * + * @author vico + */ +public class OpenAIResponseEvaluatorTest { + + private OpenAIResponseEvaluator evaluator; + + @Before + public void setUp() { + evaluator = OpenAIResponseEvaluator.get(); + } + + /** + * Scenario: Processing a response with an error + * Given a response with an error message "Model has been deprecated" + * When the response is processed + * Then the metadata should contain the error message "Model has been deprecated" + * And the status should be set to DECOMMISSIONED + */ + @Test + public void testFromResponse_withError() { + String response = new JSONObject().put("error", "Model has been deprecated").toString(); + AIResponseData metadata = new AIResponseData(); + + evaluator.fromResponse(response, metadata); + + assertEquals("Model has been deprecated", metadata.getError()); + assertEquals(ModelStatus.DECOMMISSIONED, metadata.getStatus()); + } + + /** + * Scenario: Processing a response without an error + * Given a response without an error message + * When the response is processed + * Then the metadata should not contain any error message + * And the status should be null + */ + @Test + public void testFromResponse_withoutError() { + String response = new JSONObject().put("data", "some data").toString(); + AIResponseData metadata = new AIResponseData(); + + evaluator.fromResponse(response, metadata); + + assertNull(metadata.getError()); + assertNull(metadata.getStatus()); + } + + /** + * Scenario: Processing an exception of type DotRuntimeException + * Given an exception of type DotAIModelNotFoundException with message "Model not found" + * When the exception is processed + * Then the metadata should contain the error message "Model not found" + * And the status should be set to INVALID + * And the exception should be set to the given DotRuntimeException + */ + @Test + public void testFromException_withDotRuntimeException() { + DotRuntimeException exception = new DotAIModelNotFoundException("Model not found"); + AIResponseData metadata = new AIResponseData(); + + evaluator.fromException(exception, metadata); + + assertEquals("Model not found", metadata.getError()); + assertEquals(ModelStatus.INVALID, metadata.getStatus()); + assertEquals(exception, metadata.getException()); + } + + /** + * Scenario: Processing a general exception + * Given a general exception with message "General error" + * When the exception is processed + * Then the metadata should contain the error message "General error" + * And the status should be set to UNKNOWN + * And the exception should be wrapped in a DotRuntimeException + */ + @Test + public void testFromException_withOtherException() { + Exception exception = new Exception("General error"); + AIResponseData metadata = new AIResponseData(); + + evaluator.fromException(exception, metadata); + + assertEquals("General error", metadata.getError()); + assertEquals(ModelStatus.UNKNOWN, metadata.getStatus()); + assertEquals(DotRuntimeException.class, metadata.getException().getClass()); + } +} diff --git a/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java b/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java index 81b74a231e3f..0f98dc89849b 100644 --- a/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java +++ b/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java @@ -1,6 +1,7 @@ package com.dotcms; import com.dotcms.ai.app.AIModelsTest; +import com.dotcms.ai.app.ConfigServiceTest; import com.dotcms.ai.listener.EmbeddingContentListenerTest; import com.dotcms.ai.viewtool.AIViewToolTest; import com.dotcms.ai.viewtool.CompletionsToolTest; @@ -302,6 +303,7 @@ EmbeddingsToolTest.class, CompletionsToolTest.class, AIModelsTest.class, + ConfigServiceTest.class, TimeMachineAPITest.class, Task240513UpdateContentTypesSystemFieldTest.class, PruneTimeMachineBackupJobTest.class, diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java index 855f61ad4572..02f2f31e6172 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java @@ -11,6 +11,7 @@ import com.github.tomakehurst.wiremock.WireMockServer; import java.util.Map; +import java.util.Objects; public interface AiTest { @@ -31,55 +32,55 @@ static WireMockServer prepareWireMock() { return wireMockServer; } - static Map aiAppSecrets(final WireMockServer wireMockServer, - final Host host, + static Map aiAppSecrets(final Host host, final String apiKey, final String textModels, final String imageModels, final String embeddingsModel) throws DotDataException, DotSecurityException { - final AppSecrets appSecrets = new AppSecrets.Builder() + final AppSecrets.Builder builder = new AppSecrets.Builder() .withKey(AppKeys.APP_KEY) - .withSecret(AppKeys.API_URL.key, String.format(API_URL, wireMockServer.port())) - .withSecret(AppKeys.API_IMAGE_URL.key, String.format(API_IMAGE_URL, wireMockServer.port())) - .withSecret(AppKeys.API_EMBEDDINGS_URL.key, String.format(API_EMBEDDINGS_URL, wireMockServer.port())) + .withSecret(AppKeys.API_URL.key, String.format(API_URL, PORT)) + .withSecret(AppKeys.API_IMAGE_URL.key, String.format(API_IMAGE_URL, PORT)) + .withSecret(AppKeys.API_EMBEDDINGS_URL.key, String.format(API_EMBEDDINGS_URL, PORT)) .withHiddenSecret(AppKeys.API_KEY.key, apiKey) - .withSecret(AppKeys.TEXT_MODEL_NAMES.key, textModels) - .withSecret(AppKeys.IMAGE_MODEL_NAMES.key, imageModels) - .withSecret(AppKeys.EMBEDDINGS_MODEL_NAMES.key, embeddingsModel) .withSecret(AppKeys.IMAGE_SIZE.key, IMAGE_SIZE) .withSecret(AppKeys.LISTENER_INDEXER.key, "{\"default\":\"blog\"}") .withSecret(AppKeys.COMPLETION_ROLE_PROMPT.key, AppKeys.COMPLETION_ROLE_PROMPT.defaultValue) - .withSecret(AppKeys.COMPLETION_TEXT_PROMPT.key, AppKeys.COMPLETION_TEXT_PROMPT.defaultValue) - .build(); + .withSecret(AppKeys.COMPLETION_TEXT_PROMPT.key, AppKeys.COMPLETION_TEXT_PROMPT.defaultValue); + + if (Objects.nonNull(textModels)) { + builder.withSecret(AppKeys.TEXT_MODEL_NAMES.key, textModels); + } + if (Objects.nonNull(imageModels)) { + builder.withSecret(AppKeys.IMAGE_MODEL_NAMES.key, imageModels); + } + if (Objects.nonNull(embeddingsModel)) { + builder.withSecret(AppKeys.EMBEDDINGS_MODEL_NAMES.key, embeddingsModel); + } + + final AppSecrets appSecrets = builder.build(); APILocator.getAppsAPI().saveSecrets(appSecrets, host, APILocator.systemUser()); return appSecrets.getSecrets(); } - static Map aiAppSecrets(final WireMockServer wireMockServer, - final Host host, - final String apiKey) + static Map aiAppSecrets(final Host host, final String apiKey) throws DotDataException, DotSecurityException { - return aiAppSecrets(wireMockServer, host, apiKey, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); + return aiAppSecrets(host, apiKey, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); } - static Map aiAppSecrets(final WireMockServer wireMockServer, - final Host host, + static Map aiAppSecrets(final Host host, final String textModels, final String imageModels, final String embeddingsModel) throws DotDataException, DotSecurityException { - return aiAppSecrets(wireMockServer, host, API_KEY, textModels, imageModels, embeddingsModel); + return aiAppSecrets(host, API_KEY, textModels, imageModels, embeddingsModel); } - static Map aiAppSecrets(final WireMockServer wireMockServer, final Host host) + static Map aiAppSecrets(final Host host) throws DotDataException, DotSecurityException { - return aiAppSecrets(wireMockServer, host, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); - } - - static void removeSecrets(final Host host) throws DotDataException, DotSecurityException { - APILocator.getAppsAPI().removeSecretsForSite(host, APILocator.systemUser()); + return aiAppSecrets(host, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); } } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java index e08965e20843..3da3e9a57586 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java @@ -1,16 +1,18 @@ package com.dotcms.ai.app; import com.dotcms.ai.AiTest; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIModelNotFoundException; +import com.dotcms.ai.model.SimpleModel; import com.dotcms.datagen.SiteDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; -import com.dotmarketing.exception.DotDataException; import com.dotmarketing.exception.DotRuntimeException; -import com.dotmarketing.exception.DotSecurityException; -import com.dotmarketing.util.DateUtil; import com.github.tomakehurst.wiremock.WireMockServer; +import io.vavr.Tuple2; import io.vavr.control.Try; import org.junit.After; import org.junit.AfterClass; @@ -23,9 +25,11 @@ import java.util.Set; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; /** @@ -59,7 +63,7 @@ public void before() { IPUtils.disabledIpPrivateSubnet(true); host = new SiteDataGen().nextPersisted(); otherHost = new SiteDataGen().nextPersisted(); - List.of(host, otherHost).forEach(h -> Try.of(() -> AiTest.aiAppSecrets(wireMockServer, host)).get()); + List.of(host, otherHost).forEach(h -> Try.of(() -> AiTest.aiAppSecrets(host)).get()); } @After @@ -73,31 +77,32 @@ public void after() { * Then the correct models should be found and returned. */ @Test - public void test_loadModels_andFindThem() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); - saveSecrets( + public void test_loadModels_andFindThem() throws Exception { + AiTest.aiAppSecrets(APILocator.systemHost()); + AiTest.aiAppSecrets( host, "text-model-1,text-model-2", "image-model-3,image-model-4", "embeddings-model-5,embeddings-model-6"); - saveSecrets(otherHost, "text-model-1", null, null); + AiTest.aiAppSecrets(otherHost, "text-model-1", null, null); final String hostId = host.getHostname(); + final AppConfig appConfig = ConfigService.INSTANCE.config(host); - final Optional notFound = aiModels.findModel(hostId, "some-invalid-model-name"); + final Optional notFound = aiModels.findModel(appConfig, "some-invalid-model-name", AIModelType.TEXT); assertTrue(notFound.isEmpty()); - final Optional text1 = aiModels.findModel(hostId, "text-model-1"); - final Optional text2 = aiModels.findModel(hostId, "text-model-2"); - assertModels(text1, text2, AIModelType.TEXT); + final Optional text1 = aiModels.findModel(appConfig, "text-model-1", AIModelType.TEXT); + final Optional text2 = aiModels.findModel(appConfig, "text-model-2", AIModelType.TEXT); + assertModels(text1, text2, AIModelType.TEXT, true); - final Optional image1 = aiModels.findModel(hostId, "image-model-3"); - final Optional image2 = aiModels.findModel(hostId, "image-model-4"); - assertModels(image1, image2, AIModelType.IMAGE); + final Optional image1 = aiModels.findModel(appConfig, "image-model-3", AIModelType.IMAGE); + final Optional image2 = aiModels.findModel(appConfig, "image-model-4", AIModelType.IMAGE); + assertModels(image1, image2, AIModelType.IMAGE, true); - final Optional embeddings1 = aiModels.findModel(hostId, "embeddings-model-5"); - final Optional embeddings2 = aiModels.findModel(hostId, "embeddings-model-6"); - assertModels(embeddings1, embeddings2, AIModelType.EMBEDDINGS); + final Optional embeddings1 = aiModels.findModel(appConfig, "embeddings-model-5", AIModelType.EMBEDDINGS); + final Optional embeddings2 = aiModels.findModel(appConfig, "embeddings-model-6", AIModelType.EMBEDDINGS); + assertModels(embeddings1, embeddings2, AIModelType.EMBEDDINGS, true); assertNotSame(text1.get(), image1.get()); assertNotSame(text1.get(), embeddings1.get()); @@ -112,27 +117,135 @@ public void test_loadModels_andFindThem() throws DotDataException, DotSecurityEx final Optional embeddings3 = aiModels.findModel(hostId, AIModelType.EMBEDDINGS); assertSameModels(embeddings3, embeddings1, embeddings2); - final Optional text4 = aiModels.findModel(otherHost.getHostname(), "text-model-1"); + final AppConfig otherAppConfig = ConfigService.INSTANCE.config(otherHost); + final Optional text4 = aiModels.findModel(otherAppConfig, "text-model-1", AIModelType.TEXT); assertTrue(text3.isPresent()); assertNotSame(text1.get(), text4.get()); - saveSecrets( + AiTest.aiAppSecrets( host, "text-model-7,text-model-8", "image-model-9,image-model-10", "embeddings-model-11, embeddings-model-12"); - final Optional text7 = aiModels.findModel(hostId, "text-model-7"); - final Optional text8 = aiModels.findModel(hostId, "text-model-8"); + final Optional text7 = aiModels.findModel(otherAppConfig, "text-model-7", AIModelType.TEXT); + final Optional text8 = aiModels.findModel(otherAppConfig, "text-model-8", AIModelType.TEXT); assertNotPresentModels(text7, text8); - final Optional image9 = aiModels.findModel(hostId, "image-model-9"); - final Optional image10 = aiModels.findModel(hostId, "image-model-10"); + final Optional image9 = aiModels.findModel(otherAppConfig, "image-model-9", AIModelType.IMAGE); + final Optional image10 = aiModels.findModel(otherAppConfig, "image-model-10", AIModelType.IMAGE); assertNotPresentModels(image9, image10); - final Optional embeddings11 = aiModels.findModel(hostId, "embeddings-model-11"); - final Optional embeddings12 = aiModels.findModel(hostId, "embeddings-model-12"); + final Optional embeddings11 = aiModels.findModel(otherAppConfig, "embeddings-model-11", AIModelType.EMBEDDINGS); + final Optional embeddings12 = aiModels.findModel(otherAppConfig, "embeddings-model-12", AIModelType.EMBEDDINGS); assertNotPresentModels(embeddings11, embeddings12); + + final List available = aiModels.getAvailableModels(); + final List availableNames = List.of( + "gpt-3.5-turbo-16k", "dall-e-3", "text-embedding-ada-002", + "text-model-1", "text-model-7", "text-model-8", + "image-model-9", "image-model-10", + "embeddings-model-11", "embeddings-model-12"); + assertTrue(available.stream().anyMatch(model -> availableNames.contains(model.getName()))); + } + + /** + * Given a set of models loaded into the AIModels instance + * When the resolveModel method is called with various model names and types + * Then the correct models should be resolved and their operational status verified. + */ + @Test + public void test_resolveModel() throws Exception { + AiTest.aiAppSecrets(APILocator.systemHost()); + AiTest.aiAppSecrets(host, "text-model-20", "image-model-21", "embeddings-model-22"); + ConfigService.INSTANCE.config(host); + AiTest.aiAppSecrets(otherHost, "text-model-23", null, null); + ConfigService.INSTANCE.config(otherHost); + + assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.TEXT).isOperational()); + assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.IMAGE).isOperational()); + assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.EMBEDDINGS).isOperational()); + assertTrue(aiModels.resolveModel(otherHost.getHostname(), AIModelType.TEXT).isOperational()); + assertFalse(aiModels.resolveModel(otherHost.getHostname(), AIModelType.IMAGE).isOperational()); + assertFalse(aiModels.resolveModel(otherHost.getHostname(), AIModelType.EMBEDDINGS).isOperational()); + } + + /** + * Given a set of models loaded into the AIModels instance + * When the resolveAIModelOrThrow method is called with various model names and types + * Then the correct models should be resolved and their operational status verified. + */ + @Test + public void test_resolveAIModelOrThrow() throws Exception { + AiTest.aiAppSecrets(APILocator.systemHost()); + AiTest.aiAppSecrets(host, "text-model-30", "image-model-31", "embeddings-model-32"); + + final AppConfig appConfig = ConfigService.INSTANCE.config(host); + final AIModel aiModel30 = aiModels.resolveAIModelOrThrow(appConfig, "text-model-30", AIModelType.TEXT); + final AIModel aiModel31 = aiModels.resolveAIModelOrThrow(appConfig, "image-model-31", AIModelType.IMAGE); + final AIModel aiModel32 = aiModels.resolveAIModelOrThrow( + appConfig, + "embeddings-model-32", + AIModelType.EMBEDDINGS); + + assertNotNull(aiModel30); + assertNotNull(aiModel31); + assertNotNull(aiModel32); + assertEquals("text-model-30", aiModel30.getModel("text-model-30").getName()); + assertEquals("image-model-31", aiModel31.getModel("image-model-31").getName()); + assertEquals("embeddings-model-32", aiModel32.getModel("embeddings-model-32").getName()); + + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "text-model-33", AIModelType.TEXT)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "image-model-34", AIModelType.IMAGE)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "embeddings-model-35", AIModelType.EMBEDDINGS)); + } + + /** + * Given a set of models loaded into the AIModels instance + * When the resolveModelOrThrow method is called with various model names and types + * Then the correct models should be resolved and their operational status verified. + */ + @Test + public void test_resolveModelOrThrow() throws Exception { + AiTest.aiAppSecrets(APILocator.systemHost()); + AiTest.aiAppSecrets(host, "text-model-40", "image-model-41", "embeddings-model-42"); + + final AppConfig appConfig = ConfigService.INSTANCE.config(host); + final Tuple2 modelTuple40 = aiModels.resolveModelOrThrow( + appConfig, + "text-model-40", + AIModelType.TEXT); + final Tuple2 modelTuple41 = aiModels.resolveModelOrThrow( + appConfig, + "image-model-41", + AIModelType.IMAGE); + final Tuple2 modelTuple42 = aiModels.resolveModelOrThrow( + appConfig, + "embeddings-model-42", + AIModelType.EMBEDDINGS); + + assertNotNull(modelTuple40); + assertNotNull(modelTuple41); + assertNotNull(modelTuple42); + assertEquals("text-model-40", modelTuple40._1.getModel("text-model-40").getName()); + assertEquals("image-model-41", modelTuple41._1.getModel("image-model-41").getName()); + assertEquals("embeddings-model-42", modelTuple42._1.getModel("embeddings-model-42").getName()); + + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "text-model-43", AIModelType.TEXT)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "image-model-44", AIModelType.IMAGE)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "embeddings-model-45", AIModelType.EMBEDDINGS)); } /** @@ -141,69 +254,46 @@ public void test_loadModels_andFindThem() throws DotDataException, DotSecurityEx * Then a list of supported models should be returned. */ @Test - public void test_getOrPullSupportedModules() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); + public void test_getOrPullSupportedModels() throws Exception { + final Host systemHost = APILocator.systemHost(); + AiTest.aiAppSecrets(systemHost); AIModels.get().cleanSupportedModelsCache(); - Set supported = aiModels.getOrPullSupportedModels(); + Set supported = aiModels.getOrPullSupportedModels(AiTest.API_KEY); assertNotNull(supported); assertEquals(38, supported.size()); - - AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); } /** * Given an invalid URL for supported models * When the getOrPullSupportedModules method is called - * Then an empty list of supported models should be returned. + * Then an exception should be thrown */ @Test(expected = DotRuntimeException.class) - public void test_getOrPullSupportedModules_withNetworkError() { + public void test_getOrPullSupportedModuels_withNetworkError() { AIModels.get().cleanSupportedModelsCache(); IPUtils.disabledIpPrivateSubnet(false); - final Set supported = aiModels.getOrPullSupportedModels(); - assertSupported(supported); - + aiModels.getOrPullSupportedModels(AiTest.API_KEY); IPUtils.disabledIpPrivateSubnet(true); - AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); } /** * Given no API key * When the getOrPullSupportedModules method is called - * Then an empty list of supported models should be returned. + * Then an exception should be thrown. */ @Test(expected = DotRuntimeException.class) - public void test_getOrPullSupportedModules_noApiKey() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost(), null); + public void test_getOrPullSupportedModels_noApiKey() throws Exception { + AiTest.aiAppSecrets(APILocator.systemHost(), null); AIModels.get().cleanSupportedModelsCache(); - aiModels.getOrPullSupportedModels(); + aiModels.getOrPullSupportedModels(null); } - /** - * Given no API key - * When the getOrPullSupportedModules method is called - * Then an empty list of supported models should be returned. - */ - @Test(expected = DotRuntimeException.class) - public void test_getOrPullSupportedModules_noSystemHost() throws DotDataException, DotSecurityException { - AiTest.removeSecrets(APILocator.systemHost()); - - AIModels.get().cleanSupportedModelsCache(); - aiModels.getOrPullSupportedModels(); - } - - private void saveSecrets(final Host host, - final String textModels, - final String imageModels, - final String embeddingsModels) throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, host, textModels, imageModels, embeddingsModels); - DateUtil.sleep(1000); - } - - private static void assertSameModels(Optional text3, Optional text1, Optional text2) { + private static void assertSameModels(final Optional text3, + final Optional text1, + final Optional text2) { assertTrue(text3.isPresent()); assertSame(text1.get(), text3.get()); assertSame(text2.get(), text3.get()); @@ -211,12 +301,17 @@ private static void assertSameModels(Optional text3, Optional private static void assertModels(final Optional model1, final Optional model2, - final AIModelType type) { + final AIModelType type, + final boolean assertModelNames) { assertTrue(model1.isPresent()); assertTrue(model2.isPresent()); assertSame(model1.get(), model2.get()); assertSame(type, model1.get().getType()); assertSame(type, model2.get().getType()); + if (assertModelNames) { + assertTrue(model1.get().getModels().stream().allMatch(model -> model.getStatus() == ModelStatus.ACTIVE)); + assertTrue(model2.get().getModels().stream().allMatch(model -> model.getStatus() == ModelStatus.ACTIVE)); + } } private static void assertNotPresentModels(final Optional model1, final Optional model2) { @@ -224,9 +319,4 @@ private static void assertNotPresentModels(final Optional model1, final assertTrue(model2.isEmpty()); } - private static void assertSupported(Set supported) { - assertNotNull(supported); - assertTrue(supported.isEmpty()); - } - } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/ConfigServiceTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/ConfigServiceTest.java new file mode 100644 index 000000000000..2e6143037095 --- /dev/null +++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/ConfigServiceTest.java @@ -0,0 +1,101 @@ +package com.dotcms.ai.app; + +import com.dotcms.ai.AiTest; +import com.dotcms.datagen.SiteDataGen; +import com.dotcms.util.IntegrationTestInitService; +import com.dotcms.util.LicenseValiditySupplier; +import com.dotmarketing.beans.Host; +import com.dotmarketing.business.APILocator; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for the ConfigService class. + * + *

+ * This class contains tests to verify the behavior of the ConfigService, + * including scenarios with valid and invalid licenses, and configurations + * with and without secrets. + *

+ * + *

+ * The tests ensure that the ConfigService correctly initializes and + * configures the AppConfig based on the provided Host and license validity. + *

+ * + * @author vico + */ +public class ConfigServiceTest { + + private Host host; + private ConfigService configService; + + @BeforeClass + public static void beforeClass() throws Exception { + IntegrationTestInitService.getInstance().init(); + } + + @Before + public void before() { + host = new SiteDataGen().nextPersisted(); + configService = ConfigService.INSTANCE; + } + + /** + * Given a ConfigService with an invalid license + * When the config method is called with a host + * Then the models should not be operational. + */ + @Test + public void test_invalidLicense() { + configService = new ConfigService(new LicenseValiditySupplier() { + @Override + public boolean hasValidLicense() { + return false; + } + }); + final AppConfig appConfig = configService.config(host); + + assertFalse(appConfig.getModel().isOperational()); + assertFalse(appConfig.getImageModel().isOperational()); + assertFalse(appConfig.getEmbeddingsModel().isOperational()); + } + + /** + * Given a host with secrets and a ConfigService + * When the config method is called with the host + * Then the models should be operational and the host should be correctly set in the AppConfig. + */ + @Test + public void test_config_hostWithSecrets() throws Exception { + AiTest.aiAppSecrets(host, "text-model-0", "image-model-1", "embeddings-model-2"); + final AppConfig appConfig = configService.config(host); + + assertTrue(appConfig.getModel().isOperational()); + assertTrue(appConfig.getImageModel().isOperational()); + assertTrue(appConfig.getEmbeddingsModel().isOperational()); + assertEquals(host.getHostname(), appConfig.getHost()); + } + + /** + * Given a host without secrets and a ConfigService + * When the config method is called with the host + * Then the models should be operational and the host should be set to "System Host" in the AppConfig. + */ + @Test + public void test_config_hostWithoutSecrets() throws Exception { + AiTest.aiAppSecrets(APILocator.systemHost(), "text-model-10", "image-model-11", "embeddings-model-12"); + final AppConfig appConfig = configService.config(host); + + assertTrue(appConfig.getModel().isOperational()); + assertTrue(appConfig.getImageModel().isOperational()); + assertTrue(appConfig.getEmbeddingsModel().isOperational()); + assertEquals("System Host", appConfig.getHost()); + } + +} diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java index 3c61cd335f55..a41bcc9b1398 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java @@ -191,8 +191,8 @@ private static boolean waitForEmbeddings(final Contentlet blogContent, final Str } private static void addDotAISecrets() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, host, AiTest.API_KEY); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost(), AiTest.API_KEY); + AiTest.aiAppSecrets(host, AiTest.API_KEY); + AiTest.aiAppSecrets(APILocator.systemHost(), AiTest.API_KEY); } private static void removeDotAISecrets() throws DotDataException, DotSecurityException { @@ -232,4 +232,5 @@ public static java.util.function.Predicate distinctByKey( Set seen = ConcurrentHashMap.newKeySet(); return t -> seen.add(keyExtractor.apply(t)); } + } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java index 314079da2a93..1062044719a9 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java @@ -7,6 +7,7 @@ import com.dotcms.datagen.UserDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; +import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; import com.dotmarketing.util.json.JSONObject; import com.github.tomakehurst.wiremock.WireMockServer; @@ -48,8 +49,9 @@ public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); - config = ConfigService.INSTANCE.config(); + final Host systemHost = APILocator.systemHost(); + AiTest.aiAppSecrets(systemHost, "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); + config = ConfigService.INSTANCE.config(systemHost); } @AfterClass diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java index f00c1ab31ae8..e481133edbca 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java @@ -6,6 +6,7 @@ import com.dotcms.ai.app.ConfigService; import com.dotcms.datagen.EmbeddingsDTODataGen; import com.dotcms.datagen.SiteDataGen; +import com.dotcms.datagen.UserDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; import com.dotmarketing.beans.Host; @@ -13,6 +14,7 @@ import com.dotmarketing.util.json.JSONArray; import com.dotmarketing.util.json.JSONObject; import com.github.tomakehurst.wiremock.WireMockServer; +import com.liferay.portal.model.User; import org.apache.commons.lang3.StringUtils; import org.apache.velocity.tools.view.context.ViewContext; import org.junit.AfterClass; @@ -41,6 +43,7 @@ public class CompletionsToolTest { private static AppConfig appConfig; + private static User user; private static WireMockServer wireMockServer; private static Host host; @@ -52,9 +55,10 @@ public static void beforeClass() throws Exception { IPUtils.disabledIpPrivateSubnet(true); host = new SiteDataGen().nextPersisted(); wireMockServer = AiTest.prepareWireMock(); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); - AiTest.aiAppSecrets(wireMockServer, host); + AiTest.aiAppSecrets(APILocator.systemHost(), "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); + AiTest.aiAppSecrets(host, "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); appConfig = ConfigService.INSTANCE.config(host); + user = new UserDataGen().nextPersisted(); } @AfterClass @@ -86,7 +90,7 @@ public void test_getConfig() { assertNotNull(config); assertEquals(AppKeys.COMPLETION_ROLE_PROMPT.defaultValue, config.get(AppKeys.COMPLETION_ROLE_PROMPT.key)); assertEquals(AppKeys.COMPLETION_TEXT_PROMPT.defaultValue, config.get(AppKeys.COMPLETION_TEXT_PROMPT.key)); - assertEquals("gpt-3.5-turbo-16k", config.get(AppKeys.TEXT_MODEL_NAMES.key)); + assertEquals("gpt-4o-mini", config.get(AppKeys.TEXT_MODEL_NAMES.key)); } /** @@ -119,7 +123,7 @@ public void test_summarize() { @Test public void test_raw() { final String query = "What is the speed of light in the vacuum"; - final String prompt = String.format("{\"model\":\"gpt-3.5-turbo-16k\",\"messages\":[{\"role\":\"user\",\"content\":\"%s?\"},{\"role\":\"system\",\"content\":\"You are a helpful assistant with a descriptive writing style.\"}]}", query); + final String prompt = String.format("{\"model\":\"gpt-4o-mini\",\"messages\":[{\"role\":\"user\",\"content\":\"%s?\"},{\"role\":\"system\",\"content\":\"You are a helpful assistant with a descriptive writing style.\"}]}", query); final JSONObject result = (JSONObject) completionsTool.raw(prompt); assertResult(result); @@ -142,7 +146,7 @@ public void test_raw_json() { messages.put(new JSONObject().put("role", "user").put("content", query + "?")); messages.put(new JSONObject().put("role", "system").put("content", "You are a helpful assistant with a descriptive writing style")); json.put("messages", messages); - json.put("model", "gpt-3.5-turbo-16k"); + json.put("model", "gpt-4o-mini"); final JSONObject result = (JSONObject) completionsTool.raw(json); assertResult(result); @@ -160,7 +164,7 @@ public void test_raw_json() { @Test public void test_raw_map() { final String query = "Who was the first president of the United States"; - final Map map = Map.of("model", "gpt-3.5-turbo-16k", "messages", new JSONArray() + final Map map = Map.of("model", "gpt-4o-mini", "messages", new JSONArray() .put(new JSONObject().put("role", "user").put("content", query + "?")) .put(new JSONObject().put("role", "system").put("content", "You are a helpful assistant with a descriptive writing style"))); @@ -179,6 +183,11 @@ Host host() { AppConfig config() { return appConfig; } + + @Override + User user() { + return user; + } }; } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java index a58bf579bc65..b12d1600993a 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java @@ -5,6 +5,7 @@ import com.dotcms.ai.app.ConfigService; import com.dotcms.datagen.EmbeddingsDTODataGen; import com.dotcms.datagen.SiteDataGen; +import com.dotcms.datagen.UserDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; import com.dotmarketing.beans.Host; @@ -12,6 +13,7 @@ import com.dotmarketing.exception.DotDataException; import com.dotmarketing.exception.DotSecurityException; import com.github.tomakehurst.wiremock.WireMockServer; +import com.liferay.portal.model.User; import org.apache.velocity.tools.view.context.ViewContext; import org.junit.AfterClass; import org.junit.Before; @@ -43,6 +45,7 @@ public class EmbeddingsToolTest { private Host host; private AppConfig appConfig; + private User user; private EmbeddingsTool embeddingsTool; @BeforeClass @@ -50,7 +53,7 @@ public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); + AiTest.aiAppSecrets(APILocator.systemHost()); } @Before @@ -58,8 +61,9 @@ public void before() throws DotDataException, DotSecurityException { final ViewContext viewContext = mock(ViewContext.class); when(viewContext.getRequest()).thenReturn(mock(HttpServletRequest.class)); host = new SiteDataGen().nextPersisted(); - AiTest.aiAppSecrets(wireMockServer, host); + AiTest.aiAppSecrets(host); appConfig = ConfigService.INSTANCE.config(host); + user = new UserDataGen().nextPersisted(); embeddingsTool = prepareEmbeddingsTool(viewContext); } @@ -136,6 +140,11 @@ Host host() { AppConfig appConfig() { return appConfig; } + + @Override + public User user() { + return user; + } }; } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java index 344efe66d3af..2d2d64362f76 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java @@ -78,7 +78,7 @@ public static void beforeClass() throws Exception { .withValue("{\"default\":\"blog\"}".toCharArray()) .build()); config = new AppConfig(host.getHostname(), secrets); - DotAIAPIFacadeImpl.addCompletionsAPIImplementation("default", (Object... initArguments)-> new CompletionsAPI() { + DotAIAPIFacadeImpl.addCompletionsAPIImplementation("default", (Object... initArguments) -> new CompletionsAPI() { @Override public JSONObject summarize(CompletionsForm searcher) { return null; @@ -90,7 +90,7 @@ public void summarizeStream(CompletionsForm searcher, OutputStream out) { } @Override - public JSONObject raw(JSONObject promptJSON) { + public JSONObject raw(JSONObject promptJSON, final String userId) { return null; } @@ -104,7 +104,8 @@ public JSONObject prompt(final String systemPrompt, final String userPrompt, final String model, final float temperature, - final int maxTokens) { + final int maxTokens, + final String userId) { return new JSONObject("{\n" + " \"id\": \"chatcmpl-7bHkIY2cNQXV3yWZmZ1lM1b4AIlJ6\",\n" + " \"object\": \"chat.completion\",\n" + diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java index 9e92628e2df6..3db2f537423c 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java @@ -89,7 +89,7 @@ public void summarizeStream(CompletionsForm searcher, OutputStream out) { } @Override - public JSONObject raw(final JSONObject promptJSON) { + public JSONObject raw(final JSONObject promptJSON, final String userId) { return new JSONObject("{\n" + " \"id\": \"chatcmpl-7bHkIY2cNQXV3yWZmZ1lM1b4AIlJ6\",\n" + " \"object\": \"chat.completion\",\n" + @@ -126,7 +126,8 @@ public JSONObject prompt(final String systemPrompt, final String userPrompt, final String model, final float temperature, - final int maxTokens) { + final int maxTokens, + final String userId) { return new JSONObject("{\n" + " \"id\": \"chatcmpl-7bHkIY2cNQXV3yWZmZ1lM1b4AIlJ6\",\n" + " \"object\": \"chat.completion\",\n" + diff --git a/dotcms-postman/pom.xml b/dotcms-postman/pom.xml index 57c26cd6d645..382b5d793ef3 100644 --- a/dotcms-postman/pom.xml +++ b/dotcms-postman/pom.xml @@ -116,9 +116,6 @@ WireMock: green - - --verbose - diff --git a/dotcms-postman/src/main/resources/postman/AI.postman_collection.json b/dotcms-postman/src/main/resources/postman/AI.postman_collection.json index 9a49e8bcd38e..fb7fd48b7309 100644 --- a/dotcms-postman/src/main/resources/postman/AI.postman_collection.json +++ b/dotcms-postman/src/main/resources/postman/AI.postman_collection.json @@ -2899,7 +2899,7 @@ "listen": "test", "script": { "exec": [ - "pm.test('Status code should be ok 20', function () {", + "pm.test('Status code should be ok 200', function () {", " pm.response.to.have.status(200);", "});", "", @@ -3064,7 +3064,7 @@ "listen": "test", "script": { "exec": [ - "pm.test('Status code should be ok 20', function () {", + "pm.test('Status code should be ok 200', function () {", " pm.response.to.have.status(200);", "});", "" @@ -3256,7 +3256,7 @@ "listen": "test", "script": { "exec": [ - "pm.test('Status code should be ok 20', function () {", + "pm.test('Status code should be ok 200', function () {", " pm.response.to.have.status(200);", "});", "",