diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java index d980647f7b41..38ea42606703 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java @@ -1,9 +1,7 @@ package com.dotcms.ai.api; -import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.rest.forms.CompletionsForm; import com.dotmarketing.util.json.JSONObject; -import io.vavr.Lazy; import java.io.OutputStream; @@ -37,9 +35,10 @@ public interface CompletionsAPI { * this method takes a prompt in the form of json and returns a json AI response based upon that prompt * * @param promptJSON + * @param userId * @return */ - JSONObject raw(JSONObject promptJSON); + JSONObject raw(JSONObject promptJSON, String userId); /** * this method takes a prompt and returns the AI response based upon that prompt @@ -58,9 +57,15 @@ public interface CompletionsAPI { * @param model * @param temperature * @param maxTokens + * @param userId * @return */ - JSONObject prompt(String systemPrompt, String userPrompt, String model, float temperature, int maxTokens); + JSONObject prompt(String systemPrompt, + String userPrompt, + String model, + float temperature, + int maxTokens, + String userId); /** * this method takes a prompt in the form of json and returns streaming AI response based upon that prompt diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java index 008f65609096..9c04e189ed55 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java @@ -2,13 +2,17 @@ import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.client.AIProxyClient; import com.dotcms.ai.db.EmbeddingsDTO; +import com.dotcms.ai.domain.AIResponse; +import com.dotcms.ai.client.JSONObjectAIRequest; +import com.dotcms.ai.domain.Model; import com.dotcms.ai.rest.forms.CompletionsForm; import com.dotcms.ai.util.EncodingUtil; -import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.api.web.HttpServletRequestThreadLocal; import com.dotcms.mock.request.FakeHttpRequest; import com.dotcms.mock.response.BaseResponse; @@ -16,18 +20,17 @@ import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.exception.DotRuntimeException; -import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; import com.dotmarketing.util.json.JSONArray; import com.dotmarketing.util.json.JSONObject; import io.vavr.Lazy; +import io.vavr.Tuple2; import io.vavr.control.Try; import org.apache.commons.lang3.ArrayUtils; import org.apache.velocity.context.Context; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import javax.ws.rs.HttpMethod; import java.io.OutputStream; import java.util.ArrayList; import java.util.List; @@ -42,15 +45,13 @@ public class CompletionsAPIImpl implements CompletionsAPI { private final AppConfig config; - private final Lazy defaultConfig; public CompletionsAPIImpl(final AppConfig config) { - defaultConfig = - Lazy.of(() -> ConfigService.INSTANCE.config( - Try.of(() -> WebAPILocator - .getHostWebAPI() - .getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest())) - .getOrElse(APILocator.systemHost()))); + final Lazy defaultConfig = Lazy.of(() -> ConfigService.INSTANCE.config( + Try.of(() -> WebAPILocator + .getHostWebAPI() + .getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest())) + .getOrElse(APILocator.systemHost()))); this.config = Optional.ofNullable(config).orElse(defaultConfig.get()); } @@ -59,8 +60,9 @@ public JSONObject prompt(final String systemPrompt, final String userPrompt, final String modelIn, final float temperature, - final int maxTokens) { - final AIModel model = config.resolveModelOrThrow(modelIn); + final int maxTokens, + final String userId) { + final Model model = config.resolveModelOrThrow(modelIn, AIModelType.TEXT)._2; final JSONObject json = new JSONObject(); json.put(AiKeys.TEMPERATURE, temperature); @@ -70,15 +72,17 @@ public JSONObject prompt(final String systemPrompt, json.put(AiKeys.MAX_TOKENS, maxTokens); } - json.put(AiKeys.MODEL, model.getCurrentModel()); + json.put(AiKeys.MODEL, model.getName()); - return raw(json); + return raw(json, userId); } @Override public JSONObject summarize(final CompletionsForm summaryRequest) { final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build(); - final List localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher); + final List localResults = APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .getEmbeddingResults(searcher); // send all this as a json blob to OpenAI final JSONObject json = buildRequestJson(summaryRequest, localResults); @@ -87,58 +91,64 @@ public JSONObject summarize(final CompletionsForm summaryRequest) { } json.put(AiKeys.STREAM, false); - final String openAiResponse = - Try.of(() -> OpenAIRequest.doRequest( - config.getApiUrl(), - HttpMethod.POST, - config, - json)) - .getOrElseThrow(DotRuntimeException::new); - final JSONObject dotCMSResponse = APILocator.getDotAIAPI().getEmbeddingsAPI().reduceChunksToContent(searcher, localResults); + final String openAiResponse = Try + .of(() -> sendRequest(config, json, UtilMethods.extractUserIdOrNull(summaryRequest.user))) + .getOrElseThrow(DotRuntimeException::new) + .getResponse(); + final JSONObject dotCMSResponse = APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .reduceChunksToContent(searcher, localResults); dotCMSResponse.put(AiKeys.OPEN_AI_RESPONSE, new JSONObject(openAiResponse)); return dotCMSResponse; } @Override - public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream out) { + public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream output) { final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build(); - final List localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher); + final List localResults = APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .getEmbeddingResults(searcher); final JSONObject json = buildRequestJson(summaryRequest, localResults); json.put(AiKeys.STREAM, true); - OpenAIRequest.doPost(config.getApiUrl(), config, json, out); + AIProxyClient.get().callToAI( + JSONObjectAIRequest.quickText( + config, + json, + UtilMethods.extractUserIdOrNull(summaryRequest.user)), + output); } @Override - public JSONObject raw(final JSONObject json) { - if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.info(this.getClass(), "OpenAI request:" + json.toString(2)); - } + public JSONObject raw(final JSONObject json, final String userId) { + AppConfig.debugLogger(this.getClass(), () -> "OpenAI request:" + json.toString(2)); - final String response = OpenAIRequest.doRequest( - config.getApiUrl(), - HttpMethod.POST, - config, - json); - if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.info(this.getClass(), "OpenAI response:" + response); - } + final String response = sendRequest(config, json, userId).getResponse(); + AppConfig.debugLogger(this.getClass(), () -> "OpenAI response:" + response); return new JSONObject(response); } @Override - public JSONObject raw(CompletionsForm promptForm) { + public JSONObject raw(final CompletionsForm promptForm) { JSONObject jsonObject = buildRequestJson(promptForm); - return raw(jsonObject); + return raw(jsonObject, UtilMethods.extractUserIdOrNull(promptForm.user)); } @Override - public void rawStream(final CompletionsForm promptForm, final OutputStream out) { + public void rawStream(final CompletionsForm promptForm, final OutputStream output) { final JSONObject json = buildRequestJson(promptForm); json.put(AiKeys.STREAM, true); - OpenAIRequest.doRequest(config.getApiUrl(), HttpMethod.POST, config, json, out); + AIProxyClient.get().callToAI(JSONObjectAIRequest.quickText( + config, + json, + UtilMethods.extractUserIdOrNull(promptForm.user)), + output); + } + + private AIResponse sendRequest(final AppConfig appConfig, final JSONObject payload, final String userId) { + return AIProxyClient.get().callToAI(JSONObjectAIRequest.quickText(appConfig, payload, userId)); } private void buildMessages(final String systemPrompt, final String userPrompt, final JSONObject json) { @@ -151,7 +161,7 @@ private void buildMessages(final String systemPrompt, final String userPrompt, f } private JSONObject buildRequestJson(final CompletionsForm form, final List searchResults) { - final AIModel model = config.resolveModelOrThrow(form.model); + final Tuple2 modelTuple = config.resolveModelOrThrow(form.model, AIModelType.TEXT); // aggregate matching results into text final StringBuilder supportingContent = new StringBuilder(); searchResults.forEach(s -> supportingContent.append(s.extractedText).append(" ")); @@ -162,7 +172,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List 0 && initArguments[0] instanceof AppConfig) { - return new OpenAIChatAPIImpl((AppConfig) initArguments[0]); + final User user = initArguments.length > 1 && initArguments[1] instanceof User + ? (User) initArguments[1] + : null; + return new OpenAIChatAPIImpl((AppConfig) initArguments[0], user); } throw new IllegalArgumentException("To create a ChatAPI you need to provide an AppConfig"); @@ -147,7 +150,7 @@ public static final void setDefaultImageAPIProvider(final ImageAPIProvider image /** * Set the default chat API Provider. - * @param chatAPIProviderq + * @param chatAPIProvider */ public static final void setDefaultChatAPIProvider(final ChatAPIProvider chatAPIProvider) { chatProviderMap.put(DEFAULT, chatAPIProvider); diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java index 2ffaa8e702db..e619ae11ab03 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java @@ -140,10 +140,11 @@ public interface EmbeddingsAPI { * Embeddings * * @param content The content that will be tokenized and sent to OpenAI. + * @param userId The ID of the user making the request. * * @return Tuple(Count of Tokens Input, List of Embeddings Output) */ - Tuple2> pullOrGenerateEmbeddings(final String content); + Tuple2> pullOrGenerateEmbeddings(String content, String userId); /** * this method takes a snippet of content and will try to see if we have already generated @@ -154,10 +155,11 @@ public interface EmbeddingsAPI { * * @param contentId The ID of the Contentlet being sent to the OpenAI Endpoint. * @param content The actual indexable data that will be tokenized and sent to OpenAI service. + * @param userId The ID of the user making the request. * * @return Tuple(Count of Tokens Input, List of Embeddings Output) */ - Tuple2> pullOrGenerateEmbeddings(final String contentId, final String content); + Tuple2> pullOrGenerateEmbeddings(String contentId, String content, String userId); /** * Checks if the embeddings for the given inode, indexName, and extractedText already exist in the database. diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java index c40a8e1692e4..a02d3e8fdf8f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java @@ -4,12 +4,13 @@ import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.client.AIProxyClient; import com.dotcms.ai.db.EmbeddingsDTO; import com.dotcms.ai.db.EmbeddingsDTO.Builder; import com.dotcms.ai.db.EmbeddingsFactory; +import com.dotcms.ai.client.JSONObjectAIRequest; import com.dotcms.ai.util.ContentToStringUtil; import com.dotcms.ai.util.EncodingUtil; -import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.ai.util.VelocityContextFactory; import com.dotcms.api.web.HttpServletRequestThreadLocal; import com.dotcms.api.web.HttpServletResponseThreadLocal; @@ -43,7 +44,6 @@ import org.apache.velocity.context.Context; import javax.validation.constraints.NotNull; -import javax.ws.rs.HttpMethod; import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; @@ -311,13 +311,15 @@ public void initEmbeddingsTable() { } @Override - public Tuple2> pullOrGenerateEmbeddings(@NotNull final String content) { - return pullOrGenerateEmbeddings("N/A", content); + public Tuple2> pullOrGenerateEmbeddings(@NotNull final String content, final String userId) { + return pullOrGenerateEmbeddings("N/A", content, userId); } @WrapInTransaction @Override - public Tuple2> pullOrGenerateEmbeddings(final String contentId, @NotNull final String content) { + public Tuple2> pullOrGenerateEmbeddings(final String contentId, + @NotNull final String content, + final String userId) { if (UtilMethods.isEmpty(content)) { return Tuple.of(0, List.of()); } @@ -349,7 +351,7 @@ public Tuple2> pullOrGenerateEmbeddings(final String conten final Tuple2> openAiEmbeddings = Tuple.of( tokens.size(), - sendTokensToOpenAI(contentId, tokens)); + sendTokensToOpenAI(contentId, tokens, userId)); saveEmbeddingsForCache(content, openAiEmbeddings); EMBEDDING_CACHE.put(hashed, openAiEmbeddings); @@ -420,19 +422,20 @@ private void saveEmbeddingsForCache(final String content, final Tuple2 sendTokensToOpenAI(final String contentId, @NotNull final List tokens) { + private List sendTokensToOpenAI(final String contentId, + @NotNull final List tokens, + final String userId) { final JSONObject json = new JSONObject(); json.put(AiKeys.MODEL, config.getEmbeddingsModel().getCurrentModel()); json.put(AiKeys.INPUT, tokens); debugLogger(this.getClass(), () -> String.format("Content tokens for content ID '%s': %s", contentId, tokens)); - final String responseString = OpenAIRequest.doRequest( - config.getApiEmbeddingsUrl(), - HttpMethod.POST, - config, - json); + final String responseString = AIProxyClient.get() + .callToAI(JSONObjectAIRequest.quickEmbeddings(config, json, userId)) + .getResponse(); debugLogger(this.getClass(), () -> String.format("OpenAI Response for content ID '%s': %s", contentId, responseString.replace("\n", BLANK))); final JSONObject jsonResponse = Try.of(() -> new JSONObject(responseString)).getOrElseThrow(e -> { @@ -490,8 +493,10 @@ private List getEmbeddingsFromJSON(final String contentId, final JSONObje } } - private EmbeddingsDTO getSearcher(EmbeddingsDTO searcher) { - final List queryEmbeddings = pullOrGenerateEmbeddings(searcher.query)._2; + private EmbeddingsDTO getSearcher(final EmbeddingsDTO searcher) { + final List queryEmbeddings = pullOrGenerateEmbeddings( + searcher.query, + UtilMethods.extractUserIdOrNull(searcher.user))._2; return EmbeddingsDTO.copy(searcher).withEmbeddings(queryEmbeddings).build(); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java index 47058164d4ca..91d3e9a01723 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java @@ -6,6 +6,7 @@ import com.dotcms.ai.util.EncodingUtil; import com.dotcms.business.WrapInTransaction; import com.dotcms.exception.ExceptionUtil; +import com.dotmarketing.business.APILocator; import com.dotmarketing.portlets.contentlet.model.Contentlet; import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; @@ -119,7 +120,10 @@ private void saveEmbedding(@NotNull final String initial) { } final Tuple2> embeddings = - this.embeddingsAPI.pullOrGenerateEmbeddings(this.contentlet.getIdentifier(), normalizedContent); + this.embeddingsAPI.pullOrGenerateEmbeddings( + contentlet.getIdentifier(), + normalizedContent, + UtilMethods.extractUserIdOrNull(APILocator.systemUser())); if (embeddings._2.isEmpty()) { Logger.info(this.getClass(), String.format("No tokens for Content Type " + "'%s'. Normalized content: %s", this.contentlet.getContentType().variable(), normalizedContent)); diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java index 2e94dbe4218d..b99d31b196a3 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIChatAPIImpl.java @@ -3,21 +3,24 @@ import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; -import com.dotcms.ai.util.OpenAIRequest; +import com.dotcms.ai.client.AIProxyClient; +import com.dotcms.ai.client.JSONObjectAIRequest; import com.dotmarketing.util.UtilMethods; import com.dotmarketing.util.json.JSONObject; import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.model.User; -import javax.ws.rs.HttpMethod; import java.util.List; import java.util.Map; public class OpenAIChatAPIImpl implements ChatAPI { private final AppConfig config; + private final User user; - public OpenAIChatAPIImpl(final AppConfig appConfig) { + public OpenAIChatAPIImpl(final AppConfig appConfig, final User user) { this.config = appConfig; + this.user = user; } @Override @@ -36,7 +39,7 @@ public JSONObject sendRawRequest(final JSONObject prompt) { prompt.remove(AiKeys.PROMPT); - return new JSONObject(doRequest(config.getApiUrl(), prompt)); + return new JSONObject(doRequest(prompt, UtilMethods.extractUserIdOrNull(user)); } @Override @@ -47,8 +50,8 @@ public JSONObject sendTextPrompt(final String textPrompt) { } @VisibleForTesting - public String doRequest(final String urlIn, final JSONObject json) { - return OpenAIRequest.doRequest(urlIn, HttpMethod.POST, config, json); + String doRequest(final JSONObject json, final String userId) { + return AIProxyClient.get().callToAI(JSONObjectAIRequest.quickText(config, json, userId)).getResponse(); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java index 041a49b6e2d1..8178f6a31c03 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/OpenAIImageAPIImpl.java @@ -2,8 +2,9 @@ import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.client.AIProxyClient; +import com.dotcms.ai.client.JSONObjectAIRequest; import com.dotcms.ai.model.AIImageRequestDTO; -import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.ai.util.OpenAiRequestUtil; import com.dotcms.ai.util.StopWordsUtil; import com.dotcms.api.web.HttpServletRequestThreadLocal; @@ -24,7 +25,6 @@ import io.vavr.control.Try; import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.HttpMethod; import java.net.URL; import java.text.SimpleDateFormat; import java.util.Date; @@ -147,7 +147,7 @@ private HttpServletRequest getRequest() { .request()); requestProxy.setAttribute(WebKeys.CMS_USER, user); requestProxy.getSession().setAttribute(WebKeys.CMS_USER, user); - requestProxy.setAttribute(com.liferay.portal.util.WebKeys.USER_ID, user.getUserId()); + requestProxy.setAttribute(com.liferay.portal.util.WebKeys.USER_ID, UtilMethods.extractUserIdOrNull(user)); return requestProxy; } @@ -173,8 +173,10 @@ private String generateFileName(final String originalPrompt) { } @VisibleForTesting - public String doRequest(final String urlIn, final JSONObject json) { - return OpenAIRequest.doRequest(urlIn, HttpMethod.POST, config, json); + String doRequest(final String urlIn, final JSONObject json) { + return AIProxyClient.get() + .callToAI(JSONObjectAIRequest.quickImage(config, json, UtilMethods.extractUserIdOrNull(user))) + .getResponse(); } @VisibleForTesting diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java index 7c25179bf886..91be3359de7f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java @@ -1,5 +1,8 @@ package com.dotcms.ai.app; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIModelNotFoundException; import com.dotcms.ai.model.OpenAIModel; import com.dotcms.ai.model.OpenAIModels; import com.dotcms.ai.model.SimpleModel; @@ -13,17 +16,16 @@ import io.vavr.Lazy; import io.vavr.Tuple; import io.vavr.Tuple2; +import io.vavr.Tuple3; import org.apache.commons.collections4.CollectionUtils; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -40,28 +42,55 @@ public class AIModels { private static final int AI_MODELS_FETCH_ATTEMPTS = Config.getIntProperty(AI_MODELS_FETCH_ATTEMPTS_KEY, 3); private static final String AI_MODELS_FETCH_TIMEOUT_KEY = "ai.models.fetch.timeout"; private static final int AI_MODELS_FETCH_TIMEOUT = Config.getIntProperty(AI_MODELS_FETCH_TIMEOUT_KEY, 4000); - private static final Lazy INSTANCE = Lazy.of(AIModels::new); - private static final String AI_MODELS_API_URL_KEY = "DOT_AI_MODELS_API_URL"; + private static final String AI_MODELS_API_URL_KEY = "AI_MODELS_API_URL"; + private static final String AI_MODELS_API_URL_DEFAULT = "https://api.openai.com/v1/models"; private static final String AI_MODELS_API_URL = Config.getStringProperty( AI_MODELS_API_URL_KEY, - "https://api.openai.com/v1/models"); + AI_MODELS_API_URL_DEFAULT); private static final int AI_MODELS_CACHE_TTL = 28800; // 8 hours - private static final int AI_MODELS_CACHE_SIZE = 128; + private static final int AI_MODELS_CACHE_SIZE = 256; + private static final Lazy INSTANCE = Lazy.of(AIModels::new); - private final ConcurrentMap>> internalModels = new ConcurrentHashMap<>(); - private final ConcurrentMap, AIModel> modelsByName = new ConcurrentHashMap<>(); - private final Cache> supportedModelsCache = - Caffeine.newBuilder() - .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) - .maximumSize(AI_MODELS_CACHE_SIZE) - .build(); - private Supplier appConfigSupplier = ConfigService.INSTANCE::config; + private final ConcurrentMap>> internalModels; + private final ConcurrentMap, AIModel> modelsByName; + private final Cache> supportedModelsCache; public static AIModels get() { return INSTANCE.get(); } + private static CircuitBreakerUrl.Response fetchOpenAIModels(final String apiKey) { + final CircuitBreakerUrl.Response response = CircuitBreakerUrl.builder() + .setMethod(CircuitBreakerUrl.Method.GET) + .setUrl(AI_MODELS_API_URL) + .setTimeout(AI_MODELS_FETCH_TIMEOUT) + .setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS) + .setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + apiKey)) + .setThrowWhenNot2xx(true) + .build() + .doResponse(OpenAIModels.class); + + if (!CircuitBreakerUrl.isSuccessResponse(response)) { + AppConfig.debugLogger( + AIModels.class, + () -> String.format( + "Error fetching OpenAI supported models from [%s] (status code: [%d])", + AI_MODELS_API_URL, + response.getStatusCode())); + throw new DotRuntimeException("Error fetching OpenAI supported models"); + } + + return response; + } + private AIModels() { + internalModels = new ConcurrentHashMap<>(); + modelsByName = new ConcurrentHashMap<>(); + supportedModelsCache = + Caffeine.newBuilder() + .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) + .maximumSize(AI_MODELS_CACHE_SIZE) + .build(); } /** @@ -73,46 +102,47 @@ private AIModels() { * @param loading the list of AI models to load */ public void loadModels(final String host, final List loading) { - Optional.ofNullable(internalModels.get(host)) - .ifPresentOrElse( - model -> {}, - () -> internalModels.putIfAbsent( - host, - loading.stream() - .map(model -> Tuple.of(model.getType(), model)) - .collect(Collectors.toList()))); + final List> added = internalModels.putIfAbsent( + host, + loading.stream() + .map(model -> Tuple.of(model.getType(), model)) + .collect(Collectors.toList())); loading.forEach(aiModel -> aiModel .getModels() .forEach(model -> { - final Tuple2 key = Tuple.of( - host, - model.getName().toLowerCase().trim()); + final Tuple3 key = Tuple.of(host, model, aiModel.getType()); if (modelsByName.containsKey(key)) { - Logger.debug( - this, - String.format( + AppConfig.debugLogger( + getClass(), + () -> String.format( "Model [%s] already exists for host [%s], ignoring it", - model.getName(), + model, host)); return; } modelsByName.putIfAbsent(key, aiModel); })); + activateModels(host, added == null); } /** * Finds an AI model by the host and model name. The search is case-insensitive. * - * @param host the host for which the model is being searched + * @param appConfig the AppConfig for the host * @param modelName the name of the model to find + * @param type the type of the model to find * @return an Optional containing the found AIModel, or an empty Optional if not found */ - public Optional findModel(final String host, final String modelName) { + public Optional findModel(final AppConfig appConfig, + final String modelName, + final AIModelType type) { final String lowered = modelName.toLowerCase(); - final Set supported = getOrPullSupportedModels(); - return supported.contains(lowered) - ? Optional.ofNullable(modelsByName.get(Tuple.of(host, lowered))) - : Optional.empty(); + return Optional.ofNullable( + modelsByName.get( + Tuple.of( + appConfig.getHost(), + Model.builder().withName(lowered).build(), + type))); } /** @@ -130,6 +160,45 @@ public Optional findModel(final String host, final AIModelType type) { .findFirst()); } + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * + * @param host the host for which the model is being resolved + * @param type the type of the model to find + */ + public AIModel resolveModel(final String host, final AIModelType type) { + return findModel(host, type).orElse(AIModel.NOOP_MODEL); + } + + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * + * @param appConfig the AppConfig for the host + * @param modelName the name of the model to find + * @param type the type of the model to find + */ + public AIModel resolveAIModelOrThrow(final AppConfig appConfig, final String modelName, final AIModelType type) { + return findModel(appConfig, modelName, type) + .orElseThrow(() -> new DotAIModelNotFoundException( + String.format("Unable to find model: [%s] of type [%s].", modelName, type))); + } + + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * If the model is not found or is not operational, it throws an appropriate exception. + * + * @param appConfig the AppConfig for the host + * @param modelName the name of the model to find + * @param type the type of the model to find + * @return a Tuple2 containing the AIModel and the Model + */ + public Tuple2 resolveModelOrThrow(final AppConfig appConfig, + final String modelName, + final AIModelType type) { + final AIModel aiModel = resolveAIModelOrThrow(appConfig, modelName, type); + return Tuple.of(aiModel, aiModel.getModel(modelName)); + } + /** * Resets the internal models cache for the specified host. * @@ -145,6 +214,7 @@ public void resetModels(final String host) { .filter(key -> key._1.equals(host)) .collect(Collectors.toSet()) .forEach(modelsByName::remove); + cleanSupportedModelsCache(); } /** @@ -153,19 +223,18 @@ public void resetModels(final String host) { * * @return a set of supported model names */ - public Set getOrPullSupportedModels() { + public Set getOrPullSupportedModels(final AppConfig appConfig) { final Set cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY); if (CollectionUtils.isNotEmpty(cached)) { return cached; } - final AppConfig appConfig = appConfigSupplier.get(); if (!appConfig.isEnabled()) { AppConfig.debugLogger(getClass(), () -> "dotAI is not enabled, returning empty set of supported models"); return Set.of(); } - final CircuitBreakerUrl.Response response = fetchOpenAIModels(appConfig); + final CircuitBreakerUrl.Response response = fetchOpenAIModels(appConfig.getApiKey()); if (Objects.nonNull(response.getResponse().getError())) { throw new DotRuntimeException("Found error in AI response: " + response.getResponse().getError().getMessage()); } @@ -188,55 +257,50 @@ public Set getOrPullSupportedModels() { * @return a list of available model names */ public List getAvailableModels() { - final Set configured = internalModels.entrySet() + return internalModels.entrySet() .stream() .flatMap(entry -> entry.getValue().stream()) .map(Tuple2::_2) + .filter(AIModel::isOperational) .flatMap(aiModel -> aiModel.getModels() .stream() - .map(model -> new SimpleModel(model.getName(), aiModel.getType()))) - .collect(Collectors.toSet()); - final Set supported = getOrPullSupportedModels() - .stream() - .map(SimpleModel::new) - .collect(Collectors.toSet()); - configured.retainAll(supported); - - return new ArrayList<>(configured); + .filter(Model::isOperational) + .map(model -> new SimpleModel( + model.getName(), + aiModel.getType(), + aiModel.getCurrentModelIndex() == model.getIndex()))) + .distinct() + .collect(Collectors.toList()); } - private static CircuitBreakerUrl.Response fetchOpenAIModels(final AppConfig appConfig) { - final CircuitBreakerUrl.Response response = CircuitBreakerUrl.builder() - .setMethod(CircuitBreakerUrl.Method.GET) - .setUrl(AI_MODELS_API_URL) - .setTimeout(AI_MODELS_FETCH_TIMEOUT) - .setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS) - .setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + appConfig.getApiKey())) - .setThrowWhenNot2xx(true) - .build() - .doResponse(OpenAIModels.class); + @VisibleForTesting + void cleanSupportedModelsCache() { + supportedModelsCache.invalidate(SUPPORTED_MODELS_KEY); + } - if (!CircuitBreakerUrl.isSuccessResponse(response)) { - Logger.debug( - AIModels.class, - String.format( - "Error fetching OpenAI supported models from [%s] (status code: [%d])", - AI_MODELS_API_URL, - response.getStatusCode())); - throw new DotRuntimeException("Error fetching OpenAI supported models"); + private void activateModels(final String host, boolean wasAdded) { + if (!wasAdded) { + return; } - return response; - } - - @VisibleForTesting - void setAppConfigSupplier(final Supplier appConfigSupplier) { - this.appConfigSupplier = appConfigSupplier; - } + final List aiModels = internalModels.get(host) + .stream() + .map(tuple -> tuple._2) + .collect(Collectors.toList()); - @VisibleForTesting - void cleanSupportedModelsCache() { - supportedModelsCache.invalidateAll(); + aiModels.forEach(aiModel -> + aiModel.getModels().forEach(model -> { + final String modelName = model.getName().trim().toLowerCase(); + final ModelStatus status; + status = ModelStatus.ACTIVE; + if (aiModel.getCurrentModelIndex() == -1) { + aiModel.setCurrentModelIndex(model.getIndex()); + } + Logger.debug( + this, + String.format("Model [%s] is supported by OpenAI, marking it as [%s]", modelName, status)); + model.setStatus(status); + })); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java index bd8515e55834..04647f8be4b0 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java @@ -1,11 +1,11 @@ package com.dotcms.ai.app; -import com.dotcms.ai.exception.DotAIModelNotFoundException; +import com.dotcms.ai.domain.Model; import com.dotcms.security.apps.Secret; -import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; import com.liferay.util.StringPool; +import io.vavr.Tuple2; import io.vavr.control.Try; import org.apache.commons.lang3.StringUtils; @@ -294,19 +294,29 @@ public String getConfig(final AppKeys appKey) { * @param type the type of the model to find */ public AIModel resolveModel(final AIModelType type) { - return AIModels.get().findModel(host, type).orElse(AIModel.NOOP_MODEL); + return AIModels.get().resolveModel(host, type); } /** * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. * * @param modelName the name of the model to find + * @param type the type of the model to find */ - public AIModel resolveModelOrThrow(final String modelName) { - return AIModels.get() - .findModel(host, modelName) - .orElseThrow(() -> - new DotAIModelNotFoundException(String.format("Unable to find model: [%s].", modelName))); + public AIModel resolveAIModelOrThrow(final String modelName, final AIModelType type) { + return AIModels.get().resolveAIModelOrThrow(this, modelName, type); + } + + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * If the model is not found or is not operational, it throws an appropriate exception. + * + * @param modelName the name of the model to find + * @param type the type of the model to find + * @return the resolved Model + */ + public Tuple2 resolveModelOrThrow(final String modelName, final AIModelType type) { + return AIModels.get().resolveModelOrThrow(this, modelName, type); } /** diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java new file mode 100644 index 000000000000..d5689c71da2e --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java @@ -0,0 +1,262 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModels; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.domain.AIResponseData; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.exception.DotAIAllModelsExhaustedException; +import com.dotcms.ai.validator.AIAppValidator; +import com.dotmarketing.exception.DotRuntimeException; +import com.dotmarketing.util.UtilMethods; +import io.vavr.Tuple; +import io.vavr.Tuple2; +import io.vavr.control.Try; +import org.apache.commons.io.IOUtils; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.Optional; + +/** + * Implementation of the {@link AIClientStrategy} interface that provides a fallback mechanism + * for handling AI client requests. + * + *

+ * This class attempts to send a request using a primary AI model and, if the request fails, + * it falls back to alternative models until a successful response is obtained or all models + * are exhausted. + *

+ * + *

+ * The fallback strategy ensures that the AI client can continue to function even if some models + * are not operational or fail to process the request. + *

+ * + * @author vico + */ +public class AIModelFallbackStrategy implements AIClientStrategy { + + /** + * Applies the fallback strategy to the given AI client request and handles the response. + * + *

+ * This method first attempts to send the request using the primary model. If the request + * fails, it falls back to alternative models until a successful response is obtained or + * all models are exhausted. + *

+ * + * @param client the AI client to which the request is sent + * @param handler the response evaluator to handle the response + * @param request the AI request to be processed + * @param output the output stream to which the response will be written + * @throws DotAIAllModelsExhaustedException if all models are exhausted and no successful response is obtained + */ + @Override + public void applyStrategy(final AIClient client, + final AIResponseEvaluator handler, + final AIRequest request, + final OutputStream output) { + final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request); + final Tuple2 modelTuple = resolveModel(jsonRequest); + + final AIResponseData firstAttempt = sendAttempt(client, handler, jsonRequest, output, modelTuple); + if (firstAttempt.isSuccess()) { + return; + } + + runFallbacks(client, handler, jsonRequest, output, modelTuple); + } + + private static Tuple2 resolveModel(final JSONObjectAIRequest request) { + final AppConfig appConfig = request.getConfig(); + final String modelName = request.getPayload().optString(AiKeys.MODEL); + if (UtilMethods.isSet(modelName)) { + return appConfig.resolveModelOrThrow(modelName, request.getType()); + } + + final Optional aiModelOpt = AIModels.get().findModel(appConfig.getHost(), request.getType()); + if (aiModelOpt.isPresent()) { + final AIModel aiModel = aiModelOpt.get(); + if (aiModel.isOperational()) { + aiModel.repairCurrentIndexIfNeeded(); + return appConfig.resolveModelOrThrow(aiModel.getCurrentModel(), aiModel.getType()); + } + + notifyFailure(aiModel, request); + } + + throw new DotAIAllModelsExhaustedException(String.format("No models found for type [%s]", request.getType())); + } + + private static boolean isSameAsFirst(final Model firstAttempt, final Model model) { + if (firstAttempt.equals(model)) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] is the same as the current one [%s].", + model.getName(), + firstAttempt.getName())); + return true; + } + + return false; + } + + private static boolean isOperational(final Model model) { + if (!model.isOperational()) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format("Model [%s] is not operational. Skipping.", model.getName())); + return false; + } + + return true; + } + + private static AIResponseData doSend(final AIClient client, final AIRequest request) { + final ByteArrayOutputStream output = new ByteArrayOutputStream(); + client.sendRequest(request, output); + + final AIResponseData responseData = new AIResponseData(); + responseData.setResponse(output.toString()); + IOUtils.closeQuietly(output); + + return responseData; + } + + private static void redirectOutput(final OutputStream output, final String response) { + try (final InputStream input = new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))) { + IOUtils.copy(input, output); + } catch (IOException e) { + throw new DotRuntimeException(e); + } + } + + private static void notifyFailure(final AIModel aiModel, final AIRequest request) { + AIAppValidator.get().validateModelsUsage(aiModel, request.getUserId()); + } + + private static void handleFailure(final Tuple2 modelTuple, + final AIRequest request, + final AIResponseData responseData) { + final AIModel aiModel = modelTuple._1; + final Model model = modelTuple._2; + + if (!responseData.getStatus().doesNeedToThrow()) { + model.setStatus(responseData.getStatus()); + } + + if (model.getIndex() == aiModel.getModels().size() - 1) { + aiModel.setCurrentModelIndex(-1); + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] is the last one. Cannot fallback anymore.", + model.getName())); + + notifyFailure(aiModel, request); + + throw new DotAIAllModelsExhaustedException( + String.format("All models for type [%s] has been exhausted.", aiModel.getType())); + } else { + aiModel.setCurrentModelIndex(model.getIndex() + 1); + } + } + + private static AIResponseData sendAttempt(final AIClient client, + final AIResponseEvaluator evaluator, + final JSONObjectAIRequest request, + final OutputStream output, + final Tuple2 modelTuple) { + + final AIResponseData responseData = Try + .of(() -> doSend(client, request)) + .getOrElseGet(exception -> fromException(evaluator, exception)); + + if (!responseData.isSuccess()) { + if (responseData.getStatus().doesNeedToThrow()) { + if (!modelTuple._1.isOperational()) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "All models from type [%s] are not operational. Throwing exception.", + modelTuple._1.getType())); + notifyFailure(modelTuple._1, request); + } + throw responseData.getException(); + } + } else { + evaluator.fromResponse(responseData.getResponse(), responseData, output instanceof ByteArrayOutputStream); + } + + if (responseData.isSuccess()) { + AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName())); + redirectOutput(output, responseData.getResponse()); + } else { + logFailure(modelTuple, responseData); + + handleFailure(modelTuple, request, responseData); + } + + return responseData; + } + + private static void logFailure(final Tuple2 modelTuple, final AIResponseData responseData) { + Optional + .ofNullable(responseData.getResponse()) + .ifPresentOrElse( + response -> AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] failed with response:%s%s%s. Trying next model.", + modelTuple._2.getName(), + System.lineSeparator(), + response, + System.lineSeparator())), + () -> AppConfig.debugLogger( + AIModelFallbackStrategy.class, + () -> String.format( + "Model [%s] failed with error: [%s]. Trying next model.", + modelTuple._2.getName(), + responseData.getError()))); + } + + private static AIResponseData fromException(final AIResponseEvaluator evaluator, final Throwable exception) { + final AIResponseData metadata = new AIResponseData(); + evaluator.fromException(exception, metadata); + return metadata; + } + + private static void runFallbacks(final AIClient client, + final AIResponseEvaluator evaluator, + final JSONObjectAIRequest request, + final OutputStream output, + final Tuple2 modelTuple) { + for(final Model model : modelTuple._1.getModels()) { + if (isSameAsFirst(modelTuple._2, model) || !isOperational(model)) { + continue; + } + + request.getPayload().put(AiKeys.MODEL, model.getName()); + final AIResponseData responseData = sendAttempt( + client, + evaluator, + request, + output, + Tuple.of(modelTuple._1, model)); + if (responseData.isSuccess()) { + return; + } + } + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java index 1040f3516cf6..08b2c34f0a6b 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java @@ -19,9 +19,7 @@ public enum AIProxyStrategy { DEFAULT(new AIDefaultStrategy()), - // TODO: pr-split -> uncomment this line - //MODEL_FALLBACK(new AIModelFallbackStrategy()); - MODEL_FALLBACK(null); + MODEL_FALLBACK(new AIModelFallbackStrategy()); private final AIClientStrategy strategy; diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java index 89d3638d15a7..ab12dbba58f3 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java @@ -102,9 +102,7 @@ public void sendRequest(final AIRequest request, fin final String modelName = Optional .ofNullable(payload.optString(AiKeys.MODEL)) .orElseThrow(() -> new DotAIModelNotFoundException("Model is not present in the request")); - // TODO: pr-split -> uncomment this line - //final Tuple2 modelTuple = appConfig.resolveModelOrThrow(modelName, jsonRequest.getType()); - final Tuple2 modelTuple = Tuple.of(null, null); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(modelName, jsonRequest.getType()); final AIModel aiModel = modelTuple._1; if (!modelTuple._2.isOperational()) { diff --git a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java index 9739bab313eb..5c5a7b24d5ef 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java +++ b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java @@ -3,6 +3,7 @@ import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.db.EmbeddingsDTO; +import com.dotcms.ai.exception.DotAIAppConfigDisabledException; import com.dotcms.content.elasticsearch.business.event.ContentletArchiveEvent; import com.dotcms.content.elasticsearch.business.event.ContentletDeletedEvent; import com.dotcms.content.elasticsearch.business.event.ContentletPublishEvent; @@ -10,7 +11,6 @@ import com.dotcms.system.event.local.model.Subscriber; import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; -import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.portlets.contentlet.model.Contentlet; import com.dotmarketing.portlets.contentlet.model.ContentletListener; import com.dotmarketing.util.Logger; @@ -86,7 +86,7 @@ private AppConfig getAppConfig(final String hostId) { AppConfig.debugLogger( getClass(), () -> "dotAI is not enabled since no API urls or API key found in app config"); - throw new DotRuntimeException("App dotAI config without API urls or API key"); + throw new DotAIAppConfigDisabledException("App dotAI config without API urls or API key"); } return appConfig; diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java b/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java index 53f83c3ab149..e289b64c9a0d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java +++ b/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java @@ -1,6 +1,7 @@ package com.dotcms.ai.model; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.fasterxml.jackson.annotation.JsonSetter; import com.fasterxml.jackson.annotation.Nulls; @@ -15,12 +16,11 @@ public class AIImageRequestDTO { private final String model; - public AIImageRequestDTO(Builder builder) { + public AIImageRequestDTO(final Builder builder) { this.numberOfImages = builder.numberOfImages; this.model = builder.model; this.prompt = builder.prompt; this.size = builder.size; - } public String getSize() { @@ -40,14 +40,15 @@ public String getModel() { } public static class Builder { + private AppConfig appConfig = ConfigService.INSTANCE.config(); @JsonSetter(nulls = Nulls.SKIP) private String prompt; @JsonSetter(nulls = Nulls.SKIP) private int numberOfImages = 1; @JsonSetter(nulls = Nulls.SKIP) - private String size = ConfigService.INSTANCE.config().getImageSize(); + private String size = appConfig.getImageSize(); @JsonSetter(nulls = Nulls.SKIP) - private String model = ConfigService.INSTANCE.config().getImageModel().getCurrentModel(); + private String model = appConfig.getImageModel().getCurrentModel(); public AIImageRequestDTO build() { return new AIImageRequestDTO(this); diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java b/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java index c5486b61191f..b24e042e853f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java +++ b/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java @@ -17,16 +17,20 @@ public class SimpleModel implements Serializable { private final String name; private final AIModelType type; + private final boolean current; @JsonCreator - public SimpleModel(@JsonProperty("name") final String name, @JsonProperty("type") final AIModelType type) { + public SimpleModel(@JsonProperty("name") final String name, + @JsonProperty("type") final AIModelType type, + @JsonProperty("current") final boolean current) { this.name = name; this.type = type; + this.current = current; } @JsonCreator public SimpleModel(@JsonProperty("name") final String name) { - this(name, null); + this(name, null, false); } public String getName() { @@ -37,17 +41,30 @@ public AIModelType getType() { return type; } + public boolean isCurrent() { + return current; + } + @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; SimpleModel that = (SimpleModel) o; - return Objects.equals(name, that.name); + return Objects.equals(name, that.name) && type == that.type; } @Override public int hashCode() { - return Objects.hashCode(name); + return Objects.hash(name, type); + } + + @Override + public String toString() { + return "SimpleModel{" + + "name='" + name + '\'' + + ", type=" + type + + ", current=" + current + + '}'; } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java index e7b62cf46712..5499de4ce660 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java @@ -61,7 +61,9 @@ public final Response summarizeFromContent(@Context final HttpServletRequest req response, formIn, () -> APILocator.getDotAIAPI().getCompletionsAPI().summarize(formIn), - out -> APILocator.getDotAIAPI().getCompletionsAPI().summarizeStream(formIn, new LineReadingOutputStream(out))); + output -> APILocator.getDotAIAPI() + .getCompletionsAPI() + .summarizeStream(formIn, new LineReadingOutputStream(output))); } /** @@ -84,7 +86,9 @@ public final Response rawPrompt(@Context final HttpServletRequest request, response, formIn, () -> APILocator.getDotAIAPI().getCompletionsAPI().raw(formIn), - out -> APILocator.getDotAIAPI().getCompletionsAPI().rawStream(formIn, new LineReadingOutputStream(out))); + output -> APILocator.getDotAIAPI() + .getCompletionsAPI() + .rawStream(formIn, new LineReadingOutputStream(output))); } /** @@ -107,16 +111,15 @@ public final Response getConfig(@Context final HttpServletRequest request, .init() .getUser(); final Host host = WebAPILocator.getHostWebAPI().getCurrentHostNoThrow(request); - final AppConfig app = ConfigService.INSTANCE.config(host); - + final AppConfig appConfig = ConfigService.INSTANCE.config(host); final Map map = new HashMap<>(); map.put(AiKeys.CONFIG_HOST, host.getHostname() + " (falls back to system host)"); for (final AppKeys config : AppKeys.values()) { - map.put(config.key, app.getConfig(config)); + map.put(config.key, appConfig.getConfig(config)); } - final String apiKey = UtilMethods.isSet(app.getApiKey()) ? "*****" : "NOT SET"; + final String apiKey = UtilMethods.isSet(appConfig.getApiKey()) ? "*****" : "NOT SET"; map.put(AppKeys.API_KEY.key, apiKey); final List models = AIModels.get().getAvailableModels(); @@ -140,19 +143,25 @@ private static CompletionsForm resolveForm(final HttpServletRequest request, .init() .getUser(); final Host host = WebAPILocator.getHostWebAPI().getCurrentHostNoThrow(request); - return (!user.isAdmin()) - ? CompletionsForm - .copy(formIn) - .model(ConfigService.INSTANCE.config(host).getModel().getCurrentModel()) - .build() - : formIn; + return withUserId( + !user.isAdmin() + ? CompletionsForm + .copy(formIn) + .model(ConfigService.INSTANCE.config(host).getModel().getCurrentModel()) + .build() + : formIn, + user); + } + + private static CompletionsForm withUserId(final CompletionsForm completionsForm, final User user) { + return CompletionsForm.copy(completionsForm).user(user).build(); } private static Response getResponse(final HttpServletRequest request, final HttpServletResponse response, final CompletionsForm formIn, final Supplier noStream, - final Consumer stream) { + final Consumer outputStream) { if (StringUtils.isBlank(formIn.prompt)) { return badRequestResponse(); } @@ -162,7 +171,7 @@ private static Response getResponse(final HttpServletRequest request, if (resolvedForm.stream) { final StreamingOutput streaming = output -> { - stream.accept(output); + outputStream.accept(output); output.flush(); output.close(); }; @@ -174,5 +183,4 @@ private static Response getResponse(final HttpServletRequest request, return Response.ok(jsonResponse.toString(), MediaType.APPLICATION_JSON).build(); } - } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java index 375625d58adf..5ea184f26880 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java @@ -6,7 +6,6 @@ import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.model.AIImageRequestDTO; import com.dotcms.ai.api.ImageAPI; -import com.dotcms.ai.api.OpenAIImageAPIImpl; import com.dotcms.rest.WebResource; import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java index fae06a565d3b..6527eb984419 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java @@ -10,6 +10,7 @@ import com.dotmarketing.util.UtilMethods; import com.dotmarketing.util.json.JSONArray; import com.dotmarketing.util.json.JSONObject; +import com.liferay.portal.model.User; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -56,7 +57,7 @@ public Response doGet(@Context final HttpServletRequest request, * * @param request the HTTP request * @param response the HTTP response - * @param formIn the form data containing the prompt + * @param form the form data containing the prompt * @return a Response object containing the generated text * @throws IOException if an I/O error occurs */ @@ -65,13 +66,14 @@ public Response doGet(@Context final HttpServletRequest request, @Produces(MediaType.APPLICATION_JSON) public Response doPost(@Context final HttpServletRequest request, @Context final HttpServletResponse response, - final CompletionsForm formIn) throws IOException { + final CompletionsForm form) throws IOException { - new WebResource.InitBuilder(request, response) + final User user = new WebResource.InitBuilder(request, response) .requiredBackendUser(true) .requiredFrontendUser(true) .init() .getUser(); + final CompletionsForm formIn = CompletionsForm.copy(form).user(user).build(); if (UtilMethods.isEmpty(formIn.prompt)) { return Response @@ -82,7 +84,12 @@ public Response doPost(@Context final HttpServletRequest request, final AppConfig config = ConfigService.INSTANCE.config(WebAPILocator.getHostWebAPI().getHost(request)); - return Response.ok(APILocator.getDotAIAPI().getCompletionsAPI().raw(generateRequest(formIn, config)).toString()).build(); + return Response.ok( + APILocator.getDotAIAPI() + .getCompletionsAPI() + .raw(generateRequest(formIn, config), user.getUserId()) + .toString()) + .build(); } /** @@ -93,12 +100,12 @@ public Response doPost(@Context final HttpServletRequest request, * @return a JSONObject representing the request */ private JSONObject generateRequest(final CompletionsForm form, final AppConfig config) { - final String systemPrompt = UtilMethods.isSet(config.getRolePrompt()) ? config.getRolePrompt() : null; final String model = form.model; final float temperature = form.temperature; final JSONObject request = new JSONObject(); final JSONArray messages = new JSONArray(); + final String systemPrompt = UtilMethods.isSet(config.getRolePrompt()) ? config.getRolePrompt() : null; if (UtilMethods.isSet(systemPrompt)) { messages.add(Map.of(AiKeys.ROLE, AiKeys.SYSTEM, AiKeys.CONTENT, systemPrompt)); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java index 61815b1307eb..62c61fa9d229 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java @@ -1,7 +1,6 @@ package com.dotcms.ai.rest.forms; import com.dotcms.ai.app.AppConfig; -import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.business.APILocator; import com.dotmarketing.util.UtilMethods; @@ -65,8 +64,6 @@ public static final Builder copy(EmbeddingsForm form) { .fields(String.join(",", form.fields)) .velocityTemplate(form.velocityTemplate) .indexName(form.indexName); - - } @Override @@ -103,7 +100,6 @@ public String toString() { '}'; } - public static final class Builder { @JsonSetter(nulls = Nulls.SKIP) public String fields; @@ -135,7 +131,6 @@ public Builder limit(int limit) { return this; } - public Builder offset(int offset) { this.offset = offset; return this; @@ -161,10 +156,10 @@ public Builder velocityTemplate(String velocityTemplate) { return this; } - public EmbeddingsForm build() { return new EmbeddingsForm(this); - } + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java deleted file mode 100644 index b2a9b9adf789..000000000000 --- a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java +++ /dev/null @@ -1,189 +0,0 @@ -package com.dotcms.ai.util; - -import com.dotcms.ai.AiKeys; -import com.dotcms.ai.app.AIModel; -import com.dotcms.ai.app.AppConfig; -import com.dotcms.ai.app.AppKeys; -import com.dotcms.ai.app.ConfigService; -import com.dotmarketing.exception.DotRuntimeException; -import com.dotmarketing.util.Logger; -import com.dotmarketing.util.json.JSONObject; -import io.vavr.control.Try; -import org.apache.http.HttpHeaders; -import org.apache.http.client.methods.*; -import org.apache.http.entity.ContentType; -import org.apache.http.entity.StringEntity; -import org.apache.http.impl.client.CloseableHttpClient; -import org.apache.http.impl.client.HttpClients; - -import javax.ws.rs.HttpMethod; -import javax.ws.rs.core.MediaType; -import java.io.BufferedInputStream; -import java.io.ByteArrayOutputStream; -import java.io.OutputStream; -import java.util.concurrent.ConcurrentHashMap; - -/** - * The OpenAIRequest class is a utility class that handles HTTP requests to the OpenAI API. - * It provides methods for sending GET, POST, PUT, DELETE, and PATCH requests. - * This class also manages rate limiting for the OpenAI API by keeping track of the last time a request was made. - * - * This class is implemented as a singleton, meaning that only one instance of the class is created throughout the execution of the program. - */ -public class OpenAIRequest { - - private static final ConcurrentHashMap lastRestCall = new ConcurrentHashMap<>(); - - private OpenAIRequest() {} - - /** - * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload. - * The response from the request is written to the provided OutputStream. - * This method also manages rate limiting for the OpenAI API by keeping track of the last time a request was made. - * - * @param urlIn the URL to send the request to - * @param method the HTTP method to use for the request - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doRequest(final String urlIn, - final String method, - final AppConfig appConfig, - final JSONObject json, - final OutputStream out) { - AppConfig.debugLogger( - OpenAIRequest.class, - () -> String.format( - "Posting to [%s] with method [%s]%s with app config:%s%s the payload: %s", - urlIn, - method, - System.lineSeparator(), - appConfig.toString(), - System.lineSeparator(), - json.toString(2))); - - if (!appConfig.isEnabled()) { - AppConfig.debugLogger(OpenAIRequest.class, () -> "App dotAI is not enabled and will not send request."); - throw new DotRuntimeException("App dotAI config without API urls or API key"); - } - - final AIModel model = appConfig.resolveModelOrThrow(json.optString(AiKeys.MODEL)); - final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L) - + model.minIntervalBetweenCalls() - - System.currentTimeMillis(); - if (sleep > 0) { - Logger.info( - OpenAIRequest.class, - "Rate limit:" - + model.getApiPerMinute() - + "/minute, or 1 every " - + model.minIntervalBetweenCalls() - + "ms. Sleeping:" - + sleep); - Try.run(() -> Thread.sleep(sleep)); - } - - lastRestCall.put(model, System.currentTimeMillis()); - - try (CloseableHttpClient httpClient = HttpClients.createDefault()) { - final StringEntity jsonEntity = new StringEntity(json.toString(), ContentType.APPLICATION_JSON); - final HttpUriRequest httpRequest = resolveMethod(method, urlIn); - httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON); - httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + appConfig.getApiKey()); - - if (!json.getAsMap().isEmpty()) { - Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity)); - } - - try (CloseableHttpResponse response = httpClient.execute(httpRequest)) { - final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent()); - final byte[] buffer = new byte[1024]; - int len; - while ((len = in.read(buffer)) != -1) { - out.write(buffer, 0, len); - out.flush(); - } - } - } catch (Exception e) { - if (ConfigService.INSTANCE.config().getConfigBoolean(AppKeys.DEBUG_LOGGING)){ - Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage(), e); - } else { - Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage()); - } - - Logger.warn(OpenAIRequest.class, " - " + method + " : " +json); - - throw new DotRuntimeException(e); - } - } - - /** - * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload. - * The response from the request is returned as a string. - * - * @param url the URL to send the request to - * @param method the HTTP method to use for the request - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @return the response from the request as a string - */ - public static String doRequest(final String url, - final String method, - final AppConfig appConfig, - final JSONObject json) { - final ByteArrayOutputStream out = new ByteArrayOutputStream(); - doRequest(url, method, appConfig, json, out); - - return out.toString(); - } - - /** - * Sends a POST request to the specified URL with the specified OpenAI API key and JSON payload. - * The response from the request is written to the provided OutputStream. - * - * @param urlIn the URL to send the request to - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doPost(final String urlIn, - final AppConfig appConfig, - final JSONObject json, - final OutputStream out) { - doRequest(urlIn, HttpMethod.POST, appConfig, json, out); - } - - /** - * Sends a GET request to the specified URL with the specified OpenAI API key and JSON payload. - * The response from the request is written to the provided OutputStream. - * - * @param urlIn the URL to send the request to - * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doGet(final String urlIn, - final AppConfig appConfig, - final JSONObject json, - final OutputStream out) { - doRequest(urlIn, HttpMethod.GET, appConfig, json, out); - } - - private static HttpUriRequest resolveMethod(final String method, final String urlIn) { - switch(method) { - case HttpMethod.POST: - return new HttpPost(urlIn); - case HttpMethod.PUT: - return new HttpPut(urlIn); - case HttpMethod.DELETE: - return new HttpDelete(urlIn); - case "patch": - return new HttpPatch(urlIn); - case HttpMethod.GET: - default: - return new HttpGet(urlIn); - } - } - -} diff --git a/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java b/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java index f2036cba8955..7c51f40c883e 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java +++ b/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java @@ -1,7 +1,9 @@ package com.dotcms.ai.validator; import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModels; import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.domain.Model; import com.dotcms.api.system.event.message.MessageSeverity; import com.dotcms.api.system.event.message.SystemMessageEventUtil; import com.dotcms.api.system.event.message.builder.SystemMessage; @@ -15,6 +17,8 @@ import java.util.Collections; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * The AIAppValidator class is responsible for validating AI configurations and model usage. @@ -52,8 +56,7 @@ public void validateAIConfig(final AppConfig appConfig, final String userId) { return; } - // TODO: pr-split -> uncomment this lines - /*final Set supportedModels = AIModels.get().getOrPullSupportedModels(appConfig.getApiKey()); + final Set supportedModels = AIModels.get().getOrPullSupportedModels(appConfig); final Set unsupportedModels = Stream.of( appConfig.getModel(), appConfig.getImageModel(), @@ -61,9 +64,7 @@ public void validateAIConfig(final AppConfig appConfig, final String userId) { .flatMap(aiModel -> aiModel.getModels().stream()) .map(Model::getName) .filter(model -> !supportedModels.contains(model)) - .collect(Collectors.toSet());*/ - final Set supportedModels = Set.of(); - final Set unsupportedModels = Set.of(); + .collect(Collectors.toSet()); if (unsupportedModels.isEmpty()) { return; } @@ -96,12 +97,10 @@ public void validateModelsUsage(final AIModel aiModel, final String userId) { return; } - // TODO: pr-split -> uncomment this line - /*final String unavailableModels = aiModel.getModels() + final String unavailableModels = aiModel.getModels() .stream() .map(Model::getName) - .collect(Collectors.joining(", "));*/ - final String unavailableModels = ""; + .collect(Collectors.joining(", ")); final String message = Try .of(() -> LanguageUtil.get("ai.models.exhausted", aiModel.getType(), unavailableModels)). getOrElse( diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java index 050b56b1e535..0ad6d7837a2d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java @@ -4,9 +4,7 @@ import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.api.ChatAPI; -import com.dotcms.ai.api.OpenAIChatAPIImpl; import com.dotcms.ai.api.ImageAPI; -import com.dotcms.ai.api.OpenAIImageAPIImpl; import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.json.JSONObject; @@ -30,11 +28,13 @@ public class AIViewTool implements ViewTool { private AppConfig config; private ChatAPI chatService; private ImageAPI imageService; + private User user; @Override public void init(final Object obj) { context = (ViewContext) obj; config = config(); + user = user(); chatService = chatService(); imageService = imageService(); } @@ -128,12 +128,12 @@ User user() { @VisibleForTesting ChatAPI chatService() { - return APILocator.getDotAIAPI().getChatAPI(config); + return APILocator.getDotAIAPI().getChatAPI(config, user); } @VisibleForTesting ImageAPI imageService() { - return APILocator.getDotAIAPI().getImageAPI(config, user(), APILocator.getHostAPI(), APILocator.getTempFileAPI()); + return APILocator.getDotAIAPI().getImageAPI(config, user, APILocator.getHostAPI(), APILocator.getTempFileAPI()); } private

Try generate(final P prompt, final Function serviceCall) { diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java index 5508a23f4e32..adc3a77cc60c 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java @@ -7,8 +7,11 @@ import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; +import com.dotmarketing.util.UtilMethods; import com.dotmarketing.util.json.JSONObject; import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.model.User; +import com.liferay.portal.util.PortalUtil; import org.apache.velocity.tools.view.context.ViewContext; import org.apache.velocity.tools.view.tools.ViewTool; @@ -24,14 +27,18 @@ */ public class CompletionsTool implements ViewTool { + private final ViewContext context; private final HttpServletRequest request; private final Host host; private final AppConfig config; + private final User user; CompletionsTool(Object initData) { - this.request = ((ViewContext) initData).getRequest(); + this.context = (ViewContext) initData; + this.request = this.context.getRequest(); this.host = host(); this.config = config(); + this.user = user(); } @Override @@ -69,7 +76,11 @@ public Object summarize(final String prompt) { * @return The summarized object. */ public Object summarize(final String prompt, final String indexName) { - final CompletionsForm form = new CompletionsForm.Builder().indexName(indexName).prompt(prompt).build(); + final CompletionsForm form = new CompletionsForm.Builder() + .indexName(indexName) + .prompt(prompt) + .user(user) + .build(); try { return APILocator.getDotAIAPI().getCompletionsAPI(config).summarize(form); } catch (Exception e) { @@ -112,7 +123,9 @@ public Object raw(String prompt) { */ public Object raw(final JSONObject prompt) { try { - return APILocator.getDotAIAPI().getCompletionsAPI(config).raw(prompt); + return APILocator.getDotAIAPI() + .getCompletionsAPI(config) + .raw(prompt, UtilMethods.extractUserIdOrNull(user)); } catch (Exception e) { return handleException(e); } @@ -141,4 +154,9 @@ AppConfig config() { return ConfigService.INSTANCE.config(this.host); } + @VisibleForTesting + User user() { + return PortalUtil.getUser(context.getRequest()); + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java index daf7e1756139..e8dc49722255 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java @@ -7,7 +7,10 @@ import com.dotmarketing.business.APILocator; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.Logger; +import com.dotmarketing.util.UtilMethods; import com.google.common.annotations.VisibleForTesting; +import com.liferay.portal.model.User; +import com.liferay.portal.util.PortalUtil; import org.apache.velocity.tools.view.context.ViewContext; import org.apache.velocity.tools.view.tools.ViewTool; @@ -22,9 +25,11 @@ */ public class EmbeddingsTool implements ViewTool { + private final ViewContext context; private final HttpServletRequest request; private final Host host; private final AppConfig appConfig; + private final User user; /** * Constructor for the EmbeddingsTool class. @@ -33,9 +38,11 @@ public class EmbeddingsTool implements ViewTool { * @param initData Initialization data for the tool. */ EmbeddingsTool(Object initData) { - this.request = ((ViewContext) initData).getRequest(); + this.context = (ViewContext) initData; + this.request = this.context.getRequest(); this.host = host(); this.appConfig = appConfig(); + this.user = user(); } @Override @@ -71,10 +78,14 @@ public List generateEmbeddings(final String prompt) { if (tokens > maxTokens) { Logger.warn( EmbeddingsTool.class, - "Prompt is too long. Maximum prompt size is " + maxTokens + " tokens (roughly ~" + maxTokens * .75 + " words). Your prompt was " + tokens + " tokens "); + "Prompt is too long. Maximum prompt size is " + maxTokens + " tokens (roughly ~" + + maxTokens * .75 + " words). Your prompt was " + tokens + " tokens "); } - return APILocator.getDotAIAPI().getEmbeddingsAPI().pullOrGenerateEmbeddings(prompt)._2; + return APILocator.getDotAIAPI() + .getEmbeddingsAPI() + .pullOrGenerateEmbeddings(prompt, UtilMethods.extractUserIdOrNull(user)) + ._2; } /** @@ -96,4 +107,9 @@ AppConfig appConfig() { return ConfigService.INSTANCE.config(host); } + @VisibleForTesting + User user() { + return PortalUtil.getUser(context.getRequest()); + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java index 87a068eca10c..d8d01341bd0d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java @@ -1,5 +1,6 @@ package com.dotcms.ai.workflow; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.portlets.workflows.actionlet.Actionlet; import com.dotmarketing.portlets.workflows.actionlet.WorkFlowActionlet; @@ -9,7 +10,6 @@ import com.dotmarketing.portlets.workflows.model.WorkflowActionFailureException; import com.dotmarketing.portlets.workflows.model.WorkflowActionletParameter; import com.dotmarketing.portlets.workflows.model.WorkflowProcessor; -import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Map; @@ -24,7 +24,7 @@ public List getParameters() { WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key, "Overwrite tags ", Boolean.toString(true), true, - () -> ImmutableList.of( + () -> List.of( new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)), new MultiKeyValue(Boolean.toString(true), Boolean.toString(true))) ); @@ -32,16 +32,26 @@ public List getParameters() { WorkflowActionletParameter limitTagsToHost = new MultiSelectionWorkflowActionletParameter( OpenAIParams.LIMIT_TAGS_TO_HOST.key, "Limit the keywords to pre-existing tags", "Limit", false, - () -> ImmutableList.of( + () -> List.of( new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)), new MultiKeyValue(Boolean.toString(true), Boolean.toString(true)) ) ); + + final AppConfig appConfig = ConfigService.INSTANCE.config(); return List.of( overwriteParameter, limitTagsToHost, - new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getModel().getCurrentModel(), ConfigService.INSTANCE.config().getModel().getCurrentModel(), false), - new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0.", ".1", false) + new WorkflowActionletParameter( + OpenAIParams.MODEL.key, + "The AI model to use, defaults to " + appConfig.getModel().getCurrentModel(), + appConfig.getModel().getCurrentModel(), + false), + new WorkflowActionletParameter( + OpenAIParams.TEMPERATURE.key, + "The AI temperature for the response. Between .1 and 2.0.", + ".1", + false) ); } @@ -63,5 +73,4 @@ public void executeAction(final WorkflowProcessor processor, new OpenAIAutoTagRunner(processor, params).run(); } - } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java index 2f7013bfd362..a63c03743686 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java @@ -146,7 +146,7 @@ private String openAIRequest(final Contentlet workingContentlet, final String co final String parsedContentPrompt = VelocityUtil.eval(contentToTag, ctx); final JSONObject openAIResponse = APILocator.getDotAIAPI().getCompletionsAPI() - .prompt(parsedSystemPrompt, parsedContentPrompt, model, temperature, 2000); + .prompt(parsedSystemPrompt, parsedContentPrompt, model, temperature, 2000, user.getUserId()); return openAIResponse.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content"); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java index b6a14ab22d44..8a8e9320293b 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java @@ -1,5 +1,6 @@ package com.dotcms.ai.workflow; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.portlets.workflows.actionlet.Actionlet; @@ -10,7 +11,6 @@ import com.dotmarketing.portlets.workflows.model.WorkflowActionFailureException; import com.dotmarketing.portlets.workflows.model.WorkflowActionletParameter; import com.dotmarketing.portlets.workflows.model.WorkflowProcessor; -import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Map; @@ -23,22 +23,39 @@ public class OpenAIContentPromptActionlet extends WorkFlowActionlet { @Override public List getParameters() { - WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key, + final WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key, "Overwrite existing content (true|false)", Boolean.toString(true), true, - () -> ImmutableList.of( + () -> List.of( new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)), new MultiKeyValue(Boolean.toString(true), Boolean.toString(true))) ); - + final AppConfig appConfig = ConfigService.INSTANCE.config(); return List.of( - new WorkflowActionletParameter(OpenAIParams.FIELD_TO_WRITE.key, "The field where you want to write the results. " + - "
If your response is being returned as a json object, this field can be left blank" + - "
and the keys of the json object will be used to update the content fields.", "", false), + new WorkflowActionletParameter( + OpenAIParams.FIELD_TO_WRITE.key, + "The field where you want to write the results. " + + "
If your response is being returned as a json object, this field can be left blank" + + "
and the keys of the json object will be used to update the content fields.", + "", + false), overwriteParameter, - new WorkflowActionletParameter(OpenAIParams.OPEN_AI_PROMPT.key, "The prompt that will be sent to the AI", "We need an attractive search result in Google. Return a json object that includes the fields \"pageTitle\" for a meta title of less than 55 characters and \"metaDescription\" for the meta description of less than 300 characters using this content:\\n\\n${fieldContent}\\n\\n", true), - new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getModel().getCurrentModel(), ConfigService.INSTANCE.config().getModel().getCurrentModel(), false), - new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0. Defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), false) + new WorkflowActionletParameter( + OpenAIParams.OPEN_AI_PROMPT.key, + "The prompt that will be sent to the AI", + "We need an attractive search result in Google. Return a json object that includes the fields \"pageTitle\" for a meta title of less than 55 characters and \"metaDescription\" for the meta description of less than 300 characters using this content:\\n\\n${fieldContent}\\n\\n", + true), + new WorkflowActionletParameter( + OpenAIParams.MODEL.key, + "The AI model to use, defaults to " + appConfig.getModel().getCurrentModel(), + appConfig.getModel().getCurrentModel(), + false), + new WorkflowActionletParameter( + OpenAIParams.TEMPERATURE.key, + "The AI temperature for the response. Between .1 and 2.0. Defaults to " + + appConfig.getConfig(AppKeys.COMPLETION_TEMPERATURE), + appConfig.getConfig(AppKeys.COMPLETION_TEMPERATURE), + false) ); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java index 176b12d860c4..5ebca943a044 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java @@ -145,7 +145,9 @@ private String openAIRequest(final Contentlet workingContentlet) throws Exceptio final Context ctx = VelocityContextFactory.getMockContext(workingContentlet, user); final String parsedPrompt = VelocityUtil.eval(prompt, ctx); - final JSONObject openAIResponse = APILocator.getDotAIAPI().getCompletionsAPI().raw(buildRequest(parsedPrompt, model, temperature)); + final JSONObject openAIResponse = APILocator.getDotAIAPI() + .getCompletionsAPI() + .raw(buildRequest(parsedPrompt, model, temperature), user.getUserId()); try { return openAIResponse diff --git a/dotCMS/src/main/resources/apps/dotAI.yml b/dotCMS/src/main/resources/apps/dotAI.yml index e8f03c0c5f80..425c9b42b058 100644 --- a/dotCMS/src/main/resources/apps/dotAI.yml +++ b/dotCMS/src/main/resources/apps/dotAI.yml @@ -15,11 +15,11 @@ params: hint: "Your ChatGPT API key" required: true textModelNames: - value: "gpt-4o" + value: "gpt-4o-mini" hidden: false type: "STRING" label: "Model Names" - hint: "Comma delimited list of models used to generate OpenAI API response (e.g. gpt-3.5-turbo-16k)" + hint: "Comma delimited list of models used to generate OpenAI API response (e.g. gpt-4o-mini)" required: true rolePrompt: value: "You are dotCMSbot, and AI assistant to help content creators generate and rewrite content in their content management system." diff --git a/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties b/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties index 98cae197cf4b..ddf059c7f949 100644 --- a/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties +++ b/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties @@ -189,6 +189,8 @@ anonymous=Anonymous another-layout-already-exists=Another Tool Group already exists in the system with the same name Any-Structure-Type=Any Content Type Any-Structure=Any Content Type +ai.unsupported.models=The following models are not supported: [{0}] +ai.models.exhausted=All the {0} models: [{1}] have been exhausted since they are invalid or has been decommissioned api.ruleengine.system.conditionlet.CurrentSessionLanguage.inputs.comparison.placeholder=Comparison api.ruleengine.system.conditionlet.CurrentSessionLanguage.inputs.language.placeholder=Language api.ruleengine.system.conditionlet.CurrentSessionLanguage.name=Selected Language diff --git a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js index 088436aef605..f2db45f1a34c 100644 --- a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js +++ b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js @@ -136,11 +136,12 @@ const writeModelToDropdown = async () => { } const newOption = document.createElement("option"); + console.log(JSON.stringify(dotAiState.config, null, 2)); newOption.value = dotAiState.config.availableModels[i].name; newOption.text = `${dotAiState.config.availableModels[i].name}` - if (dotAiState.config.availableModels[i] === dotAiState.config.model) { + if (dotAiState.config.availableModels[i].current) { newOption.selected = true; - newOption.text = `${dotAiState.config.availableModels[i]} (default)` + newOption.text = `${dotAiState.config.availableModels[i].name} (default)` } modelName.appendChild(newOption); } diff --git a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java similarity index 88% rename from dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java rename to dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java index 3fe155f2d01b..c51e9c6323a5 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java @@ -1,12 +1,11 @@ -package com.dotcms.ai.service; +package com.dotcms.ai.api; -import com.dotcms.ai.api.ChatAPI; -import com.dotcms.ai.api.OpenAIChatAPIImpl; import com.dotcms.ai.app.AIModel; import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotmarketing.util.json.JSONObject; +import com.liferay.portal.model.User; import org.junit.Before; import org.junit.Test; @@ -17,17 +16,19 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class OpenAIChatServiceImplTest { +public class OpenAIChatAPIImplTest { private static final String RESPONSE_JSON = "{\"data\":[{\"url\":\"http://localhost:8080\",\"value\":\"this is a response\"}]}"; private AppConfig config; private ChatAPI service; + private User user; @Before public void setUp() { config = mock(AppConfig.class); + user = mock(User.class); service = prepareService(RESPONSE_JSON); } @@ -54,11 +55,9 @@ public void test_sendTextPrompt() { } private ChatAPI prepareService(final String response) { - return new OpenAIChatAPIImpl(config) { - - + return new OpenAIChatAPIImpl(config, user) { @Override - public String doRequest(final String urlIn, final JSONObject json) { + String doRequest(final JSONObject json, final String userId) { return response; } }; diff --git a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java similarity index 96% rename from dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java rename to dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java index 4d0afb444b88..e73d9352a59b 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java @@ -1,7 +1,5 @@ -package com.dotcms.ai.service; +package com.dotcms.ai.api; -import com.dotcms.ai.api.ImageAPI; -import com.dotcms.ai.api.OpenAIImageAPIImpl; import com.dotcms.ai.app.AIModel; import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; @@ -27,7 +25,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class OpenAIImageServiceImplTest { +public class OpenAIImageAPIImplTest { private static final String RESPONSE_JSON = "{\"data\":[{\"url\":\"http://localhost:8080\",\"value\":\"this is a response\"}]}"; @@ -220,8 +218,7 @@ public AIImageRequestDTO.Builder getDtoBuilder() { } private JSONObject prepareJsonObject(final String prompt, final boolean tempFileError) throws Exception { - when(config.getImageModel()) - .thenReturn(AIModel.builder().withType(AIModelType.IMAGE).withModelNames("some-image-model").build()); + when(config.getImageModel()).thenReturn(AIModel.builder().withType(AIModelType.IMAGE).withModelNames("some-image-model").build()); when(config.getImageSize()).thenReturn("some-image-size"); final File file = mock(File.class); when(file.getName()).thenReturn(UUIDGenerator.shorty()); diff --git a/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java b/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java index 81b74a231e3f..5b604f096adb 100644 --- a/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java +++ b/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java @@ -1,7 +1,10 @@ package com.dotcms; import com.dotcms.ai.app.AIModelsTest; +import com.dotcms.ai.app.ConfigServiceTest; +import com.dotcms.ai.client.AIProxyClientTest; import com.dotcms.ai.listener.EmbeddingContentListenerTest; +import com.dotcms.ai.validator.AIAppValidatorTest; import com.dotcms.ai.viewtool.AIViewToolTest; import com.dotcms.ai.viewtool.CompletionsToolTest; import com.dotcms.ai.viewtool.EmbeddingsToolTest; @@ -302,6 +305,9 @@ EmbeddingsToolTest.class, CompletionsToolTest.class, AIModelsTest.class, + ConfigServiceTest.class, + AIProxyClientTest.class, + AIAppValidatorTest.class, TimeMachineAPITest.class, Task240513UpdateContentTypesSystemFieldTest.class, PruneTimeMachineBackupJobTest.class, diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java index 0c790bb24d74..557201f1aaa4 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java @@ -6,8 +6,6 @@ import com.dotcms.util.WireMockTestHelper; import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; -import com.dotmarketing.exception.DotDataException; -import com.dotmarketing.exception.DotSecurityException; import com.github.tomakehurst.wiremock.WireMockServer; import java.util.Map; @@ -33,61 +31,6 @@ static WireMockServer prepareWireMock() { return wireMockServer; } - static Map aiAppSecrets(final WireMockServer wireMockServer, - final Host host, - final String apiKey, - final String textModels, - final String imageModels, - final String embeddingsModel) - throws DotDataException, DotSecurityException { - final AppSecrets appSecrets = new AppSecrets.Builder() - .withKey(AppKeys.APP_KEY) - .withSecret(AppKeys.API_URL.key, String.format(API_URL, wireMockServer.port())) - .withSecret(AppKeys.API_IMAGE_URL.key, String.format(API_IMAGE_URL, wireMockServer.port())) - .withSecret(AppKeys.API_EMBEDDINGS_URL.key, String.format(API_EMBEDDINGS_URL, wireMockServer.port())) - .withHiddenSecret(AppKeys.API_KEY.key, apiKey) - .withSecret(AppKeys.TEXT_MODEL_NAMES.key, textModels) - .withSecret(AppKeys.IMAGE_MODEL_NAMES.key, imageModels) - .withSecret(AppKeys.EMBEDDINGS_MODEL_NAMES.key, embeddingsModel) - .withSecret(AppKeys.IMAGE_SIZE.key, IMAGE_SIZE) - .withSecret(AppKeys.LISTENER_INDEXER.key, "{\"default\":\"blog\"}") - .withSecret(AppKeys.COMPLETION_ROLE_PROMPT.key, AppKeys.COMPLETION_ROLE_PROMPT.defaultValue) - .withSecret(AppKeys.COMPLETION_TEXT_PROMPT.key, AppKeys.COMPLETION_TEXT_PROMPT.defaultValue) - .build(); - APILocator.getAppsAPI().saveSecrets(appSecrets, host, APILocator.systemUser()); - return appSecrets.getSecrets(); - } - - static Map aiAppSecrets(final WireMockServer wireMockServer, - final Host host, - final String apiKey) - throws DotDataException, DotSecurityException { - return aiAppSecrets(wireMockServer, host, apiKey, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); - } - - static Map aiAppSecrets(final WireMockServer wireMockServer, - final Host host, - final String textModels, - final String imageModels, - final String embeddingsModel) - throws DotDataException, DotSecurityException { - return aiAppSecrets(wireMockServer, host, API_KEY, textModels, imageModels, embeddingsModel); - } - - static Map aiAppSecrets(final WireMockServer wireMockServer, final Host host) - throws DotDataException, DotSecurityException { - - return aiAppSecrets(wireMockServer, host, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); - } - - static void removeSecrets(final Host host) throws DotDataException, DotSecurityException { - APILocator.getAppsAPI().removeSecretsForSite(host, APILocator.systemUser()); - } - - - - - // TODO: pr-split -> remove methods below static Map aiAppSecrets(final Host host, final String apiKey, final String textModels, @@ -140,6 +83,4 @@ static void removeAiAppSecrets(final Host host) throws Exception { APILocator.getAppsAPI().deleteSecrets(AppKeys.APP_KEY, host, APILocator.systemUser()); } - - } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java index b24cebcae020..0b00b151ef55 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java @@ -1,17 +1,19 @@ package com.dotcms.ai.app; import com.dotcms.ai.AiTest; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIModelNotFoundException; +import com.dotcms.ai.model.SimpleModel; import com.dotcms.datagen.SiteDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; -import com.dotmarketing.exception.DotDataException; -import com.dotmarketing.exception.DotSecurityException; -import com.dotmarketing.util.DateUtil; +import com.dotmarketing.exception.DotRuntimeException; import com.github.tomakehurst.wiremock.WireMockServer; +import io.vavr.Tuple2; import io.vavr.control.Try; -import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -22,9 +24,11 @@ import java.util.Set; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; /** @@ -45,25 +49,22 @@ public class AIModelsTest { @BeforeClass public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); + IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); + AiTest.aiAppSecrets(APILocator.systemHost()); } @AfterClass public static void afterClass() { wireMockServer.stop(); + IPUtils.disabledIpPrivateSubnet(false); } @Before public void before() { - IPUtils.disabledIpPrivateSubnet(true); host = new SiteDataGen().nextPersisted(); otherHost = new SiteDataGen().nextPersisted(); - List.of(host, otherHost).forEach(h -> Try.of(() -> AiTest.aiAppSecrets(wireMockServer, host)).get()); - } - - @After - public void after() { - IPUtils.disabledIpPrivateSubnet(false); + List.of(host, otherHost).forEach(h -> Try.of(() -> AiTest.aiAppSecrets(h)).get()); } /** @@ -72,31 +73,31 @@ public void after() { * Then the correct models should be found and returned. */ @Test - public void test_loadModels_andFindThem() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); - saveSecrets( + public void test_loadModels_andFindThem() throws Exception { + AiTest.aiAppSecrets( host, "text-model-1,text-model-2", "image-model-3,image-model-4", "embeddings-model-5,embeddings-model-6"); - saveSecrets(otherHost, "text-model-1", null, null); + AiTest.aiAppSecrets(otherHost, "text-model-1", null, null); final String hostId = host.getHostname(); + final AppConfig appConfig = ConfigService.INSTANCE.config(host); - final Optional notFound = aiModels.findModel(hostId, "some-invalid-model-name"); + final Optional notFound = aiModels.findModel(appConfig, "some-invalid-model-name", AIModelType.TEXT); assertTrue(notFound.isEmpty()); - final Optional text1 = aiModels.findModel(hostId, "text-model-1"); - final Optional text2 = aiModels.findModel(hostId, "text-model-2"); - assertModels(text1, text2, AIModelType.TEXT); + final Optional text1 = aiModels.findModel(appConfig, "text-model-1", AIModelType.TEXT); + final Optional text2 = aiModels.findModel(appConfig, "text-model-2", AIModelType.TEXT); + assertModels(text1, text2, AIModelType.TEXT, true); - final Optional image1 = aiModels.findModel(hostId, "image-model-3"); - final Optional image2 = aiModels.findModel(hostId, "image-model-4"); - assertModels(image1, image2, AIModelType.IMAGE); + final Optional image1 = aiModels.findModel(appConfig, "image-model-3", AIModelType.IMAGE); + final Optional image2 = aiModels.findModel(appConfig, "image-model-4", AIModelType.IMAGE); + assertModels(image1, image2, AIModelType.IMAGE, true); - final Optional embeddings1 = aiModels.findModel(hostId, "embeddings-model-5"); - final Optional embeddings2 = aiModels.findModel(hostId, "embeddings-model-6"); - assertModels(embeddings1, embeddings2, AIModelType.EMBEDDINGS); + final Optional embeddings1 = aiModels.findModel(appConfig, "embeddings-model-5", AIModelType.EMBEDDINGS); + final Optional embeddings2 = aiModels.findModel(appConfig, "embeddings-model-6", AIModelType.EMBEDDINGS); + assertModels(embeddings1, embeddings2, AIModelType.EMBEDDINGS, true); assertNotSame(text1.get(), image1.get()); assertNotSame(text1.get(), embeddings1.get()); @@ -111,102 +112,183 @@ public void test_loadModels_andFindThem() throws DotDataException, DotSecurityEx final Optional embeddings3 = aiModels.findModel(hostId, AIModelType.EMBEDDINGS); assertSameModels(embeddings3, embeddings1, embeddings2); - final Optional text4 = aiModels.findModel(otherHost.getHostname(), "text-model-1"); + final AppConfig otherAppConfig = ConfigService.INSTANCE.config(otherHost); + final Optional text4 = aiModels.findModel(otherAppConfig, "text-model-1", AIModelType.TEXT); assertTrue(text3.isPresent()); assertNotSame(text1.get(), text4.get()); - saveSecrets( + AiTest.aiAppSecrets( host, "text-model-7,text-model-8", "image-model-9,image-model-10", "embeddings-model-11, embeddings-model-12"); - final Optional text7 = aiModels.findModel(hostId, "text-model-7"); - final Optional text8 = aiModels.findModel(hostId, "text-model-8"); + final Optional text7 = aiModels.findModel(otherAppConfig, "text-model-7", AIModelType.TEXT); + final Optional text8 = aiModels.findModel(otherAppConfig, "text-model-8", AIModelType.TEXT); assertNotPresentModels(text7, text8); - final Optional image9 = aiModels.findModel(hostId, "image-model-9"); - final Optional image10 = aiModels.findModel(hostId, "image-model-10"); + final Optional image9 = aiModels.findModel(otherAppConfig, "image-model-9", AIModelType.IMAGE); + final Optional image10 = aiModels.findModel(otherAppConfig, "image-model-10", AIModelType.IMAGE); assertNotPresentModels(image9, image10); - final Optional embeddings11 = aiModels.findModel(hostId, "embeddings-model-11"); - final Optional embeddings12 = aiModels.findModel(hostId, "embeddings-model-12"); + final Optional embeddings11 = aiModels.findModel(otherAppConfig, "embeddings-model-11", AIModelType.EMBEDDINGS); + final Optional embeddings12 = aiModels.findModel(otherAppConfig, "embeddings-model-12", AIModelType.EMBEDDINGS); assertNotPresentModels(embeddings11, embeddings12); + + final List available = aiModels.getAvailableModels(); + final List availableNames = List.of( + "gpt-3.5-turbo-16k", "dall-e-3", "text-embedding-ada-002", + "text-model-1", "text-model-7", "text-model-8", + "image-model-9", "image-model-10", + "embeddings-model-11", "embeddings-model-12"); + assertTrue(available.stream().anyMatch(model -> availableNames.contains(model.getName()))); } /** - * Given a URL for supported models - * When the getOrPullSupportedModules method is called - * Then a list of supported models should be returned. + * Given a set of models loaded into the AIModels instance + * When the resolveModel method is called with various model names and types + * Then the correct models should be resolved and their operational status verified. */ @Test - public void test_getOrPullSupportedModules() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); - AIModels.get().cleanSupportedModelsCache(); + public void test_resolveModel() throws Exception { + AiTest.aiAppSecrets(host, "text-model-20", "image-model-21", "embeddings-model-22"); + ConfigService.INSTANCE.config(host); + AiTest.aiAppSecrets(otherHost, "text-model-23", null, null); + ConfigService.INSTANCE.config(otherHost); + + assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.TEXT).isOperational()); + assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.IMAGE).isOperational()); + assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.EMBEDDINGS).isOperational()); + assertTrue(aiModels.resolveModel(otherHost.getHostname(), AIModelType.TEXT).isOperational()); + assertFalse(aiModels.resolveModel(otherHost.getHostname(), AIModelType.IMAGE).isOperational()); + assertFalse(aiModels.resolveModel(otherHost.getHostname(), AIModelType.EMBEDDINGS).isOperational()); + } - Set supported = aiModels.getOrPullSupportedModels(); - assertNotNull(supported); - assertEquals(38, supported.size()); + /** + * Given a set of models loaded into the AIModels instance + * When the resolveAIModelOrThrow method is called with various model names and types + * Then the correct models should be resolved and their operational status verified. + */ + @Test + public void test_resolveAIModelOrThrow() throws Exception { + AiTest.aiAppSecrets(host, "text-model-30", "image-model-31", "embeddings-model-32"); + + final AppConfig appConfig = ConfigService.INSTANCE.config(host); + final AIModel aiModel30 = aiModels.resolveAIModelOrThrow(appConfig, "text-model-30", AIModelType.TEXT); + final AIModel aiModel31 = aiModels.resolveAIModelOrThrow(appConfig, "image-model-31", AIModelType.IMAGE); + final AIModel aiModel32 = aiModels.resolveAIModelOrThrow( + appConfig, + "embeddings-model-32", + AIModelType.EMBEDDINGS); + + assertNotNull(aiModel30); + assertNotNull(aiModel31); + assertNotNull(aiModel32); + assertEquals("text-model-30", aiModel30.getModel("text-model-30").getName()); + assertEquals("image-model-31", aiModel31.getModel("image-model-31").getName()); + assertEquals("embeddings-model-32", aiModel32.getModel("embeddings-model-32").getName()); + + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "text-model-33", AIModelType.TEXT)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "image-model-34", AIModelType.IMAGE)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "embeddings-model-35", AIModelType.EMBEDDINGS)); + } - AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); + /** + * Given a set of models loaded into the AIModels instance + * When the resolveModelOrThrow method is called with various model names and types + * Then the correct models should be resolved and their operational status verified. + */ + @Test + public void test_resolveModelOrThrow() throws Exception { + AiTest.aiAppSecrets(host, "text-model-40", "image-model-41", "embeddings-model-42"); + + final AppConfig appConfig = ConfigService.INSTANCE.config(host); + final Tuple2 modelTuple40 = aiModels.resolveModelOrThrow( + appConfig, + "text-model-40", + AIModelType.TEXT); + final Tuple2 modelTuple41 = aiModels.resolveModelOrThrow( + appConfig, + "image-model-41", + AIModelType.IMAGE); + final Tuple2 modelTuple42 = aiModels.resolveModelOrThrow( + appConfig, + "embeddings-model-42", + AIModelType.EMBEDDINGS); + + assertNotNull(modelTuple40); + assertNotNull(modelTuple41); + assertNotNull(modelTuple42); + assertEquals("text-model-40", modelTuple40._1.getModel("text-model-40").getName()); + assertEquals("image-model-41", modelTuple41._1.getModel("image-model-41").getName()); + assertEquals("embeddings-model-42", modelTuple42._1.getModel("embeddings-model-42").getName()); + + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "text-model-43", AIModelType.TEXT)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "image-model-44", AIModelType.IMAGE)); + assertThrows( + DotAIModelNotFoundException.class, + () -> aiModels.resolveAIModelOrThrow(appConfig, "embeddings-model-45", AIModelType.EMBEDDINGS)); } /** - * Given an invalid URL for supported models + * Given a URL for supported models * When the getOrPullSupportedModules method is called - * Then an empty list of supported models should be returned. + * Then a list of supported models should be returned. */ @Test - public void test_getOrPullSupportedModules_withNetworkError() { + public void test_getOrPullSupportedModels() { AIModels.get().cleanSupportedModelsCache(); - IPUtils.disabledIpPrivateSubnet(false); + final AppConfig appConfig = ConfigService.INSTANCE.config(host); - final Set supported = aiModels.getOrPullSupportedModels(); - assertTrue(supported.isEmpty()); - - IPUtils.disabledIpPrivateSubnet(true); - AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); + Set supported = aiModels.getOrPullSupportedModels(appConfig); + assertNotNull(supported); + assertEquals(38, supported.size()); } /** - * Given no API key + * Given an invalid URL for supported models * When the getOrPullSupportedModules method is called - * Then an empty list of supported models should be returned. + * Then an exception should be thrown */ @Test - public void test_getOrPullSupportedModules_noApiKey() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost(), null); - + public void test_getOrPullSupportedModuels_withNetworkError() { + final AppConfig appConfig = ConfigService.INSTANCE.config(host); AIModels.get().cleanSupportedModelsCache(); - final Set supported = aiModels.getOrPullSupportedModels(); + IPUtils.disabledIpPrivateSubnet(false); - assertTrue(supported.isEmpty()); + assertThrows(DotRuntimeException.class, () ->aiModels.getOrPullSupportedModels(appConfig)); + IPUtils.disabledIpPrivateSubnet(true); } /** * Given no API key * When the getOrPullSupportedModules method is called - * Then an empty list of supported models should be returned. + * Then an exception should be thrown. */ @Test - public void test_getOrPullSupportedModules_noSystemHost() throws DotDataException, DotSecurityException { - AiTest.removeSecrets(APILocator.systemHost()); + public void test_getOrPullSupportedModels_noApiKey() throws Exception { + AiTest.aiAppSecrets(host, null); + final AppConfig appConfig = ConfigService.INSTANCE.config(host); AIModels.get().cleanSupportedModelsCache(); - final Set supported = aiModels.getOrPullSupportedModels(); + final Set supported = aiModels.getOrPullSupportedModels(appConfig); assertTrue(supported.isEmpty()); } - private void saveSecrets(final Host host, - final String textModels, - final String imageModels, - final String embeddingsModels) throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, host, textModels, imageModels, embeddingsModels); - DateUtil.sleep(1000); - } - - private static void assertSameModels(Optional text3, Optional text1, Optional text2) { + private static void assertSameModels(final Optional text3, + final Optional text1, + final Optional text2) { assertTrue(text3.isPresent()); assertSame(text1.get(), text3.get()); assertSame(text2.get(), text3.get()); @@ -214,12 +296,17 @@ private static void assertSameModels(Optional text3, Optional private static void assertModels(final Optional model1, final Optional model2, - final AIModelType type) { + final AIModelType type, + final boolean assertModelNames) { assertTrue(model1.isPresent()); assertTrue(model2.isPresent()); assertSame(model1.get(), model2.get()); assertSame(type, model1.get().getType()); assertSame(type, model2.get().getType()); + if (assertModelNames) { + assertTrue(model1.get().getModels().stream().allMatch(model -> model.getStatus() == ModelStatus.ACTIVE)); + assertTrue(model2.get().getModels().stream().allMatch(model -> model.getStatus() == ModelStatus.ACTIVE)); + } } private static void assertNotPresentModels(final Optional model1, final Optional model2) { @@ -227,9 +314,4 @@ private static void assertNotPresentModels(final Optional model1, final assertTrue(model2.isEmpty()); } - private static void assertSupported(Set supported) { - assertNotNull(supported); - assertTrue(supported.isEmpty()); - } - } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java new file mode 100644 index 000000000000..af100d725f2f --- /dev/null +++ b/dotcms-integration/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java @@ -0,0 +1,318 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.AiTest; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModelType; +import com.dotcms.ai.app.AIModels; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.app.AppKeys; +import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.domain.AIResponse; +import com.dotcms.ai.domain.Model; +import com.dotcms.ai.domain.ModelStatus; +import com.dotcms.ai.exception.DotAIAllModelsExhaustedException; +import com.dotcms.ai.exception.DotAIClientConnectException; +import com.dotcms.ai.util.LineReadingOutputStream; +import com.dotcms.datagen.SiteDataGen; +import com.dotcms.datagen.UserDataGen; +import com.dotcms.util.IntegrationTestInitService; +import com.dotcms.util.network.IPUtils; +import com.dotmarketing.beans.Host; +import com.dotmarketing.business.APILocator; +import com.dotmarketing.util.UtilMethods; +import com.dotmarketing.util.json.JSONArray; +import com.dotmarketing.util.json.JSONObject; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.liferay.portal.model.User; +import io.vavr.Tuple2; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for the AIProxyClient class. + * + * @author vico + */ +public class AIProxyClientTest { + + private static WireMockServer wireMockServer; + private static User user; + private Host host; + private AppConfig appConfig; + private final AIProxyClient aiProxyClient = AIProxyClient.get(); + + @BeforeClass + public static void beforeClass() throws Exception { + IntegrationTestInitService.getInstance().init(); + IPUtils.disabledIpPrivateSubnet(true); + wireMockServer = AiTest.prepareWireMock(); + final Host systemHost = APILocator.systemHost(); + AiTest.aiAppSecrets(systemHost); + ConfigService.INSTANCE.config(systemHost); + user = new UserDataGen().nextPersisted(); + } + + @AfterClass + public static void afterClass() { + wireMockServer.stop(); + IPUtils.disabledIpPrivateSubnet(false); + } + + @Before + public void before() { + host = new SiteDataGen().nextPersisted(); + } + + @After + public void after() throws Exception { + AiTest.removeAiAppSecrets(host); + } + + /** + * Scenario: Calling AI with a valid model + * Given a valid model "gpt-4o-mini" + * When the request is sent to the AI service + * Then the response should contain the model name "gpt-4o-mini" + */ + @Test + public void test_callToAI_happiestPath() throws Exception { + final String model = "gpt-4o-mini"; + AiTest.aiAppSecrets(host, model, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + model, + "What are the major achievements of the Apollo space program?"); + + final AIResponse aiResponse = aiProxyClient.callToAI(request); + + assertNotNull(aiResponse); + assertNotNull(aiResponse.getResponse()); + assertEquals("gpt-4o-mini", new JSONObject(aiResponse.getResponse()).getString(AiKeys.MODEL)); + } + + /** + * Scenario: Calling AI with multiple models + * Given multiple models including "gpt-4o-mini" + * When the request is sent to the AI service + * Then the response should contain the model name "gpt-4o-mini" + */ + @Test + public void test_callToAI_happyPath_withMultipleModels() throws Exception { + final String model = "gpt-4o-mini"; + AiTest.aiAppSecrets( + host, + String.format("%s,some-made-up-model-1", model), + "dall-e-3", + "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + model, + "What are the major achievements of the Apollo space program?"); + + final AIResponse aiResponse = aiProxyClient.callToAI(request); + + assertNotNull(aiResponse); + assertNotNull(aiResponse.getResponse()); + assertEquals("gpt-4o-mini", new JSONObject(aiResponse.getResponse()).getString(AiKeys.MODEL)); + } + + /** + * Scenario: Calling AI with an invalid model + * Given an invalid model "some-made-up-model-10" + * When the request is sent to the AI service + * Then a DotAIAllModelsExhaustedException should be thrown + */ + @Test + public void test_callToAI_withInvalidModel() throws Exception { + final String invalidModel = "some-made-up-model-10"; + AiTest.aiAppSecrets(host, invalidModel, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + invalidModel, + "What are the major achievements of the Apollo space program?"); + + assertThrows(DotAIAllModelsExhaustedException.class, () -> aiProxyClient.callToAI(request)); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(invalidModel, AIModelType.TEXT); + assertSame(ModelStatus.INVALID, modelTuple._2.getStatus()); + assertEquals(-1, modelTuple._1.getCurrentModelIndex()); + assertTrue(AIModels.get() + .getAvailableModels() + .stream() + .noneMatch(model -> model.getName().equals(invalidModel))); + assertThrows(DotAIAllModelsExhaustedException.class, () -> aiProxyClient.callToAI(request)); + } + + /** + * Scenario: Calling AI with a decommissioned model + * Given a decommissioned model "some-decommissioned-model-20" + * When the request is sent to the AI service + * Then a DotAIAllModelsExhaustedException should be thrown + */ + @Test + public void test_callToAI_withDecommissionedModel() throws Exception { + final String decommissionedModel = "some-decommissioned-model-20"; + AiTest.aiAppSecrets(host, decommissionedModel, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + decommissionedModel, + "What are the major achievements of the Apollo space program?"); + + assertThrows(DotAIAllModelsExhaustedException.class, () -> aiProxyClient.callToAI(request)); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(decommissionedModel, AIModelType.TEXT); + assertSame(ModelStatus.DECOMMISSIONED, modelTuple._2.getStatus()); + assertEquals(-1, modelTuple._1.getCurrentModelIndex()); + assertTrue(AIModels.get() + .getAvailableModels() + .stream() + .noneMatch(model -> model.getName().equals(decommissionedModel))); + assertThrows(DotAIAllModelsExhaustedException.class, () -> aiProxyClient.callToAI(request)); + } + + /** + * Scenario: Calling AI with multiple models including invalid, decommissioned, and valid models + * Given models "some-made-up-model-30", "some-decommissioned-model-31", and "gpt-4o-mini" + * When the request is sent to the AI service + * Then the response should contain the model name "gpt-4o-mini" + */ + @Test + public void test_callToAI_withMultipleModels_invalidAndDecommissionedAndValid() throws Exception { + final String invalidModel = "some-made-up-model-30"; + final String decommissionedModel = "some-decommissioned-model-31"; + final String validModel = "gpt-4o-mini"; + AiTest.aiAppSecrets( + host, + String.format("%s,%s,%s", invalidModel, decommissionedModel, validModel), + "dall-e-3", + "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest(invalidModel, "What are the major achievements of the Apollo space program?"); + + final AIResponse aiResponse = aiProxyClient.callToAI(request); + + assertNotNull(aiResponse); + assertNotNull(aiResponse.getResponse()); + assertSame(ModelStatus.INVALID, appConfig.resolveModelOrThrow(invalidModel, AIModelType.TEXT)._2.getStatus()); + assertSame( + ModelStatus.DECOMMISSIONED, + appConfig.resolveModelOrThrow(decommissionedModel, AIModelType.TEXT)._2.getStatus()); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(validModel, AIModelType.TEXT); + assertSame(ModelStatus.ACTIVE, modelTuple._2.getStatus()); + assertEquals(2, modelTuple._1.getCurrentModelIndex()); + assertTrue(AIModels.get() + .getAvailableModels() + .stream() + .noneMatch(model -> List.of(invalidModel, decommissionedModel).contains(model.getName()))); + assertTrue(AIModels.get() + .getAvailableModels() + .stream() + .anyMatch(model -> model.getName().equals(validModel))); + assertEquals("gpt-4o-mini", new JSONObject(aiResponse.getResponse()).getString(AiKeys.MODEL)); + } + + /** + * Scenario: Calling AI with a valid model and provided output stream + * Given a valid model "gpt-4o-mini" and a provided output stream + * When the request is sent to the AI service + * Then the response should be written to the output stream + */ + @Test + public void test_callToAI_withProvidedOutput() throws Exception { + final String model = "gpt-4o-mini"; + AiTest.aiAppSecrets(host, model, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + model, + "What are the major achievements of the Apollo space program?"); + + final AIResponse aiResponse = aiProxyClient.callToAI( + request, + new LineReadingOutputStream(new ByteArrayOutputStream())); + assertNotNull(aiResponse); + assertNull(aiResponse.getResponse()); + } + + /** + * Scenario: Calling AI with an invalid model and provided output stream + * Given an invalid model "some-made-up-model-40" and a provided output stream + * When the request is sent to the AI service + * Then a DotAIAllModelsExhaustedException should be thrown + */ + @Test + public void test_callToAI_withInvalidModel_withProvidedOutput() throws Exception { + final String invalidModel = "some-made-up-model-40"; + AiTest.aiAppSecrets(host, invalidModel, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + invalidModel, + "What are the major achievements of the Apollo space program?"); + + assertThrows(DotAIAllModelsExhaustedException.class, () -> aiProxyClient.callToAI(request)); + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(invalidModel, AIModelType.TEXT); + assertSame(ModelStatus.INVALID, modelTuple._2.getStatus()); + assertEquals(-1, modelTuple._1.getCurrentModelIndex()); + assertTrue(AIModels.get() + .getAvailableModels() + .stream() + .noneMatch(model -> model.getName().equals(invalidModel))); + } + + /** + * Scenario: Calling AI with network issues + * Given a valid model "gpt-4o-mini" + * And the AI service is unavailable due to network issues + * When the request is sent to the AI service + * Then a DotAIClientConnectException should be thrown + * And the model should remain operational after the network is restored + */ + @Test + public void test_callToAI_withNetworkIssues() throws Exception { + final String model = "gpt-4o-mini"; + AiTest.aiAppSecrets(host, model, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + final JSONObjectAIRequest request = textRequest( + model, + "What are the major achievements of the Apollo space program?"); + + wireMockServer.stop(); + + assertThrows(DotAIClientConnectException.class, () -> aiProxyClient.callToAI(request)); + + wireMockServer = AiTest.prepareWireMock(); + + final Tuple2 modelTuple = appConfig.resolveModelOrThrow(model, AIModelType.TEXT); + assertTrue(modelTuple._2.isOperational()); + } + + private JSONObjectAIRequest textRequest(final String model, final String prompt) { + final JSONObject payload = new JSONObject(); + final JSONArray messages = new JSONArray(); + + final String systemPrompt = UtilMethods.isSet(appConfig.getRolePrompt()) ? appConfig.getRolePrompt() : null; + if (UtilMethods.isSet(systemPrompt)) { + messages.add(Map.of(AiKeys.ROLE, AiKeys.SYSTEM, AiKeys.CONTENT, systemPrompt)); + } + messages.add(Map.of(AiKeys.ROLE, AiKeys.USER, AiKeys.CONTENT, prompt)); + + payload.put(AiKeys.MODEL, model); + payload.put(AiKeys.TEMPERATURE, appConfig.getConfigFloat(AppKeys.COMPLETION_TEMPERATURE)); + payload.put(AiKeys.MESSAGES, messages); + + return JSONObjectAIRequest.quickText(appConfig, payload, user.getUserId()); + } + +} diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java index 3c61cd335f55..ce61bbb8b6d6 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java @@ -190,9 +190,9 @@ private static boolean waitForEmbeddings(final Contentlet blogContent, final Str return embeddingsExist; } - private static void addDotAISecrets() throws DotDataException, DotSecurityException { - AiTest.aiAppSecrets(wireMockServer, host, AiTest.API_KEY); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost(), AiTest.API_KEY); + private static void addDotAISecrets() throws Exception { + AiTest.aiAppSecrets(host, AiTest.API_KEY); + AiTest.aiAppSecrets(APILocator.systemHost(), AiTest.API_KEY); } private static void removeDotAISecrets() throws DotDataException, DotSecurityException { @@ -232,4 +232,5 @@ public static java.util.function.Predicate distinctByKey( Set seen = ConcurrentHashMap.newKeySet(); return t -> seen.add(keyExtractor.apply(t)); } + } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/validator/AIAppValidatorTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/validator/AIAppValidatorTest.java new file mode 100644 index 000000000000..8c7379ab6313 --- /dev/null +++ b/dotcms-integration/src/test/java/com/dotcms/ai/validator/AIAppValidatorTest.java @@ -0,0 +1,109 @@ +package com.dotcms.ai.validator; + +import com.dotcms.ai.AiTest; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.app.ConfigService; +import com.dotcms.api.system.event.message.SystemMessageEventUtil; +import com.dotcms.api.system.event.message.builder.SystemMessage; +import com.dotcms.datagen.SiteDataGen; +import com.dotcms.datagen.UserDataGen; +import com.dotcms.util.IntegrationTestInitService; +import com.dotcms.util.network.IPUtils; +import com.dotmarketing.beans.Host; +import com.dotmarketing.business.APILocator; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.liferay.portal.model.User; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Unit tests for the AIAppValidator class. + * This class tests the validation of AI configurations and model usage. + * It ensures that the AI models specified in the application configuration are supported + * and not exhausted. + * + * The tests cover scenarios for valid configurations, invalid configurations, and configurations + * with missing fields. + * + * @author vico + */ +public class AIAppValidatorTest { + + private static WireMockServer wireMockServer; + private static User user; + private static SystemMessageEventUtil systemMessageEventUtil; + private Host host; + private AppConfig appConfig; + private AIAppValidator validator = AIAppValidator.get(); + + @BeforeClass + public static void beforeClass() throws Exception { + IntegrationTestInitService.getInstance().init(); + wireMockServer = AiTest.prepareWireMock(); + final Host systemHost = APILocator.systemHost(); + AiTest.aiAppSecrets(systemHost); + ConfigService.INSTANCE.config(systemHost); + user = new UserDataGen().nextPersisted(); + systemMessageEventUtil = mock(SystemMessageEventUtil.class); + } + + @AfterClass + public static void afterClass() { + wireMockServer.stop(); + IPUtils.disabledIpPrivateSubnet(false); + } + + @Before + public void before() { + IPUtils.disabledIpPrivateSubnet(true); + host = new SiteDataGen().nextPersisted(); + validator.setSystemMessageEventUtil(systemMessageEventUtil); + } + + @After + public void after() throws Exception { + AiTest.removeAiAppSecrets(host); + } + + @Test/** + * Scenario: Validating AI configuration with unsupported models + * Given an AI configuration with unsupported models + * When the configuration is validated + * Then a warning message should be pushed to the user + */ + public void test_validateAIConfig() throws Exception { + final String invalidModel = "some-made-up-model-10"; + AiTest.aiAppSecrets(host, invalidModel, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + + verify(systemMessageEventUtil, atLeastOnce()).pushMessage(any(SystemMessage.class), anyList()); + } + + /** + * Scenario: Validating AI models usage with exhausted models + * Given an AI model with exhausted models + * When the models usage is validated + * Then a warning message should be pushed to the user for each exhausted model + */ + @Test + public void test_validateModelsUsage() throws Exception { + final String invalidModels = "some-made-up-model-20,some-decommissioned-model-21"; + AiTest.aiAppSecrets(host, invalidModels, "dall-e-3", "text-embedding-ada-002"); + appConfig = ConfigService.INSTANCE.config(host); + + validator.validateModelsUsage(appConfig.getModel(), user.getUserId()); + + verify(systemMessageEventUtil, atLeast(2)) + .pushMessage(any(SystemMessage.class), anyList()); + } +} diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java index 314079da2a93..1062044719a9 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java @@ -7,6 +7,7 @@ import com.dotcms.datagen.UserDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; +import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; import com.dotmarketing.util.json.JSONObject; import com.github.tomakehurst.wiremock.WireMockServer; @@ -48,8 +49,9 @@ public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); - config = ConfigService.INSTANCE.config(); + final Host systemHost = APILocator.systemHost(); + AiTest.aiAppSecrets(systemHost, "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); + config = ConfigService.INSTANCE.config(systemHost); } @AfterClass diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java index f00c1ab31ae8..e481133edbca 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java @@ -6,6 +6,7 @@ import com.dotcms.ai.app.ConfigService; import com.dotcms.datagen.EmbeddingsDTODataGen; import com.dotcms.datagen.SiteDataGen; +import com.dotcms.datagen.UserDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; import com.dotmarketing.beans.Host; @@ -13,6 +14,7 @@ import com.dotmarketing.util.json.JSONArray; import com.dotmarketing.util.json.JSONObject; import com.github.tomakehurst.wiremock.WireMockServer; +import com.liferay.portal.model.User; import org.apache.commons.lang3.StringUtils; import org.apache.velocity.tools.view.context.ViewContext; import org.junit.AfterClass; @@ -41,6 +43,7 @@ public class CompletionsToolTest { private static AppConfig appConfig; + private static User user; private static WireMockServer wireMockServer; private static Host host; @@ -52,9 +55,10 @@ public static void beforeClass() throws Exception { IPUtils.disabledIpPrivateSubnet(true); host = new SiteDataGen().nextPersisted(); wireMockServer = AiTest.prepareWireMock(); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); - AiTest.aiAppSecrets(wireMockServer, host); + AiTest.aiAppSecrets(APILocator.systemHost(), "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); + AiTest.aiAppSecrets(host, "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); appConfig = ConfigService.INSTANCE.config(host); + user = new UserDataGen().nextPersisted(); } @AfterClass @@ -86,7 +90,7 @@ public void test_getConfig() { assertNotNull(config); assertEquals(AppKeys.COMPLETION_ROLE_PROMPT.defaultValue, config.get(AppKeys.COMPLETION_ROLE_PROMPT.key)); assertEquals(AppKeys.COMPLETION_TEXT_PROMPT.defaultValue, config.get(AppKeys.COMPLETION_TEXT_PROMPT.key)); - assertEquals("gpt-3.5-turbo-16k", config.get(AppKeys.TEXT_MODEL_NAMES.key)); + assertEquals("gpt-4o-mini", config.get(AppKeys.TEXT_MODEL_NAMES.key)); } /** @@ -119,7 +123,7 @@ public void test_summarize() { @Test public void test_raw() { final String query = "What is the speed of light in the vacuum"; - final String prompt = String.format("{\"model\":\"gpt-3.5-turbo-16k\",\"messages\":[{\"role\":\"user\",\"content\":\"%s?\"},{\"role\":\"system\",\"content\":\"You are a helpful assistant with a descriptive writing style.\"}]}", query); + final String prompt = String.format("{\"model\":\"gpt-4o-mini\",\"messages\":[{\"role\":\"user\",\"content\":\"%s?\"},{\"role\":\"system\",\"content\":\"You are a helpful assistant with a descriptive writing style.\"}]}", query); final JSONObject result = (JSONObject) completionsTool.raw(prompt); assertResult(result); @@ -142,7 +146,7 @@ public void test_raw_json() { messages.put(new JSONObject().put("role", "user").put("content", query + "?")); messages.put(new JSONObject().put("role", "system").put("content", "You are a helpful assistant with a descriptive writing style")); json.put("messages", messages); - json.put("model", "gpt-3.5-turbo-16k"); + json.put("model", "gpt-4o-mini"); final JSONObject result = (JSONObject) completionsTool.raw(json); assertResult(result); @@ -160,7 +164,7 @@ public void test_raw_json() { @Test public void test_raw_map() { final String query = "Who was the first president of the United States"; - final Map map = Map.of("model", "gpt-3.5-turbo-16k", "messages", new JSONArray() + final Map map = Map.of("model", "gpt-4o-mini", "messages", new JSONArray() .put(new JSONObject().put("role", "user").put("content", query + "?")) .put(new JSONObject().put("role", "system").put("content", "You are a helpful assistant with a descriptive writing style"))); @@ -179,6 +183,11 @@ Host host() { AppConfig config() { return appConfig; } + + @Override + User user() { + return user; + } }; } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java index a58bf579bc65..5f5b875a7f31 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java @@ -5,13 +5,13 @@ import com.dotcms.ai.app.ConfigService; import com.dotcms.datagen.EmbeddingsDTODataGen; import com.dotcms.datagen.SiteDataGen; +import com.dotcms.datagen.UserDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; -import com.dotmarketing.exception.DotDataException; -import com.dotmarketing.exception.DotSecurityException; import com.github.tomakehurst.wiremock.WireMockServer; +import com.liferay.portal.model.User; import org.apache.velocity.tools.view.context.ViewContext; import org.junit.AfterClass; import org.junit.Before; @@ -43,6 +43,7 @@ public class EmbeddingsToolTest { private Host host; private AppConfig appConfig; + private User user; private EmbeddingsTool embeddingsTool; @BeforeClass @@ -50,16 +51,17 @@ public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); - AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); + AiTest.aiAppSecrets(APILocator.systemHost()); } @Before - public void before() throws DotDataException, DotSecurityException { + public void before() throws Exception { final ViewContext viewContext = mock(ViewContext.class); when(viewContext.getRequest()).thenReturn(mock(HttpServletRequest.class)); host = new SiteDataGen().nextPersisted(); - AiTest.aiAppSecrets(wireMockServer, host); + AiTest.aiAppSecrets(host); appConfig = ConfigService.INSTANCE.config(host); + user = new UserDataGen().nextPersisted(); embeddingsTool = prepareEmbeddingsTool(viewContext); } @@ -136,6 +138,11 @@ Host host() { AppConfig appConfig() { return appConfig; } + + @Override + public User user() { + return user; + } }; } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/SearchToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/SearchToolTest.java index a9cd8a62cd3b..4d7b7ef6e891 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/SearchToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/SearchToolTest.java @@ -1,5 +1,6 @@ package com.dotcms.ai.viewtool; +import com.dotcms.ai.AiTest; import com.dotcms.contenttype.model.field.Field; import com.dotcms.contenttype.model.field.TextField; import com.dotcms.contenttype.model.type.ContentType; @@ -12,6 +13,7 @@ import com.dotcms.datagen.TemplateDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotmarketing.beans.Host; +import com.dotmarketing.business.APILocator; import com.dotmarketing.portlets.contentlet.model.Contentlet; import com.dotmarketing.portlets.htmlpageasset.model.HTMLPageAsset; import com.dotmarketing.portlets.templates.model.Template; @@ -47,14 +49,16 @@ public class SearchToolTest { @BeforeClass public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); + AiTest.aiAppSecrets(APILocator.systemHost(), "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); } @Before - public void before() { + public void before() throws Exception { final ViewContext viewContext = mock(ViewContext.class); when(viewContext.getRequest()).thenReturn(mock(HttpServletRequest.class)); host = new SiteDataGen().nextPersisted(); searchTool = prepareSearchTool(viewContext); + AiTest.aiAppSecrets(host, "gpt-4o-mini", "dall-e-3", "text-embedding-ada-002"); } /** diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java index 344efe66d3af..b825ee8842c8 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIAutoTagActionletTest.java @@ -39,7 +39,6 @@ */ public class OpenAIAutoTagActionletTest { - private static AppConfig config; private static Host host; @BeforeClass @@ -77,8 +76,8 @@ public static void beforeClass() throws Exception { .withType(Type.STRING) .withValue("{\"default\":\"blog\"}".toCharArray()) .build()); - config = new AppConfig(host.getHostname(), secrets); - DotAIAPIFacadeImpl.addCompletionsAPIImplementation("default", (Object... initArguments)-> new CompletionsAPI() { + new AppConfig(host.getHostname(), secrets); + DotAIAPIFacadeImpl.addCompletionsAPIImplementation("default", (Object... initArguments) -> new CompletionsAPI() { @Override public JSONObject summarize(CompletionsForm searcher) { return null; @@ -90,7 +89,7 @@ public void summarizeStream(CompletionsForm searcher, OutputStream out) { } @Override - public JSONObject raw(JSONObject promptJSON) { + public JSONObject raw(JSONObject promptJSON, final String userId) { return null; } @@ -104,7 +103,8 @@ public JSONObject prompt(final String systemPrompt, final String userPrompt, final String model, final float temperature, - final int maxTokens) { + final int maxTokens, + final String userId) { return new JSONObject("{\n" + " \"id\": \"chatcmpl-7bHkIY2cNQXV3yWZmZ1lM1b4AIlJ6\",\n" + " \"object\": \"chat.completion\",\n" + diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java index 9e92628e2df6..3db2f537423c 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIContentPromptActionletTest.java @@ -89,7 +89,7 @@ public void summarizeStream(CompletionsForm searcher, OutputStream out) { } @Override - public JSONObject raw(final JSONObject promptJSON) { + public JSONObject raw(final JSONObject promptJSON, final String userId) { return new JSONObject("{\n" + " \"id\": \"chatcmpl-7bHkIY2cNQXV3yWZmZ1lM1b4AIlJ6\",\n" + " \"object\": \"chat.completion\",\n" + @@ -126,7 +126,8 @@ public JSONObject prompt(final String systemPrompt, final String userPrompt, final String model, final float temperature, - final int maxTokens) { + final int maxTokens, + final String userId) { return new JSONObject("{\n" + " \"id\": \"chatcmpl-7bHkIY2cNQXV3yWZmZ1lM1b4AIlJ6\",\n" + " \"object\": \"chat.completion\",\n" + diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIGenerateImageActionletTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIGenerateImageActionletTest.java index 8aa61917bb9f..d35e9a02a8d0 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIGenerateImageActionletTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/workflow/OpenAIGenerateImageActionletTest.java @@ -40,7 +40,6 @@ */ public class OpenAIGenerateImageActionletTest { - private static AppConfig config; private static Host host; @BeforeClass @@ -78,34 +77,29 @@ public static void beforeClass() throws Exception { .withType(Type.STRING) .withValue("{\"default\":\"blog\"}".toCharArray()) .build()); - config = new AppConfig(host.getHostname(), secrets); - DotAIAPIFacadeImpl.setDefaultImageAPIProvider(new ImageAPIProvider() { + new AppConfig(host.getHostname(), secrets); + DotAIAPIFacadeImpl.setDefaultImageAPIProvider(initArguments -> new ImageAPI() { @Override - public ImageAPI getImageAPI(Object... initArguments) { - return new ImageAPI() { - @Override - public JSONObject sendTextPrompt(String prompt) { - return new JSONObject("{\n" + - " \"response\":\"image_id123\",\n" + - " \"tempFile\":\"image_id123\"\n" + - "}"); - } - - @Override - public JSONObject sendRawRequest(String prompt) { - return null; - } - - @Override - public JSONObject sendRequest(JSONObject jsonObject) { - return null; - } - - @Override - public JSONObject sendRequest(AIImageRequestDTO dto) { - return null; - } - }; + public JSONObject sendTextPrompt(String prompt) { + return new JSONObject("{\n" + + " \"response\":\"image_id123\",\n" + + " \"tempFile\":\"image_id123\"\n" + + "}"); + } + + @Override + public JSONObject sendRawRequest(String prompt) { + return null; + } + + @Override + public JSONObject sendRequest(JSONObject jsonObject) { + return null; + } + + @Override + public JSONObject sendRequest(AIImageRequestDTO dto) { + return null; } }); } diff --git a/dotcms-postman/pom.xml b/dotcms-postman/pom.xml index 8957e0ff6835..a84e1627f93b 100644 --- a/dotcms-postman/pom.xml +++ b/dotcms-postman/pom.xml @@ -116,9 +116,6 @@ [WireMock] green - - --verbose -