diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java index d980647f7b41..38ea42606703 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java @@ -1,9 +1,7 @@ package com.dotcms.ai.api; -import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.rest.forms.CompletionsForm; import com.dotmarketing.util.json.JSONObject; -import io.vavr.Lazy; import java.io.OutputStream; @@ -37,9 +35,10 @@ public interface CompletionsAPI { * this method takes a prompt in the form of json and returns a json AI response based upon that prompt * * @param promptJSON + * @param userId * @return */ - JSONObject raw(JSONObject promptJSON); + JSONObject raw(JSONObject promptJSON, String userId); /** * this method takes a prompt and returns the AI response based upon that prompt @@ -58,9 +57,15 @@ public interface CompletionsAPI { * @param model * @param temperature * @param maxTokens + * @param userId * @return */ - JSONObject prompt(String systemPrompt, String userPrompt, String model, float temperature, int maxTokens); + JSONObject prompt(String systemPrompt, + String userPrompt, + String model, + float temperature, + int maxTokens, + String userId); /** * this method takes a prompt in the form of json and returns streaming AI response based upon that prompt diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java index 008f65609096..ebfcaa5b2c9b 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java @@ -2,13 +2,17 @@ import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.client.AIProxyClient; import com.dotcms.ai.db.EmbeddingsDTO; +import com.dotcms.ai.domain.AIResponse; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import com.dotcms.ai.domain.Model; import com.dotcms.ai.rest.forms.CompletionsForm; import com.dotcms.ai.util.EncodingUtil; -import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.api.web.HttpServletRequestThreadLocal; import com.dotcms.mock.request.FakeHttpRequest; import com.dotcms.mock.response.BaseResponse; @@ -16,18 +20,17 @@ import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.exception.DotRuntimeException; -import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; import com.dotmarketing.util.json.JSONArray; import com.dotmarketing.util.json.JSONObject; import io.vavr.Lazy; +import io.vavr.Tuple2; import io.vavr.control.Try; import org.apache.commons.lang3.ArrayUtils; import org.apache.velocity.context.Context; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import javax.ws.rs.HttpMethod; import java.io.OutputStream; import java.util.ArrayList; import java.util.List; @@ -42,15 +45,13 @@ public class CompletionsAPIImpl implements CompletionsAPI { private final AppConfig config; - private final Lazy defaultConfig; public CompletionsAPIImpl(final AppConfig config) { - defaultConfig = - Lazy.of(() -> ConfigService.INSTANCE.config( - Try.of(() -> WebAPILocator - .getHostWebAPI() - .getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest())) - .getOrElse(APILocator.systemHost()))); + final Lazy defaultConfig = Lazy.of(() -> ConfigService.INSTANCE.config( + Try.of(() -> WebAPILocator + .getHostWebAPI() + .getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest())) + .getOrElse(APILocator.systemHost()))); this.config = Optional.ofNullable(config).orElse(defaultConfig.get()); } @@ -59,8 +60,9 @@ public JSONObject prompt(final String systemPrompt, final String userPrompt, final String modelIn, final float temperature, - final int maxTokens) { - final AIModel model = config.resolveModelOrThrow(modelIn); + final int maxTokens, + final String userId) { + final Model model = config.resolveModelOrThrow(modelIn, AIModelType.TEXT)._2; final JSONObject json = new JSONObject(); json.put(AiKeys.TEMPERATURE, temperature); @@ -70,15 +72,17 @@ public JSONObject prompt(final String systemPrompt, json.put(AiKeys.MAX_TOKENS, maxTokens); } - json.put(AiKeys.MODEL, model.getCurrentModel()); + json.put(AiKeys.MODEL, model.getName()); - return raw(json); + return raw(json, userId); } @Override public JSONObject summarize(final CompletionsForm summaryRequest) { final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build(); - final List localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher); + final List localResults = APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .getEmbeddingResults(searcher); // send all this as a json blob to OpenAI final JSONObject json = buildRequestJson(summaryRequest, localResults); @@ -87,58 +91,64 @@ public JSONObject summarize(final CompletionsForm summaryRequest) { } json.put(AiKeys.STREAM, false); - final String openAiResponse = - Try.of(() -> OpenAIRequest.doRequest( - config.getApiUrl(), - HttpMethod.POST, - config, - json)) - .getOrElseThrow(DotRuntimeException::new); - final JSONObject dotCMSResponse = APILocator.getDotAIAPI().getEmbeddingsAPI().reduceChunksToContent(searcher, localResults); + final String openAiResponse = Try + .of(() -> sendRequest(config, json, UtilMethods.extractUserIdOrNull(summaryRequest.user))) + .getOrElseThrow(DotRuntimeException::new) + .getResponse(); + final JSONObject dotCMSResponse = APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .reduceChunksToContent(searcher, localResults); dotCMSResponse.put(AiKeys.OPEN_AI_RESPONSE, new JSONObject(openAiResponse)); return dotCMSResponse; } @Override - public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream out) { + public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream output) { final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build(); - final List localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher); + final List localResults = APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .getEmbeddingResults(searcher); final JSONObject json = buildRequestJson(summaryRequest, localResults); json.put(AiKeys.STREAM, true); - OpenAIRequest.doPost(config.getApiUrl(), config, json, out); + AIProxyClient.get().callToAI( + JSONObjectAIRequest.quickText( + config, + json, + UtilMethods.extractUserIdOrNull(summaryRequest.user)), + output); } @Override - public JSONObject raw(final JSONObject json) { - if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.info(this.getClass(), "OpenAI request:" + json.toString(2)); - } + public JSONObject raw(final JSONObject json, final String userId) { + AppConfig.debugLogger(this.getClass(), () -> "OpenAI request:" + json.toString(2)); - final String response = OpenAIRequest.doRequest( - config.getApiUrl(), - HttpMethod.POST, - config, - json); - if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.info(this.getClass(), "OpenAI response:" + response); - } + final String response = sendRequest(config, json, userId).getResponse(); + AppConfig.debugLogger(this.getClass(), () -> "OpenAI response:" + response); return new JSONObject(response); } @Override - public JSONObject raw(CompletionsForm promptForm) { + public JSONObject raw(final CompletionsForm promptForm) { JSONObject jsonObject = buildRequestJson(promptForm); - return raw(jsonObject); + return raw(jsonObject, UtilMethods.extractUserIdOrNull(promptForm.user)); } @Override - public void rawStream(final CompletionsForm promptForm, final OutputStream out) { + public void rawStream(final CompletionsForm promptForm, final OutputStream output) { final JSONObject json = buildRequestJson(promptForm); json.put(AiKeys.STREAM, true); - OpenAIRequest.doRequest(config.getApiUrl(), HttpMethod.POST, config, json, out); + AIProxyClient.get().callToAI(JSONObjectAIRequest.quickText( + config, + json, + UtilMethods.extractUserIdOrNull(promptForm.user)), + output); + } + + private AIResponse sendRequest(final AppConfig appConfig, final JSONObject payload, final String userId) { + return AIProxyClient.get().callToAI(JSONObjectAIRequest.quickText(appConfig, payload, userId)); } private void buildMessages(final String systemPrompt, final String userPrompt, final JSONObject json) { @@ -151,7 +161,7 @@ private void buildMessages(final String systemPrompt, final String userPrompt, f } private JSONObject buildRequestJson(final CompletionsForm form, final List searchResults) { - final AIModel model = config.resolveModelOrThrow(form.model); + final Tuple2 modelTuple = config.resolveModelOrThrow(form.model, AIModelType.TEXT); // aggregate matching results into text final StringBuilder supportingContent = new StringBuilder(); searchResults.forEach(s -> supportingContent.append(s.extractedText).append(" ")); @@ -162,7 +172,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List 0 && initArguments[0] instanceof AppConfig) { - return new OpenAIChatAPIImpl((AppConfig) initArguments[0]); + final User user = initArguments.length > 1 && initArguments[1] instanceof User ? (User) initArguments[1] : null; + return new OpenAIChatAPIImpl((AppConfig) initArguments[0], user); } throw new IllegalArgumentException("To create a ChatAPI you need to provide an AppConfig"); diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java index 2ffaa8e702db..e619ae11ab03 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java @@ -140,10 +140,11 @@ public interface EmbeddingsAPI { * Embeddings * * @param content The content that will be tokenized and sent to OpenAI. + * @param userId The ID of the user making the request. * * @return Tuple(Count of Tokens Input, List of Embeddings Output) */ - Tuple2> pullOrGenerateEmbeddings(final String content); + Tuple2> pullOrGenerateEmbeddings(String content, String userId); /** * this method takes a snippet of content and will try to see if we have already generated @@ -154,10 +155,11 @@ public interface EmbeddingsAPI { * * @param contentId The ID of the Contentlet being sent to the OpenAI Endpoint. * @param content The actual indexable data that will be tokenized and sent to OpenAI service. + * @param userId The ID of the user making the request. * * @return Tuple(Count of Tokens Input, List of Embeddings Output) */ - Tuple2> pullOrGenerateEmbeddings(final String contentId, final String content); + Tuple2> pullOrGenerateEmbeddings(String contentId, String content, String userId); /** * Checks if the embeddings for the given inode, indexName, and extractedText already exist in the database. diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java index c40a8e1692e4..9ed0d974ec8f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java @@ -4,12 +4,13 @@ import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.client.AIProxyClient; import com.dotcms.ai.db.EmbeddingsDTO; import com.dotcms.ai.db.EmbeddingsDTO.Builder; import com.dotcms.ai.db.EmbeddingsFactory; +import com.dotcms.ai.domain.JSONObjectAIRequest; import com.dotcms.ai.util.ContentToStringUtil; import com.dotcms.ai.util.EncodingUtil; -import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.ai.util.VelocityContextFactory; import com.dotcms.api.web.HttpServletRequestThreadLocal; import com.dotcms.api.web.HttpServletResponseThreadLocal; @@ -43,7 +44,6 @@ import org.apache.velocity.context.Context; import javax.validation.constraints.NotNull; -import javax.ws.rs.HttpMethod; import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; @@ -311,13 +311,15 @@ public void initEmbeddingsTable() { } @Override - public Tuple2> pullOrGenerateEmbeddings(@NotNull final String content) { - return pullOrGenerateEmbeddings("N/A", content); + public Tuple2> pullOrGenerateEmbeddings(@NotNull final String content, final String userId) { + return pullOrGenerateEmbeddings("N/A", content, userId); } @WrapInTransaction @Override - public Tuple2> pullOrGenerateEmbeddings(final String contentId, @NotNull final String content) { + public Tuple2> pullOrGenerateEmbeddings(final String contentId, + @NotNull final String content, + final String userId) { if (UtilMethods.isEmpty(content)) { return Tuple.of(0, List.of()); } @@ -349,7 +351,7 @@ public Tuple2> pullOrGenerateEmbeddings(final String conten final Tuple2> openAiEmbeddings = Tuple.of( tokens.size(), - sendTokensToOpenAI(contentId, tokens)); + sendTokensToOpenAI(contentId, tokens, userId)); saveEmbeddingsForCache(content, openAiEmbeddings); EMBEDDING_CACHE.put(hashed, openAiEmbeddings); @@ -420,19 +422,20 @@ private void saveEmbeddingsForCache(final String content, final Tuple2 sendTokensToOpenAI(final String contentId, @NotNull final List tokens) { + private List sendTokensToOpenAI(final String contentId, + @NotNull final List tokens, + final String userId) { final JSONObject json = new JSONObject(); json.put(AiKeys.MODEL, config.getEmbeddingsModel().getCurrentModel()); json.put(AiKeys.INPUT, tokens); debugLogger(this.getClass(), () -> String.format("Content tokens for content ID '%s': %s", contentId, tokens)); - final String responseString = OpenAIRequest.doRequest( - config.getApiEmbeddingsUrl(), - HttpMethod.POST, - config, - json); + final String responseString = AIProxyClient.get() + .callToAI(JSONObjectAIRequest.quickEmbeddings(config, json, userId)) + .getResponse(); debugLogger(this.getClass(), () -> String.format("OpenAI Response for content ID '%s': %s", contentId, responseString.replace("\n", BLANK))); final JSONObject jsonResponse = Try.of(() -> new JSONObject(responseString)).getOrElseThrow(e -> { @@ -490,8 +493,10 @@ private List getEmbeddingsFromJSON(final String contentId, final JSONObje } } - private EmbeddingsDTO getSearcher(EmbeddingsDTO searcher) { - final List queryEmbeddings = pullOrGenerateEmbeddings(searcher.query)._2; + private EmbeddingsDTO getSearcher(final EmbeddingsDTO searcher) { + final List queryEmbeddings = pullOrGenerateEmbeddings( + searcher.query, + UtilMethods.extractUserIdOrNull(searcher.user))._2; return EmbeddingsDTO.copy(searcher).withEmbeddings(queryEmbeddings).build(); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java index 47058164d4ca..ed2c4c5c385d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java @@ -6,6 +6,7 @@ import com.dotcms.ai.util.EncodingUtil; import com.dotcms.business.WrapInTransaction; import com.dotcms.exception.ExceptionUtil; +import com.dotmarketing.business.APILocator; import com.dotmarketing.portlets.contentlet.model.Contentlet; import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; @@ -119,7 +120,10 @@ private void saveEmbedding(@NotNull final String initial) { } final Tuple2> embeddings = - this.embeddingsAPI.pullOrGenerateEmbeddings(this.contentlet.getIdentifier(), normalizedContent); + this.embeddingsAPI.pullOrGenerateEmbeddings( + contentlet.getIdentifier(), + normalizedContent, + APILocator.systemUser().getUserId()); if (embeddings._2.isEmpty()) { Logger.info(this.getClass(), String.format("No tokens for Content Type " + "'%s'. Normalized content: %s", this.contentlet.getContentType().variable(), normalizedContent)); diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java index 2e94dbe4218d..99889d1e6466 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java @@ -3,21 +3,24 @@ import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; -import com.dotcms.ai.util.OpenAIRequest; +import com.dotcms.ai.client.AIProxyClient; +import com.dotcms.ai.domain.JSONObjectAIRequest; import com.dotmarketing.util.UtilMethods; import com.dotmarketing.util.json.JSONObject; import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.model.User; -import javax.ws.rs.HttpMethod; import java.util.List; import java.util.Map; public class OpenAIChatAPIImpl implements ChatAPI { private final AppConfig config; + private final User user; - public OpenAIChatAPIImpl(final AppConfig appConfig) { + public OpenAIChatAPIImpl(final AppConfig appConfig, final User user) { this.config = appConfig; + this.user = user; } @Override @@ -36,7 +39,7 @@ public JSONObject sendRawRequest(final JSONObject prompt) { prompt.remove(AiKeys.PROMPT); - return new JSONObject(doRequest(config.getApiUrl(), prompt)); + return new JSONObject(doRequest(prompt, user.getUserId())); } @Override @@ -47,8 +50,8 @@ public JSONObject sendTextPrompt(final String textPrompt) { } @VisibleForTesting - public String doRequest(final String urlIn, final JSONObject json) { - return OpenAIRequest.doRequest(urlIn, HttpMethod.POST, config, json); + String doRequest(final JSONObject json, final String userId) { + return AIProxyClient.get().callToAI(JSONObjectAIRequest.quickText(config, json, userId)).getResponse(); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java index 041a49b6e2d1..eeef0880bfa5 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java @@ -2,8 +2,9 @@ import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.client.AIProxyClient; +import com.dotcms.ai.domain.JSONObjectAIRequest; import com.dotcms.ai.model.AIImageRequestDTO; -import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.ai.util.OpenAiRequestUtil; import com.dotcms.ai.util.StopWordsUtil; import com.dotcms.api.web.HttpServletRequestThreadLocal; @@ -24,7 +25,6 @@ import io.vavr.control.Try; import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.HttpMethod; import java.net.URL; import java.text.SimpleDateFormat; import java.util.Date; @@ -173,8 +173,10 @@ private String generateFileName(final String originalPrompt) { } @VisibleForTesting - public String doRequest(final String urlIn, final JSONObject json) { - return OpenAIRequest.doRequest(urlIn, HttpMethod.POST, config, json); + String doRequest(final String urlIn, final JSONObject json) { + return AIProxyClient.get() + .callToAI(JSONObjectAIRequest.quickImage(config, json, user.getUserId())) + .getResponse(); } @VisibleForTesting diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java index a4f6d2c8fb12..3cb2a1dad162 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java @@ -6,6 +6,7 @@ import com.liferay.util.StringPool; import io.vavr.Lazy; import io.vavr.control.Try; +import org.apache.commons.collections4.CollectionUtils; import java.util.Arrays; import java.util.List; @@ -40,9 +41,14 @@ public static AIAppUtil get() { * @return the created text model instance */ public AIModel createTextModel(final Map secrets) { + final List modelNames = splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES); + if (CollectionUtils.isEmpty(modelNames)) { + return AIModel.NOOP_MODEL; + } + return AIModel.builder() .withType(AIModelType.TEXT) - .withNames(splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES)) + .withModelNames(modelNames) .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_TOKENS_PER_MINUTE)) .withApiPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_API_PER_MINUTE)) .withMaxTokens(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_MAX_TOKENS)) @@ -57,9 +63,14 @@ public AIModel createTextModel(final Map secrets) { * @return the created image model instance */ public AIModel createImageModel(final Map secrets) { + final List modelNames = splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES); + if (CollectionUtils.isEmpty(modelNames)) { + return AIModel.NOOP_MODEL; + } + return AIModel.builder() .withType(AIModelType.IMAGE) - .withNames(splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES)) + .withModelNames(modelNames) .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_TOKENS_PER_MINUTE)) .withApiPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_API_PER_MINUTE)) .withMaxTokens(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_MAX_TOKENS)) @@ -74,9 +85,14 @@ public AIModel createImageModel(final Map secrets) { * @return the created embeddings model instance */ public AIModel createEmbeddingsModel(final Map secrets) { + final List modelNames = splitDiscoveredSecret(secrets, AppKeys.EMBEDDINGS_MODEL_NAMES); + if (CollectionUtils.isEmpty(modelNames)) { + return AIModel.NOOP_MODEL; + } + return AIModel.builder() .withType(AIModelType.EMBEDDINGS) - .withNames(splitDiscoveredSecret(secrets, AppKeys.EMBEDDINGS_MODEL_NAMES)) + .withModelNames(modelNames) .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_TOKENS_PER_MINUTE)) .withApiPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_API_PER_MINUTE)) .withMaxTokens(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_MAX_TOKENS)) @@ -117,9 +133,11 @@ public String discoverSecret(final Map secrets, final AppKeys ke * @return the list of split secret values */ public List splitDiscoveredSecret(final Map secrets, final AppKeys key) { - return Arrays.stream(Optional.ofNullable(discoverSecret(secrets, key)).orElse(StringPool.BLANK).split(",")) - .map(String::trim) - .map(String::toLowerCase) + return Arrays + .stream(Optional + .ofNullable(discoverSecret(secrets, key)) + .map(secret -> secret.split(StringPool.COMMA)) + .orElse(new String[0])) .collect(Collectors.toList()); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java index efbcc09a0872..070b68306877 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java @@ -1,12 +1,15 @@ package com.dotcms.ai.app; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.exception.DotAIModelNotFoundException; import com.dotcms.util.DotPreconditions; import com.dotmarketing.util.Logger; import java.util.List; import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; /** * Represents an AI model with various attributes such as type, names, tokens per minute, @@ -20,41 +23,34 @@ public class AIModel { public static final AIModel NOOP_MODEL = AIModel.builder() .withType(AIModelType.UNKNOWN) - .withNames(List.of()) + .withModelNames(List.of()) .build(); private final AIModelType type; - private final List names; + private final List models; private final int tokensPerMinute; private final int apiPerMinute; private final int maxTokens; private final boolean isCompletion; - private final AtomicInteger current; - private final AtomicBoolean decommissioned; - - private AIModel(final AIModelType type, - final List names, - final int tokensPerMinute, - final int apiPerMinute, - final int maxTokens, - final boolean isCompletion) { - DotPreconditions.checkNotNull(type, "type cannot be null"); - this.type = type; - this.names = Optional.ofNullable(names).orElse(List.of()); - this.tokensPerMinute = tokensPerMinute; - this.apiPerMinute = apiPerMinute; - this.maxTokens = maxTokens; - this.isCompletion = isCompletion; - current = new AtomicInteger(this.names.isEmpty() ? -1 : 0); - decommissioned = new AtomicBoolean(false); + private final AtomicInteger currentModelIndex; + + private AIModel(final Builder builder) { + DotPreconditions.checkNotNull(builder.type, "type cannot be null"); + this.type = builder.type; + this.models = builder.models; + this.tokensPerMinute = builder.tokensPerMinute; + this.apiPerMinute = builder.apiPerMinute; + this.maxTokens = builder.maxTokens; + this.isCompletion = builder.isCompletion; + currentModelIndex = new AtomicInteger(this.models.isEmpty() ? -1 : 0); } public AIModelType getType() { return type; } - public List getNames() { - return names; + public List getModels() { + return models; } public int getTokensPerMinute() { @@ -73,38 +69,37 @@ public boolean isCompletion() { return isCompletion; } - public int getCurrent() { - return current.get(); - } - - public void setCurrent(final int current) { - if (!isCurrentValid(current)) { - logInvalidModelMessage(); - return; - } - this.current.set(current); - } - - public boolean isDecommissioned() { - return decommissioned.get(); + public int getCurrentModelIndex() { + return currentModelIndex.get(); } - public void setDecommissioned(final boolean decommissioned) { - this.decommissioned.set(decommissioned); + public void setCurrentModelIndex(final int currentModelIndex) { + this.currentModelIndex.set(currentModelIndex); } public boolean isOperational() { return this != NOOP_MODEL; } - public String getCurrentModel() { - final int currentIndex = this.current.get(); + public Model getCurrent() { + final int currentIndex = currentModelIndex.get(); if (!isCurrentValid(currentIndex)) { logInvalidModelMessage(); return null; } + return models.get(currentIndex); + } - return names.get(currentIndex); + public String getCurrentModel() { + return getCurrent().getName(); + } + + public Model getModel(final String modelName) { + final String normalized = modelName.trim().toLowerCase(); + return models.stream() + .filter(model -> normalized.equals(model.getName())) + .findFirst() + .orElseThrow(() -> new DotAIModelNotFoundException(String.format("Model [%s] not found", modelName))); } public long minIntervalBetweenCalls() { @@ -115,22 +110,21 @@ public long minIntervalBetweenCalls() { public String toString() { return "AIModel{" + "type=" + type + - ", names=" + names + + ", models='" + models + '\'' + ", tokensPerMinute=" + tokensPerMinute + ", apiPerMinute=" + apiPerMinute + ", maxTokens=" + maxTokens + ", isCompletion=" + isCompletion + - ", current=" + current + - ", decommissioned=" + decommissioned + + ", currentModelIndex=" + currentModelIndex.get() + '}'; } private boolean isCurrentValid(final int current) { - return !names.isEmpty() && current >= 0 && current < names.size(); + return !models.isEmpty() && current >= 0 && current < models.size(); } private void logInvalidModelMessage() { - Logger.debug(getClass(), String.format("Current model index must be between 0 and %d", names.size())); + Logger.debug(getClass(), String.format("Current model index must be between 0 and %d", models.size())); } public static Builder builder() { @@ -140,7 +134,7 @@ public static Builder builder() { public static class Builder { private AIModelType type; - private List names; + private List models; private int tokensPerMinute; private int apiPerMinute; private int maxTokens; @@ -154,13 +148,25 @@ public Builder withType(final AIModelType type) { return this; } - public Builder withNames(final List names) { - this.names = names; + public Builder withModels(final List models) { + this.models = Optional.ofNullable(models).orElse(List.of()); return this; } - public Builder withNames(final String... names) { - return withNames(List.of(names)); + public Builder withModelNames(final List names) { + return withModels( + Optional.ofNullable(names) + .map(modelNames -> IntStream.range(0, modelNames.size()) + .mapToObj(index -> Model.builder() + .withName(modelNames.get(index)) + .withIndex(index) + .build()) + .collect(Collectors.toList())) + .orElse(List.of())); + } + + public Builder withModelNames(final String... names) { + return withModelNames(List.of(names)); } public Builder withTokensPerMinute(final int tokensPerMinute) { @@ -184,7 +190,7 @@ public Builder withIsCompletion(final boolean isCompletion) { } public AIModel build() { - return new AIModel(type, names, tokensPerMinute, apiPerMinute, maxTokens, isCompletion); + return new AIModel(this); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java index 388afb7545e3..7f3765e4b043 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java @@ -1,5 +1,9 @@ package com.dotcms.ai.app; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIModelNotFoundException; +import com.dotcms.ai.exception.DotAIModelNotOperationalException; import com.dotcms.ai.model.OpenAIModel; import com.dotcms.ai.model.OpenAIModels; import com.dotcms.ai.model.SimpleModel; @@ -13,17 +17,17 @@ import io.vavr.Lazy; import io.vavr.Tuple; import io.vavr.Tuple2; +import io.vavr.Tuple3; import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -40,28 +44,55 @@ public class AIModels { private static final int AI_MODELS_FETCH_ATTEMPTS = Config.getIntProperty(AI_MODELS_FETCH_ATTEMPTS_KEY, 3); private static final String AI_MODELS_FETCH_TIMEOUT_KEY = "ai.models.fetch.timeout"; private static final int AI_MODELS_FETCH_TIMEOUT = Config.getIntProperty(AI_MODELS_FETCH_TIMEOUT_KEY, 4000); - private static final Lazy INSTANCE = Lazy.of(AIModels::new); - private static final String AI_MODELS_API_URL_KEY = "DOT_AI_MODELS_API_URL"; + private static final String AI_MODELS_API_URL_KEY = "AI_MODELS_API_URL"; + private static final String AI_MODELS_API_URL_DEFAULT = "https://api.openai.com/v1/models"; private static final String AI_MODELS_API_URL = Config.getStringProperty( AI_MODELS_API_URL_KEY, - "https://api.openai.com/v1/models"); + AI_MODELS_API_URL_DEFAULT); private static final int AI_MODELS_CACHE_TTL = 28800; // 8 hours - private static final int AI_MODELS_CACHE_SIZE = 128; + private static final int AI_MODELS_CACHE_SIZE = 256; + private static final Lazy INSTANCE = Lazy.of(AIModels::new); - private final ConcurrentMap>> internalModels = new ConcurrentHashMap<>(); - private final ConcurrentMap, AIModel> modelsByName = new ConcurrentHashMap<>(); - private final Cache> supportedModelsCache = - Caffeine.newBuilder() - .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) - .maximumSize(AI_MODELS_CACHE_SIZE) - .build(); - private Supplier appConfigSupplier = ConfigService.INSTANCE::config; + private final ConcurrentMap>> internalModels; + private final ConcurrentMap, AIModel> modelsByName; + private final Cache> supportedModelsCache; public static AIModels get() { return INSTANCE.get(); } + private static CircuitBreakerUrl.Response fetchOpenAIModels(final String apiKey) { + final CircuitBreakerUrl.Response response = CircuitBreakerUrl.builder() + .setMethod(CircuitBreakerUrl.Method.GET) + .setUrl(AI_MODELS_API_URL) + .setTimeout(AI_MODELS_FETCH_TIMEOUT) + .setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS) + .setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + apiKey)) + .setThrowWhenNot2xx(true) + .build() + .doResponse(OpenAIModels.class); + + if (!CircuitBreakerUrl.isSuccessResponse(response)) { + AppConfig.debugLogger( + AIModels.class, + () -> String.format( + "Error fetching OpenAI supported models from [%s] (status code: [%d])", + AI_MODELS_API_URL, + response.getStatusCode())); + throw new DotRuntimeException("Error fetching OpenAI supported models"); + } + + return response; + } + private AIModels() { + internalModels = new ConcurrentHashMap<>(); + modelsByName = new ConcurrentHashMap<>(); + supportedModelsCache = + Caffeine.newBuilder() + .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) + .maximumSize(AI_MODELS_CACHE_SIZE) + .build(); } /** @@ -73,46 +104,47 @@ private AIModels() { * @param loading the list of AI models to load */ public void loadModels(final String host, final List loading) { - Optional.ofNullable(internalModels.get(host)) - .ifPresentOrElse( - model -> {}, - () -> internalModels.putIfAbsent( - host, - loading.stream() - .map(model -> Tuple.of(model.getType(), model)) - .collect(Collectors.toList()))); - loading.forEach(model -> model - .getNames() - .forEach(name -> { - final Tuple2 key = Tuple.of( - host, - name.toLowerCase().trim()); + final List> added = internalModels.putIfAbsent( + host, + loading.stream() + .map(model -> Tuple.of(model.getType(), model)) + .collect(Collectors.toList())); + loading.forEach(aiModel -> aiModel + .getModels() + .forEach(model -> { + final Tuple3 key = Tuple.of(host, model, aiModel.getType()); if (modelsByName.containsKey(key)) { - Logger.debug( - this, - String.format( + AppConfig.debugLogger( + getClass(), + () -> String.format( "Model [%s] already exists for host [%s], ignoring it", - name, + model, host)); return; } - modelsByName.putIfAbsent(key, model); + modelsByName.putIfAbsent(key, aiModel); })); + activateModels(host, added == null); } /** * Finds an AI model by the host and model name. The search is case-insensitive. * - * @param host the host for which the model is being searched + * @param appConfig the AppConfig for the host * @param modelName the name of the model to find + * @param type the type of the model to find * @return an Optional containing the found AIModel, or an empty Optional if not found */ - public Optional findModel(final String host, final String modelName) { + public Optional findModel(final AppConfig appConfig, + final String modelName, + final AIModelType type) { final String lowered = modelName.toLowerCase(); - final Set supported = getOrPullSupportedModels(); - return supported.contains(lowered) - ? Optional.ofNullable(modelsByName.get(Tuple.of(host, lowered))) - : Optional.empty(); + return Optional.ofNullable( + modelsByName.get( + Tuple.of( + appConfig.getHost(), + Model.builder().withName(lowered).build(), + type))); } /** @@ -130,6 +162,45 @@ public Optional findModel(final String host, final AIModelType type) { .findFirst()); } + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * + * @param host the host for which the model is being resolved + * @param type the type of the model to find + */ + public AIModel resolveModel(final String host, final AIModelType type) { + return findModel(host, type).orElse(AIModel.NOOP_MODEL); + } + + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * + * @param appConfig the AppConfig for the host + * @param modelName the name of the model to find + * @param type the type of the model to find + */ + public AIModel resolveAIModelOrThrow(final AppConfig appConfig, final String modelName, final AIModelType type) { + return findModel(appConfig, modelName, type) + .orElseThrow(() -> new DotAIModelNotFoundException( + String.format("Unable to find model: [%s] of type [%s].", modelName, type))); + } + + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * If the model is not found or is not operational, it throws an appropriate exception. + * + * @param appConfig the AppConfig for the host + * @param modelName the name of the model to find + * @param type the type of the model to find + * @return a Tuple2 containing the AIModel and the Model + */ + public Tuple2 resolveModelOrThrow(final AppConfig appConfig, + final String modelName, + final AIModelType type) { + final AIModel aiModel = resolveAIModelOrThrow(appConfig, modelName, type); + return Tuple.of(aiModel, aiModel.getModel(modelName)); + } + /** * Resets the internal models cache for the specified host. * @@ -145,27 +216,28 @@ public void resetModels(final String host) { .filter(key -> key._1.equals(host)) .collect(Collectors.toSet()) .forEach(modelsByName::remove); + cleanSupportedModelsCache(); } /** * Retrieves the list of supported models, either from the cache or by fetching them * from an external source if the cache is empty or expired. * + * @param apiKey the API key to use for fetching the supported models * @return a set of supported model names */ - public Set getOrPullSupportedModels() { + public Set getOrPullSupportedModels(final String apiKey) { final Set cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY); if (CollectionUtils.isNotEmpty(cached)) { return cached; } - final AppConfig appConfig = appConfigSupplier.get(); - if (!appConfig.isEnabled()) { - AppConfig.debugLogger(getClass(), () -> "dotAI is not enabled, returning empty list of supported models"); - throw new DotRuntimeException("App dotAI config without API urls or API key"); + if (StringUtils.isBlank(apiKey)) { + Logger.debug(this, "OpenAI is not enabled, returning empty list of supported models"); + throw new DotAIModelNotOperationalException("OpenAI is not enabled, cannot get list of models from OpenAI"); } - final CircuitBreakerUrl.Response response = fetchOpenAIModels(appConfig); + final CircuitBreakerUrl.Response response = fetchOpenAIModels(apiKey); if (Objects.nonNull(response.getResponse().getError())) { throw new DotRuntimeException("Found error in AI response: " + response.getResponse().getError().getMessage()); } @@ -176,7 +248,9 @@ public Set getOrPullSupportedModels() { .stream() .map(OpenAIModel::getId) .map(String::toLowerCase) + .map(String::trim) .collect(Collectors.toSet()); + supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported); return supported; @@ -188,53 +262,50 @@ public Set getOrPullSupportedModels() { * @return a list of available model names */ public List getAvailableModels() { - final Set configured = internalModels.entrySet() + return internalModels.entrySet() .stream() .flatMap(entry -> entry.getValue().stream()) .map(Tuple2::_2) - .flatMap(model -> model.getNames().stream().map(name -> new SimpleModel(name, model.getType()))) - .collect(Collectors.toSet()); - final Set supported = getOrPullSupportedModels() - .stream() - .map(SimpleModel::new) - .collect(Collectors.toSet()); - configured.retainAll(supported); - - return new ArrayList<>(configured); + .filter(AIModel::isOperational) + .flatMap(aiModel -> aiModel.getModels() + .stream() + .filter(Model::isOperational) + .map(model -> new SimpleModel( + model.getName(), + aiModel.getType(), + aiModel.getCurrentModelIndex() == model.getIndex()))) + .distinct() + .collect(Collectors.toList()); } - private static CircuitBreakerUrl.Response fetchOpenAIModels(final AppConfig appConfig) { - final CircuitBreakerUrl.Response response = CircuitBreakerUrl.builder() - .setMethod(CircuitBreakerUrl.Method.GET) - .setUrl(AI_MODELS_API_URL) - .setTimeout(AI_MODELS_FETCH_TIMEOUT) - .setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS) - .setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + appConfig.getApiKey())) - .setThrowWhenNot2xx(true) - .build() - .doResponse(OpenAIModels.class); + @VisibleForTesting + void cleanSupportedModelsCache() { + supportedModelsCache.invalidate(SUPPORTED_MODELS_KEY); + } - if (!CircuitBreakerUrl.isSuccessResponse(response)) { - Logger.debug( - AIModels.class, - String.format( - "Error fetching OpenAI supported models from [%s] (status code: [%d])", - AI_MODELS_API_URL, - response.getStatusCode())); - throw new DotRuntimeException("Error fetching OpenAI supported models"); + private void activateModels(final String host, boolean wasAdded) { + if (!wasAdded) { + return; } - return response; - } - - @VisibleForTesting - void setAppConfigSupplier(final Supplier appConfigSupplier) { - this.appConfigSupplier = appConfigSupplier; - } + final List aiModels = internalModels.get(host) + .stream() + .map(tuple -> tuple._2) + .collect(Collectors.toList()); - @VisibleForTesting - void cleanSupportedModelsCache() { - supportedModelsCache.invalidateAll(); + aiModels.forEach(aiModel -> + aiModel.getModels().forEach(model -> { + final String modelName = model.getName().trim().toLowerCase(); + final ModelStatus status; + status = ModelStatus.ACTIVE; + if (aiModel.getCurrentModelIndex() == -1) { + aiModel.setCurrentModelIndex(model.getIndex()); + } + Logger.debug( + this, + String.format("Model [%s] is supported by OpenAI, marking it as [%s]", modelName, status)); + model.setStatus(status); + })); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java index 1053537f79f9..20df6fb078e7 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java @@ -1,11 +1,12 @@ package com.dotcms.ai.app; +import com.dotcms.ai.domain.Model; import com.dotcms.security.apps.Secret; -import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Config; import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; import com.liferay.util.StringPool; +import io.vavr.Tuple2; import io.vavr.control.Try; import org.apache.commons.lang3.StringUtils; @@ -116,6 +117,15 @@ public static void setSystemHostConfig(final AppConfig systemHostConfig) { AppConfig.SYSTEM_HOST_CONFIG.set(systemHostConfig); } + /** + * Retrieves the host. + * + * @return the host + */ + public String getHost() { + return host; + } + /** * Retrieves the API URL. * @@ -137,7 +147,7 @@ public String getApiImageUrl() { /** * Retrieves the API Embeddings URL. * - * @return + * @return the API Embeddings URL */ public String getApiEmbeddingsUrl() { return UtilMethods.isEmpty(apiEmbeddingsUrl) ? AppKeys.API_EMBEDDINGS_URL.defaultValue : apiEmbeddingsUrl; @@ -287,33 +297,29 @@ public String getConfig(final AppKeys appKey) { * @param type the type of the model to find */ public AIModel resolveModel(final AIModelType type) { - return AIModels.get().findModel(host, type).orElse(AIModel.NOOP_MODEL); + return AIModels.get().resolveModel(host, type); } /** * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. * * @param modelName the name of the model to find + * @param type the type of the model to find */ - public AIModel resolveModelOrThrow(final String modelName) { - final AIModel aiModel = AIModels.get() - .findModel(host, modelName) - .orElseThrow(() -> { - final String supported = String.join(", ", AIModels.get().getOrPullSupportedModels()); - return new DotRuntimeException( - "Unable to find model: [" + modelName + "]. Only [" + supported + "] are supported "); - }); - - if (!aiModel.isOperational()) { - debugLogger( - AppConfig.class, - () -> String.format( - "Resolved model [%s] is not operational, avoiding its usage", - aiModel.getCurrentModel())); - throw new DotRuntimeException(String.format("Model [%s] is not operational", aiModel.getCurrentModel())); - } + public AIModel resolveAIModelOrThrow(final String modelName, final AIModelType type) { + return AIModels.get().resolveAIModelOrThrow(this, modelName, type); + } - return aiModel; + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * If the model is not found or is not operational, it throws an appropriate exception. + * + * @param modelName the name of the model to find + * @param type the type of the model to find + * @return the resolved Model + */ + public Tuple2 resolveModelOrThrow(final String modelName, final AIModelType type) { + return AIModels.get().resolveModelOrThrow(this, modelName, type); } /** diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java b/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java index 115439388cc2..50f70eea4ad7 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java @@ -7,6 +7,7 @@ import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.Logger; +import com.google.common.annotations.VisibleForTesting; import com.liferay.portal.model.User; import io.vavr.control.Try; @@ -23,7 +24,12 @@ public class ConfigService { private final LicenseValiditySupplier licenseValiditySupplier; private ConfigService() { - licenseValiditySupplier = new LicenseValiditySupplier() {}; + this(new LicenseValiditySupplier() {}); + } + + @VisibleForTesting + ConfigService(final LicenseValiditySupplier licenseValiditySupplier) { + this.licenseValiditySupplier = licenseValiditySupplier; } /** diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java new file mode 100644 index 000000000000..5393bcaa35ac --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java @@ -0,0 +1,108 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIProvider; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import org.apache.http.client.methods.HttpDelete; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPatch; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; +import org.apache.http.client.methods.HttpUriRequest; + +import javax.ws.rs.HttpMethod; +import java.io.OutputStream; +import java.io.Serializable; + +/** + * Interface representing an AI client capable of sending requests to an AI service. + * + *

+ * This interface defines methods for obtaining the AI provider and sending requests + * to the AI service. Implementations of this interface should handle the specifics + * of interacting with the AI service, including request formatting and response handling. + *

+ * + *

+ * The interface also provides a NOOP implementation that throws an + * {@link UnsupportedOperationException} for all operations. + *

+ * + * @author vico + */ +public interface AIClient { + + AIClient NOOP = new AIClient() { + @Override + public AIProvider getProvider() { + return AIProvider.NONE; + } + + @Override + public void sendRequest(final AIRequest request, final OutputStream output) { + throwUnsupported(); + } + + private void throwUnsupported() { + throw new UnsupportedOperationException("Noop client does not support sending requests"); + } + }; + + /** + * Resolves the appropriate HTTP method for the given method name and URL. + * + * @param method the HTTP method name (e.g., "GET", "POST", "PUT", "DELETE", "patch") + * @param url the URL to which the request will be sent + * @return the corresponding {@link HttpUriRequest} for the given method and URL + */ + static HttpUriRequest resolveMethod(final String method, final String url) { + switch(method) { + case HttpMethod.POST: + return new HttpPost(url); + case HttpMethod.PUT: + return new HttpPut(url); + case HttpMethod.DELETE: + return new HttpDelete(url); + case "patch": + return new HttpPatch(url); + case HttpMethod.GET: + default: + return new HttpGet(url); + } + } + + /** + * Validates and casts the given AI request to a {@link JSONObjectAIRequest}. + * + * @param the type of the request payload + * @param request the AI request to be validated and cast + * @return the validated and cast {@link JSONObjectAIRequest} + * @throws UnsupportedOperationException if the request is not an instance of {@link JSONObjectAIRequest} + */ + static JSONObjectAIRequest useRequestOrThrow(final AIRequest request) { + // When we get rid of JSONObject usage, we can remove this check + if (request instanceof JSONObjectAIRequest) { + return (JSONObjectAIRequest) request; + } + + throw new UnsupportedOperationException("Only JSONObjectAIRequest (JSONObject) is supported"); + } + + /** + * Returns the AI provider associated with this client. + * + * @return the AI provider + */ + AIProvider getProvider(); + + /** + * Sends the given AI request to the AI service and writes the response to the provided output stream. + * + * @param the type of the request payload + * @param request the AI request to be sent + * @param output the output stream to which the response will be written + * @throws Exception if any error occurs during the request execution + */ + void sendRequest(AIRequest request, OutputStream output); + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java new file mode 100644 index 000000000000..f015c5f83c13 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java @@ -0,0 +1,42 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; + +import java.io.OutputStream; +import java.io.Serializable; + +/** + * Interface representing a strategy for handling AI client requests and responses. + * + *

+ * This interface defines a method for applying a strategy to an AI client request, + * allowing for different handling mechanisms to be implemented. The NOOP strategy + * is provided as a default implementation that performs no operations. + *

+ * + *

+ * Implementations of this interface should define how to process the AI request + * and handle the response, potentially writing the response to an output stream. + *

+ * + * @author vico + */ +public interface AIClientStrategy { + + AIClientStrategy NOOP = (client, handler, request, output) -> AIResponse.builder().build(); + + /** + * Applies the strategy to the given AI client request and handles the response. + * + * @param client the AI client to which the request is sent + * @param handler the response evaluator to handle the response + * @param request the AI request to be processed + * @param output the output stream to which the response will be written + */ + void applyStrategy(AIClient client, + AIResponseEvaluator handler, + AIRequest request, + OutputStream output); + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java new file mode 100644 index 000000000000..22c3a5aec788 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java @@ -0,0 +1,34 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; + +import java.io.OutputStream; +import java.io.Serializable; + +/** + * Default implementation of the {@link AIClientStrategy} interface. + * + *

+ * This class provides a default strategy for handling AI client requests by + * directly sending the request using the provided AI client and writing the + * response to the given output stream. + *

+ * + *

+ * The default strategy does not perform any additional processing or handling + * of the request or response, delegating the entire operation to the AI client. + *

+ * + * @author vico + */ +public class AIDefaultStrategy implements AIClientStrategy { + + @Override + public void applyStrategy(final AIClient client, + final AIResponseEvaluator handler, + final AIRequest request, + final OutputStream output) { + client.sendRequest(request, output); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java new file mode 100644 index 000000000000..a946b883fea5 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java @@ -0,0 +1,238 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponseData; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.exception.DotAIAllModelsExhaustedException; +import com.dotcms.ai.validator.AIAppValidator; +import com.dotmarketing.exception.DotRuntimeException; +import io.vavr.Tuple; +import io.vavr.Tuple2; +import io.vavr.control.Try; +import org.apache.commons.io.IOUtils; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.Optional; + +/** + * Implementation of the {@link AIClientStrategy} interface that provides a fallback mechanism + * for handling AI client requests. + * + *

+ * This class attempts to send a request using a primary AI model and, if the request fails, + * it falls back to alternative models until a successful response is obtained or all models + * are exhausted. + *

+ * + *

+ * The fallback strategy ensures that the AI client can continue to function even if some models + * are not operational or fail to process the request. + *

+ * + * @author vico + */ +public class AIModelFallbackStrategy implements AIClientStrategy { + + /** + * Applies the fallback strategy to the given AI client request and handles the response. + * + *

+ * This method first attempts to send the request using the primary model. If the request + * fails, it falls back to alternative models until a successful response is obtained or + * all models are exhausted. + *

+ * + * @param client the AI client to which the request is sent + * @param handler the response evaluator to handle the response + * @param request the AI request to be processed + * @param output the output stream to which the response will be written + * @throws DotAIAllModelsExhaustedException if all models are exhausted and no successful response is obtained + */ + @Override + public void applyStrategy(final AIClient client, + final AIResponseEvaluator handler, + final AIRequest request, + final OutputStream output) { + final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request); + final Tuple2 modelTuple = resolveModel(jsonRequest); + + final AIResponseData firstAttempt = sendAttempt(client, handler, jsonRequest, output, modelTuple); + if (firstAttempt.isSuccess()) { + return; + } + + runFallbacks(client, handler, jsonRequest, output, modelTuple); + } + + private static Tuple2 resolveModel(final JSONObjectAIRequest request) { + final String modelName = request.getPayload().optString(AiKeys.MODEL); + return request.getConfig().resolveModelOrThrow(modelName, request.getType()); + } + + private static boolean isSameAsFirst(final Model firstAttempt, final Model model) { + if (firstAttempt.equals(model)) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] is the same as the current one [%s].", + model.getName(), + firstAttempt.getName())); + return true; + } + + return false; + } + + private static boolean isOperational(final Model model) { + if (!model.isOperational()) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format("Model [%s] is not operational. Skipping.", model.getName())); + return false; + } + + return true; + } + + private static AIResponseData doSend(final AIClient client, final AIRequest request) { + final ByteArrayOutputStream output = new ByteArrayOutputStream(); + client.sendRequest(request, output); + + final AIResponseData responseData = new AIResponseData(); + responseData.setResponse(output.toString()); + IOUtils.closeQuietly(output); + + return responseData; + } + + private static void redirectOutput(final OutputStream output, final String response) { + try (final InputStream input = new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))) { + IOUtils.copy(input, output); + } catch (IOException e) { + throw new DotRuntimeException(e); + } + } + + private static void notifyFailure(final AIModel aiModel, final AIRequest request) { + AIAppValidator.get().validateModelsUsage(aiModel, request.getUserId()); + } + + private static void handleFailure(final Tuple2 modelTuple, + final AIRequest request, + final AIResponseData responseData) { + final AIModel aiModel = modelTuple._1; + final Model model = modelTuple._2; + + if (!responseData.getStatus().doesNeedToThrow()) { + model.setStatus(responseData.getStatus()); + } + + if (model.getIndex() == aiModel.getModels().size() - 1) { + aiModel.setCurrentModelIndex(-1); + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] is the last one. Cannot fallback anymore.", + model.getName())); + + notifyFailure(aiModel, request); + + throw new DotAIAllModelsExhaustedException( + String.format("All models for type [%s] has been exhausted.", aiModel.getType())); + } else { + aiModel.setCurrentModelIndex(model.getIndex() + 1); + } + } + + private static AIResponseData sendAttempt(final AIClient client, + final AIResponseEvaluator evaluator, + final JSONObjectAIRequest request, + final OutputStream output, + final Tuple2 modelTuple) { + + final AIResponseData responseData = Try + .of(() -> doSend(client, request)) + .getOrElseGet(exception -> fromException(evaluator, exception)); + + if (!responseData.isSuccess()) { + if (responseData.getStatus().doesNeedToThrow()) { + throw responseData.getException(); + } + } else { + evaluator.fromResponse(responseData.getResponse(), responseData, output instanceof ByteArrayOutputStream); + } + + if (responseData.isSuccess()) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName())); + redirectOutput(output, responseData.getResponse()); + } else { + logFailure(modelTuple, responseData); + + handleFailure(modelTuple, request, responseData); + } + + return responseData; + } + + private static void logFailure(final Tuple2 modelTuple, final AIResponseData responseData) { + Optional + .ofNullable(responseData.getResponse()) + .ifPresentOrElse( + response -> AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] failed with response:%s%s%s. Trying next model.", + modelTuple._2.getName(), + System.lineSeparator(), + response, + System.lineSeparator())), + () -> AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] failed with error: [%s]. Trying next model.", + modelTuple._2.getName(), + responseData.getError()))); + } + + private static AIResponseData fromException(final AIResponseEvaluator evaluator, final Throwable exception) { + final AIResponseData metadata = new AIResponseData(); + evaluator.fromException(exception, metadata); + return metadata; + } + + private static void runFallbacks(final AIClient client, + final AIResponseEvaluator evaluator, + final JSONObjectAIRequest request, + final OutputStream output, + final Tuple2 modelTuple) { + for(final Model model : modelTuple._1.getModels()) { + if (isSameAsFirst(modelTuple._2, model) || !isOperational(model)) { + continue; + } + + request.getPayload().put(AiKeys.MODEL, model.getName()); + final AIResponseData responseData = sendAttempt( + client, + evaluator, + request, + output, + Tuple.of(modelTuple._1, model)); + if (responseData.isSuccess()) { + return; + } + } + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java new file mode 100644 index 000000000000..bd6020317f2a --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java @@ -0,0 +1,86 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; + +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Optional; + +/** + * A proxy client for interacting with an AI service using a specified strategy. + * + *

+ * This class provides a mechanism to send requests to an AI service through a proxied client, + * applying a given strategy for handling the requests and responses. It supports a NOOP implementation + * that performs no operations. + *

+ * + *

+ * The class allows for the creation of proxied clients with different strategies and response evaluators, + * enabling flexible handling of AI service interactions. + *

+ * + * @author vico + */ +public class AIProxiedClient { + + public static final AIProxiedClient NOOP = new AIProxiedClient(null, AIClientStrategy.NOOP, null); + + private final AIClient client; + private final AIClientStrategy strategy; + private final AIResponseEvaluator responseEvaluator; + + private AIProxiedClient(final AIClient client, + final AIClientStrategy strategy, + final AIResponseEvaluator responseEvaluator) { + this.client = client; + this.strategy = strategy; + this.responseEvaluator = responseEvaluator; + } + + /** + * Creates an AIProxiedClient with the specified client, strategy, and response evaluator. + * + * @param client the AI client to be proxied + * @param strategy the strategy to be applied for handling requests and responses + * @param responseParser the response evaluator to process responses + * @return a new instance of AIProxiedClient + */ + public static AIProxiedClient of(final AIClient client, + final AIProxyStrategy strategy, + final AIResponseEvaluator responseParser) { + return new AIProxiedClient(client, strategy.getStrategy(), responseParser); + } + + /** + * Creates an AIProxiedClient with the specified client and strategy. + * + * @param client the AI client to be proxied + * @param strategy the strategy to be applied for handling requests and responses + * @return a new instance of AIProxiedClient + */ + public static AIProxiedClient of(final AIClient client, final AIProxyStrategy strategy) { + return of(client, strategy, null); + } + + /** + * Sends the given AI request to the AI service and writes the response to the provided output stream. + * + * @param the type of the request payload + * @param request the AI request to be sent + * @param output the output stream to which the response will be written + * @return the AI response + */ + public AIResponse sendToAI(final AIRequest request, final OutputStream output) { + final OutputStream finalOutput = Optional.ofNullable(output).orElseGet(ByteArrayOutputStream::new); + + strategy.applyStrategy(client, responseEvaluator, request, finalOutput); + + return Optional.ofNullable(output) + .map(out -> AIResponse.EMPTY) + .orElseGet(() -> AIResponse.builder().withResponse(finalOutput.toString()).build()); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java new file mode 100644 index 000000000000..9be824627675 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java @@ -0,0 +1,121 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.client.openai.OpenAIClient; +import com.dotcms.ai.client.openai.OpenAIResponseEvaluator; +import com.dotcms.ai.domain.AIProvider; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; +import com.google.common.annotations.VisibleForTesting; +import io.vavr.Lazy; + +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; + +/** + * A proxy client for managing and interacting with multiple AI service providers. + * + *

+ * This class provides a mechanism to send requests to various AI service providers through proxied clients, + * applying different strategies for handling the requests and responses. It supports adding new clients and + * switching between different AI providers. + *

+ * + *

+ * The class allows for flexible handling of AI service interactions by maintaining a map of proxied clients + * and providing methods to send requests to the current or specified provider. + *

+ * + * @author vico + */ +public class AIProxyClient { + + private static final Lazy INSTANCE = Lazy.of(AIProxyClient::new); + + private final ConcurrentMap proxiedClients; + private final AtomicReference currentProvider; + + private AIProxyClient() { + proxiedClients = new ConcurrentHashMap<>(); + addClient( + AIProvider.OPEN_AI, + AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategy.MODEL_FALLBACK, OpenAIResponseEvaluator.get())); + currentProvider = new AtomicReference<>(AIProvider.OPEN_AI); + } + + @VisibleForTesting + AIProxyClient(final AIProxiedClient client) { + proxiedClients = new ConcurrentHashMap<>(); + addClient(AIProvider.OPEN_AI, client); + currentProvider = new AtomicReference<>(AIProvider.OPEN_AI); + } + + public static AIProxyClient get() { + return INSTANCE.get(); + } + + /** + * Adds a proxied client for the specified AI provider. + * + * @param provider the AI provider for which the client is added + * @param client the proxied client to be added + */ + public void addClient(final AIProvider provider, final AIProxiedClient client) { + proxiedClients.put(provider, client); + } + + /** + * Sends the given AI request to the specified AI provider and writes the response to the provided output stream. + * + * @param provider the AI provider to which the request is sent + * @param request the AI request to be sent + * @param output the output stream to which the response will be written + * @return the AI response + */ + public AIResponse callToAI(final AIProvider provider, + final AIRequest request, + final OutputStream output) { + return Optional.ofNullable(proxiedClients.getOrDefault(provider, AIProxiedClient.NOOP)) + .map(client -> client.sendToAI(request, output)) + .orElse(AIResponse.EMPTY); + } + + /** + * Sends the given AI request to the specified AI provider. + * + * @param the type of the request payload + * @param provider the AI provider to which the request is sent + * @param request the AI request to be sent + * @return the AI response + */ + public AIResponse callToAI(final AIProvider provider, final AIRequest request) { + return callToAI(provider, request, null); + } + + /** + * Sends the given AI request to the current AI provider and writes the response to the provided output stream. + * + * @param the type of the request payload + * @param request the AI request to be sent + * @param output the output stream to which the response will be written + * @return the AI response + */ + public AIResponse callToAI(final AIRequest request, final OutputStream output) { + return callToAI(currentProvider.get(), request, output); + } + + /** + * Sends the given AI request to the current AI provider. + * + * @param the type of the request payload + * @param request the AI request to be sent + * @return the AI response + */ + public AIResponse callToAI(final AIRequest request) { + return callToAI(request, null); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java new file mode 100644 index 000000000000..08b2c34f0a6b --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java @@ -0,0 +1,34 @@ +package com.dotcms.ai.client; + +/** + * Enumeration representing different strategies for proxying AI client requests. + * + *

+ * This enum provides different strategies for handling AI client requests, including + * a default strategy and a model fallback strategy. Each strategy is associated with + * an implementation of the {@link AIClientStrategy} interface. + *

+ * + *

+ * The strategies can be used to customize the behavior of AI client interactions, + * allowing for flexible handling of requests and responses. + *

+ * + * @author vico + */ +public enum AIProxyStrategy { + + DEFAULT(new AIDefaultStrategy()), + MODEL_FALLBACK(new AIModelFallbackStrategy()); + + private final AIClientStrategy strategy; + + AIProxyStrategy(final AIClientStrategy strategy) { + this.strategy = strategy; + } + + public AIClientStrategy getStrategy() { + return strategy; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseEvaluator.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseEvaluator.java new file mode 100644 index 000000000000..d7f8d2ba5ce4 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseEvaluator.java @@ -0,0 +1,36 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIResponseData; + +/** + * Interface for evaluating AI responses. + * It provides methods to process responses and exceptions, updating the provided metadata. + * + *

Methods:

+ *
    + *
  • \fromResponse\ - Processes a response string and updates the metadata.
  • + *
  • \fromThrowable\ - Processes an exception and updates the metadata.
  • + *
+ * + * @author vico + */ +public interface AIResponseEvaluator { + + /** + * Processes a response string and updates the metadata. + * + * @param response the response string to process + * @param metadata the metadata to update based on the response + * @param jsonExpected flag for expecting the response to be a JSON + */ + void fromResponse(String response, AIResponseData metadata, boolean jsonExpected); + + /** + * Processes an exception and updates the metadata. + * + * @param exception the exception to process + * @param metadata the metadata to update based on the exception + */ + void fromException(Throwable exception, AIResponseData metadata); + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java new file mode 100644 index 000000000000..a698477bd381 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java @@ -0,0 +1,159 @@ +package com.dotcms.ai.client.openai; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.app.AppKeys; +import com.dotcms.ai.client.AIClient; +import com.dotcms.ai.domain.AIProvider; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.exception.DotAIAppConfigDisabledException; +import com.dotcms.ai.exception.DotAIClientConnectException; +import com.dotcms.ai.exception.DotAIModelNotOperationalException; +import com.dotmarketing.util.Logger; +import com.dotmarketing.util.json.JSONObject; +import io.vavr.Lazy; +import io.vavr.Tuple2; +import io.vavr.control.Try; +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpEntityEnclosingRequestBase; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; + +import javax.ws.rs.core.MediaType; +import java.io.BufferedInputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Implementation of the {@link AIClient} interface for interacting with the OpenAI service. + * + *

+ * This class provides methods to send requests to the OpenAI service and handle responses. + * It includes functionality to manage rate limiting and ensure that models are operational + * before sending requests. + *

+ * + *

+ * The class uses a singleton pattern to ensure a single instance of the client is used + * throughout the application. It also maintains a record of the last REST call for each + * model to enforce rate limiting. + *

+ * + * @auhor vico + */ +public class OpenAIClient implements AIClient { + + private static final Lazy INSTANCE = Lazy.of(OpenAIClient::new); + + private final ConcurrentHashMap lastRestCall; + + public static OpenAIClient get() { + return INSTANCE.get(); + } + + private OpenAIClient() { + lastRestCall = new ConcurrentHashMap<>(); + } + + /** + * {@inheritDoc} + */ + @Override + public AIProvider getProvider() { + return AIProvider.OPEN_AI; + } + + /** + * {@inheritDoc} + */ + @Override + public void sendRequest(final AIRequest request, final OutputStream output) { + final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request); + final AppConfig appConfig = jsonRequest.getConfig(); + + AppConfig.debugLogger( + OpenAIClient.class, + () -> String.format( + "Posting to [%s] with method [%s]%s with app config:%s%s the payload: %s", + jsonRequest.getUrl(), + jsonRequest.getMethod(), + System.lineSeparator(), + appConfig.toString(), + System.lineSeparator(), + jsonRequest.payloadToString())); + + if (!appConfig.isEnabled()) { + AppConfig.debugLogger(OpenAIClient.class, () -> "App dotAI is not enabled and will not send request."); + throw new DotAIAppConfigDisabledException("App dotAI config without API urls or API key"); + } + + final JSONObject payload = jsonRequest.getPayload(); + final String modelName = payload.optString(AiKeys.MODEL); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(modelName, jsonRequest.getType()); + final AIModel aiModel = modelTuple._1; + + if (!modelTuple._2.isOperational()) { + AppConfig.debugLogger( + getClass(), + () -> String.format("Resolved model [%s] is not operational, avoiding its usage", modelName)); + throw new DotAIModelNotOperationalException(String.format("Model [%s] is not operational", modelName)); + } + + final long sleep = lastRestCall.computeIfAbsent(aiModel, m -> 0L) + + aiModel.minIntervalBetweenCalls() + - System.currentTimeMillis(); + if (sleep > 0) { + Logger.info( + this, + "Rate limit:" + + aiModel.getApiPerMinute() + + "/minute, or 1 every " + + aiModel.minIntervalBetweenCalls() + + "ms. Sleeping:" + + sleep); + Try.run(() -> Thread.sleep(sleep)); + } + + lastRestCall.put(aiModel, System.currentTimeMillis()); + + try (CloseableHttpClient httpClient = HttpClients.createDefault()) { + final StringEntity jsonEntity = new StringEntity(payload.toString(), ContentType.APPLICATION_JSON); + final HttpUriRequest httpRequest = AIClient.resolveMethod(jsonRequest.getMethod(), jsonRequest.getUrl()); + httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON); + httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + appConfig.getApiKey()); + + if (!payload.getAsMap().isEmpty()) { + Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity)); + } + + try (CloseableHttpResponse response = httpClient.execute(httpRequest)) { + final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent()); + final byte[] buffer = new byte[1024]; + int len; + while ((len = in.read(buffer)) != -1) { + output.write(buffer, 0, len); + output.flush(); + } + } + } catch (Exception e) { + if (appConfig.getConfigBoolean(AppKeys.DEBUG_LOGGING)){ + Logger.warn(this, "INVALID REQUEST: " + e.getMessage(), e); + } else { + Logger.warn(this, "INVALID REQUEST: " + e.getMessage()); + } + + Logger.warn(this, " - " + jsonRequest.getMethod() + " : " + payload); + + throw new DotAIClientConnectException("Error while sending request to OpenAI", e); + } + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluator.java b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluator.java new file mode 100644 index 000000000000..d7c3f84cf7cf --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluator.java @@ -0,0 +1,91 @@ +package com.dotcms.ai.client.openai; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.client.AIResponseEvaluator; +import com.dotcms.ai.domain.AIResponseData; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIModelNotFoundException; +import com.dotcms.ai.exception.DotAIModelNotOperationalException; +import com.dotmarketing.exception.DotRuntimeException; +import com.dotmarketing.util.json.JSONObject; +import io.vavr.Lazy; + +import java.util.Optional; +import java.util.stream.Stream; + +/** + * Evaluates AI responses from OpenAI and updates the provided metadata. + * This class implements the singleton pattern and provides methods to process responses and exceptions. + * + *

Methods:

+ *
    + *
  • \fromResponse\ - Processes a response string and updates the metadata.
  • + *
  • \fromThrowable\ - Processes an exception and updates the metadata.
  • + *
+ * + * @author vico + */ +public class OpenAIResponseEvaluator implements AIResponseEvaluator { + + private static final String JSON_ERROR_FIELD = "\"error\":"; + private static final Lazy INSTANCE = Lazy.of(OpenAIResponseEvaluator::new); + + public static OpenAIResponseEvaluator get() { + return INSTANCE.get(); + } + + private OpenAIResponseEvaluator() { + } + + /** + * {@inheritDoc} + */ + @Override + public void fromResponse(final String response, final AIResponseData metadata, final boolean jsonExpected) { + Optional.ofNullable(response) + .ifPresent(resp -> { + if (jsonExpected || resp.contains(JSON_ERROR_FIELD)) { + final JSONObject jsonResponse = new JSONObject(resp); + if (jsonResponse.has(AiKeys.ERROR)) { + final JSONObject error = jsonResponse.getJSONObject(AiKeys.ERROR); + final String message = error.getString(AiKeys.MESSAGE); + metadata.setError(message); + metadata.setStatus(resolveStatus(message)); + } + } + }); + } + + /** + * {@inheritDoc} + */ + @Override + public void fromException(final Throwable exception, final AIResponseData metadata) { + metadata.setError(exception.getMessage()); + metadata.setStatus(resolveStatus(exception)); + metadata.setException(exception instanceof DotRuntimeException + ? (DotRuntimeException) exception + : new DotRuntimeException(exception)); + } + + private ModelStatus resolveStatus(final String error) { + if (error.contains("has been deprecated")) { + return ModelStatus.DECOMMISSIONED; + } else if (error.contains("does not exist or you do not have access to it")) { + return ModelStatus.INVALID; + } else { + return ModelStatus.UNKNOWN; + } + } + + private ModelStatus resolveStatus(final Throwable throwable) { + if (Stream + .of(DotAIModelNotFoundException.class, DotAIModelNotOperationalException.class) + .anyMatch(exception -> exception.isInstance(throwable))) { + return ModelStatus.INVALID; + } else { + return ModelStatus.UNKNOWN; + } + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java b/dotCMS/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java index a5afa4b2bdc9..d621b2e0cd57 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java +++ b/dotCMS/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java @@ -87,7 +87,8 @@ public static Builder from(final CompletionsForm form) { .withOperator(form.operator) .withThreshold(form.threshold) .withTemperature(form.temperature) - .withTokenCount(form.responseLengthTokens); + .withTokenCount(form.responseLengthTokens) + .withUser(form.user); } public static Builder from(final Map form) { diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java new file mode 100644 index 000000000000..9e844d47619f --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java @@ -0,0 +1,35 @@ +package com.dotcms.ai.domain; + +/** + * Enumeration representing different AI service providers. + * + *

+ * This enum defines various AI service providers that can be used within the application. + * Each provider is associated with a specific name that identifies the AI service. + *

+ * + *

+ * The providers can be used to configure and manage interactions with different AI services, + * allowing for flexible integration and switching between multiple AI providers. + *

+ * + * @author vico + */ +public enum AIProvider { + + NONE("None"), + OPEN_AI("OpenAI"), + BEDROCK("Amazon Bedrock"), + GEMINI("Google Gemini"); + + private final String provider; + + AIProvider(final String provider) { + this.provider = provider; + } + + public String getProvider() { + return provider; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java new file mode 100644 index 000000000000..2635267c3e2a --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java @@ -0,0 +1,219 @@ +package com.dotcms.ai.domain; + +import com.dotcms.ai.app.AIModelType; +import com.dotcms.ai.app.AppConfig; + +import javax.ws.rs.HttpMethod; +import java.io.Serializable; + +/** + * Represents a request to an AI service. + * + *

+ * This class encapsulates the details of an AI request, including the URL, HTTP method, + * configuration, model type, payload, and user ID. It provides methods to create and + * configure AI requests for different model types such as text, image, and embeddings. + *

+ * + * @param the type of the request payload + * @author vico + */ +public class AIRequest { + + private final String url; + private final String method; + private final AppConfig config; + private final AIModelType type; + private final T payload; + private final String userId; + + > AIRequest(final Builder builder) { + this.url = builder.url; + this.method = builder.method; + this.config = builder.config; + this.type = builder.type; + this.payload = builder.payload; + this.userId = builder.userId; + } + + /** + * Creates a quick text AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @param the type of the request payload + * @param the type of the AIRequest + * @return a new AIRequest instance + */ + public static > R quickText(final AppConfig appConfig, + final T payload, + final String userId) { + return quick(AIModelType.TEXT, appConfig, payload, userId); + } + + /** + * Creates a quick image AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @param the type of the request payload + * @param the type of the AIRequest + * @return a new AIRequest instance + */ + public static > R quickImage(final AppConfig appConfig, + final T payload, + final String userId) { + return quick(AIModelType.IMAGE, appConfig, payload, userId); + } + + /** + * Creates a quick embeddings AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @param the type of the request payload + * @param the type of the AIRequest + * @return a new AIRequest instance + */ + public static > R quickEmbeddings(final AppConfig appConfig, + final T payload, + final String userId) { + return quick(AIModelType.EMBEDDINGS, appConfig, payload, userId); + } + + public static > Builder builder() { + return new Builder<>(); + } + + /** + * Resolves the URL for the specified model type and application configuration. + * + * @param type the AI model type + * @param appConfig the application configuration + * @return the resolved URL + */ + static String resolveUrl(final AIModelType type, final AppConfig appConfig) { + switch (type) { + case TEXT: + return appConfig.getApiUrl(); + case IMAGE: + return appConfig.getApiImageUrl(); + case EMBEDDINGS: + return appConfig.getApiEmbeddingsUrl(); + default: + throw new IllegalArgumentException("Invalid AIModelType: " + type); + } + } + + @SuppressWarnings("unchecked") + private static , R extends AIRequest> R quick( + final String url, + final AppConfig appConfig, + final AIModelType type, + final T payload, + final String usderId) { + return (R) AIRequest.builder() + .withUrl(url) + .withConfig(appConfig) + .withType(type) + .withPayload(payload) + .withUserId(usderId) + .build(); + } + + private static > R quick( + final AIModelType type, + final AppConfig appConfig, + final T payload, + final String userId) { + return quick(resolveUrl(type, appConfig), appConfig, type, payload, userId); + } + + public String getUrl() { + return url; + } + + public String getMethod() { + return method; + } + + public AppConfig getConfig() { + return config; + } + + public AIModelType getType() { + return type; + } + + public T getPayload() { + return payload; + } + + public String getUserId() { + return userId; + } + + @Override + public String toString() { + return "AIRequest{" + + "url='" + url + '\'' + + ", method='" + method + '\'' + + ", config=" + config + + ", type=" + type + + ", payload=" + payloadToString() + + ", userId='" + userId + '\'' + + '}'; + } + + public String payloadToString() { + return payload.toString(); + } + + public static class Builder> { + + String url; + String method = HttpMethod.POST; + AppConfig config; + AIModelType type; + T payload; + String userId; + + @SuppressWarnings("unchecked") + B self() { + return (B) this; + } + + public B withUrl(final String url) { + this.url = url; + return self(); + } + + public B withConfig(final AppConfig config) { + this.config = config; + return self(); + } + + public B withType(final AIModelType type) { + this.type = type; + return self(); + } + + public B withPayload(final T payload) { + this.payload = payload; + return self(); + } + + public B withUserId(final String userId) { + this.userId = userId; + return self(); + } + + public AIRequest build() { + return new AIRequest<>(this); + } + + } +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java new file mode 100644 index 000000000000..8d9887b24571 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java @@ -0,0 +1,50 @@ +package com.dotcms.ai.domain; + +/** + * Represents a response from an AI service. + * + *

+ * This class encapsulates the details of an AI response, including the response content. + * It provides methods to build and retrieve the response. + *

+ * + *

+ * The class also provides a static instance representing an empty response. + *

+ * + * @author vico + */ +public class AIResponse { + + public static final AIResponse EMPTY = builder().build(); + + private final String response; + + private AIResponse(final Builder builder) { + this.response = builder.response; + } + + public static Builder builder() { + return new Builder(); + } + + public String getResponse() { + return response; + } + + public static class Builder { + + private String response; + + public Builder withResponse(final String response) { + this.response = response; + return this; + } + + + public AIResponse build() { + return new AIResponse(this); + } + + } +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java new file mode 100644 index 000000000000..85ac2d9d0483 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java @@ -0,0 +1,68 @@ +package com.dotcms.ai.domain; + +import com.dotmarketing.exception.DotRuntimeException; +import org.apache.commons.lang3.StringUtils; + +/** + * Represents the data of a response from an AI service. + * + *

+ * This class encapsulates the details of an AI response, including the response content, error message, + * status, and any exceptions that may have occurred. It provides methods to retrieve and set these details, + * as well as a method to check if the response was successful. + *

+ * + * @author vico + */ +public class AIResponseData { + + private String response; + private String error; + private ModelStatus status; + private DotRuntimeException exception; + + public String getResponse() { + return response; + } + + public void setResponse(String response) { + this.response = response; + } + + public String getError() { + return error; + } + + public void setError(final String error) { + this.error = error; + } + + public ModelStatus getStatus() { + return status; + } + + public void setStatus(ModelStatus status) { + this.status = status; + } + + public DotRuntimeException getException() { + return exception; + } + + public void setException(DotRuntimeException exception) { + this.exception = exception; + } + + public boolean isSuccess() { + return StringUtils.isBlank(error); + } + + @Override + public String toString() { + return "AIResponseData{" + + "response='" + response + '\'' + + ", error='" + error + '\'' + + ", status=" + status + + '}'; + } +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java new file mode 100644 index 000000000000..8758f3a795cc --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java @@ -0,0 +1,106 @@ +package com.dotcms.ai.domain; + +import com.dotcms.ai.app.AIModelType; +import com.dotcms.ai.app.AppConfig; +import com.dotmarketing.util.json.JSONObject; + +/** + * Represents a request to an AI service with a JSON payload. + * + *

+ * This class encapsulates the details of an AI request with a JSON payload, including the URL, HTTP method, + * configuration, model type, payload, and user ID. It provides methods to create and configure AI requests + * for different model types such as text, image, and embeddings. + *

+ * + * @author vico + */ +public class JSONObjectAIRequest extends AIRequest { + + JSONObjectAIRequest(final Builder builder) { + super(builder); + } + + /** + * Creates a quick text AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @return a new JSONObjectAIRequest instance + */ + public static JSONObjectAIRequest quickText(final AppConfig appConfig, + final JSONObject payload, + final String userId) { + return quick(AIModelType.TEXT, appConfig, payload, userId); + } + + /** + * Creates a quick image AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @return a new JSONObjectAIRequest instance + */ + public static JSONObjectAIRequest quickImage(final AppConfig appConfig, + final JSONObject payload, + final String userId) { + return quick(AIModelType.IMAGE, appConfig, payload, userId); + } + + /** + * Creates a quick embeddings AI request with the specified configuration, payload, and user ID. + * + * @param appConfig the application configuration + * @param payload the request payload + * @param userId the user ID + * @return a new JSONObjectAIRequest instance + */ + public static JSONObjectAIRequest quickEmbeddings(final AppConfig appConfig, + final JSONObject payload, + final String userId) { + return quick(AIModelType.EMBEDDINGS, appConfig, payload, userId); + } + + private static JSONObjectAIRequest quick(final String url, + final AppConfig appConfig, + final AIModelType type, + final JSONObject payload, + final String userId) { + return JSONObjectAIRequest.builder() + .withUrl(url) + .withConfig(appConfig) + .withType(type) + .withPayload(payload) + .withUserId(userId) + .build(); + } + + private static JSONObjectAIRequest quick(final AIModelType type, + final AppConfig appConfig, + final JSONObject payload, + final String userId) { + return quick(resolveUrl(type, appConfig), appConfig, type, payload, userId); + } + + @Override + public String payloadToString() { + return getPayload().toString(2); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder extends AIRequest.Builder { + + @Override + public JSONObjectAIRequest build() { + return new JSONObjectAIRequest(this); + } + + } + + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/Model.java b/dotCMS/src/main/java/com/dotcms/ai/domain/Model.java new file mode 100644 index 000000000000..7b2b9ca150ba --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/Model.java @@ -0,0 +1,104 @@ +package com.dotcms.ai.domain; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Represents an AI model with a name, status, and index. + * + *

+ * This class encapsulates the details of an AI model, including its name, status, and index. + * It provides methods to retrieve and set these details, as well as methods to check if the model is operational. + *

+ * + *

+ * The class also provides a builder for constructing instances of the model. + *

+ * + * @author vico + */ +public class Model { + + private final String name; + private final AtomicReference status; + private final AtomicInteger index; + + private Model(final Builder builder) { + name = builder.name; + status = new AtomicReference<>(null); + index = new AtomicInteger(builder.index); + } + + public static Builder builder() { + return new Builder(); + } + + public String getName() { + return name; + } + + public ModelStatus getStatus() { + return status.get(); + } + + public void setStatus(final ModelStatus status) { + this.status.set(status); + } + + public int getIndex() { + return index.get(); + } + + public void setIndex(final int index) { + this.index.set(index); + } + + public boolean isOperational() { + return ModelStatus.ACTIVE == status.get(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Model model = (Model) o; + return Objects.equals(name, model.name); + } + + @Override + public int hashCode() { + return Objects.hashCode(name); + } + + @Override + public String toString() { + return "Model{" + + "name='" + name + '\'' + + ", status=" + status + + ", index=" + index.get() + + '}'; + } + + public static class Builder { + + private String name; + private int index; + + public Builder withName(final String name) { + this.name = name.toLowerCase().trim(); + return this; + } + + public Builder withIndex(final int index) { + this.index = index; + return this; + } + + public Model build() { + return new Model(this); + } + + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/ModelStatus.java b/dotCMS/src/main/java/com/dotcms/ai/domain/ModelStatus.java new file mode 100644 index 000000000000..15aa6bd9b69c --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/ModelStatus.java @@ -0,0 +1,30 @@ +package com.dotcms.ai.domain; + +/** + * Represents the status of an AI model. + * + *

+ * This enum defines various statuses that an AI model can have, such as active, invalid, decommissioned, or unknown. + * Each status may have different implications for the operation of the model. + *

+ * + * @author vico + */ +public enum ModelStatus { + + ACTIVE(false), + INVALID(false), + DECOMMISSIONED(false), + UNKNOWN(true); + + private final boolean needsToThrow; + + ModelStatus(final boolean needsToThrow) { + this.needsToThrow = needsToThrow; + } + + public boolean doesNeedToThrow() { + return needsToThrow; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAllModelsExhaustedException.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAllModelsExhaustedException.java new file mode 100644 index 000000000000..6833fdabb252 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAllModelsExhaustedException.java @@ -0,0 +1,22 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +/** + * Exception thrown when all AI models have been exhausted. + * + *

+ * This exception is used to indicate that all available AI models have been exhausted and no further models + * are available for processing. It extends the {@link DotRuntimeException} to provide additional context + * specific to AI model exhaustion scenarios. + *

+ * + * @author vico + */ +public class DotAIAllModelsExhaustedException extends DotRuntimeException { + + public DotAIAllModelsExhaustedException(final String message) { + super(message); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAppConfigDisabledException.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAppConfigDisabledException.java new file mode 100644 index 000000000000..5b549c74fed5 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIAppConfigDisabledException.java @@ -0,0 +1,22 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +/** + * Exception thrown when the AI application configuration is disabled. + * + *

+ * This exception is used to indicate that the AI application configuration is disabled and cannot be used. + * It extends the {@link DotRuntimeException} to provide additional context specific to AI application configuration + * disabled scenarios. + *

+ * + * @author vico + */ +public class DotAIAppConfigDisabledException extends DotRuntimeException { + + public DotAIAppConfigDisabledException(final String message) { + super(message); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIClientConnectException.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIClientConnectException.java new file mode 100644 index 000000000000..e17bf9b8d09f --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIClientConnectException.java @@ -0,0 +1,21 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +/** + * Exception thrown when there is a connection error with the AI client. + * + *

+ * This exception is used to indicate that there is a connection error with the AI client. It extends the {@link DotRuntimeException} + * to provide additional context specific to AI client connection error scenarios. + *

+ * + * @author vico + */ +public class DotAIClientConnectException extends DotRuntimeException { + + public DotAIClientConnectException(final String message, final Throwable cause) { + super(message, cause); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotFoundException.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotFoundException.java new file mode 100644 index 000000000000..3bd70811b123 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotFoundException.java @@ -0,0 +1,21 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +/** + * Exception thrown when an AI model is not found. + * + *

+ * This exception is used to indicate that a specific AI model could not be found. It extends the {@link DotRuntimeException} + * to provide additional context specific to AI model not found scenarios. + *

+ * + * @author vico + */ +public class DotAIModelNotFoundException extends DotRuntimeException { + + public DotAIModelNotFoundException(final String message) { + super(message); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotOperationalException.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotOperationalException.java new file mode 100644 index 000000000000..ddea5f6866f8 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotOperationalException.java @@ -0,0 +1,21 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +/** + * Exception thrown when there is a connection error with the AI client. + * + *

+ * This exception is used to indicate that there is a connection error with the AI client. It extends the {@link DotRuntimeException} + * to provide additional context specific to AI client connection error scenarios. + *

+ * + * @author vico + */ +public class DotAIModelNotOperationalException extends DotRuntimeException { + + public DotAIModelNotOperationalException(final String message) { + super(message); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java b/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java index 32bd759b58b2..a95b8f8df520 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java +++ b/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java @@ -1,8 +1,10 @@ package com.dotcms.ai.listener; import com.dotcms.ai.app.AIModels; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.validator.AIAppValidator; import com.dotcms.security.apps.AppSecretSavedEvent; import com.dotcms.system.event.local.model.EventSubscriber; import com.dotcms.system.event.local.model.KeyFilterable; @@ -35,6 +37,21 @@ public AIAppListener() { this(APILocator.getHostAPI()); } + /** + * Notifies the listener of an {@link AppSecretSavedEvent}. + * + *

+ * This method is called when an {@link AppSecretSavedEvent} occurs. It performs the following actions: + *

    + *
  • Logs a debug message if the event is null or the event's host identifier is blank.
  • + *
  • Finds the host associated with the event's host identifier.
  • + *
  • Resets the AI models for the found host's hostname.
  • + *
  • Validates the AI configuration using the {@link AIAppValidator}.
  • + *
+ *

+ * + * @param event the {@link AppSecretSavedEvent} that triggered the notification + */ @Override public void notify(final AppSecretSavedEvent event) { if (Objects.isNull(event)) { @@ -51,7 +68,9 @@ public void notify(final AppSecretSavedEvent event) { final Host host = Try.of(() -> hostAPI.find(hostId, APILocator.systemUser(), false)).getOrNull(); Optional.ofNullable(host).ifPresent(found -> AIModels.get().resetModels(found.getHostname())); - ConfigService.INSTANCE.config(host); + final AppConfig appConfig = ConfigService.INSTANCE.config(host); + + AIAppValidator.get().validateAIConfig(appConfig, event.getUserId()); } @Override diff --git a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java index 9739bab313eb..5c5a7b24d5ef 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java +++ b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java @@ -3,6 +3,7 @@ import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.db.EmbeddingsDTO; +import com.dotcms.ai.exception.DotAIAppConfigDisabledException; import com.dotcms.content.elasticsearch.business.event.ContentletArchiveEvent; import com.dotcms.content.elasticsearch.business.event.ContentletDeletedEvent; import com.dotcms.content.elasticsearch.business.event.ContentletPublishEvent; @@ -10,7 +11,6 @@ import com.dotcms.system.event.local.model.Subscriber; import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; -import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.portlets.contentlet.model.Contentlet; import com.dotmarketing.portlets.contentlet.model.ContentletListener; import com.dotmarketing.util.Logger; @@ -86,7 +86,7 @@ private AppConfig getAppConfig(final String hostId) { AppConfig.debugLogger( getClass(), () -> "dotAI is not enabled since no API urls or API key found in app config"); - throw new DotRuntimeException("App dotAI config without API urls or API key"); + throw new DotAIAppConfigDisabledException("App dotAI config without API urls or API key"); } return appConfig; diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java b/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java index 53f83c3ab149..e289b64c9a0d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java +++ b/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java @@ -1,6 +1,7 @@ package com.dotcms.ai.model; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.fasterxml.jackson.annotation.JsonSetter; import com.fasterxml.jackson.annotation.Nulls; @@ -15,12 +16,11 @@ public class AIImageRequestDTO { private final String model; - public AIImageRequestDTO(Builder builder) { + public AIImageRequestDTO(final Builder builder) { this.numberOfImages = builder.numberOfImages; this.model = builder.model; this.prompt = builder.prompt; this.size = builder.size; - } public String getSize() { @@ -40,14 +40,15 @@ public String getModel() { } public static class Builder { + private AppConfig appConfig = ConfigService.INSTANCE.config(); @JsonSetter(nulls = Nulls.SKIP) private String prompt; @JsonSetter(nulls = Nulls.SKIP) private int numberOfImages = 1; @JsonSetter(nulls = Nulls.SKIP) - private String size = ConfigService.INSTANCE.config().getImageSize(); + private String size = appConfig.getImageSize(); @JsonSetter(nulls = Nulls.SKIP) - private String model = ConfigService.INSTANCE.config().getImageModel().getCurrentModel(); + private String model = appConfig.getImageModel().getCurrentModel(); public AIImageRequestDTO build() { return new AIImageRequestDTO(this); diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java b/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java index c5486b61191f..b24e042e853f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java +++ b/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java @@ -17,16 +17,20 @@ public class SimpleModel implements Serializable { private final String name; private final AIModelType type; + private final boolean current; @JsonCreator - public SimpleModel(@JsonProperty("name") final String name, @JsonProperty("type") final AIModelType type) { + public SimpleModel(@JsonProperty("name") final String name, + @JsonProperty("type") final AIModelType type, + @JsonProperty("current") final boolean current) { this.name = name; this.type = type; + this.current = current; } @JsonCreator public SimpleModel(@JsonProperty("name") final String name) { - this(name, null); + this(name, null, false); } public String getName() { @@ -37,17 +41,30 @@ public AIModelType getType() { return type; } + public boolean isCurrent() { + return current; + } + @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; SimpleModel that = (SimpleModel) o; - return Objects.equals(name, that.name); + return Objects.equals(name, that.name) && type == that.type; } @Override public int hashCode() { - return Objects.hashCode(name); + return Objects.hash(name, type); + } + + @Override + public String toString() { + return "SimpleModel{" + + "name='" + name + '\'' + + ", type=" + type + + ", current=" + current + + '}'; } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java index e7b62cf46712..5499de4ce660 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java @@ -61,7 +61,9 @@ public final Response summarizeFromContent(@Context final HttpServletRequest req response, formIn, () -> APILocator.getDotAIAPI().getCompletionsAPI().summarize(formIn), - out -> APILocator.getDotAIAPI().getCompletionsAPI().summarizeStream(formIn, new LineReadingOutputStream(out))); + output -> APILocator.getDotAIAPI() + .getCompletionsAPI() + .summarizeStream(formIn, new LineReadingOutputStream(output))); } /** @@ -84,7 +86,9 @@ public final Response rawPrompt(@Context final HttpServletRequest request, response, formIn, () -> APILocator.getDotAIAPI().getCompletionsAPI().raw(formIn), - out -> APILocator.getDotAIAPI().getCompletionsAPI().rawStream(formIn, new LineReadingOutputStream(out))); + output -> APILocator.getDotAIAPI() + .getCompletionsAPI() + .rawStream(formIn, new LineReadingOutputStream(output))); } /** @@ -107,16 +111,15 @@ public final Response getConfig(@Context final HttpServletRequest request, .init() .getUser(); final Host host = WebAPILocator.getHostWebAPI().getCurrentHostNoThrow(request); - final AppConfig app = ConfigService.INSTANCE.config(host); - + final AppConfig appConfig = ConfigService.INSTANCE.config(host); final Map map = new HashMap<>(); map.put(AiKeys.CONFIG_HOST, host.getHostname() + " (falls back to system host)"); for (final AppKeys config : AppKeys.values()) { - map.put(config.key, app.getConfig(config)); + map.put(config.key, appConfig.getConfig(config)); } - final String apiKey = UtilMethods.isSet(app.getApiKey()) ? "*****" : "NOT SET"; + final String apiKey = UtilMethods.isSet(appConfig.getApiKey()) ? "*****" : "NOT SET"; map.put(AppKeys.API_KEY.key, apiKey); final List models = AIModels.get().getAvailableModels(); @@ -140,19 +143,25 @@ private static CompletionsForm resolveForm(final HttpServletRequest request, .init() .getUser(); final Host host = WebAPILocator.getHostWebAPI().getCurrentHostNoThrow(request); - return (!user.isAdmin()) - ? CompletionsForm - .copy(formIn) - .model(ConfigService.INSTANCE.config(host).getModel().getCurrentModel()) - .build() - : formIn; + return withUserId( + !user.isAdmin() + ? CompletionsForm + .copy(formIn) + .model(ConfigService.INSTANCE.config(host).getModel().getCurrentModel()) + .build() + : formIn, + user); + } + + private static CompletionsForm withUserId(final CompletionsForm completionsForm, final User user) { + return CompletionsForm.copy(completionsForm).user(user).build(); } private static Response getResponse(final HttpServletRequest request, final HttpServletResponse response, final CompletionsForm formIn, final Supplier noStream, - final Consumer stream) { + final Consumer outputStream) { if (StringUtils.isBlank(formIn.prompt)) { return badRequestResponse(); } @@ -162,7 +171,7 @@ private static Response getResponse(final HttpServletRequest request, if (resolvedForm.stream) { final StreamingOutput streaming = output -> { - stream.accept(output); + outputStream.accept(output); output.flush(); output.close(); }; @@ -174,5 +183,4 @@ private static Response getResponse(final HttpServletRequest request, return Response.ok(jsonResponse.toString(), MediaType.APPLICATION_JSON).build(); } - } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java index 375625d58adf..e536de66e87c 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java @@ -1,7 +1,7 @@ package com.dotcms.ai.rest; import com.dotcms.ai.AiKeys; -import com.dotcms.ai.Marshaller; +import com.dotcms.ai.util.Marshaller; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.model.AIImageRequestDTO; diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java index fae06a565d3b..6527eb984419 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java @@ -10,6 +10,7 @@ import com.dotmarketing.util.UtilMethods; import com.dotmarketing.util.json.JSONArray; import com.dotmarketing.util.json.JSONObject; +import com.liferay.portal.model.User; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -56,7 +57,7 @@ public Response doGet(@Context final HttpServletRequest request, * * @param request the HTTP request * @param response the HTTP response - * @param formIn the form data containing the prompt + * @param form the form data containing the prompt * @return a Response object containing the generated text * @throws IOException if an I/O error occurs */ @@ -65,13 +66,14 @@ public Response doGet(@Context final HttpServletRequest request, @Produces(MediaType.APPLICATION_JSON) public Response doPost(@Context final HttpServletRequest request, @Context final HttpServletResponse response, - final CompletionsForm formIn) throws IOException { + final CompletionsForm form) throws IOException { - new WebResource.InitBuilder(request, response) + final User user = new WebResource.InitBuilder(request, response) .requiredBackendUser(true) .requiredFrontendUser(true) .init() .getUser(); + final CompletionsForm formIn = CompletionsForm.copy(form).user(user).build(); if (UtilMethods.isEmpty(formIn.prompt)) { return Response @@ -82,7 +84,12 @@ public Response doPost(@Context final HttpServletRequest request, final AppConfig config = ConfigService.INSTANCE.config(WebAPILocator.getHostWebAPI().getHost(request)); - return Response.ok(APILocator.getDotAIAPI().getCompletionsAPI().raw(generateRequest(formIn, config)).toString()).build(); + return Response.ok( + APILocator.getDotAIAPI() + .getCompletionsAPI() + .raw(generateRequest(formIn, config), user.getUserId()) + .toString()) + .build(); } /** @@ -93,12 +100,12 @@ public Response doPost(@Context final HttpServletRequest request, * @return a JSONObject representing the request */ private JSONObject generateRequest(final CompletionsForm form, final AppConfig config) { - final String systemPrompt = UtilMethods.isSet(config.getRolePrompt()) ? config.getRolePrompt() : null; final String model = form.model; final float temperature = form.temperature; final JSONObject request = new JSONObject(); final JSONArray messages = new JSONArray(); + final String systemPrompt = UtilMethods.isSet(config.getRolePrompt()) ? config.getRolePrompt() : null; if (UtilMethods.isSet(systemPrompt)) { messages.add(Map.of(AiKeys.ROLE, AiKeys.SYSTEM, AiKeys.CONTENT, systemPrompt)); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java index f4eb199d4bf2..2e1f58923556 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java @@ -8,6 +8,7 @@ import com.fasterxml.jackson.annotation.JsonSetter; import com.fasterxml.jackson.annotation.Nulls; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.liferay.portal.model.User; import io.vavr.control.Try; import javax.validation.constraints.Max; @@ -49,6 +50,7 @@ public class CompletionsForm { public final String model; public final String operator; public final String site; + public final User user; @Override public boolean equals(final Object o) { @@ -88,6 +90,7 @@ public String toString() { ", operator='" + operator + '\'' + ", site='" + site + '\'' + ", contentType=" + Arrays.toString(contentType) + + ", user=" + user + '}'; } @@ -118,6 +121,7 @@ private CompletionsForm(final Builder builder) { this.temperature = builder.temperature >= 2 ? 2 : builder.temperature; } this.model = UtilMethods.isSet(builder.model) ? builder.model : ConfigService.INSTANCE.config().getModel().getCurrentModel(); + this.user = builder.user; } private String validateBuilderQuery(final String query) { @@ -131,7 +135,6 @@ private long validateLanguage(final String language) { return Try.of(() -> Long.parseLong(language)) .recover(x -> APILocator.getLanguageAPI().getLanguage(language).getId()) .getOrElseTry(() -> APILocator.getLanguageAPI().getDefaultLanguage().getId()); - } public static Builder copy(final CompletionsForm form) { @@ -149,7 +152,8 @@ public static Builder copy(final CompletionsForm form) { .operator(form.operator) .indexName(form.indexName) .threshold(form.threshold) - .stream(form.stream); + .stream(form.stream) + .user(form.user); } public static final class Builder { @@ -182,6 +186,8 @@ public static final class Builder { private String operator = "cosine"; @JsonSetter(nulls = Nulls.SKIP) private String site; + @JsonSetter(nulls = Nulls.SKIP) + private User user; public Builder prompt(String queryOrPrompt) { this.prompt = queryOrPrompt; @@ -224,7 +230,7 @@ public Builder fieldVar(String fieldVar) { } public Builder model(String model) { - this.model =model; + this.model = model; return this; } @@ -254,7 +260,12 @@ public Builder operator(String operator) { } public Builder site(String site) { - this.site =site; + this.site = site; + return this; + } + + public Builder user(User user) { + this.user = user; return this; } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java index 61815b1307eb..62c61fa9d229 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java @@ -1,7 +1,6 @@ package com.dotcms.ai.rest.forms; import com.dotcms.ai.app.AppConfig; -import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.business.APILocator; import com.dotmarketing.util.UtilMethods; @@ -65,8 +64,6 @@ public static final Builder copy(EmbeddingsForm form) { .fields(String.join(",", form.fields)) .velocityTemplate(form.velocityTemplate) .indexName(form.indexName); - - } @Override @@ -103,7 +100,6 @@ public String toString() { '}'; } - public static final class Builder { @JsonSetter(nulls = Nulls.SKIP) public String fields; @@ -135,7 +131,6 @@ public Builder limit(int limit) { return this; } - public Builder offset(int offset) { this.offset = offset; return this; @@ -161,10 +156,10 @@ public Builder velocityTemplate(String velocityTemplate) { return this; } - public EmbeddingsForm build() { return new EmbeddingsForm(this); - } + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/Marshaller.java b/dotCMS/src/main/java/com/dotcms/ai/util/Marshaller.java similarity index 98% rename from dotCMS/src/main/java/com/dotcms/ai/Marshaller.java rename to dotCMS/src/main/java/com/dotcms/ai/util/Marshaller.java index fc39f5f88e8c..0f92396e50be 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/Marshaller.java +++ b/dotCMS/src/main/java/com/dotcms/ai/util/Marshaller.java @@ -1,4 +1,4 @@ -package com.dotcms.ai; +package com.dotcms.ai.util; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java deleted file mode 100644 index b2a9b9adf789..000000000000 --- a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java +++ /dev/null @@ -1,189 +0,0 @@ -package com.dotcms.ai.util; - -import com.dotcms.ai.AiKeys; -import com.dotcms.ai.app.AIModel; -import com.dotcms.ai.app.AppConfig; -import com.dotcms.ai.app.AppKeys; -import com.dotcms.ai.app.ConfigService; -import com.dotmarketing.exception.DotRuntimeException; -import com.dotmarketing.util.Logger; -import com.dotmarketing.util.json.JSONObject; -import io.vavr.control.Try; -import org.apache.http.HttpHeaders; -import org.apache.http.client.methods.*; -import org.apache.http.entity.ContentType; -import org.apache.http.entity.StringEntity; -import org.apache.http.impl.client.CloseableHttpClient; -import org.apache.http.impl.client.HttpClients; - -import javax.ws.rs.HttpMethod; -import javax.ws.rs.core.MediaType; -import java.io.BufferedInputStream; -import java.io.ByteArrayOutputStream; -import java.io.OutputStream; -import java.util.concurrent.ConcurrentHashMap; - -/** - * The OpenAIRequest class is a utility class that handles HTTP requests to the OpenAI API. - * It provides methods for sending GET, POST, PUT, DELETE, and PATCH requests. - * This class also manages rate limiting for the OpenAI API by keeping track of the last time a request was made. - * - * This class is implemented as a singleton, meaning that only one instance of the class is created throughout the execution of the program. - */ -public class OpenAIRequest { - - private static final ConcurrentHashMap lastRestCall = new ConcurrentHashMap<>(); - - private OpenAIRequest() {} - - /** - * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload. - * The response from the request is written to the provided OutputStream. - * This method also manages rate limiting for the OpenAI API by keeping track of the last time a request was made. - * - * @param urlIn the URL to send the request to - * @param method the HTTP method to use for the request - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doRequest(final String urlIn, - final String method, - final AppConfig appConfig, - final JSONObject json, - final OutputStream out) { - AppConfig.debugLogger( - OpenAIRequest.class, - () -> String.format( - "Posting to [%s] with method [%s]%s with app config:%s%s the payload: %s", - urlIn, - method, - System.lineSeparator(), - appConfig.toString(), - System.lineSeparator(), - json.toString(2))); - - if (!appConfig.isEnabled()) { - AppConfig.debugLogger(OpenAIRequest.class, () -> "App dotAI is not enabled and will not send request."); - throw new DotRuntimeException("App dotAI config without API urls or API key"); - } - - final AIModel model = appConfig.resolveModelOrThrow(json.optString(AiKeys.MODEL)); - final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L) - + model.minIntervalBetweenCalls() - - System.currentTimeMillis(); - if (sleep > 0) { - Logger.info( - OpenAIRequest.class, - "Rate limit:" - + model.getApiPerMinute() - + "/minute, or 1 every " - + model.minIntervalBetweenCalls() - + "ms. Sleeping:" - + sleep); - Try.run(() -> Thread.sleep(sleep)); - } - - lastRestCall.put(model, System.currentTimeMillis()); - - try (CloseableHttpClient httpClient = HttpClients.createDefault()) { - final StringEntity jsonEntity = new StringEntity(json.toString(), ContentType.APPLICATION_JSON); - final HttpUriRequest httpRequest = resolveMethod(method, urlIn); - httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON); - httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + appConfig.getApiKey()); - - if (!json.getAsMap().isEmpty()) { - Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity)); - } - - try (CloseableHttpResponse response = httpClient.execute(httpRequest)) { - final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent()); - final byte[] buffer = new byte[1024]; - int len; - while ((len = in.read(buffer)) != -1) { - out.write(buffer, 0, len); - out.flush(); - } - } - } catch (Exception e) { - if (ConfigService.INSTANCE.config().getConfigBoolean(AppKeys.DEBUG_LOGGING)){ - Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage(), e); - } else { - Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage()); - } - - Logger.warn(OpenAIRequest.class, " - " + method + " : " +json); - - throw new DotRuntimeException(e); - } - } - - /** - * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload. - * The response from the request is returned as a string. - * - * @param url the URL to send the request to - * @param method the HTTP method to use for the request - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @return the response from the request as a string - */ - public static String doRequest(final String url, - final String method, - final AppConfig appConfig, - final JSONObject json) { - final ByteArrayOutputStream out = new ByteArrayOutputStream(); - doRequest(url, method, appConfig, json, out); - - return out.toString(); - } - - /** - * Sends a POST request to the specified URL with the specified OpenAI API key and JSON payload. - * The response from the request is written to the provided OutputStream. - * - * @param urlIn the URL to send the request to - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doPost(final String urlIn, - final AppConfig appConfig, - final JSONObject json, - final OutputStream out) { - doRequest(urlIn, HttpMethod.POST, appConfig, json, out); - } - - /** - * Sends a GET request to the specified URL with the specified OpenAI API key and JSON payload. - * The response from the request is written to the provided OutputStream. - * - * @param urlIn the URL to send the request to - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doGet(final String urlIn, - final AppConfig appConfig, - final JSONObject json, - final OutputStream out) { - doRequest(urlIn, HttpMethod.GET, appConfig, json, out); - } - - private static HttpUriRequest resolveMethod(final String method, final String urlIn) { - switch(method) { - case HttpMethod.POST: - return new HttpPost(urlIn); - case HttpMethod.PUT: - return new HttpPut(urlIn); - case HttpMethod.DELETE: - return new HttpDelete(urlIn); - case "patch": - return new HttpPatch(urlIn); - case HttpMethod.GET: - default: - return new HttpGet(urlIn); - } - } - -} diff --git a/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java b/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java new file mode 100644 index 000000000000..ad92b1167ca2 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java @@ -0,0 +1,125 @@ +package com.dotcms.ai.validator; + +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModels; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.domain.Model; +import com.dotcms.api.system.event.message.MessageSeverity; +import com.dotcms.api.system.event.message.SystemMessageEventUtil; +import com.dotcms.api.system.event.message.builder.SystemMessage; +import com.dotcms.api.system.event.message.builder.SystemMessageBuilder; +import com.dotmarketing.util.DateUtil; +import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.language.LanguageUtil; +import io.vavr.Lazy; +import io.vavr.control.Try; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * The AIAppValidator class is responsible for validating AI configurations and model usage. + * It ensures that the AI models specified in the application configuration are supported + * and not exhausted. + * + * @author vico + */ +public class AIAppValidator { + + private static final Lazy INSTANCE = Lazy.of(AIAppValidator::new); + + private SystemMessageEventUtil systemMessageEventUtil; + + private AIAppValidator() { + setSystemMessageEventUtil(SystemMessageEventUtil.getInstance()); + } + + public static AIAppValidator get() { + return INSTANCE.get(); + } + + /** + * Validates the AI configuration for the specified user. + * If the user ID is null, the validation is skipped. + * Checks if the models specified in the application configuration are supported. + * If any unsupported models are found, a warning message is pushed to the user. + * + * @param appConfig the application configuration + * @param userId the user ID + */ + public void validateAIConfig(final AppConfig appConfig, final String userId) { + if (Objects.isNull(userId)) { + AppConfig.debugLogger(getClass(), () -> "User Id is null, skipping AI configuration validation"); + return; + } + + final Set supportedModels = AIModels.get().getOrPullSupportedModels(appConfig.getApiKey()); + final Set unsupportedModels = Stream.of( + appConfig.getModel(), + appConfig.getImageModel(), + appConfig.getEmbeddingsModel()) + .flatMap(aiModel -> aiModel.getModels().stream()) + .map(Model::getName) + .filter(model -> !supportedModels.contains(model)) + .collect(Collectors.toSet()); + if (unsupportedModels.isEmpty()) { + return; + } + + final String unsupported = String.join(", ", unsupportedModels); + final String message = Try + .of(() -> LanguageUtil.get("ai.unsupported.models", unsupported)) + .getOrElse(String.format("The following models are not supported: [%s]", unsupported)); + final SystemMessage systemMessage = new SystemMessageBuilder() + .setMessage(message) + .setSeverity(MessageSeverity.WARNING) + .setLife(DateUtil.SEVEN_SECOND_MILLIS) + .create(); + + systemMessageEventUtil.pushMessage(systemMessage, Collections.singletonList(userId)); + } + + /** + * Validates the usage of AI models for the specified user. + * If the user ID is null, the validation is skipped. + * Checks if the models specified in the AI model are exhausted or invalid. + * If any exhausted or invalid models are found, a warning message is pushed to the user. + * + * @param aiModel the AI model + * @param userId the user ID + */ + public void validateModelsUsage(final AIModel aiModel, final String userId) { + if (Objects.isNull(userId)) { + AppConfig.debugLogger(getClass(), () -> "User Id is null, skipping AI models usage validation"); + return; + } + + final String unavailableModels = aiModel.getModels() + .stream() + .map(Model::getName) + .collect(Collectors.joining(", ")); + final String message = Try + .of(() -> LanguageUtil.get("ai.models.exhausted", aiModel.getType(), unavailableModels)). + getOrElse( + String.format( + "All the %s models: [%s] have been exhausted since they are invalid or has been decommissioned", + aiModel.getType(), + unavailableModels)); + final SystemMessage systemMessage = new SystemMessageBuilder() + .setMessage(message) + .setSeverity(MessageSeverity.WARNING) + .setLife(DateUtil.SEVEN_SECOND_MILLIS) + .create(); + + systemMessageEventUtil.pushMessage(systemMessage, Collections.singletonList(userId)); + } + + @VisibleForTesting + void setSystemMessageEventUtil(SystemMessageEventUtil systemMessageEventUtil) { + this.systemMessageEventUtil = systemMessageEventUtil; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java index 050b56b1e535..0ad6d7837a2d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java @@ -4,9 +4,7 @@ import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.api.ChatAPI; -import com.dotcms.ai.api.OpenAIChatAPIImpl; import com.dotcms.ai.api.ImageAPI; -import com.dotcms.ai.api.OpenAIImageAPIImpl; import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.json.JSONObject; @@ -30,11 +28,13 @@ public class AIViewTool implements ViewTool { private AppConfig config; private ChatAPI chatService; private ImageAPI imageService; + private User user; @Override public void init(final Object obj) { context = (ViewContext) obj; config = config(); + user = user(); chatService = chatService(); imageService = imageService(); } @@ -128,12 +128,12 @@ User user() { @VisibleForTesting ChatAPI chatService() { - return APILocator.getDotAIAPI().getChatAPI(config); + return APILocator.getDotAIAPI().getChatAPI(config, user); } @VisibleForTesting ImageAPI imageService() { - return APILocator.getDotAIAPI().getImageAPI(config, user(), APILocator.getHostAPI(), APILocator.getTempFileAPI()); + return APILocator.getDotAIAPI().getImageAPI(config, user, APILocator.getHostAPI(), APILocator.getTempFileAPI()); } private

Try generate(final P prompt, final Function serviceCall) { diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java index 5508a23f4e32..3f8e2b5b4ea9 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java @@ -9,6 +9,8 @@ import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.json.JSONObject; import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.model.User; +import com.liferay.portal.util.PortalUtil; import org.apache.velocity.tools.view.context.ViewContext; import org.apache.velocity.tools.view.tools.ViewTool; @@ -24,14 +26,18 @@ */ public class CompletionsTool implements ViewTool { + private final ViewContext context; private final HttpServletRequest request; private final Host host; private final AppConfig config; + private final User user; CompletionsTool(Object initData) { - this.request = ((ViewContext) initData).getRequest(); + this.context = (ViewContext) initData; + this.request = this.context.getRequest(); this.host = host(); this.config = config(); + this.user = user(); } @Override @@ -69,7 +75,11 @@ public Object summarize(final String prompt) { * @return The summarized object. */ public Object summarize(final String prompt, final String indexName) { - final CompletionsForm form = new CompletionsForm.Builder().indexName(indexName).prompt(prompt).build(); + final CompletionsForm form = new CompletionsForm.Builder() + .indexName(indexName) + .prompt(prompt) + .user(user) + .build(); try { return APILocator.getDotAIAPI().getCompletionsAPI(config).summarize(form); } catch (Exception e) { @@ -112,7 +122,7 @@ public Object raw(String prompt) { */ public Object raw(final JSONObject prompt) { try { - return APILocator.getDotAIAPI().getCompletionsAPI(config).raw(prompt); + return APILocator.getDotAIAPI().getCompletionsAPI(config).raw(prompt, user.getUserId()); } catch (Exception e) { return handleException(e); } @@ -141,4 +151,9 @@ AppConfig config() { return ConfigService.INSTANCE.config(this.host); } + @VisibleForTesting + User user() { + return PortalUtil.getUser(context.getRequest()); + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java index daf7e1756139..e8dc49722255 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java @@ -7,7 +7,10 @@ import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.Logger; +import com.dotmarketing.util.UtilMethods; import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.model.User; +import com.liferay.portal.util.PortalUtil; import org.apache.velocity.tools.view.context.ViewContext; import org.apache.velocity.tools.view.tools.ViewTool; @@ -22,9 +25,11 @@ */ public class EmbeddingsTool implements ViewTool { + private final ViewContext context; private final HttpServletRequest request; private final Host host; private final AppConfig appConfig; + private final User user; /** * Constructor for the EmbeddingsTool class. @@ -33,9 +38,11 @@ public class EmbeddingsTool implements ViewTool { * @param initData Initialization data for the tool. */ EmbeddingsTool(Object initData) { - this.request = ((ViewContext) initData).getRequest(); + this.context = (ViewContext) initData; + this.request = this.context.getRequest(); this.host = host(); this.appConfig = appConfig(); + this.user = user(); } @Override @@ -71,10 +78,14 @@ public List generateEmbeddings(final String prompt) { if (tokens > maxTokens) { Logger.warn( EmbeddingsTool.class, - "Prompt is too long. Maximum prompt size is " + maxTokens + " tokens (roughly ~" + maxTokens * .75 + " words). Your prompt was " + tokens + " tokens "); + "Prompt is too long. Maximum prompt size is " + maxTokens + " tokens (roughly ~" + + maxTokens * .75 + " words). Your prompt was " + tokens + " tokens "); } - return APILocator.getDotAIAPI().getEmbeddingsAPI().pullOrGenerateEmbeddings(prompt)._2; + return APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .pullOrGenerateEmbeddings(prompt, UtilMethods.extractUserIdOrNull(user)) + ._2; } /** @@ -96,4 +107,9 @@ AppConfig appConfig() { return ConfigService.INSTANCE.config(host); } + @VisibleForTesting + User user() { + return PortalUtil.getUser(context.getRequest()); + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java index 87a068eca10c..d8d01341bd0d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java @@ -1,5 +1,6 @@ package com.dotcms.ai.workflow; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.portlets.workflows.actionlet.Actionlet; import com.dotmarketing.portlets.workflows.actionlet.WorkFlowActionlet; @@ -9,7 +10,6 @@ import com.dotmarketing.portlets.workflows.model.WorkflowActionFailureException; import com.dotmarketing.portlets.workflows.model.WorkflowActionletParameter; import com.dotmarketing.portlets.workflows.model.WorkflowProcessor; -import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Map; @@ -24,7 +24,7 @@ public List getParameters() { WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key, "Overwrite tags ", Boolean.toString(true), true, - () -> ImmutableList.of( + () -> List.of( new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)), new MultiKeyValue(Boolean.toString(true), Boolean.toString(true))) ); @@ -32,16 +32,26 @@ public List getParameters() { WorkflowActionletParameter limitTagsToHost = new MultiSelectionWorkflowActionletParameter( OpenAIParams.LIMIT_TAGS_TO_HOST.key, "Limit the keywords to pre-existing tags", "Limit", false, - () -> ImmutableList.of( + () -> List.of( new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)), new MultiKeyValue(Boolean.toString(true), Boolean.toString(true)) ) ); + + final AppConfig appConfig = ConfigService.INSTANCE.config(); return List.of( overwriteParameter, limitTagsToHost, - new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getModel().getCurrentModel(), ConfigService.INSTANCE.config().getModel().getCurrentModel(), false), - new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0.", ".1", false) + new WorkflowActionletParameter( + OpenAIParams.MODEL.key, + "The AI model to use, defaults to " + appConfig.getModel().getCurrentModel(), + appConfig.getModel().getCurrentModel(), + false), + new WorkflowActionletParameter( + OpenAIParams.TEMPERATURE.key, + "The AI temperature for the response. Between .1 and 2.0.", + ".1", + false) ); } @@ -63,5 +73,4 @@ public void executeAction(final WorkflowProcessor processor, new OpenAIAutoTagRunner(processor, params).run(); } - } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java index 2f7013bfd362..a63c03743686 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java @@ -146,7 +146,7 @@ private String openAIRequest(final Contentlet workingContentlet, final String co final String parsedContentPrompt = VelocityUtil.eval(contentToTag, ctx); final JSONObject openAIResponse = APILocator.getDotAIAPI().getCompletionsAPI() - .prompt(parsedSystemPrompt, parsedContentPrompt, model, temperature, 2000); + .prompt(parsedSystemPrompt, parsedContentPrompt, model, temperature, 2000, user.getUserId()); return openAIResponse.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content"); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java index b6a14ab22d44..8a8e9320293b 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java @@ -1,5 +1,6 @@ package com.dotcms.ai.workflow; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.portlets.workflows.actionlet.Actionlet; @@ -10,7 +11,6 @@ import com.dotmarketing.portlets.workflows.model.WorkflowActionFailureException; import com.dotmarketing.portlets.workflows.model.WorkflowActionletParameter; import com.dotmarketing.portlets.workflows.model.WorkflowProcessor; -import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Map; @@ -23,22 +23,39 @@ public class OpenAIContentPromptActionlet extends WorkFlowActionlet { @Override public List getParameters() { - WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key, + final WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key, "Overwrite existing content (true|false)", Boolean.toString(true), true, - () -> ImmutableList.of( + () -> List.of( new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)), new MultiKeyValue(Boolean.toString(true), Boolean.toString(true))) ); - + final AppConfig appConfig = ConfigService.INSTANCE.config(); return List.of( - new WorkflowActionletParameter(OpenAIParams.FIELD_TO_WRITE.key, "The field where you want to write the results. " + - "
If your response is being returned as a json object, this field can be left blank" + - "
and the keys of the json object will be used to update the content fields.", "", false), + new WorkflowActionletParameter( + OpenAIParams.FIELD_TO_WRITE.key, + "The field where you want to write the results. " + + "
If your response is being returned as a json object, this field can be left blank" + + "
and the keys of the json object will be used to update the content fields.", + "", + false), overwriteParameter, - new WorkflowActionletParameter(OpenAIParams.OPEN_AI_PROMPT.key, "The prompt that will be sent to the AI", "We need an attractive search result in Google. Return a json object that includes the fields \"pageTitle\" for a meta title of less than 55 characters and \"metaDescription\" for the meta description of less than 300 characters using this content:\\n\\n${fieldContent}\\n\\n", true), - new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getModel().getCurrentModel(), ConfigService.INSTANCE.config().getModel().getCurrentModel(), false), - new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0. Defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), false) + new WorkflowActionletParameter( + OpenAIParams.OPEN_AI_PROMPT.key, + "The prompt that will be sent to the AI", + "We need an attractive search result in Google. Return a json object that includes the fields \"pageTitle\" for a meta title of less than 55 characters and \"metaDescription\" for the meta description of less than 300 characters using this content:\\n\\n${fieldContent}\\n\\n", + true), + new WorkflowActionletParameter( + OpenAIParams.MODEL.key, + "The AI model to use, defaults to " + appConfig.getModel().getCurrentModel(), + appConfig.getModel().getCurrentModel(), + false), + new WorkflowActionletParameter( + OpenAIParams.TEMPERATURE.key, + "The AI temperature for the response. Between .1 and 2.0. Defaults to " + + appConfig.getConfig(AppKeys.COMPLETION_TEMPERATURE), + appConfig.getConfig(AppKeys.COMPLETION_TEMPERATURE), + false) ); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java index 176b12d860c4..5ebca943a044 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java @@ -145,7 +145,9 @@ private String openAIRequest(final Contentlet workingContentlet) throws Exceptio final Context ctx = VelocityContextFactory.getMockContext(workingContentlet, user); final String parsedPrompt = VelocityUtil.eval(prompt, ctx); - final JSONObject openAIResponse = APILocator.getDotAIAPI().getCompletionsAPI().raw(buildRequest(parsedPrompt, model, temperature)); + final JSONObject openAIResponse = APILocator.getDotAIAPI() + .getCompletionsAPI() + .raw(buildRequest(parsedPrompt, model, temperature), user.getUserId()); try { return openAIResponse diff --git a/dotCMS/src/main/java/com/dotmarketing/util/UtilMethods.java b/dotCMS/src/main/java/com/dotmarketing/util/UtilMethods.java index 61085e6bb0f7..0dfe0d2adc27 100644 --- a/dotCMS/src/main/java/com/dotmarketing/util/UtilMethods.java +++ b/dotCMS/src/main/java/com/dotmarketing/util/UtilMethods.java @@ -3678,4 +3678,15 @@ public static T isSetOrGet(final T toEvaluate, final T defaultValue){ public static boolean exceedsMaxLength(final T value, final int maxLength) { return value != null && value.length() > maxLength; } + + /** + * Extracts the user id from a User object or returns null if the object is null + * + * @param user User object + * @return User id or null + */ + public static String extractUserIdOrNull(final User user) { + return Optional.ofNullable(user).map(User::getUserId).orElse(null); + } + } \ No newline at end of file diff --git a/dotCMS/src/main/resources/apps/dotAI.yml b/dotCMS/src/main/resources/apps/dotAI.yml index e8f03c0c5f80..425c9b42b058 100644 --- a/dotCMS/src/main/resources/apps/dotAI.yml +++ b/dotCMS/src/main/resources/apps/dotAI.yml @@ -15,11 +15,11 @@ params: hint: "Your ChatGPT API key" required: true textModelNames: - value: "gpt-4o" + value: "gpt-4o-mini" hidden: false type: "STRING" label: "Model Names" - hint: "Comma delimited list of models used to generate OpenAI API response (e.g. gpt-3.5-turbo-16k)" + hint: "Comma delimited list of models used to generate OpenAI API response (e.g. gpt-4o-mini)" required: true rolePrompt: value: "You are dotCMSbot, and AI assistant to help content creators generate and rewrite content in their content management system." diff --git a/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties b/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties index 782f5f57793a..63c55fbc5354 100644 --- a/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties +++ b/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties @@ -189,6 +189,8 @@ anonymous=Anonymous another-layout-already-exists=Another Tool Group already exists in the system with the same name Any-Structure-Type=Any Content Type Any-Structure=Any Content Type +ai.unsupported.models=The following models are not supported: [{0}] +ai.models.exhausted=All the {0} models: [{1}] have been exhausted since they are invalid or has been decommissioned api.ruleengine.system.conditionlet.CurrentSessionLanguage.inputs.comparison.placeholder=Comparison api.ruleengine.system.conditionlet.CurrentSessionLanguage.inputs.language.placeholder=Language api.ruleengine.system.conditionlet.CurrentSessionLanguage.name=Selected Language diff --git a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js index 088436aef605..f2db45f1a34c 100644 --- a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js +++ b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js @@ -136,11 +136,12 @@ const writeModelToDropdown = async () => { } const newOption = document.createElement("option"); + console.log(JSON.stringify(dotAiState.config, null, 2)); newOption.value = dotAiState.config.availableModels[i].name; newOption.text = `${dotAiState.config.availableModels[i].name}` - if (dotAiState.config.availableModels[i] === dotAiState.config.model) { + if (dotAiState.config.availableModels[i].current) { newOption.selected = true; - newOption.text = `${dotAiState.config.availableModels[i]} (default)` + newOption.text = `${dotAiState.config.availableModels[i].name} (default)` } modelName.appendChild(newOption); } diff --git a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java similarity index 86% rename from dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java rename to dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java index e4c43486c3f1..c51e9c6323a5 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java @@ -1,12 +1,11 @@ -package com.dotcms.ai.service; +package com.dotcms.ai.api; -import com.dotcms.ai.api.ChatAPI; -import com.dotcms.ai.api.OpenAIChatAPIImpl; import com.dotcms.ai.app.AIModel; import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotmarketing.util.json.JSONObject; +import com.liferay.portal.model.User; import org.junit.Before; import org.junit.Test; @@ -17,17 +16,19 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class OpenAIChatServiceImplTest { +public class OpenAIChatAPIImplTest { private static final String RESPONSE_JSON = "{\"data\":[{\"url\":\"http://localhost:8080\",\"value\":\"this is a response\"}]}"; private AppConfig config; private ChatAPI service; + private User user; @Before public void setUp() { config = mock(AppConfig.class); + user = mock(User.class); service = prepareService(RESPONSE_JSON); } @@ -54,11 +55,9 @@ public void test_sendTextPrompt() { } private ChatAPI prepareService(final String response) { - return new OpenAIChatAPIImpl(config) { - - + return new OpenAIChatAPIImpl(config, user) { @Override - public String doRequest(final String urlIn, final JSONObject json) { + String doRequest(final JSONObject json, final String userId) { return response; } }; @@ -66,7 +65,7 @@ public String doRequest(final String urlIn, final JSONObject json) { private JSONObject prepareJsonObject(final String prompt) { when(config.getModel()) - .thenReturn(AIModel.builder().withType(AIModelType.TEXT).withNames("some-model").build()); + .thenReturn(AIModel.builder().withType(AIModelType.TEXT).withModelNames("some-model").build()); when(config.getConfigFloat(AppKeys.COMPLETION_TEMPERATURE)).thenReturn(123.321F); when(config.getRolePrompt()).thenReturn("some-role-prompt"); diff --git a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java similarity index 97% rename from dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java rename to dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java index 6c3fc6822473..e73d9352a59b 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java @@ -1,7 +1,5 @@ -package com.dotcms.ai.service; +package com.dotcms.ai.api; -import com.dotcms.ai.api.ImageAPI; -import com.dotcms.ai.api.OpenAIImageAPIImpl; import com.dotcms.ai.app.AIModel; import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; @@ -27,7 +25,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class OpenAIImageServiceImplTest { +public class OpenAIImageAPIImplTest { private static final String RESPONSE_JSON = "{\"data\":[{\"url\":\"http://localhost:8080\",\"value\":\"this is a response\"}]}"; @@ -220,7 +218,7 @@ public AIImageRequestDTO.Builder getDtoBuilder() { } private JSONObject prepareJsonObject(final String prompt, final boolean tempFileError) throws Exception { - when(config.getImageModel()).thenReturn(AIModel.builder().withType(AIModelType.IMAGE).withNames("some-image-model").build()); + when(config.getImageModel()).thenReturn(AIModel.builder().withType(AIModelType.IMAGE).withModelNames("some-image-model").build()); when(config.getImageSize()).thenReturn("some-image-size"); final File file = mock(File.class); when(file.getName()).thenReturn(UUIDGenerator.shorty()); diff --git a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java index c4d5c93b7627..8c1cd1e79c4e 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java @@ -1,10 +1,12 @@ package com.dotcms.ai.app; +import com.dotcms.ai.domain.Model; import com.dotcms.security.apps.Secret; import org.junit.Before; import org.junit.Test; import java.util.Map; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -127,7 +129,7 @@ public void testCreateTextModel() { AIModel model = aiAppUtil.createTextModel(secrets); assertNotNull(model); assertEquals(AIModelType.TEXT, model.getType()); - assertTrue(model.getNames().contains("textmodel")); + assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()).contains("textmodel")); } /** @@ -143,7 +145,7 @@ public void testCreateImageModel() { AIModel model = aiAppUtil.createImageModel(secrets); assertNotNull(model); assertEquals(AIModelType.IMAGE, model.getType()); - assertTrue(model.getNames().contains("imagemodel")); + assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()).contains("imagemodel")); } /** @@ -159,7 +161,8 @@ public void testCreateEmbeddingsModel() { AIModel model = aiAppUtil.createEmbeddingsModel(secrets); assertNotNull(model); assertEquals(AIModelType.EMBEDDINGS, model.getType()); - assertTrue(model.getNames().contains("embeddingsmodel")); + assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()) + .contains("embeddingsmodel")); } @Test diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/openai/AIProxiedClientTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/openai/AIProxiedClientTest.java new file mode 100644 index 000000000000..b90f3b24a63d --- /dev/null +++ b/dotCMS/src/test/java/com/dotcms/ai/client/openai/AIProxiedClientTest.java @@ -0,0 +1,102 @@ +package com.dotcms.ai.client.openai; + +import com.dotcms.ai.client.AIClient; +import com.dotcms.ai.client.AIClientStrategy; +import com.dotcms.ai.client.AIProxiedClient; +import com.dotcms.ai.client.AIProxyStrategy; +import com.dotcms.ai.client.AIResponseEvaluator; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; +import org.junit.Before; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; +import java.io.Serializable; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Unit tests for the AIProxiedClient class. + * + * @author vico + */ +public class AIProxiedClientTest { + + private AIClient mockClient; + private AIProxyStrategy mockProxyStrategy; + private AIClientStrategy mockClientStrategy; + private AIResponseEvaluator mockResponseEvaluator; + private AIProxiedClient proxiedClient; + + @Before + public void setUp() { + mockClient = mock(AIClient.class); + mockProxyStrategy = mock(AIProxyStrategy.class); + mockClientStrategy = mock(AIClientStrategy.class); + when(mockProxyStrategy.getStrategy()).thenReturn(mockClientStrategy); + mockResponseEvaluator = mock(AIResponseEvaluator.class); + proxiedClient = AIProxiedClient.of(mockClient, mockProxyStrategy, mockResponseEvaluator); + } + + /** + * Scenario: Sending a valid AI request + * Given a valid AI request + * When the request is sent to the AI service + * Then the strategy should be applied + * And the response should be written to the output stream + */ + @Test + public void testSendToAI_withValidRequest() { + AIRequest request = mock(AIRequest.class); + OutputStream output = mock(OutputStream.class); + + AIResponse response = proxiedClient.sendToAI(request, output); + + verify(mockClientStrategy).applyStrategy(mockClient, mockResponseEvaluator, request, output); + assertEquals(AIResponse.EMPTY, response); + } + + /** + * Scenario: Sending an AI request with null output stream + * Given a valid AI request and a null output stream + * When the request is sent to the AI service + * Then the strategy should be applied + * And the response should be returned as a string + */ + @Test + public void testSendToAI_withNullOutput() { + AIRequest request = mock(AIRequest.class); + AIResponse response = proxiedClient.sendToAI(request, null); + + verify(mockClientStrategy).applyStrategy( + eq(mockClient), + eq(mockResponseEvaluator), + eq(request), + any(OutputStream.class)); + assertEquals("", response.getResponse()); + } + + /** + * Scenario: Sending an AI request with NOOP client + * Given a valid AI request and a NOOP client + * When the request is sent to the AI service + * Then no operations should be performed + * And the response should be empty + */ + @Test + public void testSendToAI_withNoopClient() { + proxiedClient = AIProxiedClient.NOOP; + AIRequest request = AIRequest.builder().build(); + OutputStream output = new ByteArrayOutputStream(); + + AIResponse response = proxiedClient.sendToAI(request, output); + + assertEquals(AIResponse.EMPTY, response); + } +} \ No newline at end of file diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluatorTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluatorTest.java new file mode 100644 index 000000000000..0ee7548304e7 --- /dev/null +++ b/dotCMS/src/test/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluatorTest.java @@ -0,0 +1,139 @@ +package com.dotcms.ai.client.openai; + +import com.dotcms.ai.domain.AIResponseData; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIModelNotFoundException; +import com.dotmarketing.exception.DotRuntimeException; +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** + * Tests for the OpenAIResponseEvaluator class. + * + * @author vico + */ +public class OpenAIResponseEvaluatorTest { + + private OpenAIResponseEvaluator evaluator; + + @Before + public void setUp() { + evaluator = OpenAIResponseEvaluator.get(); + } + + /** + * Scenario: Processing a response with an error + * Given a response with an error message "Model has been deprecated" + * When the response is processed + * Then the metadata should contain the error message "Model has been deprecated" + * And the status should be set to DECOMMISSIONED + */ + @Test + public void testFromResponse_withError() { + String response = new JSONObject().put("error", "Model has been deprecated").toString(); + AIResponseData metadata = new AIResponseData(); + + evaluator.fromResponse(response, metadata, true); + + assertEquals("Model has been deprecated", metadata.getError()); + assertEquals(ModelStatus.DECOMMISSIONED, metadata.getStatus()); + } + + /** + * Scenario: Processing a response with an error + * Given a response with an error message "Model has been deprecated" + * When the response is processed as no JSON + * Then the metadata should contain the error message "Model has been deprecated" + * And the status should be set to DECOMMISSIONED + */ + @Test + public void testFromResponse_withErrorNoJson() { + String response = new JSONObject().put("error", "Model has been deprecated").toString(); + AIResponseData metadata = new AIResponseData(); + + evaluator.fromResponse(response, metadata, false); + + assertEquals("Model has been deprecated", metadata.getError()); + assertEquals(ModelStatus.DECOMMISSIONED, metadata.getStatus()); + } + + /** + * Scenario: Processing a response with an error + * Given a response with an error message "Model has been deprecated" + * When the response is processed as no JSON + * Then the metadata should contain the error message "Model has been deprecated" + * And the status should be set to DECOMMISSIONED + */ + @Test + public void testFromResponse_withoutErrorNoJson() { + String response = "not a json response"; + AIResponseData metadata = new AIResponseData(); + + evaluator.fromResponse(response, metadata, false); + + assertNull(metadata.getError()); + assertNull(metadata.getStatus()); + } + + /** + * Scenario: Processing a response without an error + * Given a response without an error message + * When the response is processed + * Then the metadata should not contain any error message + * And the status should be null + */ + @Test + public void testFromResponse_withoutError() { + String response = new JSONObject().put("data", "some data").toString(); + AIResponseData metadata = new AIResponseData(); + + evaluator.fromResponse(response, metadata, true); + + assertNull(metadata.getError()); + assertNull(metadata.getStatus()); + } + + /** + * Scenario: Processing an exception of type DotRuntimeException + * Given an exception of type DotAIModelNotFoundException with message "Model not found" + * When the exception is processed + * Then the metadata should contain the error message "Model not found" + * And the status should be set to INVALID + * And the exception should be set to the given DotRuntimeException + */ + @Test + public void testFromException_withDotRuntimeException() { + DotRuntimeException exception = new DotAIModelNotFoundException("Model not found"); + AIResponseData metadata = new AIResponseData(); + + evaluator.fromException(exception, metadata); + + assertEquals("Model not found", metadata.getError()); + assertEquals(ModelStatus.INVALID, metadata.getStatus()); + assertEquals(exception, metadata.getException()); + } + + /** + * Scenario: Processing a general exception + * Given a general exception with message "General error" + * When the exception is processed + * Then the metadata should contain the error message "General error" + * And the status should be set to UNKNOWN + * And the exception should be wrapped in a DotRuntimeException + */ + @Test + public void testFromException_withOtherException() { + Exception exception = new Exception("General error"); + AIResponseData metadata = new AIResponseData(); + + evaluator.fromException(exception, metadata); + + assertEquals("General error", metadata.getError()); + assertEquals(ModelStatus.UNKNOWN, metadata.getStatus()); + assertEquals(DotRuntimeException.class, metadata.getException().getClass()); + } +} diff --git a/dotCMS/src/test/java/com/dotmarketing/util/UtilMethodsTest.java b/dotCMS/src/test/java/com/dotmarketing/util/UtilMethodsTest.java index 6a4e474aac4d..f6f3d4fd8213 100644 --- a/dotCMS/src/test/java/com/dotmarketing/util/UtilMethodsTest.java +++ b/dotCMS/src/test/java/com/dotmarketing/util/UtilMethodsTest.java @@ -5,9 +5,12 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import com.dotcms.UnitTestBase; import com.dotmarketing.portlets.contentlet.model.Contentlet; +import com.liferay.portal.model.User; import org.junit.Test; /** @@ -208,4 +211,29 @@ public void test_isImage_method(){ } } + /** + * Scenario: Extracting user ID from a User object + * Given a null User object + * When the user ID is extracted + * Then the result should be null + * + * Given a mocked User object with no user ID + * When the user ID is extracted + * Then the result should be null + * + * Given a mocked User object with a user ID "userId" + * When the user ID is extracted + * Then the result should be "userId" + */ + @Test + public void test_extractUserIdOrNull(){ + assertNull(UtilMethods.extractUserIdOrNull(null)); + + final User user = mock(User.class); + assertNull(UtilMethods.extractUserIdOrNull(user)); + + when(user.getUserId()).thenReturn("userId"); + assertEquals("userId", UtilMethods.extractUserIdOrNull(user)); + } + } diff --git a/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java b/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java index 81b74a231e3f..5b604f096adb 100644 --- a/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java +++ b/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java @@ -1,7 +1,10 @@ package com.dotcms; import com.dotcms.ai.app.AIModelsTest; +import com.dotcms.ai.app.ConfigServiceTest; +import com.dotcms.ai.client.AIProxyClientTest; import com.dotcms.ai.listener.EmbeddingContentListenerTest; +import com.dotcms.ai.validator.AIAppValidatorTest; import com.dotcms.ai.viewtool.AIViewToolTest; import com.dotcms.ai.viewtool.CompletionsToolTest; import com.dotcms.ai.viewtool.EmbeddingsToolTest; @@ -302,6 +305,9 @@ EmbeddingsToolTest.class, CompletionsToolTest.class, AIModelsTest.class, + ConfigServiceTest.class, + AIProxyClientTest.class, + AIAppValidatorTest.class, TimeMachineAPITest.class, Task240513UpdateContentTypesSystemFieldTest.class, PruneTimeMachineBackupJobTest.class, diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java index 855f61ad4572..557201f1aaa4 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java @@ -6,11 +6,11 @@ import com.dotcms.util.WireMockTestHelper; import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; -import com.dotmarketing.exception.DotDataException; -import com.dotmarketing.exception.DotSecurityException; import com.github.tomakehurst.wiremock.WireMockServer; import java.util.Map; +import java.util.Objects; +import java.util.concurrent.TimeUnit; public interface AiTest { @@ -31,55 +31,56 @@ static WireMockServer prepareWireMock() { return wireMockServer; } - static Map aiAppSecrets(final WireMockServer wireMockServer, - final Host host, + static Map aiAppSecrets(final Host host, final String apiKey, final String textModels, final String imageModels, - final String embeddingsModel) - throws DotDataException, DotSecurityException { - final AppSecrets appSecrets = new AppSecrets.Builder() + final String embeddingsModel) throws Exception { + final AppSecrets.Builder builder = new AppSecrets.Builder() .withKey(AppKeys.APP_KEY) - .withSecret(AppKeys.API_URL.key, String.format(API_URL, wireMockServer.port())) - .withSecret(AppKeys.API_IMAGE_URL.key, String.format(API_IMAGE_URL, wireMockServer.port())) - .withSecret(AppKeys.API_EMBEDDINGS_URL.key, String.format(API_EMBEDDINGS_URL, wireMockServer.port())) + .withSecret(AppKeys.API_URL.key, String.format(API_URL, PORT)) + .withSecret(AppKeys.API_IMAGE_URL.key, String.format(API_IMAGE_URL, PORT)) + .withSecret(AppKeys.API_EMBEDDINGS_URL.key, String.format(API_EMBEDDINGS_URL, PORT)) .withHiddenSecret(AppKeys.API_KEY.key, apiKey) - .withSecret(AppKeys.TEXT_MODEL_NAMES.key, textModels) - .withSecret(AppKeys.IMAGE_MODEL_NAMES.key, imageModels) - .withSecret(AppKeys.EMBEDDINGS_MODEL_NAMES.key, embeddingsModel) .withSecret(AppKeys.IMAGE_SIZE.key, IMAGE_SIZE) .withSecret(AppKeys.LISTENER_INDEXER.key, "{\"default\":\"blog\"}") .withSecret(AppKeys.COMPLETION_ROLE_PROMPT.key, AppKeys.COMPLETION_ROLE_PROMPT.defaultValue) - .withSecret(AppKeys.COMPLETION_TEXT_PROMPT.key, AppKeys.COMPLETION_TEXT_PROMPT.defaultValue) - .build(); + .withSecret(AppKeys.COMPLETION_TEXT_PROMPT.key, AppKeys.COMPLETION_TEXT_PROMPT.defaultValue); + + if (Objects.nonNull(textModels)) { + builder.withSecret(AppKeys.TEXT_MODEL_NAMES.key, textModels); + } + if (Objects.nonNull(imageModels)) { + builder.withSecret(AppKeys.IMAGE_MODEL_NAMES.key, imageModels); + } + if (Objects.nonNull(embeddingsModel)) { + builder.withSecret(AppKeys.EMBEDDINGS_MODEL_NAMES.key, embeddingsModel); + } + + final AppSecrets appSecrets = builder.build(); APILocator.getAppsAPI().saveSecrets(appSecrets, host, APILocator.systemUser()); + TimeUnit.SECONDS.sleep(1); return appSecrets.getSecrets(); } - static Map aiAppSecrets(final WireMockServer wireMockServer, - final Host host, - final String apiKey) - throws DotDataException, DotSecurityException { - return aiAppSecrets(wireMockServer, host, apiKey, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); + static Map aiAppSecrets(final Host host, final String apiKey) throws Exception { + return aiAppSecrets(host, apiKey, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); } - static Map aiAppSecrets(final WireMockServer wireMockServer, - final Host host, + static Map aiAppSecrets(final Host host, final String textModels, final String imageModels, - final String embeddingsModel) - throws DotDataException, DotSecurityException { - return aiAppSecrets(wireMockServer, host, API_KEY, textModels, imageModels, embeddingsModel); + final String embeddingsModel) throws Exception { + return aiAppSecrets(host, API_KEY, textModels, imageModels, embeddingsModel); } - static Map aiAppSecrets(final WireMockServer wireMockServer, final Host host) - throws DotDataException, DotSecurityException { + static Map aiAppSecrets(final Host host) throws Exception { - return aiAppSecrets(wireMockServer, host, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); + return aiAppSecrets(host, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); } - static void removeSecrets(final Host host) throws DotDataException, DotSecurityException { - APILocator.getAppsAPI().removeSecretsForSite(host, APILocator.systemUser()); + static void removeAiAppSecrets(final Host host) throws Exception { + APILocator.getAppsAPI().deleteSecrets(AppKeys.APP_KEY, host, APILocator.systemUser()); } } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java index e08965e20843..8d8b3fe1441b 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java @@ -1,18 +1,19 @@ package com.dotcms.ai.app; import com.dotcms.ai.AiTest; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIModelNotFoundException; +import com.dotcms.ai.model.SimpleModel; import com.dotcms.datagen.SiteDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; -import com.dotmarketing.exception.DotDataException; import com.dotmarketing.exception.DotRuntimeException; -import com.dotmarketing.exception.DotSecurityException; -import com.dotmarketing.util.DateUtil; import com.github.tomakehurst.wiremock.WireMockServer; +import io.vavr.Tuple2; import io.vavr.control.Try; -import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -23,9 +24,11 @@ import java.util.Set; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; /** @@ -46,25 +49,22 @@ public class AIModelsTest { @BeforeClass public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); + IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); + AiTest.aiAppSecrets(APILocator.systemHost()); } @AfterClass public static void afterClass() { wireMockServer.stop(); + IPUtils.disabledIpPrivateSubnet(false); } @Before public void before() { - IPUtils.disabledIpPrivateSubnet(true); host = new SiteDataGen().nextPersisted(); otherHost = new SiteDataGen().nextPersisted(); - List.of(host, otherHost).forEach(h -> Try.of(() -> AiTest.aiAppSecrets(wireMockServer, host)).get()); - } - - @After - public void after() { - IPUtils.disabledIpPrivateSubnet(false); + List.of(host, otherHost).forEach(h -> Try.of(() -> AiTest.aiAppSecrets(h)).get()); } /** @@ -73,31 +73,31 @@ public void after() { * Then the correct models should be found and returned. */ @Test - public void test_loadModels_andFindThem() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); - saveSecrets( + public void test_loadModels_andFindThem() throws Exception { + AiTest.aiAppSecrets( host, "text-model-1,text-model-2", "image-model-3,image-model-4", "embeddings-model-5,embeddings-model-6"); - saveSecrets(otherHost, "text-model-1", null, null); + AiTest.aiAppSecrets(otherHost, "text-model-1", null, null); final String hostId = host.getHostname(); + final AppConfig appConfig = ConfigService.INSTANCE.config(host); - final Optional notFound = aiModels.findModel(hostId, "some-invalid-model-name"); + final Optional notFound = aiModels.findModel(appConfig, "some-invalid-model-name", AIModelType.TEXT); assertTrue(notFound.isEmpty()); - final Optional text1 = aiModels.findModel(hostId, "text-model-1"); - final Optional text2 = aiModels.findModel(hostId, "text-model-2"); - assertModels(text1, text2, AIModelType.TEXT); + final Optional text1 = aiModels.findModel(appConfig, "text-model-1", AIModelType.TEXT); + final Optional text2 = aiModels.findModel(appConfig, "text-model-2", AIModelType.TEXT); + assertModels(text1, text2, AIModelType.TEXT, true); - final Optional image1 = aiModels.findModel(hostId, "image-model-3"); - final Optional image2 = aiModels.findModel(hostId, "image-model-4"); - assertModels(image1, image2, AIModelType.IMAGE); + final Optional image1 = aiModels.findModel(appConfig, "image-model-3", AIModelType.IMAGE); + final Optional image2 = aiModels.findModel(appConfig, "image-model-4", AIModelType.IMAGE); + assertModels(image1, image2, AIModelType.IMAGE, true); - final Optional embeddings1 = aiModels.findModel(hostId, "embeddings-model-5"); - final Optional embeddings2 = aiModels.findModel(hostId, "embeddings-model-6"); - assertModels(embeddings1, embeddings2, AIModelType.EMBEDDINGS); + final Optional embeddings1 = aiModels.findModel(appConfig, "embeddings-model-5", AIModelType.EMBEDDINGS); + final Optional embeddings2 = aiModels.findModel(appConfig, "embeddings-model-6", AIModelType.EMBEDDINGS); + assertModels(embeddings1, embeddings2, AIModelType.EMBEDDINGS, true); assertNotSame(text1.get(), image1.get()); assertNotSame(text1.get(), embeddings1.get()); @@ -112,27 +112,132 @@ public void test_loadModels_andFindThem() throws DotDataException, DotSecurityEx final Optional embeddings3 = aiModels.findModel(hostId, AIModelType.EMBEDDINGS); assertSameModels(embeddings3, embeddings1, embeddings2); - final Optional text4 = aiModels.findModel(otherHost.getHostname(), "text-model-1"); + final AppConfig otherAppConfig = ConfigService.INSTANCE.config(otherHost); + final Optional text4 = aiModels.findModel(otherAppConfig, "text-model-1", AIModelType.TEXT); assertTrue(text3.isPresent()); assertNotSame(text1.get(), text4.get()); - saveSecrets( + AiTest.aiAppSecrets( host, "text-model-7,text-model-8", "image-model-9,image-model-10", "embeddings-model-11, embeddings-model-12"); - final Optional text7 = aiModels.findModel(hostId, "text-model-7"); - final Optional text8 = aiModels.findModel(hostId, "text-model-8"); + final Optional text7 = aiModels.findModel(otherAppConfig, "text-model-7", AIModelType.TEXT); + final Optional text8 = aiModels.findModel(otherAppConfig, "text-model-8", AIModelType.TEXT); assertNotPresentModels(text7, text8); - final Optional image9 = aiModels.findModel(hostId, "image-model-9"); - final Optional image10 = aiModels.findModel(hostId, "image-model-10"); + final Optional image9 = aiModels.findModel(otherAppConfig, "image-model-9", AIModelType.IMAGE); + final Optional image10 = aiModels.findModel(otherAppConfig, "image-model-10", AIModelType.IMAGE); assertNotPresentModels(image9, image10); - final Optional embeddings11 = aiModels.findModel(hostId, "embeddings-model-11"); - final Optional embeddings12 = aiModels.findModel(hostId, "embeddings-model-12"); + final Optional embeddings11 = aiModels.findModel(otherAppConfig, "embeddings-model-11", AIModelType.EMBEDDINGS); + final Optional embeddings12 = aiModels.findModel(otherAppConfig, "embeddings-model-12", AIModelType.EMBEDDINGS); assertNotPresentModels(embeddings11, embeddings12); + + final List available = aiModels.getAvailableModels(); + final List availableNames = List.of( + "gpt-3.5-turbo-16k", "dall-e-3", "text-embedding-ada-002", + "text-model-1", "text-model-7", "text-model-8", + "image-model-9", "image-model-10", + "embeddings-model-11", "embeddings-model-12"); + assertTrue(available.stream().anyMatch(model -> availableNames.contains(model.getName()))); + } + + /** + * Given a set of models loaded into the AIModels instance + * When the resolveModel method is called with various model names and types + * Then the correct models should be resolved and their operational status verified. + */ + @Test + public void test_resolveModel() throws Exception { + AiTest.aiAppSecrets(host, "text-model-20", "image-model-21", "embeddings-model-22"); + ConfigService.INSTANCE.config(host); + AiTest.aiAppSecrets(otherHost, "text-model-23", null, null); + ConfigService.INSTANCE.config(otherHost); + + assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.TEXT).isOperational()); + assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.IMAGE).isOperational()); + assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.EMBEDDINGS).isOperational()); + assertTrue(aiModels.resolveModel(otherHost.getHostname(), AIModelType.TEXT).isOperational()); + assertFalse(aiModels.resolveModel(otherHost.getHostname(), AIModelType.IMAGE).isOperational()); + assertFalse(aiModels.resolveModel(otherHost.getHostname(), AIModelType.EMBEDDINGS).isOperational()); + } + + /** + * Given a set of models loaded into the AIModels instance + * When the resolveAIModelOrThrow method is called with various model names and types + * Then the correct models should be resolved and their operational status verified. + */ + @Test + public void test_resolveAIModelOrThrow() throws Exception { + AiTest.aiAppSecrets(host, "text-model-30", "image-model-31", "embeddings-model-32"); + + final AppConfig appConfig = ConfigService.INSTANCE.config(host); + final AIModel aiModel30 = aiModels.resolveAIModelOrThrow(appConfig, "text-model-30", AIModelType.TEXT); + final AIModel aiModel31 = aiModels.resolveAIModelOrThrow(appConfig, "image-model-31", AIModelType.IMAGE); + final AIModel aiModel32 = aiModels.resolveAIModelOrThrow( + appConfig, + "embeddings-model-32", + AIModelType.EMBEDDINGS); + + assertNotNull(aiModel30); + assertNotNull(aiModel31); + assertNotNull(aiModel32); + assertEquals("text-model-30", aiModel30.getModel("text-model-30").getName()); + assertEquals("image-model-31", aiModel31.getModel("image-model-31").getName()); + assertEquals("embeddings-model-32", aiModel32.getModel("embeddings-model-32").getName()); + + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "text-model-33", AIModelType.TEXT)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "image-model-34", AIModelType.IMAGE)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "embeddings-model-35", AIModelType.EMBEDDINGS)); + } + + /** + * Given a set of models loaded into the AIModels instance + * When the resolveModelOrThrow method is called with various model names and types + * Then the correct models should be resolved and their operational status verified. + */ + @Test + public void test_resolveModelOrThrow() throws Exception { + AiTest.aiAppSecrets(host, "text-model-40", "image-model-41", "embeddings-model-42"); + + final AppConfig appConfig = ConfigService.INSTANCE.config(host); + final Tuple2 modelTuple40 = aiModels.resolveModelOrThrow( + appConfig, + "text-model-40", + AIModelType.TEXT); + final Tuple2 modelTuple41 = aiModels.resolveModelOrThrow( + appConfig, + "image-model-41", + AIModelType.IMAGE); + final Tuple2 modelTuple42 = aiModels.resolveModelOrThrow( + appConfig, + "embeddings-model-42", + AIModelType.EMBEDDINGS); + + assertNotNull(modelTuple40); + assertNotNull(modelTuple41); + assertNotNull(modelTuple42); + assertEquals("text-model-40", modelTuple40._1.getModel("text-model-40").getName()); + assertEquals("image-model-41", modelTuple41._1.getModel("image-model-41").getName()); + assertEquals("embeddings-model-42", modelTuple42._1.getModel("embeddings-model-42").getName()); + + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "text-model-43", AIModelType.TEXT)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "image-model-44", AIModelType.IMAGE)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "embeddings-model-45", AIModelType.EMBEDDINGS)); } /** @@ -141,69 +246,44 @@ public void test_loadModels_andFindThem() throws DotDataException, DotSecurityEx * Then a list of supported models should be returned. */ @Test - public void test_getOrPullSupportedModules() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); + public void test_getOrPullSupportedModels() throws Exception { AIModels.get().cleanSupportedModelsCache(); - Set supported = aiModels.getOrPullSupportedModels(); + Set supported = aiModels.getOrPullSupportedModels(AiTest.API_KEY); assertNotNull(supported); assertEquals(38, supported.size()); - - AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); } /** * Given an invalid URL for supported models * When the getOrPullSupportedModules method is called - * Then an empty list of supported models should be returned. + * Then an exception should be thrown */ @Test(expected = DotRuntimeException.class) - public void test_getOrPullSupportedModules_withNetworkError() { + public void test_getOrPullSupportedModuels_withNetworkError() { AIModels.get().cleanSupportedModelsCache(); IPUtils.disabledIpPrivateSubnet(false); - final Set supported = aiModels.getOrPullSupportedModels(); - assertSupported(supported); - + aiModels.getOrPullSupportedModels(AiTest.API_KEY); IPUtils.disabledIpPrivateSubnet(true); - AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); - } - - /** - * Given no API key - * When the getOrPullSupportedModules method is called - * Then an empty list of supported models should be returned. - */ - @Test(expected = DotRuntimeException.class) - public void test_getOrPullSupportedModules_noApiKey() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost(), null); - - AIModels.get().cleanSupportedModelsCache(); - aiModels.getOrPullSupportedModels(); } /** * Given no API key * When the getOrPullSupportedModules method is called - * Then an empty list of supported models should be returned. + * Then an exception should be thrown. */ @Test(expected = DotRuntimeException.class) - public void test_getOrPullSupportedModules_noSystemHost() throws DotDataException, DotSecurityException { - AiTest.removeSecrets(APILocator.systemHost()); + public void test_getOrPullSupportedModels_noApiKey() throws Exception { + AiTest.aiAppSecrets(APILocator.systemHost(), null); AIModels.get().cleanSupportedModelsCache(); - aiModels.getOrPullSupportedModels(); + aiModels.getOrPullSupportedModels(null); } - private void saveSecrets(final Host host, - final String textModels, - final String imageModels, - final String embeddingsModels) throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, host, textModels, imageModels, embeddingsModels); - DateUtil.sleep(1000); - } - - private static void assertSameModels(Optional text3, Optional text1, Optional text2) { + private static void assertSameModels(final Optional text3, + final Optional text1, + final Optional text2) { assertTrue(text3.isPresent()); assertSame(text1.get(), text3.get()); assertSame(text2.get(), text3.get()); @@ -211,12 +291,17 @@ private static void assertSameModels(Optional text3, Optional private static void assertModels(final Optional model1, final Optional model2, - final AIModelType type) { + final AIModelType type, + final boolean assertModelNames) { assertTrue(model1.isPresent()); assertTrue(model2.isPresent()); assertSame(model1.get(), model2.get()); assertSame(type, model1.get().getType()); assertSame(type, model2.get().getType()); + if (assertModelNames) { + assertTrue(model1.get().getModels().stream().allMatch(model -> model.getStatus() == ModelStatus.ACTIVE)); + assertTrue(model2.get().getModels().stream().allMatch(model -> model.getStatus() == ModelStatus.ACTIVE)); + } } private static void assertNotPresentModels(final Optional model1, final Optional model2) { @@ -224,9 +309,4 @@ private static void assertNotPresentModels(final Optional model1, final assertTrue(model2.isEmpty()); } - private static void assertSupported(Set supported) { - assertNotNull(supported); - assertTrue(supported.isEmpty()); - } - } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/ConfigServiceTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/ConfigServiceTest.java new file mode 100644 index 000000000000..2e6143037095 --- /dev/null +++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/ConfigServiceTest.java @@ -0,0 +1,101 @@ +package com.dotcms.ai.app; + +import com.dotcms.ai.AiTest; +import com.dotcms.datagen.SiteDataGen; +import com.dotcms.util.IntegrationTestInitService; +import com.dotcms.util.LicenseValiditySupplier; +import com.dotmarketing.beans.Host; +import com.dotmarketing.business.APILocator; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for the ConfigService class. + * + *

+ * This class contains tests to verify the behavior of the ConfigService, + * including scenarios with valid and invalid licenses, and configurations + * with and without secrets. + *

+ * + *

+ * The tests ensure that the ConfigService correctly initializes and + * configures the AppConfig based on the provided Host and license validity. + *

+ * + * @author vico + */ +public class ConfigServiceTest { + + private Host host; + private ConfigService configService; + + @BeforeClass + public static void beforeClass() throws Exception { + IntegrationTestInitService.getInstance().init(); + } + + @Before + public void before() { + host = new SiteDataGen().nextPersisted(); + configService = ConfigService.INSTANCE; + } + + /** + * Given a ConfigService with an invalid license + * When the config method is called with a host + * Then the models should not be operational. + */ + @Test + public void test_invalidLicense() { + configService = new ConfigService(new LicenseValiditySupplier() { + @Override + public boolean hasValidLicense() { + return false; + } + }); + final AppConfig appConfig = configService.config(host); + + assertFalse(appConfig.getModel().isOperational()); + assertFalse(appConfig.getImageModel().isOperational()); + assertFalse(appConfig.getEmbeddingsModel().isOperational()); + } + + /** + * Given a host with secrets and a ConfigService + * When the config method is called with the host + * Then the models should be operational and the host should be correctly set in the AppConfig. + */ + @Test + public void test_config_hostWithSecrets() throws Exception { + AiTest.aiAppSecrets(host, "text-model-0", "image-model-1", "embeddings-model-2"); + final AppConfig appConfig = configService.config(host); + + assertTrue(appConfig.getModel().isOperational()); + assertTrue(appConfig.getImageModel().isOperational()); + assertTrue(appConfig.getEmbeddingsModel().isOperational()); + assertEquals(host.getHostname(), appConfig.getHost()); + } + + /** + * Given a host without secrets and a ConfigService + * When the config method is called with the host + * Then the models should be operational and the host should be set to "System Host" in the AppConfig. + */ + @Test + public void test_config_hostWithoutSecrets() throws Exception { + AiTest.aiAppSecrets(APILocator.systemHost(), "text-model-10", "image-model-11", "embeddings-model-12"); + final AppConfig appConfig = configService.config(host); + + assertTrue(appConfig.getModel().isOperational()); + assertTrue(appConfig.getImageModel().isOperational()); + assertTrue(appConfig.getEmbeddingsModel().isOperational()); + assertEquals("System Host", appConfig.getHost()); + } + +} diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java new file mode 100644 index 000000000000..2913af554309 --- /dev/null +++ b/dotcms-integration/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java @@ -0,0 +1,289 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.AiTest; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModelType; +import com.dotcms.ai.app.AIModels; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.app.AppKeys; +import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.domain.AIResponse; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIAllModelsExhaustedException; +import com.dotcms.ai.util.LineReadingOutputStream; +import com.dotcms.datagen.SiteDataGen; +import com.dotcms.datagen.UserDataGen; +import com.dotcms.util.IntegrationTestInitService; +import com.dotcms.util.network.IPUtils; +import com.dotmarketing.beans.Host; +import com.dotmarketing.business.APILocator; +import com.dotmarketing.util.UtilMethods; +import com.dotmarketing.util.json.JSONArray; +import com.dotmarketing.util.json.JSONObject; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.liferay.portal.model.User; +import io.vavr.Tuple2; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for the AIProxyClient class. + * + * @author vico + */ +public class AIProxyClientTest { + + private static WireMockServer wireMockServer; + private static User user; + private Host host; + private AppConfig appConfig; + private final AIProxyClient aiProxyClient = AIProxyClient.get(); + + @BeforeClass + public static void beforeClass() throws Exception { + IntegrationTestInitService.getInstance().init(); + wireMockServer = AiTest.prepareWireMock(); + final Host systemHost = APILocator.systemHost(); + AiTest.aiAppSecrets(systemHost); + ConfigService.INSTANCE.config(systemHost); + user = new UserDataGen().nextPersisted(); + } + + @AfterClass + public static void afterClass() { + wireMockServer.stop(); + IPUtils.disabledIpPrivateSubnet(false); + } + + @Before + public void before() { + IPUtils.disabledIpPrivateSubnet(true); + host = new SiteDataGen().nextPersisted(); + } + + @After + public void after() throws Exception { + AiTest.removeAiAppSecrets(host); + } + + /** + * Scenario: Calling AI with a valid model + * Given a valid model "gpt-4o-mini" + * When the request is sent to the AI service + * Then the response should contain the model name "gpt-4o-mini" + */ + @Test + public void test_callToAI_happiestPath() throws Exception { + final String model = "gpt-4o-mini"; + AiTest.aiAppSecrets(host, model, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + model, + "What are the major achievements of the Apollo space program?"); + + final AIResponse aiResponse = aiProxyClient.callToAI(request); + + assertNotNull(aiResponse); + assertNotNull(aiResponse.getResponse()); + assertEquals("gpt-4o-mini", new JSONObject(aiResponse.getResponse()).getString(AiKeys.MODEL)); + } + + /** + * Scenario: Calling AI with multiple models + * Given multiple models including "gpt-4o-mini" + * When the request is sent to the AI service + * Then the response should contain the model name "gpt-4o-mini" + */ + @Test + public void test_callToAI_happyPath_withMultipleModels() throws Exception { + final String model = "gpt-4o-mini"; + AiTest.aiAppSecrets( + host, + String.format("%s,some-made-up-model-1", model), + "dall-e-3", + "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + model, + "What are the major achievements of the Apollo space program?"); + + final AIResponse aiResponse = aiProxyClient.callToAI(request); + + assertNotNull(aiResponse); + assertNotNull(aiResponse.getResponse()); + assertEquals("gpt-4o-mini", new JSONObject(aiResponse.getResponse()).getString(AiKeys.MODEL)); + } + + /** + * Scenario: Calling AI with an invalid model + * Given an invalid model "some-made-up-model-10" + * When the request is sent to the AI service + * Then a DotAIAllModelsExhaustedException should be thrown + */ + @Test + public void test_callToAI_withInvalidModel() throws Exception { + final String invalidModel = "some-made-up-model-10"; + AiTest.aiAppSecrets(host, invalidModel, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + invalidModel, + "What are the major achievements of the Apollo space program?"); + + assertThrows(DotAIAllModelsExhaustedException.class, () -> aiProxyClient.callToAI(request)); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(invalidModel, AIModelType.TEXT); + assertSame(ModelStatus.INVALID, modelTuple._2.getStatus()); + assertEquals(-1, modelTuple._1.getCurrentModelIndex()); + assertTrue(AIModels.get() + .getAvailableModels() + .stream() + .noneMatch(model -> model.getName().equals(invalidModel))); + } + + /** + * Scenario: Calling AI with a decommissioned model + * Given a decommissioned model "some-decommissioned-model-20" + * When the request is sent to the AI service + * Then a DotAIAllModelsExhaustedException should be thrown + */ + @Test + public void test_callToAI_withDecommissionedModel() throws Exception { + final String decommissionedModel = "some-decommissioned-model-20"; + AiTest.aiAppSecrets(host, decommissionedModel, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + decommissionedModel, + "What are the major achievements of the Apollo space program?"); + + assertThrows(DotAIAllModelsExhaustedException.class, () -> aiProxyClient.callToAI(request)); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(decommissionedModel, AIModelType.TEXT); + assertSame(ModelStatus.DECOMMISSIONED, modelTuple._2.getStatus()); + assertEquals(-1, modelTuple._1.getCurrentModelIndex()); + assertTrue(AIModels.get() + .getAvailableModels() + .stream() + .noneMatch(model -> model.getName().equals(decommissionedModel))); + } + + /** + * Scenario: Calling AI with multiple models including invalid, decommissioned, and valid models + * Given models "some-made-up-model-30", "some-decommissioned-model-31", and "gpt-4o-mini" + * When the request is sent to the AI service + * Then the response should contain the model name "gpt-4o-mini" + */ + @Test + public void test_callToAI_withMultipleModels_invalidAndDecommissionedAndValid() throws Exception { + final String invalidModel = "some-made-up-model-30"; + final String decommissionedModel = "some-decommissioned-model-31"; + final String validModel = "gpt-4o-mini"; + AiTest.aiAppSecrets( + host, + String.format("%s,%s,%s", invalidModel, decommissionedModel, validModel), + "dall-e-3", + "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest(invalidModel, "What are the major achievements of the Apollo space program?"); + + final AIResponse aiResponse = aiProxyClient.callToAI(request); + + assertNotNull(aiResponse); + assertNotNull(aiResponse.getResponse()); + assertSame(ModelStatus.INVALID, appConfig.resolveModelOrThrow(invalidModel, AIModelType.TEXT)._2.getStatus()); + assertSame( + ModelStatus.DECOMMISSIONED, + appConfig.resolveModelOrThrow(decommissionedModel, AIModelType.TEXT)._2.getStatus()); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(validModel, AIModelType.TEXT); + assertSame(ModelStatus.ACTIVE, modelTuple._2.getStatus()); + assertEquals(2, modelTuple._1.getCurrentModelIndex()); + assertTrue(AIModels.get() + .getAvailableModels() + .stream() + .noneMatch(model -> List.of(invalidModel, decommissionedModel).contains(model.getName()))); + assertTrue(AIModels.get() + .getAvailableModels() + .stream() + .anyMatch(model -> model.getName().equals(validModel))); + assertEquals("gpt-4o-mini", new JSONObject(aiResponse.getResponse()).getString(AiKeys.MODEL)); + } + + /** + * Scenario: Calling AI with a valid model and provided output stream + * Given a valid model "gpt-4o-mini" and a provided output stream + * When the request is sent to the AI service + * Then the response should be written to the output stream + */ + @Test + public void test_callToAI_withProvidedOutput() throws Exception { + final String model = "gpt-4o-mini"; + AiTest.aiAppSecrets(host, model, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + model, + "What are the major achievements of the Apollo space program?"); + + final AIResponse aiResponse = aiProxyClient.callToAI( + request, + new LineReadingOutputStream(new ByteArrayOutputStream())); + assertNotNull(aiResponse); + assertNull(aiResponse.getResponse()); + } + + /** + * Scenario: Calling AI with an invalid model and provided output stream + * Given an invalid model "some-made-up-model-40" and a provided output stream + * When the request is sent to the AI service + * Then a DotAIAllModelsExhaustedException should be thrown + */ + @Test + public void test_callToAI_withInvalidModel_withProvidedOutput() throws Exception { + final String invalidModel = "some-made-up-model-40"; + AiTest.aiAppSecrets(host, invalidModel, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + invalidModel, + "What are the major achievements of the Apollo space program?"); + + assertThrows(DotAIAllModelsExhaustedException.class, () -> aiProxyClient.callToAI(request)); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(invalidModel, AIModelType.TEXT); + assertSame(ModelStatus.INVALID, modelTuple._2.getStatus()); + assertEquals(-1, modelTuple._1.getCurrentModelIndex()); + assertTrue(AIModels.get() + .getAvailableModels() + .stream() + .noneMatch(model -> model.getName().equals(invalidModel))); + } + + private JSONObjectAIRequest textRequest(final String model, final String prompt) { + final JSONObject payload = new JSONObject(); + final JSONArray messages = new JSONArray(); + + final String systemPrompt = UtilMethods.isSet(appConfig.getRolePrompt()) ? appConfig.getRolePrompt() : null; + if (UtilMethods.isSet(systemPrompt)) { + messages.add(Map.of(AiKeys.ROLE, AiKeys.SYSTEM, AiKeys.CONTENT, systemPrompt)); + } + messages.add(Map.of(AiKeys.ROLE, AiKeys.USER, AiKeys.CONTENT, prompt)); + + payload.put(AiKeys.MODEL, model); + payload.put(AiKeys.TEMPERATURE, appConfig.getConfigFloat(AppKeys.COMPLETION_TEMPERATURE)); + payload.put(AiKeys.MESSAGES, messages); + + return JSONObjectAIRequest.quickText(appConfig, payload, user.getUserId()); + } + +} diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java index 3c61cd335f55..ce61bbb8b6d6 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java @@ -190,9 +190,9 @@ private static boolean waitForEmbeddings(final Contentlet blogContent, final Str return embeddingsExist; } - private static void addDotAISecrets() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, host, AiTest.API_KEY); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost(), AiTest.API_KEY); + private static void addDotAISecrets() throws Exception { + AiTest.aiAppSecrets(host, AiTest.API_KEY); + AiTest.aiAppSecrets(APILocator.systemHost(), AiTest.API_KEY); } private static void removeDotAISecrets() throws DotDataException, DotSecurityException { @@ -232,4 +232,5 @@ public static java.util.function.Predicate distinctByKey( Set seen = ConcurrentHashMap.newKeySet(); return t -> seen.add(keyExtractor.apply(t)); } + } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/validator/AIAppValidatorTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/validator/AIAppValidatorTest.java new file mode 100644 index 000000000000..bf297cdfac3b --- /dev/null +++ b/dotcms-integration/src/test/java/com/dotcms/ai/validator/AIAppValidatorTest.java @@ -0,0 +1,108 @@ +package com.dotcms.ai.validator; + +import com.dotcms.ai.AiTest; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.app.ConfigService; +import com.dotcms.api.system.event.message.SystemMessageEventUtil; +import com.dotcms.api.system.event.message.builder.SystemMessage; +import com.dotcms.datagen.SiteDataGen; +import com.dotcms.datagen.UserDataGen; +import com.dotcms.util.IntegrationTestInitService; +import com.dotcms.util.network.IPUtils; +import com.dotmarketing.beans.Host; +import com.dotmarketing.business.APILocator; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.liferay.portal.model.User; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Unit tests for the AIAppValidator class. + * This class tests the validation of AI configurations and model usage. + * It ensures that the AI models specified in the application configuration are supported + * and not exhausted. + * + * The tests cover scenarios for valid configurations, invalid configurations, and configurations + * with missing fields. + * + * @author vico + */ +public class AIAppValidatorTest { + + private static WireMockServer wireMockServer; + private static User user; + private static SystemMessageEventUtil systemMessageEventUtil; + private Host host; + private AppConfig appConfig; + private AIAppValidator validator = AIAppValidator.get(); + + @BeforeClass + public static void beforeClass() throws Exception { + IntegrationTestInitService.getInstance().init(); + wireMockServer = AiTest.prepareWireMock(); + final Host systemHost = APILocator.systemHost(); + AiTest.aiAppSecrets(systemHost); + ConfigService.INSTANCE.config(systemHost); + user = new UserDataGen().nextPersisted(); + systemMessageEventUtil = mock(SystemMessageEventUtil.class); + } + + @AfterClass + public static void afterClass() { + wireMockServer.stop(); + IPUtils.disabledIpPrivateSubnet(false); + } + + @Before + public void before() { + IPUtils.disabledIpPrivateSubnet(true); + host = new SiteDataGen().nextPersisted(); + validator.setSystemMessageEventUtil(systemMessageEventUtil); + } + + @After + public void after() throws Exception { + AiTest.removeAiAppSecrets(host); + } + + @Test/** + * Scenario: Validating AI configuration with unsupported models + * Given an AI configuration with unsupported models + * When the configuration is validated + * Then a warning message should be pushed to the user + */ + public void test_validateAIConfig() throws Exception { + final String invalidModel = "some-made-up-model-10"; + AiTest.aiAppSecrets(host, invalidModel, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + + verify(systemMessageEventUtil).pushMessage(any(SystemMessage.class), anyList()); + } + + /** + * Scenario: Validating AI models usage with exhausted models + * Given an AI model with exhausted models + * When the models usage is validated + * Then a warning message should be pushed to the user for each exhausted model + */ + @Test + public void test_validateModelsUsage() throws Exception { + final String invalidModels = "some-made-up-model-20,some-decommissioned-model-21"; + AiTest.aiAppSecrets(host, invalidModels, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + + validator.validateModelsUsage(appConfig.getModel(), user.getUserId()); + + verify(systemMessageEventUtil, times(2)) + .pushMessage(any(SystemMessage.class), anyList()); + } +} diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java index 314079da2a93..1062044719a9 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java @@ -7,6 +7,7 @@ import com.dotcms.datagen.UserDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; +import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; import com.dotmarketing.util.json.JSONObject; import com.github.tomakehurst.wiremock.WireMockServer; @@ -48,8 +49,9 @@ public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); - config = ConfigService.INSTANCE.config(); + final Host systemHost = APILocator.systemHost(); + AiTest.aiAppSecrets(systemHost, "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); + config = ConfigService.INSTANCE.config(systemHost); } @AfterClass diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java index f00c1ab31ae8..e481133edbca 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java @@ -6,6 +6,7 @@ import com.dotcms.ai.app.ConfigService; import com.dotcms.datagen.EmbeddingsDTODataGen; import com.dotcms.datagen.SiteDataGen; +import com.dotcms.datagen.UserDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; import com.dotmarketing.beans.Host; @@ -13,6 +14,7 @@ import com.dotmarketing.util.json.JSONArray; import com.dotmarketing.util.json.JSONObject; import com.github.tomakehurst.wiremock.WireMockServer; +import com.liferay.portal.model.User; import org.apache.commons.lang3.StringUtils; import org.apache.velocity.tools.view.context.ViewContext; import org.junit.AfterClass; @@ -41,6 +43,7 @@ public class CompletionsToolTest { private static AppConfig appConfig; + private static User user; private static WireMockServer wireMockServer; private static Host host; @@ -52,9 +55,10 @@ public static void beforeClass() throws Exception { IPUtils.disabledIpPrivateSubnet(true); host = new SiteDataGen().nextPersisted(); wireMockServer = AiTest.prepareWireMock(); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); - AiTest.aiAppSecrets(wireMockServer, host); + AiTest.aiAppSecrets(APILocator.systemHost(), "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); + AiTest.aiAppSecrets(host, "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); appConfig = ConfigService.INSTANCE.config(host); + user = new UserDataGen().nextPersisted(); } @AfterClass @@ -86,7 +90,7 @@ public void test_getConfig() { assertNotNull(config); assertEquals(AppKeys.COMPLETION_ROLE_PROMPT.defaultValue, config.get(AppKeys.COMPLETION_ROLE_PROMPT.key)); assertEquals(AppKeys.COMPLETION_TEXT_PROMPT.defaultValue, config.get(AppKeys.COMPLETION_TEXT_PROMPT.key)); - assertEquals("gpt-3.5-turbo-16k", config.get(AppKeys.TEXT_MODEL_NAMES.key)); + assertEquals("gpt-4o-mini", config.get(AppKeys.TEXT_MODEL_NAMES.key)); } /** @@ -119,7 +123,7 @@ public void test_summarize() { @Test public void test_raw() { final String query = "What is the speed of light in the vacuum"; - final String prompt = String.format("{\"model\":\"gpt-3.5-turbo-16k\",\"messages\":[{\"role\":\"user\",\"content\":\"%s?\"},{\"role\":\"system\",\"content\":\"You are a helpful assistant with a descriptive writing style.\"}]}", query); + final String prompt = String.format("{\"model\":\"gpt-4o-mini\",\"messages\":[{\"role\":\"user\",\"content\":\"%s?\"},{\"role\":\"system\",\"content\":\"You are a helpful assistant with a descriptive writing style.\"}]}", query); final JSONObject result = (JSONObject) completionsTool.raw(prompt); assertResult(result); @@ -142,7 +146,7 @@ public void test_raw_json() { messages.put(new JSONObject().put("role", "user").put("content", query + "?")); messages.put(new JSONObject().put("role", "system").put("content", "You are a helpful assistant with a descriptive writing style")); json.put("messages", messages); - json.put("model", "gpt-3.5-turbo-16k"); + json.put("model", "gpt-4o-mini"); final JSONObject result = (JSONObject) completionsTool.raw(json); assertResult(result); @@ -160,7 +164,7 @@ public void test_raw_json() { @Test public void test_raw_map() { final String query = "Who was the first president of the United States"; - final Map map = Map.of("model", "gpt-3.5-turbo-16k", "messages", new JSONArray() + final Map map = Map.of("model", "gpt-4o-mini", "messages", new JSONArray() .put(new JSONObject().put("role", "user").put("content", query + "?")) .put(new JSONObject().put("role", "system").put("content", "You are a helpful assistant with a descriptive writing style"))); @@ -179,6 +183,11 @@ Host host() { AppConfig config() { return appConfig; } + + @Override + User user() { + return user; + } }; } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java index a58bf579bc65..5f5b875a7f31 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java @@ -5,13 +5,13 @@ import com.dotcms.ai.app.ConfigService; import com.dotcms.datagen.EmbeddingsDTODataGen; import com.dotcms.datagen.SiteDataGen; +import com.dotcms.datagen.UserDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; -import com.dotmarketing.exception.DotDataException; -import com.dotmarketing.exception.DotSecurityException; import com.github.tomakehurst.wiremock.WireMockServer; +import com.liferay.portal.model.User; import org.apache.velocity.tools.view.context.ViewContext; import org.junit.AfterClass; import org.junit.Before; @@ -43,6 +43,7 @@ public class EmbeddingsToolTest { private Host host; private AppConfig appConfig; + private User user; private EmbeddingsTool embeddingsTool; @BeforeClass @@ -50,16 +51,17 @@ public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); + AiTest.aiAppSecrets(APILocator.systemHost()); } @Before - public void before() throws DotDataException, DotSecurityException { + public void before() throws Exception { final ViewContext viewContext = mock(ViewContext.class); when(viewContext.getRequest()).thenReturn(mock(HttpServletRequest.class)); host = new SiteDataGen().nextPersisted(); - AiTest.aiAppSecrets(wireMockServer, host); + AiTest.aiAppSecrets(host); appConfig = ConfigService.INSTANCE.config(host); + user = new UserDataGen().nextPersisted(); embeddingsTool = prepareEmbeddingsTool(viewContext); } @@ -136,6 +138,11 @@ Host host() { AppConfig appConfig() { return appConfig; } + + @Override + public User user() { + return user; + } }; } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/SearchToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/SearchToolTest.java index a9cd8a62cd3b..4d7b7ef6e891 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/SearchToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/SearchToolTest.java @@ -1,5 +1,6 @@ package com.dotcms.ai.viewtool; +import com.dotcms.ai.AiTest; import com.dotcms.contenttype.model.field.Field; import com.dotcms.contenttype.model.field.TextField; import com.dotcms.contenttype.model.type.ContentType; @@ -12,6 +13,7 @@ import com.dotcms.datagen.TemplateDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotmarketing.beans.Host; +import com.dotmarketing.business.APILocator; import com.dotmarketing.portlets.contentlet.model.Contentlet; import com.dotmarketing.portlets.htmlpageasset.model.HTMLPageAsset; import com.dotmarketing.portlets.templates.model.Template; @@ -47,14 +49,16 @@ public class SearchToolTest { @BeforeClass public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); + AiTest.aiAppSecrets(APILocator.systemHost(), "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); } @Before - public void before() { + public void before() throws Exception { final ViewContext viewContext = mock(ViewContext.class); when(viewContext.getRequest()).thenReturn(mock(HttpServletRequest.class)); host = new SiteDataGen().nextPersisted(); searchTool = prepareSearchTool(viewContext); + AiTest.aiAppSecrets(host, "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); } /** diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java index 344efe66d3af..2d2d64362f76 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java @@ -78,7 +78,7 @@ public static void beforeClass() throws Exception { .withValue("{\"default\":\"blog\"}".toCharArray()) .build()); config = new AppConfig(host.getHostname(), secrets); - DotAIAPIFacadeImpl.addCompletionsAPIImplementation("default", (Object... initArguments)-> new CompletionsAPI() { + DotAIAPIFacadeImpl.addCompletionsAPIImplementation("default", (Object... initArguments) -> new CompletionsAPI() { @Override public JSONObject summarize(CompletionsForm searcher) { return null; @@ -90,7 +90,7 @@ public void summarizeStream(CompletionsForm searcher, OutputStream out) { } @Override - public JSONObject raw(JSONObject promptJSON) { + public JSONObject raw(JSONObject promptJSON, final String userId) { return null; } @@ -104,7 +104,8 @@ public JSONObject prompt(final String systemPrompt, final String userPrompt, final String model, final float temperature, - final int maxTokens) { + final int maxTokens, + final String userId) { return new JSONObject("{\n" + " \"id\": \"chatcmpl-7bHkIY2cNQXV3yWZmZ1lM1b4AIlJ6\",\n" + " \"object\": \"chat.completion\",\n" + diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java index 9e92628e2df6..3db2f537423c 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java @@ -89,7 +89,7 @@ public void summarizeStream(CompletionsForm searcher, OutputStream out) { } @Override - public JSONObject raw(final JSONObject promptJSON) { + public JSONObject raw(final JSONObject promptJSON, final String userId) { return new JSONObject("{\n" + " \"id\": \"chatcmpl-7bHkIY2cNQXV3yWZmZ1lM1b4AIlJ6\",\n" + " \"object\": \"chat.completion\",\n" + @@ -126,7 +126,8 @@ public JSONObject prompt(final String systemPrompt, final String userPrompt, final String model, final float temperature, - final int maxTokens) { + final int maxTokens, + final String userId) { return new JSONObject("{\n" + " \"id\": \"chatcmpl-7bHkIY2cNQXV3yWZmZ1lM1b4AIlJ6\",\n" + " \"object\": \"chat.completion\",\n" + diff --git a/dotcms-integration/src/test/resources/mappings/apollo-space-program.json b/dotcms-integration/src/test/resources/mappings/apollo-space-program.json new file mode 100644 index 000000000000..7118ba306929 --- /dev/null +++ b/dotcms-integration/src/test/resources/mappings/apollo-space-program.json @@ -0,0 +1,41 @@ +{ + "request": { + "method": "POST", + "url": "/c", + "headers": { + "Content-Type": { + "equalTo": "application/json" + }, + "Authorization": { + "equalTo": "Bearer some-api-key-1a2bc3" + } + }, + "bodyPatterns": [ + { + "matches": ".*\"model\":\"gpt-4o-mini\".*\"content\":\"What are the major achievements of the Apollo space program.*" + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "id": "cmpl-11", + "object": "text_completion", + "created": 1699999999, + "model": "gpt-4o-mini", + "choices": [ + { + "text": "The Apollo space program, conducted by NASA, achieved several major milestones in space exploration. Its most significant achievement was the successful landing of humans on the Moon. Apollo 11, in 1969, saw astronauts Neil Armstrong and Buzz Aldrin become the first humans to set foot on the lunar surface. The program also provided extensive scientific data and technological advancements.", + "index": 0, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 57, + "total_tokens": 67 + } + } + } +} diff --git a/dotcms-integration/src/test/resources/mappings/decommissioned-model.json b/dotcms-integration/src/test/resources/mappings/decommissioned-model.json new file mode 100644 index 000000000000..46f83ccb4b0b --- /dev/null +++ b/dotcms-integration/src/test/resources/mappings/decommissioned-model.json @@ -0,0 +1,30 @@ +{ + "request": { + "method": "POST", + "url": "/c", + "headers": { + "Content-Type": { + "equalTo": "application/json" + }, + "Authorization": { + "equalTo": "Bearer some-api-key-1a2bc3" + } + }, + "bodyPatterns": [ + { + "matches": ".*\"model\":\"some-decommissioned-model-..\".*" + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "error": { + "message": "The model `some-decommissioned-model` has been deprecated, learn more here: https://platform.openai.com/docs/deprecations", + "type": "invalid_request_error", + "param": null, + "code": "model_not_found" + } + } + } +} diff --git a/dotcms-integration/src/test/resources/mappings/invalid-model.json b/dotcms-integration/src/test/resources/mappings/invalid-model.json new file mode 100644 index 000000000000..810073b0b0ec --- /dev/null +++ b/dotcms-integration/src/test/resources/mappings/invalid-model.json @@ -0,0 +1,30 @@ +{ + "request": { + "method": "POST", + "url": "/c", + "headers": { + "Content-Type": { + "equalTo": "application/json" + }, + "Authorization": { + "equalTo": "Bearer some-api-key-1a2bc3" + } + }, + "bodyPatterns": [ + { + "matches": ".*\"model\":\"some-made-up-model-..\".*" + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "error": { + "message": "The model `some-made-up-mode` does not exist or you do not have access to it.", + "type": "invalid_request_error", + "param": null, + "code": "model_not_found" + } + } + } +} diff --git a/dotcms-postman/pom.xml b/dotcms-postman/pom.xml index 8957e0ff6835..a84e1627f93b 100644 --- a/dotcms-postman/pom.xml +++ b/dotcms-postman/pom.xml @@ -116,9 +116,6 @@ [WireMock] green - - --verbose - diff --git a/dotcms-postman/src/main/resources/postman/AI.postman_collection.json b/dotcms-postman/src/main/resources/postman/AI.postman_collection.json index 9a49e8bcd38e..dd1b0b760d92 100644 --- a/dotcms-postman/src/main/resources/postman/AI.postman_collection.json +++ b/dotcms-postman/src/main/resources/postman/AI.postman_collection.json @@ -60,7 +60,7 @@ ], "body": { "mode": "raw", - "raw": "{\n \"apiKey\": {\n \"value\": \"some-api-key-1a2bc3\"\n },\n \"textModelNames\": {\n \"value\": \"gpt-4o\"\n },\n \"textModelMaxTokens\": {\n \"value\":\"16384\"\n },\n \"imageModelNames\": {\n \"value\": \"dall-e-3\"\n },\n \"imageSize\": {\n \"value\": \"1024x1024\"\n },\n \"imageModelMaxTokens\": {\n \"value\":\"0\"\n },\n \"embeddingsModelNames\": {\n \"value\": \"text-embedding-ada-002\"\n },\n \"embeddingsModelMaxTokens\": {\n \"value\":\"8191\"\n },\n \"listenerIndexer\": {\n \"value\": \"{\\\"default\\\":\\\"blog,dotcmsdocumentation,feature,ProductBriefs,news,report.file,builds,casestudy\\\",\\\"documentation\\\":\\\"dotcmsdocumentation\\\"}\"\n }\n}\n" + "raw": "{\n \"apiKey\": {\n \"value\": \"some-api-key-1a2bc3\"\n },\n \"textModelNames\": {\n \"value\": \"gpt-4o-mini\"\n },\n \"textModelMaxTokens\": {\n \"value\":\"16384\"\n },\n \"imageModelNames\": {\n \"value\": \"dall-e-3\"\n },\n \"imageSize\": {\n \"value\": \"1024x1024\"\n },\n \"imageModelMaxTokens\": {\n \"value\":\"0\"\n },\n \"embeddingsModelNames\": {\n \"value\": \"text-embedding-ada-002\"\n },\n \"embeddingsModelMaxTokens\": {\n \"value\":\"8191\"\n },\n \"listenerIndexer\": {\n \"value\": \"{\\\"default\\\":\\\"blog,dotcmsdocumentation,feature,ProductBriefs,news,report.file,builds,casestudy\\\",\\\"documentation\\\":\\\"dotcmsdocumentation\\\"}\"\n }\n}\n" }, "url": { "raw": "{{serverURL}}/api/v1/apps/dotAI/SYSTEM_HOST", @@ -3035,7 +3035,7 @@ ], "body": { "mode": "raw", - "raw": "{\n \"prompt\": \"{{seoText}}\",\n \"model\": \"text-embedding-ada-002\",\n \"responseLengthTokens\": 1\n}", + "raw": "{\n \"prompt\": \"{{seoText}}\",\n \"responseLengthTokens\": 1\n}", "options": { "raw": { "language": "json" @@ -3111,7 +3111,7 @@ ], "body": { "mode": "raw", - "raw": "{\n \"prompt\": \"{{seoText}}\",\n \"responseLengthTokens\": 1,\n \"model\": \"text-embedding-ada-002\",\n \"stream\": true\n}", + "raw": "{\n \"prompt\": \"{{seoText}}\",\n \"responseLengthTokens\": 1,\n \"stream\": true\n}", "options": { "raw": { "language": "json" @@ -3140,7 +3140,7 @@ "listen": "test", "script": { "exec": [ - "pm.test('Status code should be ok 200', function () {", + "pm.test('Status code should be ok 20', function () {", " pm.response.to.have.status(200);", "});", "", @@ -3430,20 +3430,16 @@ } ], "variable": [ - { - "key": "seoIndex", - "value": "" - }, { "key": "seoId", "value": "" }, { - "key": "seoContentTypeId", + "key": "seoText", "value": "" }, { - "key": "seoContentTypeVar", + "key": "seoIndex", "value": "" }, { @@ -3451,11 +3447,11 @@ "value": "" }, { - "key": "seoText", + "key": "seoContentTypeId", "value": "" }, { - "key": "key", + "key": "seoContentTypeVar", "value": "" } ] diff --git a/dotcms-postman/src/test/resources/mappings/apollo-apce-program.json b/dotcms-postman/src/test/resources/mappings/apollo-space-program.json similarity index 100% rename from dotcms-postman/src/test/resources/mappings/apollo-apce-program.json rename to dotcms-postman/src/test/resources/mappings/apollo-space-program.json