diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java index bfbc6fbda290..3a69c6935a76 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java @@ -1,13 +1,13 @@ package com.dotcms.ai.api; import com.dotcms.ai.AiKeys; +import com.dotcms.ai.app.AIModel; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.db.EmbeddingsDTO; import com.dotcms.ai.rest.forms.CompletionsForm; import com.dotcms.ai.util.EncodingUtil; -import com.dotcms.ai.util.OpenAIModel; import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.api.web.HttpServletRequestThreadLocal; import com.dotcms.mock.request.FakeHttpRequest; @@ -42,7 +42,7 @@ public class CompletionsAPIImpl implements CompletionsAPI { private final Lazy config; - final Lazy defaultConfig = + private final Lazy defaultConfig = Lazy.of(() -> ConfigService.INSTANCE.config( Try.of(() -> WebAPILocator .getHostWebAPI() @@ -60,7 +60,7 @@ public JSONObject prompt(final String systemPrompt, final String modelIn, final float temperature, final int maxTokens) { - final OpenAIModel model = OpenAIModel.resolveModel(modelIn); + final AIModel model = config.get().resolveModelOrThrow(modelIn); final JSONObject json = new JSONObject(); json.put(AiKeys.TEMPERATURE, temperature); @@ -70,7 +70,7 @@ public JSONObject prompt(final String systemPrompt, json.put(AiKeys.MAX_TOKENS, maxTokens); } - json.put(AiKeys.MODEL, model.modelName); + json.put(AiKeys.MODEL, model.getCurrentModel()); return raw(json); } @@ -91,7 +91,7 @@ public JSONObject summarize(final CompletionsForm summaryRequest) { Try.of(() -> OpenAIRequest.doRequest( config.get().getApiUrl(), HttpMethod.POST, - config.get().getApiKey(), + config.get(), json)) .getOrElseThrow(DotRuntimeException::new); final JSONObject dotCMSResponse = 
EmbeddingsAPI.impl().reduceChunksToContent(searcher, localResults); @@ -107,7 +107,7 @@ public void summarizeStream(final CompletionsForm summaryRequest, final OutputSt final JSONObject json = buildRequestJson(summaryRequest, localResults); json.put(AiKeys.STREAM, true); - OpenAIRequest.doPost(config.get().getApiUrl(), config.get().getApiKey(), json, out); + OpenAIRequest.doPost(config.get().getApiUrl(), config.get(), json, out); } @Override @@ -119,7 +119,7 @@ public JSONObject raw(final JSONObject json) { final String response = OpenAIRequest.doRequest( config.get().getApiUrl(), HttpMethod.POST, - config.get().getApiKey(), + config.get(), json); if (config.get().getConfigBoolean(AppKeys.DEBUG_LOGGING)) { Logger.info(this.getClass(), "OpenAI response:" + response); @@ -138,7 +138,7 @@ public JSONObject raw(CompletionsForm promptForm) { public void rawStream(final CompletionsForm promptForm, final OutputStream out) { final JSONObject json = buildRequestJson(promptForm); json.put(AiKeys.STREAM, true); - OpenAIRequest.doRequest(config.get().getApiUrl(), HttpMethod.POST, config.get().getApiKey(), json, out); + OpenAIRequest.doRequest(config.get().getApiUrl(), HttpMethod.POST, config.get(), json, out); } private void buildMessages(final String systemPrompt, final String userPrompt, final JSONObject json) { @@ -151,7 +151,7 @@ private void buildMessages(final String systemPrompt, final String userPrompt, f } private JSONObject buildRequestJson(final CompletionsForm form, final List searchResults) { - final OpenAIModel model = OpenAIModel.resolveModel(form.model); + final AIModel model = config.get().resolveModelOrThrow(form.model); // aggregate matching results into text final StringBuilder supportingContent = new StringBuilder(); searchResults.forEach(s -> supportingContent.append(s.extractedText).append(" ")); @@ -162,7 +162,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List enc.countTokens(testString)) .orElseThrow(() -> new 
DotRuntimeException("Encoder not found")); } @@ -244,20 +244,19 @@ private String reduceStringToTokenSize(final String incomingString, final int ma } private JSONObject buildRequestJson(final CompletionsForm form) { - final int maxTokenSize = OpenAIModel.resolveModel(config.get().getConfig(AppKeys.MODEL)).maxTokens; + final AIModel aiModel = config.get().getModel(); final int promptTokens = countTokens(form.prompt); final JSONArray messages = new JSONArray(); final String textPrompt = reduceStringToTokenSize( form.prompt, - maxTokenSize - form.responseLengthTokens - promptTokens); + aiModel.getMaxTokens() - form.responseLengthTokens - promptTokens); messages.add(Map.of(AiKeys.ROLE, AiKeys.USER, AiKeys.CONTENT, textPrompt)); final JSONObject json = new JSONObject(); json.put(AiKeys.MESSAGES, messages); - json.putIfAbsent(AiKeys.MODEL, config.get().getConfig(AppKeys.MODEL)); - + json.putIfAbsent(AiKeys.MODEL, config.get().getConfig(AppKeys.TEXT_MODEL_NAMES)); json.put(AiKeys.TEMPERATURE, form.temperature); json.put(AiKeys.MAX_TOKENS, form.responseLengthTokens); json.put(AiKeys.STREAM, form.stream); diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java index 587382751141..2c49fac5efe2 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java @@ -27,7 +27,6 @@ import com.dotmarketing.exception.DotCorruptedDataException; import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.portlets.contentlet.model.Contentlet; -import com.dotmarketing.util.Config; import com.dotmarketing.util.Logger; import com.dotmarketing.util.StringUtils; import com.dotmarketing.util.UtilMethods; @@ -37,7 +36,6 @@ import com.github.benmanes.caffeine.cache.Cache; import com.github.benmanes.caffeine.cache.Caffeine; import com.liferay.portal.model.User; -import io.vavr.Lazy; import io.vavr.Tuple; import 
io.vavr.Tuple2; import io.vavr.Tuple3; @@ -69,9 +67,6 @@ */ class EmbeddingsAPIImpl implements EmbeddingsAPI { - private static final Lazy OPEN_AI_EMBEDDINGS_URL = Lazy.of(() - -> Config.getStringProperty("OPEN_AI_EMBEDDINGS_URL", "https://api.openai.com/v1/embeddings")); - private static final Cache>> EMBEDDING_CACHE = Caffeine.newBuilder() .expireAfterWrite( @@ -348,7 +343,9 @@ public Tuple2> pullOrGenerateEmbeddings(final String conten return Tuple.of(dbEmbeddings._2, dbEmbeddings._3); } - final Tuple2> openAiEmbeddings = Tuple.of(tokens.size(), this.sendTokensToOpenAI(contentId, tokens)); + final Tuple2> openAiEmbeddings = Tuple.of( + tokens.size(), + sendTokensToOpenAI(contentId, tokens)); saveEmbeddingsForCache(content, openAiEmbeddings); EMBEDDING_CACHE.put(hashed, openAiEmbeddings); @@ -424,13 +421,13 @@ private void saveEmbeddingsForCache(final String content, final Tuple2 sendTokensToOpenAI(final String contentId, @NotNull final List tokens) { final JSONObject json = new JSONObject(); - json.put(AiKeys.MODEL, config.getConfig(AppKeys.EMBEDDINGS_MODEL)); + json.put(AiKeys.MODEL, config.getEmbeddingsModel().getCurrentModel()); json.put(AiKeys.INPUT, tokens); debugLogger(this.getClass(), () -> String.format("Content tokens for content ID '%s': %s", contentId, tokens)); final String responseString = OpenAIRequest.doRequest( - OPEN_AI_EMBEDDINGS_URL.get(), + config.getApiEmbeddingsUrl(), HttpMethod.POST, - this.config.getApiKey(), + config, json); debugLogger(this.getClass(), () -> String.format("OpenAI Response for content ID '%s': %s", contentId, responseString.replace("\n", BLANK))); diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java new file mode 100644 index 000000000000..6feaaf24afba --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java @@ -0,0 +1,169 @@ +package com.dotcms.ai.app; + +import com.dotcms.security.apps.AppsUtil; +import 
com.dotcms.security.apps.Secret;
+import com.dotmarketing.util.UtilMethods;
+import com.liferay.util.StringPool;
+import io.vavr.Lazy;
+import io.vavr.control.Try;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Utility class for handling AI application configurations and secrets.
+ * This class provides methods to resolve secrets, split and normalize
+ * comma-separated model names, and create AI model instances based on
+ * the provided secrets.
+ *
+ * @author vico
+ */
+public class AIAppUtil {
+
+    private static final Lazy<AIAppUtil> INSTANCE = Lazy.of(AIAppUtil::new);
+
+    private AIAppUtil() {
+        // Private constructor to prevent instantiation
+    }
+
+    /**
+     * Returns the lazily created shared instance of this utility.
+     *
+     * @return the shared {@link AIAppUtil} instance
+     */
+    public static AIAppUtil get() {
+        return INSTANCE.get();
+    }
+
+    /**
+     * Creates a text model instance based on the provided secrets.
+     * The names secret is split on commas, exactly like the embeddings model,
+     * so every configured name becomes an individual entry of the model.
+     *
+     * @param secrets the map of secrets
+     * @return the created text model instance
+     */
+    public AIModel createTextModel(final Map<String, Secret> secrets) {
+        return AIModel.builder()
+                .withType(AIModelType.TEXT)
+                .withNames(splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES))
+                .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_TOKENS_PER_MINUTE))
+                .withApiPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_API_PER_MINUTE))
+                .withMaxTokens(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_MAX_TOKENS))
+                .withIsCompletion(discoverBooleanSecret(secrets, AppKeys.TEXT_MODEL_COMPLETION))
+                .build();
+    }
+
+    /**
+     * Creates an image model instance based on the provided secrets.
+     * The names secret is split on commas, exactly like the embeddings model,
+     * so every configured name becomes an individual entry of the model.
+     *
+     * @param secrets the map of secrets
+     * @return the created image model instance
+     */
+    public AIModel createImageModel(final Map<String, Secret> secrets) {
+        return AIModel.builder()
+                .withType(AIModelType.IMAGE)
+                .withNames(splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES))
+                .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_TOKENS_PER_MINUTE))
+                .withApiPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_API_PER_MINUTE))
+                .withMaxTokens(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_MAX_TOKENS))
+                .withIsCompletion(discoverBooleanSecret(secrets, AppKeys.IMAGE_MODEL_COMPLETION))
+                .build();
+    }
+
+    /**
+     * Creates an embeddings model instance based on the provided secrets.
+     *
+     * @param secrets the map of secrets
+     * @return the created embeddings model instance
+     */
+    public AIModel createEmbeddingsModel(final Map<String, Secret> secrets) {
+        return AIModel.builder()
+                .withType(AIModelType.EMBEDDINGS)
+                .withNames(splitDiscoveredSecret(secrets, AppKeys.EMBEDDINGS_MODEL_NAMES))
+                .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_TOKENS_PER_MINUTE))
+                .withApiPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_API_PER_MINUTE))
+                .withMaxTokens(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_MAX_TOKENS))
+                .withIsCompletion(discoverBooleanSecret(secrets, AppKeys.EMBEDDINGS_MODEL_COMPLETION))
+                .build();
+    }
+
+    /**
+     * Resolves a secret value from the provided secrets map using the specified key.
+     * If the lookup fails for any reason (for instance the key is not in the map),
+     * the default value is returned.
+     *
+     * @param secrets the map of secrets
+     * @param key the key to look up the secret
+     * @param defaultValue the default value to return if the secret cannot be read
+     * @return the resolved secret value or the default value if the secret cannot be read
+     */
+    public String discoverSecret(final Map<String, Secret> secrets, final AppKeys key, final String defaultValue) {
+        // Try swallows the NullPointerException raised when the key has no entry
+        return Try.of(() -> secrets.get(key.key).getString()).getOrElse(defaultValue);
+    }
+
+    /**
+     * Resolves a secret value from the provided secrets map using the specified key.
+     * If the secret cannot be read, the default value defined in the key is returned.
+     *
+     * @param secrets the map of secrets
+     * @param key the key to look up the secret
+     * @return the resolved secret value or the default value defined in the key
+     */
+    public String discoverSecret(final Map<String, Secret> secrets, final AppKeys key) {
+        return discoverSecret(secrets, key, key.defaultValue);
+    }
+
+    /**
+     * Resolves the secret for the specified key and splits it on commas. Every token is
+     * trimmed and lower-cased; empty tokens are dropped.
+     *
+     * @param secrets the map of secrets
+     * @param key the key to look up the secret
+     * @return the list of split secret values, empty when the secret resolves to null or blank
+     */
+    public List<String> splitDiscoveredSecret(final Map<String, Secret> secrets, final AppKeys key) {
+        return Optional.ofNullable(discoverSecret(secrets, key))
+                .map(value -> Arrays.stream(value.split(","))
+                        .map(String::trim)
+                        .map(String::toLowerCase)
+                        // "".split(",") yields a single empty token that is not a model name
+                        .filter(token -> !token.isEmpty())
+                        .collect(Collectors.toList()))
+                .orElse(List.of());
+    }
+
+    /**
+     * Resolves the secret for the specified key and parses it as an integer.
+     *
+     * @param secrets the map of secrets
+     * @param key the key to look up the secret
+     * @return the parsed value, or 0 when the resolved secret is not a valid integer
+     */
+    public int discoverIntSecret(final Map<String, Secret> secrets, final AppKeys key) {
+        return toInt(discoverSecret(secrets, key));
+    }
+
+    /**
+     * Resolves the secret for the specified key and parses it with {@link Boolean#parseBoolean(String)}.
+     *
+     * @param secrets the map of secrets
+     * @param key the key to look up the secret
+     * @return the parsed boolean value
+     */
+    public boolean discoverBooleanSecret(final Map<String, Secret> secrets, final AppKeys key) {
+        return Boolean.parseBoolean(discoverSecret(secrets, key));
+    }
+
+    /**
+     * Resolves an environment-specific secret value from the provided secrets map using the specified key.
+     * If the secret is not set, it falls back to {@code AppsUtil.discoverEnvVarValue} for the same key.
+     *
+     * @param secrets the map of secrets
+     * @param key the key to look up the secret
+     * @return the resolved environment-specific secret value or an empty string if not found
+     */
+    public String discoverEnvSecret(final Map<String, Secret> secrets, final AppKeys key) {
+        final String secret = discoverSecret(secrets, key, StringPool.BLANK);
+        if (UtilMethods.isSet(secret)) {
+            return secret;
+        }
+
+        return Optional
+                .ofNullable(AppsUtil.discoverEnvVarValue(AppKeys.APP_KEY, key.key, null))
+                .orElse(StringPool.BLANK);
+    }
+
+    /** Parses the value as an integer, falling back to 0 when it cannot be parsed. */
+    private int toInt(final String value) {
+        return Try.of(() -> Integer.parseInt(value)).getOrElse(0);
+    }
+
+}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
new file mode 100644
index 000000000000..88b3ef6d58df
--- /dev/null
+++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
@@ -0,0 +1,179 @@
+package com.dotcms.ai.app;
+
+import com.dotcms.util.DotPreconditions;
+import com.dotmarketing.util.Logger;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Represents an AI model with various attributes such as type, names, tokens per minute,
+ * API calls per minute, maximum tokens, and completion status. This class provides methods
+ * to manage the current model, decommission status, and calculate the minimum interval
+ * between API calls. It also includes a builder for creating instances of AIModel.
+ *
+ * @author vico
+ */
+public class AIModel {
+
+    private final AIModelType type;
+    private final List<String> names;
+    private final int tokensPerMinute;
+    private final int apiPerMinute;
+    private final int maxTokens;
+    private final boolean isCompletion;
+    private final AtomicInteger current;
+    private final AtomicBoolean decommissioned;
+
+    private AIModel(final AIModelType type,
+                    final List<String> names,
+                    final int tokensPerMinute,
+                    final int apiPerMinute,
+                    final int maxTokens,
+                    final boolean isCompletion) {
+        DotPreconditions.checkNotNull(type, "type cannot be null");
+        this.type = type;
+        this.names = Optional.ofNullable(names).orElse(List.of());
+        this.tokensPerMinute = tokensPerMinute;
+        this.apiPerMinute = apiPerMinute;
+        this.maxTokens = maxTokens;
+        this.isCompletion = isCompletion;
+        // -1 marks "no current model" when there is no name to point at
+        current = new AtomicInteger(this.names.isEmpty() ? -1 : 0);
+        decommissioned = new AtomicBoolean(false);
+    }
+
+    /**
+     * @return the type of this model, never null
+     */
+    public AIModelType getType() {
+        return type;
+    }
+
+    /**
+     * @return the configured model names, empty when none were provided
+     */
+    public List<String> getNames() {
+        return names;
+    }
+
+    /**
+     * @return the configured tokens-per-minute value
+     */
+    public int getTokensPerMinute() {
+        return tokensPerMinute;
+    }
+
+    /**
+     * @return the configured API-calls-per-minute value
+     */
+    public int getApiPerMinute() {
+        return apiPerMinute;
+    }
+
+    /**
+     * @return the configured maximum number of tokens
+     */
+    public int getMaxTokens() {
+        return maxTokens;
+    }
+
+    /**
+     * @return the configured completion flag
+     */
+    public boolean isCompletion() {
+        return isCompletion;
+    }
+
+    /**
+     * @return the index of the current model name, or -1 when the model has no names
+     */
+    public int getCurrent() {
+        return current.get();
+    }
+
+    /**
+     * Selects the current model name by index. An index outside the names list is
+     * ignored (only a debug message is logged) and the previous selection is kept.
+     *
+     * @param current the index within {@link #getNames()} to select
+     */
+    public void setCurrent(final int current) {
+        if (!isCurrentValid(current)) {
+            logInvalidModelMessage();
+            return;
+        }
+        this.current.set(current);
+    }
+
+    /**
+     * @return the decommissioned flag of this model
+     */
+    public boolean isDecommissioned() {
+        return decommissioned.get();
+    }
+
+    /**
+     * @param decommissioned the new decommissioned flag of this model
+     */
+    public void setDecommissioned(final boolean decommissioned) {
+        this.decommissioned.set(decommissioned);
+    }
+
+    /**
+     * Resolves the name the current index points at.
+     *
+     * @return the current model name, or {@code null} when the model has no names
+     */
+    public String getCurrentModel() {
+        final int currentIndex = this.current.get();
+        if (!isCurrentValid(currentIndex)) {
+            logInvalidModelMessage();
+            return null;
+        }
+        return names.get(currentIndex);
+    }
+
+    /**
+     * Calculates the minimum interval, in milliseconds, between two API calls so that
+     * {@link #getApiPerMinute()} is not exceeded.
+     *
+     * @return 60000 divided by the API calls per minute, or 0 when no positive limit is set
+     */
+    public long minIntervalBetweenCalls() {
+        // apiPerMinute stays 0 when it is never set on the builder; dividing by it would throw
+        return apiPerMinute > 0 ? 60000 / apiPerMinute : 0;
+    }
+
+    @Override
+    public String toString() {
+        return "AIModel{" +
+                "type=" + type +
+                ", names=" + names +
+                ", current=" + current.get() +
+                ", decommissioned=" + decommissioned.get() +
+                ", tokensPerMinute=" + tokensPerMinute +
+                ", apiPerMinute=" + apiPerMinute +
+                ", maxTokens=" + maxTokens +
+                ", isCompletion=" + isCompletion +
+                '}';
+    }
+
+    /** Tells whether the index points at an existing entry of the names list. */
+    private boolean isCurrentValid(final int current) {
+        return current >= 0 && current < names.size();
+    }
+
+    /** Logs, at debug level, why the current index cannot be used. */
+    private void logInvalidModelMessage() {
+        final String message = names.isEmpty()
+                ? "No model names are configured, there is no current model"
+                // the highest valid index is size - 1, not size
+                : String.format("Current model index must be between 0 and %d", names.size() - 1);
+        Logger.debug(getClass(), message);
+    }
+
+    /**
+     * @return a new, empty builder
+     */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /**
+     * Builder for {@link AIModel}. Values that are never set keep the Java defaults
+     * (0, false, null names), and null names end up as an empty list in the model.
+     */
+    public static class Builder {
+
+        private AIModelType type;
+        private List<String> names;
+        private int tokensPerMinute;
+        private int apiPerMinute;
+        private int maxTokens;
+        private boolean isCompletion;
+
+        private Builder() {
+        }
+
+        public Builder withType(final AIModelType type) {
+            this.type = type;
+            return this;
+        }
+
+        public Builder withNames(final List<String> names) {
+            this.names = names;
+            return this;
+        }
+
+        /**
+         * Sets the model names from individual values. A null array or null elements
+         * are tolerated: null elements are skipped instead of failing in {@code List.of}.
+         *
+         * @param names the model names
+         * @return this builder
+         */
+        public Builder withNames(final String... names) {
+            final List<String> collected = new java.util.ArrayList<>();
+            if (names != null) {
+                for (final String name : names) {
+                    if (name != null) {
+                        collected.add(name);
+                    }
+                }
+            }
+            return withNames(collected);
+        }
+
+        public Builder withTokensPerMinute(final int tokensPerMinute) {
+            this.tokensPerMinute = tokensPerMinute;
+            return this;
+        }
+
+        public Builder withApiPerMinute(final int apiPerMinute) {
+            this.apiPerMinute = apiPerMinute;
+            return this;
+        }
+
+        public Builder withMaxTokens(final int maxTokens) {
+            this.maxTokens = maxTokens;
+            return this;
+        }
+
+        public Builder withIsCompletion(final boolean isCompletion) {
+            this.isCompletion = isCompletion;
+            return this;
+        }
+
+        /**
+         * @return the model built from the collected values; fails when no type was set
+         */
+        public AIModel build() {
+            return new AIModel(type, names, tokensPerMinute, apiPerMinute, maxTokens, isCompletion);
+        }
+
+    }
+
+}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModelType.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModelType.java
new file mode 100644
index 000000000000..5f25c015e428
--- /dev/null
+++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModelType.java
@@ -0,0 +1,19 @@
+package com.dotcms.ai.app;
+
+/**
+ * Enum representing different types of AI models used in the application.
+ * The types include:
+ *
<ul>
+ *     <li>TEXT: Models used for text generation and processing.</li>
+ *     <li>IMAGE: Models used for image generation and processing.</li>
+ *     <li>EMBEDDINGS: Models used for generating vector embeddings from text or other data.</li>
+ *     <li>UNKNOWN: Represents an unknown or unsupported model type.</li>
+ * </ul>
+ *
+ * @author vico
+ */
+public enum AIModelType {
+
+    TEXT, IMAGE, EMBEDDINGS, UNKNOWN
+
+}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
new file mode 100644
index 000000000000..0773d0de5711
--- /dev/null
+++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
@@ -0,0 +1,236 @@
+package com.dotcms.ai.app;
+
+import com.dotcms.ai.model.OpenAIModel;
+import com.dotcms.ai.model.OpenAIModels;
+import com.dotcms.http.CircuitBreakerUrl;
+import com.dotmarketing.beans.Host;
+import com.dotmarketing.util.Config;
+import com.dotmarketing.util.Logger;
+import com.github.benmanes.caffeine.cache.Cache;
+import com.github.benmanes.caffeine.cache.Caffeine;
+import io.vavr.Lazy;
+import io.vavr.Tuple;
+import io.vavr.Tuple2;
+import io.vavr.control.Try;
+import org.apache.commons.collections4.CollectionUtils;
+
+import java.time.Duration;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.stream.Collectors;
+
+/**
+ * Manages the AI models used in the application. This class handles loading, caching,
+ * and retrieving AI models based on the host and model type. It also fetches supported
+ * models from external sources and maintains a cache of these models.
+ *
+ * @author vico
+ */
+public class AIModels {
+
+    private static final String SUPPORTED_MODELS_KEY = "supportedModels";
+    private static final String AI_MODELS_FETCH_ATTEMPTS_KEY = "ai.models.fetch.attempts";
+    private static final int AI_MODELS_FETCH_ATTEMPTS = Config.getIntProperty(AI_MODELS_FETCH_ATTEMPTS_KEY, 3);
+    private static final String AI_MODELS_FETCH_TIMEOUT_KEY = "ai.models.fetch.timeout";
+    private static final int AI_MODELS_FETCH_TIMEOUT = Config.getIntProperty(AI_MODELS_FETCH_TIMEOUT_KEY, 4000);
+    private static final Lazy<AIModels> INSTANCE = Lazy.of(AIModels::new);
+    private static final String OPEN_AI_MODELS_URL = Config.getStringProperty(
+            "OPEN_AI_MODELS_URL",
+            "https://api.openai.com/v1/models");
+    private static final int AI_MODELS_CACHE_TTL = 28800;  // 8 hours
+    private static final int AI_MODELS_CACHE_SIZE = 128;
+
+    public static final AIModel NOOP_MODEL = AIModel.builder()
+            .withType(AIModelType.UNKNOWN)
+            .withNames(List.of())
+            .build();
+
+    private final ConcurrentMap<String, List<Tuple2<AIModelType, AIModel>>> internalModels =
+            new ConcurrentHashMap<>();
+    private final ConcurrentMap<Tuple2<String, String>, AIModel> modelsByName = new ConcurrentHashMap<>();
+    private final Cache<String, List<String>> supportedModelsCache =
+            Caffeine.newBuilder()
+                    .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL))
+                    .maximumSize(AI_MODELS_CACHE_SIZE)
+                    .build();
+
+    /**
+     * Returns the lazily created shared instance.
+     *
+     * @return the shared {@link AIModels} instance
+     */
+    public static AIModels get() {
+        return INSTANCE.get();
+    }
+
+    private AIModels() {
+    }
+
+    /**
+     * Loads the given list of AI models for the specified host. If models for the host
+     * are already loaded, the per-host list is left untouched. Each model name is also
+     * mapped, trimmed and lower-cased, to its AIModel instance; a name that is already
+     * mapped for the host is ignored.
+     *
+     * @param host the host for which the models are being loaded
+     * @param loading the list of AI models to load
+     */
+    public void loadModels(final String host, final List<AIModel> loading) {
+        // computeIfAbsent only builds and stores the list when the host has none yet
+        internalModels.computeIfAbsent(
+                host,
+                hostKey -> loading.stream()
+                        .map(model -> Tuple.of(model.getType(), model))
+                        .collect(Collectors.toList()));
+        loading.forEach(model -> model
+                .getNames()
+                .forEach(name -> {
+                    final Tuple2<String, String> key = Tuple.of(host, normalizeName(name));
+                    // putIfAbsent checks and inserts atomically; non-null means the name was taken
+                    if (modelsByName.putIfAbsent(key, model) != null) {
+                        Logger.debug(
+                                this,
+                                String.format(
+                                        "Model [%s] already exists for host [%s], ignoring it",
+                                        name,
+                                        host));
+                    }
+                }));
+    }
+
+    /**
+     * Finds an AI model by the host and model name. The name is trimmed and lower-cased
+     * the same way {@link #loadModels(String, List)} registers it, so the search is
+     * case-insensitive and tolerant of surrounding whitespace.
+     *
+     * @param host the host for which the model is being searched
+     * @param modelName the name of the model to find, may be {@code null}
+     * @return an Optional containing the found AIModel, or an empty Optional if the name
+     *         is {@code null} or no model is registered under it
+     */
+    public Optional<AIModel> findModel(final String host, final String modelName) {
+        return Optional.ofNullable(modelName)
+                .map(name -> modelsByName.get(Tuple.of(host, normalizeName(name))));
+    }
+
+    /**
+     * Finds an AI model by the host and model type.
+     *
+     * @param host the host for which the model is being searched
+     * @param type the type of the model to find
+     * @return an Optional containing the found AIModel, or an empty Optional if not found
+     */
+    public Optional<AIModel> findModel(final String host, final AIModelType type) {
+        return Optional.ofNullable(internalModels.get(host))
+                .flatMap(tuples -> tuples.stream()
+                        .filter(tuple -> tuple._1 == type)
+                        .map(Tuple2::_2)
+                        .findFirst());
+    }
+
+    /**
+     * Removes every model registered for the specified host and then calls
+     * ConfigService.INSTANCE.config(host), presumably to rebuild its configuration (confirm).
+     *
+     * @param host the host for which the models are being reset
+     */
+    public void resetModels(final Host host) {
+        final String hostKey = host.getHostname();
+        synchronized (AIModels.class) {
+            // Only unlink the list: clearing it in place could disturb a concurrent
+            // findModel(host, type) call that is still streaming over it.
+            internalModels.remove(hostKey);
+            modelsByName.keySet().removeIf(key -> key._1.equals(hostKey));
+            ConfigService.INSTANCE.config(host);
+        }
+    }
+
+    /**
+     * Retrieves the list of supported models, either from the cache or by fetching them
+     * from an external source if the cache is empty or expired. A failed fetch yields an
+     * empty list, which is not cached.
+     *
+     * @return a list of supported model names, lower-cased
+     */
+    public List<String> getOrPullSupportedModels() {
+        final List<String> cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY);
+        if (CollectionUtils.isNotEmpty(cached)) {
+            return cached;
+        }
+
+        final AppConfig appConfig = ConfigService.INSTANCE.config();
+        final List<String> supported = Try.of(() ->
+                        fetchOpenAIModels(appConfig)
+                                .getResponse()
+                                .getData()
+                                .stream()
+                                .map(OpenAIModel::getId)
+                                .map(String::toLowerCase)
+                                .collect(Collectors.toList()))
+                .getOrElse(List.of());
+        // an empty result is never served from the cache (see the guard above), so do not store it
+        if (!supported.isEmpty()) {
+            supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported);
+        }
+
+        return supported;
+    }
+
+    /**
+     * Retrieves the list of available models that are both configured and supported.
+     * Configured names are trimmed and lower-cased before the comparison because the
+     * supported names are lower-cased when fetched.
+     *
+     * @return a sorted list of available model names
+     */
+    public List<String> getAvailableModels() {
+        final Set<String> supported = new HashSet<>(getOrPullSupportedModels());
+        return internalModels.values()
+                .stream()
+                .flatMap(List::stream)
+                .map(Tuple2::_2)
+                .flatMap(model -> model.getNames().stream())
+                .map(AIModels::normalizeName)
+                .filter(supported::contains)
+                .distinct()
+                .sorted()
+                .collect(Collectors.toList());
+    }
+
+    /** Normalizes a model name into the form used as lookup key: lower-cased and trimmed. */
+    private static String normalizeName(final String name) {
+        return name.toLowerCase().trim();
+    }
+
+    /**
+     * Requests the model list from OPEN_AI_MODELS_URL using the API key of the given config.
+     * A non-2xx response is only logged at debug level; the response is returned either way.
+     */
+    private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final AppConfig appConfig) {
+
+        final CircuitBreakerUrl.Response<OpenAIModels> response = CircuitBreakerUrl.builder()
+                .setMethod(CircuitBreakerUrl.Method.GET)
+                .setUrl(OPEN_AI_MODELS_URL)
+                .setTimeout(AI_MODELS_FETCH_TIMEOUT)
+                .setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS)
+                .setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + appConfig.getApiKey()))
+                .setThrowWhenNot2xx(false)
+                .build()
+                .doResponse(OpenAIModels.class);
+
+        if (!CircuitBreakerUrl.isSuccessResponse(response)) {
+            Logger.debug(
+                    AIModels.class,
+                    String.format(
+                            "Error fetching OpenAI supported models from [%s] (status code: [%d])",
+                            OPEN_AI_MODELS_URL,
+                            response.getStatusCode()));
+        }
+
+        return response;
+    }
+}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
index d5a8105d3895..d3a161daa746 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
@@ -1,15 +1,16 @@
 package com.dotcms.ai.app;
 
-import com.dotcms.security.apps.AppsUtil;
 import com.dotcms.security.apps.Secret;
+import com.dotmarketing.exception.DotRuntimeException;
+import com.dotmarketing.util.Config;
 import com.dotmarketing.util.Logger;
 import com.dotmarketing.util.UtilMethods;
-import com.liferay.util.StringPool;
 import io.vavr.control.Try;
+import org.apache.commons.lang3.StringUtils;
 
 import java.io.Serializable;
+import java.util.List;
 import java.util.Map;
-import
java.util.Optional; import java.util.function.Supplier; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -20,12 +21,16 @@ */ public class AppConfig implements Serializable { - public static final Pattern SPLITTER= Pattern.compile("\\s?,\\s?"); + private static final String OPEN_AI_EMBEDDINGS_URL_KEY = "OPEN_AI_EMBEDDINGS_URL"; + public static final Pattern SPLITTER = Pattern.compile("\\s?,\\s?"); - public final String model; - public final String imageModel; + private final String host; + private final transient AIModel model; + private final transient AIModel imageModel; + private final transient AIModel embeddingsModel; private final String apiUrl; private final String apiImageUrl; + private final String apiEmbeddingsUrl; private final String apiKey; private final String rolePrompt; private final String textPrompt; @@ -34,47 +39,45 @@ public class AppConfig implements Serializable { private final String listenerIndexer; private final Map configValues; - public AppConfig(final Map secrets) { - this.configValues = secrets.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - apiUrl = resolveEnvSecret(secrets, AppKeys.API_URL); - apiImageUrl = resolveEnvSecret(secrets, AppKeys.API_IMAGE_URL); - apiKey = resolveEnvSecret(secrets, AppKeys.API_KEY); - rolePrompt = resolveSecretOrBlank(secrets, AppKeys.ROLE_PROMPT); - textPrompt = resolveSecretOrBlank(secrets, AppKeys.TEXT_PROMPT); - imagePrompt = resolveSecretOrBlank(secrets, AppKeys.IMAGE_PROMPT); - imageSize = resolveSecret(secrets, AppKeys.IMAGE_SIZE, AppKeys.IMAGE_SIZE.defaultValue); - model = resolveSecretOrBlank(secrets, AppKeys.MODEL); - imageModel = resolveSecret(secrets, AppKeys.IMAGE_MODEL, "dall-e-3"); - listenerIndexer = resolveSecretOrBlank(secrets, AppKeys.LISTENER_INDEXER); - Logger.debug(this.getClass().getName(), () -> "apiUrl: " + apiUrl); - Logger.debug(this.getClass().getName(), () -> "apiImageUrl: " + apiImageUrl); - 
Logger.debug(this.getClass().getName(), () -> "apiKey: " + apiKey); - Logger.debug(this.getClass().getName(), () -> "rolePrompt: " + rolePrompt); - Logger.debug(this.getClass().getName(), () -> "textPrompt: " + textPrompt); - Logger.debug(this.getClass().getName(), () -> "imagePrompt: " + imagePrompt); - Logger.debug(this.getClass().getName(), () -> "imageModel: " + imageModel); - Logger.debug(this.getClass().getName(), () -> "imageSize: " + imageSize); - Logger.debug(this.getClass().getName(), () -> "model: " + model); - Logger.debug(this.getClass().getName(), () -> "listerIndexer: " + listenerIndexer); - } - - private String resolveSecret(final Map secrets, final AppKeys key, final String defaultValue) { - return Try.of(() -> secrets.get(key.key).getString()).getOrElse(defaultValue); - } - - private String resolveSecretOrBlank(final Map secrets, final AppKeys key) { - return resolveSecret(secrets, key, StringPool.BLANK); - } - - private String resolveEnvSecret(final Map secrets, final AppKeys key) { - final String secret = resolveSecretOrBlank(secrets, key); - if (UtilMethods.isSet(secret)) { - return secret; - } + public AppConfig(final String host, final Map secrets) { + this.host = host; + + final AIAppUtil aiAppUtil = AIAppUtil.get(); + AIModels.get().loadModels( + this.host, + List.of( + aiAppUtil.createTextModel(secrets), + aiAppUtil.createImageModel(secrets), + aiAppUtil.createEmbeddingsModel(secrets))); + + model = resolveModel(AIModelType.TEXT); + imageModel = resolveModel(AIModelType.IMAGE); + embeddingsModel = resolveModel(AIModelType.EMBEDDINGS); + + apiUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_URL); + apiImageUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_IMAGE_URL); + apiEmbeddingsUrl = discoverEmbeddingsApiUrl(secrets); + apiKey = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY); + rolePrompt = aiAppUtil.discoverSecret(secrets, AppKeys.ROLE_PROMPT); + textPrompt = aiAppUtil.discoverSecret(secrets, AppKeys.TEXT_PROMPT); + 
imagePrompt = aiAppUtil.discoverSecret(secrets, AppKeys.IMAGE_PROMPT); + imageSize = aiAppUtil.discoverSecret(secrets, AppKeys.IMAGE_SIZE); + listenerIndexer = aiAppUtil.discoverSecret(secrets, AppKeys.LISTENER_INDEXER); - return Optional - .ofNullable(AppsUtil.discoverEnvVarValue(AppKeys.APP_KEY, key.key, null)) - .orElse(StringPool.BLANK); + configValues = secrets.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + Logger.debug(getClass(), () -> "apiUrl: " + apiUrl); + Logger.debug(getClass(), () -> "apiImageUrl: " + apiImageUrl); + Logger.debug(getClass(), () -> "embeddingsUrl: " + apiEmbeddingsUrl); + Logger.debug(getClass(), () -> "apiKey: " + apiKey); + Logger.debug(getClass(), () -> "model: " + model); + Logger.debug(getClass(), () -> "imageModel: " + imageModel); + Logger.debug(getClass(), () -> "embeddingsModel: " + embeddingsModel); + Logger.debug(getClass(), () -> "rolePrompt: " + rolePrompt); + Logger.debug(getClass(), () -> "textPrompt: " + textPrompt); + Logger.debug(getClass(), () -> "imagePrompt: " + imagePrompt); + Logger.debug(getClass(), () -> "imageSize: " + imageSize); + Logger.debug(getClass(), () -> "listerIndexer: " + listenerIndexer); } /** @@ -83,7 +86,7 @@ private String resolveEnvSecret(final Map secrets, final AppKeys * @return the API URL */ public String getApiUrl() { - return UtilMethods.isEmpty(apiUrl) ? "https://api.openai.com/v1/chat/completions" : apiUrl; + return UtilMethods.isEmpty(apiUrl) ? AppKeys.API_URL.defaultValue : apiUrl; } /** @@ -92,7 +95,16 @@ public String getApiUrl() { * @return the API Image URL */ public String getApiImageUrl() { - return UtilMethods.isEmpty(apiImageUrl)? "https://api.openai.com/v1/images/generations" : apiImageUrl; + return UtilMethods.isEmpty(apiImageUrl) ? AppKeys.API_IMAGE_URL.defaultValue : apiImageUrl; + } + + /** + * Retrieves the API Embeddings URL. 
+ * + * @return + */ + public String getApiEmbeddingsUrl() { + return UtilMethods.isEmpty(apiEmbeddingsUrl) ? AppKeys.API_EMBEDDINGS_URL.defaultValue : apiEmbeddingsUrl; } /** @@ -105,12 +117,12 @@ public String getApiKey() { } /** - * Retrieves the Role Prompt. + * Retrieves the Model. * - * @return the Role Prompt + * @return the Model */ - public String getRolePrompt() { - return rolePrompt; + public AIModel getModel() { + return model; } /** @@ -118,7 +130,27 @@ public String getRolePrompt() { * * @return the Image Model */ - public String getImageModel() {return imageModel;} + public AIModel getImageModel() { + return imageModel; + } + + /** + * Retrieves the Embeddings Model. + * + * @return the Embeddings Model + */ + public AIModel getEmbeddingsModel() { + return embeddingsModel; + } + + /** + * Retrieves the Role Prompt. + * + * @return the Role Prompt + */ + public String getRolePrompt() { + return rolePrompt; + } /** * Retrieves the Text Prompt. @@ -147,15 +179,6 @@ public String getImageSize() { return imageSize; } - /** - * Retrieves the Model. - * - * @return the Model - */ - public String getModel() { - return model; - } - /** * Retrieves the Listener Indexer. 
* @@ -171,9 +194,9 @@ public String getListenerIndexer() { * @param appKey the key to retrieve the configuration value for * @return the integer configuration value */ - public int getConfigInteger(AppKeys appKey) { - String value = Try.of(() -> configValues.get(appKey.key).getString()).getOrElse(appKey.defaultValue); - return Try.of(()->Integer.parseInt(value)).getOrElse(0); + public int getConfigInteger(final AppKeys appKey) { + String value = Try.of(() -> configValues.get(appKey.key).getString()).getOrElse(appKey.defaultValue); + return Try.of(() -> Integer.parseInt(value)).getOrElse(0); } /** @@ -182,9 +205,9 @@ public int getConfigInteger(AppKeys appKey) { * @param appKey the key to retrieve the configuration value for * @return the float configuration value */ - public float getConfigFloat(AppKeys appKey) { - String value = Try.of(() -> configValues.get(appKey.key).getString()).getOrElse(appKey.defaultValue); - return Try.of(()->Float.parseFloat(value)).getOrElse(0f); + public float getConfigFloat(final AppKeys appKey) { + String value = Try.of(() -> configValues.get(appKey.key).getString()).getOrElse(appKey.defaultValue); + return Try.of(() -> Float.parseFloat(value)).getOrElse(0f); } /** @@ -193,9 +216,9 @@ public float getConfigFloat(AppKeys appKey) { * @param appKey the key to retrieve the configuration value for * @return the boolean configuration value */ - public boolean getConfigBoolean(AppKeys appKey) { - String value = Try.of(() -> configValues.get(appKey.key).getString()).getOrElse(appKey.defaultValue); - return Try.of(()->Boolean.parseBoolean(value)).getOrElse(false); + public boolean getConfigBoolean(final AppKeys appKey) { + final String value = Try.of(() -> configValues.get(appKey.key).getString()).getOrElse(appKey.defaultValue); + return Try.of(() -> Boolean.parseBoolean(value)).getOrElse(false); } /** @@ -204,9 +227,8 @@ public boolean getConfigBoolean(AppKeys appKey) { * @param appKey the key to retrieve the configuration value for * @return 
the array configuration value */ - public String[] getConfigArray(AppKeys appKey) { - String returnValue = getConfig(appKey); - + public String[] getConfigArray(final AppKeys appKey) { + final String returnValue = getConfig(appKey); return returnValue != null ? SPLITTER.split(returnValue) : new String[0]; } @@ -216,13 +238,37 @@ public String[] getConfigArray(AppKeys appKey) { * @param appKey the key to retrieve the configuration value for * @return the configuration value */ - public String getConfig(AppKeys appKey) { + public String getConfig(final AppKeys appKey) { if (configValues.containsKey(appKey.key)) { return Try.of(() -> configValues.get(appKey.key).getString()).getOrElse(appKey.defaultValue); } return appKey.defaultValue; } + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * + * @param type the type of the model to find + */ + public AIModel resolveModel(final AIModelType type) { + return AIModels.get().findModel(host, type).orElse(AIModels.NOOP_MODEL); + } + + /** + * Resolves a model-specific secret value from the provided secrets map using the specified key and model type. + * + * @param modelName the name of the model to find + */ + public AIModel resolveModelOrThrow(final String modelName) { + return AIModels.get() + .findModel(host, modelName) + .orElseThrow(() -> { + final String supported = String.join(", ", AIModels.get().getOrPullSupportedModels()); + return new DotRuntimeException( + "Unable to find model: [" + modelName + "]. Only [" + supported + "] are supported "); + }); + } + /** * Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING} * property instead of the usual Log4j configuration. 
@@ -236,4 +282,11 @@ public static void debugLogger(final Class clazz, final Supplier mess } } -} \ No newline at end of file + private String discoverEmbeddingsApiUrl(final Map secrets) { + final String url = AIAppUtil.get().discoverEnvSecret(secrets, AppKeys.API_EMBEDDINGS_URL); + return StringUtils.isBlank(url) + ? Config.getStringProperty(OPEN_AI_EMBEDDINGS_URL_KEY, "https://api.openai.com/v1/embeddings") + : url; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java index 7f79b831d966..947c0bf2a831 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java @@ -1,22 +1,33 @@ package com.dotcms.ai.app; public enum AppKeys { + API_URL("apiUrl", "https://api.openai.com/v1/chat/completions"), API_IMAGE_URL("apiImageUrl", "https://api.openai.com/v1/images/generations"), + API_EMBEDDINGS_URL("apiEmbeddingsUrl", null), API_KEY("apiKey", null), - ROLE_PROMPT("rolePrompt", "You are dotCMSbot, and AI assistant to help content" + - " creators generate and rewrite content in their content management system."), + ROLE_PROMPT( + "rolePrompt", + "You are dotCMSbot, and AI assistant to help content" + + " creators generate and rewrite content in their content management system."), TEXT_PROMPT("textPrompt", "Use Descriptive writing style."), IMAGE_PROMPT("imagePrompt", "Use 16:9 aspect ratio."), IMAGE_SIZE("imageSize", "1024x1024"), - MODEL("model", "gpt-3.5-turbo-16k"), - IMAGE_MODEL("imageModel", "dall-e-3"), - DEBUG_LOGGING("com.dotcms.ai.debug.logging", "false"), - COMPLETION_TEMPERATURE("com.dotcms.ai.completion.default.temperature", "1"), - COMPLETION_ROLE_PROMPT("com.dotcms.ai.completion.role.prompt", - "You are a helpful assistant with a descriptive writing style."), - COMPLETION_TEXT_PROMPT("com.dotcms.ai.completion.text.prompt", "Answer this question\\n\\\"$!{prompt}?\\\"\\n\\nby using only the information in the following 
text:\\n\"\"\"\\n$!{supportingContent} \\n\"\"\"\\n"), - EMBEDDINGS_MODEL("com.dotcms.ai.embeddings.model", "text-embedding-ada-002"), + TEXT_MODEL_NAMES("textModelNames", "gpt-3.5-turbo-16k"), + TEXT_MODEL_TOKENS_PER_MINUTE("textModelTokensPerMinute", "1000"), + TEXT_MODEL_API_PER_MINUTE("textModelApiPerMinute", "1000"), + TEXT_MODEL_MAX_TOKENS("textModelMaxTokens", "1000"), + TEXT_MODEL_COMPLETION("textModelCompletion", "true"), + IMAGE_MODEL_NAMES("imageModelNames", "dall-e-3"), + IMAGE_MODEL_TOKENS_PER_MINUTE("imageModelTokensPerMinute", "1000"), + IMAGE_MODEL_API_PER_MINUTE("imageModelApiPerMinute", "1000"), + IMAGE_MODEL_MAX_TOKENS("imageModelMaxTokens", "1000"), + IMAGE_MODEL_COMPLETION("imageModelCompletion", "true"), + EMBEDDINGS_MODEL_NAMES("embeddingsModelNames", "text-embedding-ada-002"), + EMBEDDINGS_MODEL_TOKENS_PER_MINUTE("embeddingsModelTokensPerMinute", "1000"), + EMBEDDINGS_MODEL_API_PER_MINUTE("embeddingsModelApiPerMinute", "1000"), + EMBEDDINGS_MODEL_MAX_TOKENS("embeddingsModelMaxTokens", "1000"), + EMBEDDINGS_MODEL_COMPLETION("embeddingsModelCompletion", "true"), EMBEDDINGS_SPLIT_AT_TOKENS("com.dotcms.ai.embeddings.split.at.tokens", "512"), EMBEDDINGS_MINIMUM_TEXT_LENGTH_TO_INDEX("com.dotcms.ai.embeddings.minimum.text.length", "64"), EMBEDDINGS_MINIMUM_FILE_SIZE_TO_INDEX("com.dotcms.ai.embeddings.minimum.file.size", "1024"), @@ -27,8 +38,19 @@ public enum AppKeys { EMBEDDINGS_THREADS_QUEUE("com.dotcms.ai.embeddings.threads.queue", "10000"), EMBEDDINGS_CACHE_TTL_SECONDS("com.dotcms.ai.embeddings.cache.ttl.seconds", "600"), EMBEDDINGS_CACHE_SIZE("com.dotcms.ai.embeddings.cache.size", "1000"), + EMBEDDINGS_DB_DELETE_OLD_ON_UPDATE("com.dotcms.ai.embeddings.delete.old.on.update", "true"), + DEBUG_LOGGING("com.dotcms.ai.debug.logging", "false"), + COMPLETION_TEMPERATURE("com.dotcms.ai.completion.default.temperature", "1"), + COMPLETION_ROLE_PROMPT( + "com.dotcms.ai.completion.role.prompt", + "You are a helpful assistant with a descriptive writing 
style."), + COMPLETION_TEXT_PROMPT( + "com.dotcms.ai.completion.text.prompt", + "Answer this question\\n\\\"$!{prompt}?\\\"\\n\\nby using only the information in" + + " the following text:\\n\"\"\"\\n$!{supportingContent} \\n\"\"\"\\n"), LISTENER_INDEXER("listenerIndexer", "{}"), - EMBEDDINGS_DB_DELETE_OLD_ON_UPDATE("com.dotcms.ai.embeddings.delete.old.on.update", "true"); + AI_MODELS_CACHE_TTL("com.dotcms.ai.models.supported.ttl", "28800"), + AI_MODELS_CACHE_SIZE("com.dotcms.ai.models.supported.size", "64"); public static final String APP_KEY = "dotAI"; @@ -39,4 +61,5 @@ public enum AppKeys { this.key = key; this.defaultValue = defaultValue; } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java b/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java index 3e035e277bb6..ca1e9d7eb91c 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java @@ -17,8 +17,7 @@ public class ConfigService { public static final ConfigService INSTANCE = new ConfigService(); - public AppConfig config() { - return config(null); + private ConfigService() { } /** @@ -26,12 +25,21 @@ public AppConfig config() { * by dotCMS. */ public AppConfig config(final Host host) { + final Host resolved = resolveHost(host); final Optional appSecrets = Try.of(() -> APILocator - .getAppsAPI() - .getSecrets(AppKeys.APP_KEY, true, resolveHost(host), APILocator.systemUser())) + .getAppsAPI() + .getSecrets(AppKeys.APP_KEY, true, resolved, APILocator.systemUser())) .getOrElse(Optional.empty()); - return new AppConfig(appSecrets.map(AppSecrets::getSecrets).orElse(Map.of())); + return new AppConfig(resolved.getHostname(), appSecrets.map(AppSecrets::getSecrets).orElse(Map.of())); + } + + /** + * Gets the secrets from the App - this will check the current host then the SYSTEM_HOST for a valid configuration. This lookup is low overhead and cached + * by dotCMS. 
+ */ + public AppConfig config() { + return config(null); } /** diff --git a/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java b/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java new file mode 100644 index 000000000000..226be03607e8 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/listener/AIAppListener.java @@ -0,0 +1,70 @@ +package com.dotcms.ai.listener; + +import com.dotcms.ai.app.AIModels; +import com.dotcms.ai.app.AppKeys; +import com.dotcms.security.apps.AppSecretSavedEvent; +import com.dotcms.system.event.local.model.EventSubscriber; +import com.dotcms.system.event.local.model.KeyFilterable; +import com.dotmarketing.beans.Host; +import com.dotmarketing.business.APILocator; +import com.dotmarketing.portlets.contentlet.business.HostAPI; +import com.dotmarketing.util.Logger; +import io.vavr.control.Try; +import org.apache.commons.lang3.StringUtils; + +import java.util.Objects; +import java.util.Optional; + +/** + * This class listens to events related to the AI application and performs actions based on those events. + * It implements the EventSubscriber interface and overrides its methods to provide custom functionality. + * The class also implements the KeyFilterable interface to filter events based on a specific key. 
+ * + * @author vico + */ +public final class AIAppListener implements EventSubscriber, KeyFilterable { + + private final HostAPI hostAPI; + + public AIAppListener(final HostAPI hostAPI) { + this.hostAPI = hostAPI; + } + + public AIAppListener() { + this(APILocator.getHostAPI()); + } + + @Override + public void notify(final AppSecretSavedEvent event) { + if (Objects.isNull(event)) { + Logger.debug(this, "Missing event, aborting"); + return; + } + + if (StringUtils.isBlank(event.getHostIdentifier())) { + Logger.debug(this, "Missing event's host id, aborting"); + return; + } + + final String hostId = event.getHostIdentifier(); + final Host host = Try.of(() -> hostAPI.find(hostId, APILocator.systemUser(), false)).getOrNull(); + + Optional.ofNullable(host).ifPresent(found -> AIModels.get().resetModels(found)); + } + + @Override + public Comparable getKey() { + return AppKeys.APP_KEY; + } + + public enum Instance { + SINGLETON; + + private final AIAppListener provider = new AIAppListener(); + + public static AIAppListener get() { + return AIAppListener.Instance.SINGLETON.provider; + } + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java b/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java index 0e5ba4cbe1f1..53f83c3ab149 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java +++ b/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java @@ -27,16 +27,17 @@ public String getSize() { return size; } - public int getNumberOfImages() { return numberOfImages; } - public String getPrompt() { return prompt; } + public String getModel() { + return model; + } public static class Builder { @JsonSetter(nulls = Nulls.SKIP) @@ -46,7 +47,7 @@ public static class Builder { @JsonSetter(nulls = Nulls.SKIP) private String size = ConfigService.INSTANCE.config().getImageSize(); @JsonSetter(nulls = Nulls.SKIP) - private String model = ConfigService.INSTANCE.config().getImageModel(); + private String model = 
ConfigService.INSTANCE.config().getImageModel().getCurrentModel(); public AIImageRequestDTO build() { return new AIImageRequestDTO(this); @@ -72,4 +73,5 @@ public Builder size(String size) { return this; } } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModel.java b/dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModel.java new file mode 100644 index 000000000000..3ff86d8ad44c --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModel.java @@ -0,0 +1,43 @@ +package com.dotcms.ai.model; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; + +public class OpenAIModel implements Serializable { + + private final String id; + private final String object; + private final long created; + private final String ownedBy; + + @JsonCreator + public OpenAIModel(@JsonProperty("id") final String id, + @JsonProperty("object") final String object, + @JsonProperty("created") final long created, + @JsonProperty("owned_by") final String ownedBy) { + this.id = id; + this.object = object; + this.created = created; + this.ownedBy = ownedBy; + } + + public String getId() { + return id; + } + + public String getObject() { + return object; + } + + public long getCreated() { + return created; + } + + @JsonProperty("owned_by") + public String getOwnedBy() { + return ownedBy; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModels.java b/dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModels.java new file mode 100644 index 000000000000..faa691b6a9c1 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModels.java @@ -0,0 +1,29 @@ +package com.dotcms.ai.model; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; +import java.util.List; + +public class OpenAIModels implements Serializable { + + private final String object; + private final List data; + + 
@JsonCreator + public OpenAIModels(@JsonProperty("object") final String object, + @JsonProperty("data") final List data) { + this.object = object; + this.data = data; + } + + public String getObject() { + return object; + } + + public List getData() { + return data; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java index 6a9e3065fd32..d56f4857870f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java @@ -2,12 +2,12 @@ import com.dotcms.ai.AiKeys; import com.dotcms.ai.api.CompletionsAPI; +import com.dotcms.ai.app.AIModels; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.rest.forms.CompletionsForm; import com.dotcms.ai.util.LineReadingOutputStream; -import com.dotcms.ai.util.OpenAIModel; import com.dotcms.rest.WebResource; import com.dotmarketing.beans.Host; import com.dotmarketing.business.web.WebAPILocator; @@ -28,13 +28,11 @@ import javax.ws.rs.core.Response; import javax.ws.rs.core.StreamingOutput; import java.io.OutputStream; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Consumer; import java.util.function.Supplier; -import java.util.stream.Collectors; /** * The CompletionsResource class provides REST endpoints for interacting with the AI completions service. 
@@ -101,7 +99,7 @@ public final Response rawPrompt(@Context final HttpServletRequest request, @Produces({MediaType.APPLICATION_OCTET_STREAM, MediaType.APPLICATION_JSON}) public final Response getConfig(@Context final HttpServletRequest request, @Context final HttpServletResponse response) { - // get user if we have one (this is allow anon) + // get user if we have one (this allows anon) new WebResource .InitBuilder(request, response) .requiredBackendUser(true) @@ -120,10 +118,7 @@ public final Response getConfig(@Context final HttpServletRequest request, final String apiKey = UtilMethods.isSet(app.getApiKey()) ? "*****" : "NOT SET"; map.put(AppKeys.API_KEY.key, apiKey); - final List models = Arrays.stream(OpenAIModel.values()) - .filter(m->m.completionModel) - .map(m-> m.modelName) - .collect(Collectors.toList()); + final List models = AIModels.get().getAvailableModels(); map.put(AiKeys.AVAILABLE_MODELS, models); return Response.ok(map).build(); @@ -145,7 +140,10 @@ private static CompletionsForm resolveForm(final HttpServletRequest request, .getUser(); final Host host = WebAPILocator.getHostWebAPI().getCurrentHostNoThrow(request); return (!user.isAdmin()) - ? CompletionsForm.copy(formIn).model(ConfigService.INSTANCE.config(host).getModel()).build() + ? CompletionsForm + .copy(formIn) + .model(ConfigService.INSTANCE.config(host).getModel().getCurrentModel()) + .build() : formIn; } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java index cfa0045b4450..a6bbbbeeec81 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java @@ -117,7 +117,7 @@ private CompletionsForm(final Builder builder) { } else { this.temperature = builder.temperature >= 2 ? 2 : builder.temperature; } - this.model = UtilMethods.isSet(builder.model) ? 
builder.model : ConfigService.INSTANCE.config().getConfig(AppKeys.MODEL); + this.model = UtilMethods.isSet(builder.model) ? builder.model : ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES); } private String validateBuilderQuery(final String query) { diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java index e70a3814a557..61815b1307eb 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java @@ -44,7 +44,7 @@ private EmbeddingsForm(Builder builder) { this.indexName = UtilMethods.isSet(builder.indexName) ? builder.indexName : "default"; this.velocityTemplate = builder.velocityTemplate; this.offset = builder.offset; - this.model = UtilMethods.isSet(builder.model) ? builder.model : ConfigService.INSTANCE.config().getConfig(AppKeys.EMBEDDINGS_MODEL); + this.model = UtilMethods.isSet(builder.model) ? builder.model : ConfigService.INSTANCE.config().getEmbeddingsModel().getCurrentModel(); this.fields = (builder.fields != null) ? AppConfig.SPLITTER.split(builder.fields.toLowerCase()) : new String[0]; this.userId= PortalUtil.getUser() != null ? 
PortalUtil.getUser().getUserId() : APILocator.systemUser().getUserId(); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/service/OpenAIChatServiceImpl.java b/dotCMS/src/main/java/com/dotcms/ai/service/OpenAIChatServiceImpl.java index 2e219c62dbf7..08edb4d5d691 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/service/OpenAIChatServiceImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/service/OpenAIChatServiceImpl.java @@ -22,7 +22,7 @@ public OpenAIChatServiceImpl(final AppConfig appConfig) { @Override public JSONObject sendRawRequest(final JSONObject prompt) { - prompt.putIfAbsent(AiKeys.MODEL, config.getModel()); + prompt.putIfAbsent(AiKeys.MODEL, config.getModel().getCurrentModel()); prompt.putIfAbsent(AiKeys.TEMPERATURE, config.getConfigFloat(AppKeys.COMPLETION_TEMPERATURE)); if (UtilMethods.isEmpty(prompt.optString(AiKeys.MESSAGES))) { @@ -36,7 +36,7 @@ public JSONObject sendRawRequest(final JSONObject prompt) { prompt.remove(AiKeys.PROMPT); - return new JSONObject(doRequest(config.getApiUrl(), config.getApiKey(), prompt)); + return new JSONObject(doRequest(config.getApiUrl(), prompt)); } @Override @@ -47,8 +47,8 @@ public JSONObject sendTextPrompt(final String textPrompt) { } @VisibleForTesting - String doRequest(final String urlIn, final String openAiAPIKey, final JSONObject json) { - return OpenAIRequest.doRequest(urlIn, HttpMethod.POST, openAiAPIKey, json); + String doRequest(final String urlIn, final JSONObject json) { + return OpenAIRequest.doRequest(urlIn, HttpMethod.POST, config, json); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/service/OpenAIImageServiceImpl.java b/dotCMS/src/main/java/com/dotcms/ai/service/OpenAIImageServiceImpl.java index c5da0cd6f4b7..57b571dc140b 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/service/OpenAIImageServiceImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/service/OpenAIImageServiceImpl.java @@ -55,12 +55,12 @@ public JSONObject sendRequest(final JSONObject jsonObject) { } 
OpenAiRequestUtil.get().handleLargePrompt(jsonObject); - jsonObject.putIfAbsent(AiKeys.MODEL, config.getImageModel()); + jsonObject.putIfAbsent(AiKeys.MODEL, config.getImageModel().getCurrentModel()); jsonObject.putIfAbsent(AiKeys.SIZE, config.getImageSize()); String responseString = ""; try { - responseString = doRequest(config.getApiImageUrl(), config.getApiKey(), jsonObject); + responseString = doRequest(config.getApiImageUrl(), jsonObject); JSONObject returnObject = new JSONObject(responseString); if (returnObject.containsKey(AiKeys.ERROR)) { @@ -87,7 +87,7 @@ public JSONObject sendRawRequest(final String prompt) { @Override public JSONObject sendRequest(final AIImageRequestDTO dto) { JSONObject jsonRequest = new JSONObject(); - jsonRequest.put(AiKeys.MODEL, config.getImageModel()); + jsonRequest.put(AiKeys.MODEL, config.getImageModel().getCurrentModel()); jsonRequest.put(AiKeys.PROMPT, dto.getPrompt()); jsonRequest.put(AiKeys.SIZE, dto.getSize()); return sendRequest(jsonRequest); @@ -173,8 +173,8 @@ private String generateFileName(final String originalPrompt) { } @VisibleForTesting - String doRequest(final String urlIn, final String openAiAPIKey, final JSONObject json) { - return OpenAIRequest.doRequest(urlIn, HttpMethod.POST, openAiAPIKey, json); + String doRequest(final String urlIn, final JSONObject json) { + return OpenAIRequest.doRequest(urlIn, HttpMethod.POST, config, json); } @VisibleForTesting diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/EncodingUtil.java b/dotCMS/src/main/java/com/dotcms/ai/util/EncodingUtil.java index cb5a836c1783..0a4258959cd7 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/util/EncodingUtil.java +++ b/dotCMS/src/main/java/com/dotcms/ai/util/EncodingUtil.java @@ -1,6 +1,5 @@ package com.dotcms.ai.util; -import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.knuddels.jtokkit.Encodings; import com.knuddels.jtokkit.api.Encoding; @@ -11,7 +10,7 @@ public class EncodingUtil { public static final 
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); - public static final String model = ConfigService.INSTANCE.config().getConfig(AppKeys.EMBEDDINGS_MODEL); + public static final String model = ConfigService.INSTANCE.config().getEmbeddingsModel().getCurrentModel(); public static Lazy encoding = Lazy.of(()-> registry.getEncodingForModel(model).get() diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIModel.java b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIModel.java deleted file mode 100644 index 7e368c002ae3..000000000000 --- a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIModel.java +++ /dev/null @@ -1,78 +0,0 @@ -package com.dotcms.ai.util; - -import com.dotmarketing.exception.DotRuntimeException; - -import java.util.Arrays; -import java.util.stream.Collectors; - -/** - * Enum representing different models of OpenAI. - * Each enum value contains the model name, tokens per minute, API per minute, maximum tokens, and a flag indicating if it's a completion model. 
- */ -public enum OpenAIModel { - - GPT_3_5_TURBO("gpt-3.5-turbo", 3000, 3500, 4096, true), - GPT_3_5_TURBO_16k("gpt-3.5-turbo-16k", 180000, 3500, 16384, true), - GPT_4("gpt-4", 10000, 200, 8191, true), - GPT_4_TURBO("gpt-4-1106-preview", 10000, 200, 128000, true), - GPT_4_TURBO_PREVIEW("gpt-4-turbo-preview", 10000, 200, 128000, true), - TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", 1000000, 3000, 8191, false), - DALL_E_2("dall-e-2", 0, 50, 0, false), - DALL_E_3("dall-e-3", 0, 50, 0, false); - - public final int tokensPerMinute; - public final int apiPerMinute; - public final int maxTokens; - public final String modelName; - public final boolean completionModel; - - OpenAIModel(final String modelName, - final int tokensPerMinute, - final int apiPerMinute, - final int maxTokens, - final boolean completionModel) { - this.modelName = modelName; - this.tokensPerMinute = tokensPerMinute; - this.apiPerMinute = apiPerMinute; - this.maxTokens = maxTokens; - this.completionModel = completionModel; - } - - /** - * Resolves the model based on the input string. - * - * @param modelIn The input string representing the model. - * @return The corresponding OpenAIModel. - * @throws DotRuntimeException If the input string does not correspond to any OpenAIModel. - */ - public static OpenAIModel resolveModel(final String modelIn) { - final String modelOut = modelIn.replace("-", "_").replace(".", "_").toUpperCase().trim(); - for (final OpenAIModel openAiModel : OpenAIModel.values()) { - if (openAiModel.modelName.equalsIgnoreCase(modelIn) || openAiModel.name().equalsIgnoreCase(modelOut)) { - return openAiModel; - } - } - - throw new DotRuntimeException( - "Unable to parse model:'" + modelIn + "'. Only " + supportedModels() + " are supported "); - } - - /** - * Returns a string representing the supported models. - * - * @return A string representing the supported models. 
- */ - private static String supportedModels() { - return Arrays.stream(OpenAIModel.values()).map(o -> o.modelName).collect(Collectors.joining(", ")); - } - - /** - * Returns the minimum interval between calls for the model. - * - * @return The minimum interval between calls for the model. - */ - public long minIntervalBetweenCalls() { - return 60000 / apiPerMinute; - } - -} diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java index f319fea6a3cc..e851c9b8f871 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java +++ b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java @@ -1,6 +1,8 @@ package com.dotcms.ai.util; import com.dotcms.ai.AiKeys; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.exception.DotRuntimeException; @@ -19,6 +21,7 @@ import java.io.BufferedInputStream; import java.io.ByteArrayOutputStream; import java.io.OutputStream; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; /** @@ -30,62 +33,10 @@ */ public class OpenAIRequest { - private static final ConcurrentHashMap lastRestCall = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap lastRestCall = new ConcurrentHashMap<>(); private OpenAIRequest() {} - /** - * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload. - * The response from the request is returned as a string. 
- * - * @param url the URL to send the request to - * @param method the HTTP method to use for the request - * @param openAiAPIKey the OpenAI API key to use for the request - * @param json the JSON payload to send with the request - * @return the response from the request as a string - */ - public static String doRequest(final String url, - final String method, - final String openAiAPIKey, - final JSONObject json) { - final ByteArrayOutputStream out = new ByteArrayOutputStream(); - doRequest(url, method, openAiAPIKey, json, out); - - return out.toString(); - } - - /** - * Sends a POST request to the specified URL with the specified OpenAI API key and JSON payload. - * The response from the request is written to the provided OutputStream. - * - * @param urlIn the URL to send the request to - * @param openAiAPIKey the OpenAI API key to use for the request - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doPost(final String urlIn, - final String openAiAPIKey, - final JSONObject json, - final OutputStream out) { - doRequest(urlIn, HttpMethod.POST, openAiAPIKey, json, out); - } - - /** - * Sends a GET request to the specified URL with the specified OpenAI API key and JSON payload. - * The response from the request is written to the provided OutputStream. - * - * @param urlIn the URL to send the request to - * @param openAiAPIKey the OpenAI API key to use for the request - * @param json the JSON payload to send with the request - * @param out the OutputStream to write the response to - */ - public static void doGet(final String urlIn, - final String openAiAPIKey, - final JSONObject json, - final OutputStream out) { - doRequest(urlIn, HttpMethod.GET, openAiAPIKey,json,out); - } - /** * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload. * The response from the request is written to the provided OutputStream. 
@@ -93,21 +44,23 @@ public static void doGet(final String urlIn, * * @param urlIn the URL to send the request to * @param method the HTTP method to use for the request - * @param openAiAPIKey the OpenAI API key to use for the request - * @param json the JSON payload to send with the request + * @param appConfig the AppConfig object containing the OpenAI API key and models + * @param payload the JSON payload to send with the request * @param out the OutputStream to write the response to */ public static void doRequest(final String urlIn, final String method, - final String openAiAPIKey, - final JSONObject json, + final AppConfig appConfig, + final JSONObject payload, final OutputStream out) { - if (ConfigService.INSTANCE.config().getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.debug(OpenAIRequest.class, "posting:" + json); + final JSONObject json = Optional.ofNullable(payload).orElse(new JSONObject()); + + if (appConfig.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { + Logger.debug(OpenAIRequest.class, "posting: " + json); } - final OpenAIModel model = OpenAIModel.resolveModel(json.optString(AiKeys.MODEL)); + final AIModel model = appConfig.resolveModelOrThrow(json.optString(AiKeys.MODEL)); final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L) + model.minIntervalBetweenCalls() - System.currentTimeMillis(); @@ -115,9 +68,9 @@ public static void doRequest(final String urlIn, Logger.info( OpenAIRequest.class, "Rate limit:" - + model.apiPerMinute + + model.getApiPerMinute() + "/minute, or 1 every " - + (60000 / model.apiPerMinute) + + model.minIntervalBetweenCalls() + "ms. 
Sleeping:" + sleep); Try.run(() -> Thread.sleep(sleep)); @@ -129,7 +82,7 @@ public static void doRequest(final String urlIn, final StringEntity jsonEntity = new StringEntity(json.toString(), ContentType.APPLICATION_JSON); final HttpUriRequest httpRequest = resolveMethod(method, urlIn); httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON); - httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + openAiAPIKey); + httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + appConfig.getApiKey()); if (!json.getAsMap().isEmpty()) { Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity)); @@ -157,6 +110,58 @@ public static void doRequest(final String urlIn, } } + /** + * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload. + * The response from the request is returned as a string. + * + * @param url the URL to send the request to + * @param method the HTTP method to use for the request + * @param appConfig the AppConfig object containing the OpenAI API key and models + * @param payload the JSON payload to send with the request + * @return the response from the request as a string + */ + public static String doRequest(final String url, + final String method, + final AppConfig appConfig, + final JSONObject payload) { + final ByteArrayOutputStream out = new ByteArrayOutputStream(); + doRequest(url, method, appConfig, payload, out); + + return out.toString(); + } + + /** + * Sends a POST request to the specified URL with the specified OpenAI API key and JSON payload. + * The response from the request is written to the provided OutputStream. 
+ * + * @param urlIn the URL to send the request to + * @param appConfig the AppConfig object containing the OpenAI API key and models + * @param payload the JSON payload to send with the request + * @param out the OutputStream to write the response to + */ + public static void doPost(final String urlIn, + final AppConfig appConfig, + final JSONObject payload, + final OutputStream out) { + doRequest(urlIn, HttpMethod.POST, appConfig, payload, out); + } + + /** + * Sends a GET request to the specified URL with the specified OpenAI API key and JSON payload. + * The response from the request is written to the provided OutputStream. + * + * @param urlIn the URL to send the request to + * @param appConfig the AppConfig object containing the OpenAI API key and models + * @param payload the JSON payload to send with the request + * @param out the OutputStream to write the response to + */ + public static void doGet(final String urlIn, + final AppConfig appConfig, + final JSONObject payload, + final OutputStream out) { + doRequest(urlIn, HttpMethod.GET, appConfig, payload, out); + } + private static HttpUriRequest resolveMethod(final String method, final String urlIn) { switch(method) { case HttpMethod.POST: diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java index 25dd1b0c74cd..03f73a37a8ec 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java @@ -49,8 +49,8 @@ public Map getConfig() { this.config.getConfig(AppKeys.COMPLETION_ROLE_PROMPT), AppKeys.COMPLETION_TEXT_PROMPT.key, this.config.getConfig(AppKeys.COMPLETION_TEXT_PROMPT), - AppKeys.MODEL.key, - this.config.getConfig(AppKeys.MODEL)); + AppKeys.TEXT_MODEL_NAMES.key, + this.config.getConfig(AppKeys.TEXT_MODEL_NAMES)); } /** diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java 
b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java index 6754aa034204..aa2813d123d0 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java @@ -2,10 +2,8 @@ import com.dotcms.ai.api.EmbeddingsAPI; import com.dotcms.ai.app.AppConfig; -import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotcms.ai.util.EncodingUtil; -import com.dotcms.ai.util.OpenAIModel; import com.dotmarketing.beans.Host; import com.dotmarketing.business.web.WebAPILocator; import com.dotmarketing.util.Logger; @@ -54,7 +52,7 @@ public void init(Object initData) { */ public int countTokens(final String prompt) { return EncodingUtil.registry - .getEncodingForModel(appConfig.getModel()) + .getEncodingForModel(appConfig.getModel().getCurrentModel()) .map(encoding -> encoding.countTokens(prompt)) .orElse(-1); } @@ -69,9 +67,7 @@ public int countTokens(final String prompt) { */ public List generateEmbeddings(final String prompt) { int tokens = countTokens(prompt); - int maxTokens = OpenAIModel - .resolveModel(ConfigService.INSTANCE.config(host).getConfig(AppKeys.EMBEDDINGS_MODEL)) - .maxTokens; + int maxTokens = ConfigService.INSTANCE.config(host).getEmbeddingsModel().getMaxTokens(); if (tokens > maxTokens) { Logger.warn( EmbeddingsTool.class, diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java index f09a172cb937..c05689bccf5f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java @@ -45,7 +45,7 @@ public List getParameters() { limitTagsToHost, new WorkflowActionletParameter(OpenAIParams.RUN_DELAY.key, "Update the content asynchronously, after X seconds. 
O means run in-process", "5", true), - new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.MODEL), ConfigService.INSTANCE.config().getConfig(AppKeys.MODEL), false), + new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES), ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES), false), new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0.", ".1", false) ); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java index 9c8d403bb430..5e99b809d843 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java @@ -40,7 +40,7 @@ public List getParameters() { overwriteParameter, new WorkflowActionletParameter(OpenAIParams.OPEN_AI_PROMPT.key, "The prompt that will be sent to the AI", "We need an attractive search result in Google. Return a json object that includes the fields \"pageTitle\" for a meta title of less than 55 characters and \"metaDescription\" for the meta description of less than 300 characters using this content:\\n\\n${fieldContent}\\n\\n", true), new WorkflowActionletParameter(OpenAIParams.RUN_DELAY.key, "Update the content asynchronously, after X seconds. 
O means run in-process", "5", true), - new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.MODEL), ConfigService.INSTANCE.config().getConfig(AppKeys.MODEL), false), + new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES), ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES), false), new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0. Defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), false) ); } diff --git a/dotCMS/src/main/java/com/dotcms/analytics/AnalyticsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/analytics/AnalyticsAPIImpl.java index 2adcd87badbb..62cbdc3f5f99 100644 --- a/dotCMS/src/main/java/com/dotcms/analytics/AnalyticsAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/analytics/AnalyticsAPIImpl.java @@ -202,8 +202,9 @@ public void resetAnalyticsKey(final AnalyticsApp analyticsApp, final boolean for Logger.info( this, String.format( - "For clientId %s found this ANALYTICS_KEY response:\n%s", + "For clientId %s found this ANALYTICS_KEY response:%s%s", analyticsApp.getAnalyticsProperties().clientId(), + System.lineSeparator(), DotObjectMapperProvider.getInstance().getDefaultObjectMapper().writeValueAsString(response))); AnalyticsHelper.get().extractAnalyticsKey(response) @@ -278,7 +279,7 @@ private void validateAnalyticsApp(final AnalyticsApp analyticsApp) { * @param analyticsApp analytics app */ private void logTokenResponse(final CircuitBreakerUrl.Response response, AnalyticsApp analyticsApp) { - if (AnalyticsHelper.get().isSuccessResponse(response)) { + if (CircuitBreakerUrl.isSuccessResponse(response)) { return; } @@ -340,7 +341,7 @@ private String 
prepareRequestData(final AnalyticsApp analyticsApp) { private void logKeyResponse(final CircuitBreakerUrl.Response response, final AnalyticsApp analyticsApp) { - if (AnalyticsHelper.get().isSuccessResponse(response)) { + if (CircuitBreakerUrl.isSuccessResponse(response)) { return; } @@ -382,10 +383,7 @@ private CircuitBreakerUrl.Response requestAnalyticsKey(final Analy * @return map representation of http headers */ private Map analyticsKeyHeaders(final AccessToken accessToken) throws AnalyticsException { - return ImmutableMap.builder() - .put(HttpHeaders.AUTHORIZATION, AnalyticsHelper.get().formatBearer(accessToken)) - .put(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON) - .build(); + return CircuitBreakerUrl.authHeaders(AnalyticsHelper.get().formatBearer(accessToken)); } } diff --git a/dotCMS/src/main/java/com/dotcms/analytics/helper/AnalyticsHelper.java b/dotCMS/src/main/java/com/dotcms/analytics/helper/AnalyticsHelper.java index 32a7b1ce4015..c62d3add78e4 100644 --- a/dotCMS/src/main/java/com/dotcms/analytics/helper/AnalyticsHelper.java +++ b/dotCMS/src/main/java/com/dotcms/analytics/helper/AnalyticsHelper.java @@ -25,8 +25,6 @@ import org.apache.commons.lang3.StringUtils; import org.apache.http.HttpStatus; -import javax.validation.constraints.NotNull; -import javax.ws.rs.core.Response; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.ArrayList; @@ -51,26 +49,6 @@ public static AnalyticsHelper get(){ private AnalyticsHelper() {} - /** - * Evaluates if a given status code instance has a http status within the SUCCESSFUL range. - * - * @param statusCode http status code - * @return true if the response http status is considered tobe successful, otherwise false - */ - public boolean isSuccessResponse(final int statusCode) { - return Response.Status.Family.familyOf(statusCode) == Response.Status.Family.SUCCESSFUL; - } - - /** - * Evaluates if a given status code instance has a http status within the SUCCESSFUL range. 
- * - * @param response http response representation - * @return true if the response http status is considered tobe successful, otherwise false - */ - public boolean isSuccessResponse(@NotNull final CircuitBreakerUrl.Response response) { - return isSuccessResponse(response.getStatusCode()); - } - /** * Given a {@link CircuitBreakerUrl.Response} instance, extracts JSON representing the token and * deserializes to {@link AccessToken}. @@ -251,7 +229,7 @@ public AnalyticsApp appFromHost(final Host host) { */ public void throwFromResponse(final CircuitBreakerUrl.Response response, final String message) throws AnalyticsException { - if (isSuccessResponse(response)) { + if (CircuitBreakerUrl.isSuccessResponse(response)) { return; } diff --git a/dotCMS/src/main/java/com/dotcms/http/CircuitBreakerUrl.java b/dotCMS/src/main/java/com/dotcms/http/CircuitBreakerUrl.java index 9c6b6ca91338..e074368ef21b 100644 --- a/dotCMS/src/main/java/com/dotcms/http/CircuitBreakerUrl.java +++ b/dotCMS/src/main/java/com/dotcms/http/CircuitBreakerUrl.java @@ -20,7 +20,6 @@ import org.apache.commons.io.IOUtils; import org.apache.http.Header; import org.apache.http.HttpResponse; -import org.apache.http.client.ClientProtocolException; import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpEntityEnclosingRequestBase; @@ -35,6 +34,9 @@ import javax.servlet.ServletOutputStream; import javax.servlet.WriteListener; import javax.servlet.http.HttpServletResponse; +import javax.validation.constraints.NotNull; +import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.MediaType; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; @@ -65,6 +67,13 @@ */ public class CircuitBreakerUrl { + private static final Lazy circuitBreakerMaxConnTotal = + Lazy.of(() -> Config.getIntProperty("CIRCUIT_BREAKER_MAX_CONN_TOTAL", 100)); + private static final Lazy 
allowAccessToPrivateSubnets = + Lazy.of(() -> Config.getBooleanProperty("ALLOW_ACCESS_TO_PRIVATE_SUBNETS", false)); + private static final CircuitBreakerConnectionControl circuitBreakerConnectionControl = + new CircuitBreakerConnectionControl(circuitBreakerMaxConnTotal.get()); + private final String proxyUrl; private final long timeoutMs; private final CircuitBreaker circuitBreaker; @@ -76,11 +85,12 @@ public class CircuitBreakerUrl { private final boolean allowRedirects; private final boolean throwWhenNot2xx; - private static final Lazy circuitBreakerMaxConnTotal = Lazy.of(()->Config.getIntProperty("CIRCUIT_BREAKER_MAX_CONN_TOTAL",100)); - private static final Lazy allowAccessToPrivateSubnets = Lazy.of(()->Config.getBooleanProperty("ALLOW_ACCESS_TO_PRIVATE_SUBNETS", false)); - private static final CircuitBreakerConnectionControl circuitBreakerConnectionControl = new CircuitBreakerConnectionControl(circuitBreakerMaxConnTotal.get()); + public static final Response EMPTY_RESPONSE = new Response<>(StringPool.BLANK, 0, new Header[] {}); + + public enum Method { + GET, POST, PUT, DELETE, PATCH + } - public static final Response EMPTY_RESPONSE = new Response<>(StringPool.BLANK, 0, new Header[]{}); /** * * @param proxyUrl @@ -118,13 +128,12 @@ public CircuitBreakerUrl(final String proxyUrl, timeoutMs, circuitBreaker, new HttpGet(proxyUrl), - ImmutableMap.of(), - ImmutableMap.of(), + Map.of(), + Map.of(), verbose, null); } - - + @VisibleForTesting public CircuitBreakerUrl(final String proxyUrl, final long timeoutMs, @@ -135,23 +144,8 @@ public CircuitBreakerUrl(final String proxyUrl, final boolean verbose, final String rawData) { this(proxyUrl, timeoutMs, circuitBreaker, request, params, headers, verbose, rawData, false, true); - } - /** - * Full featured constructor - * - * @param proxyUrl - * @param timeoutMs - * @param circuitBreaker - * @param request - * @param params - * @param headers - * @param verbose - * @param rawData - * @param allowRedirects - * @param 
throwWhenNot2xx - */ @VisibleForTesting public CircuitBreakerUrl(final String proxyUrl, final long timeoutMs, @@ -207,9 +201,10 @@ public String doString() throws IOException { Logger.warn( this, String.format( - "Invalid response detected when consuming [%s] with http status [%d] and response:\n%s", + "Invalid response detected when consuming [%s] with http status [%d] and response:%s%s", this.proxyUrl, this.response, + System.lineSeparator(), output)); } return output; @@ -234,7 +229,7 @@ public boolean isReady() { @Override public void setWriteListener(WriteListener writeListener) { - + // no-op } }; } @@ -305,42 +300,18 @@ public void doOut(final HttpServletResponse response) throws IOException { } } - public static boolean isWithin2xx(final int response) { - return response >= 200 && response <= 299; - } - - private void copyHeaders(final HttpResponse innerResponse, final HttpServletResponse response) { - final Header contentTypeHeader = innerResponse.getFirstHeader("Content-Type"); - - if (UtilMethods.isSet(contentTypeHeader)) { - response.setHeader(contentTypeHeader.getName(), contentTypeHeader.getValue()); - } - - final Header contentLengthHeader = innerResponse.getFirstHeader("Content-Length"); - - if (UtilMethods.isSet(contentLengthHeader)) { - response.setHeader(contentLengthHeader.getName(), contentLengthHeader.getValue()); - } - } - public int response() { - return this.response; + return this.response; } - public static CircuitBreakerUrlBuilder builder() { - return new CircuitBreakerUrlBuilder(); - } - - - @Override - public String toString() { - return "CircuitBreakerUrl [proxyUrl=" + proxyUrl + ", timeoutMs=" + timeoutMs + ", circuitBreaker=" + circuitBreaker + "]"; + public static boolean isWithin2xx(final int response) { + return response >= 200 && response <= 299; } public T doObject(final Class clazz) { return Try.of(() -> DotObjectMapperProvider.getInstance().getDefaultObjectMapper().readValue(doString(), clazz)) - .onFailure(e -> 
Logger.warnAndDebug(CircuitBreakerUrl.class, e)) - .getOrElse((T) null); + .onFailure(e -> Logger.warnAndDebug(CircuitBreakerUrl.class, e)) + .getOrElse((T) null); } public Response doResponse(final Class clazz) { @@ -364,9 +335,54 @@ public Header[] getResponseHeaders() { return responseHeaders; } - public enum Method { - GET, POST, PUT, DELETE, PATCH; + @Override + public String toString() { + return "CircuitBreakerUrl [proxyUrl=" + proxyUrl + ", timeoutMs=" + timeoutMs + ", circuitBreaker=" + circuitBreaker + "]"; + } + + private void copyHeaders(final HttpResponse innerResponse, final HttpServletResponse response) { + final Header contentTypeHeader = innerResponse.getFirstHeader("Content-Type"); + + if (UtilMethods.isSet(contentTypeHeader)) { + response.setHeader(contentTypeHeader.getName(), contentTypeHeader.getValue()); + } + final Header contentLengthHeader = innerResponse.getFirstHeader("Content-Length"); + + if (UtilMethods.isSet(contentLengthHeader)) { + response.setHeader(contentLengthHeader.getName(), contentLengthHeader.getValue()); + } + } + + public static Map authHeaders(final String token) { + return ImmutableMap.builder() + .put(HttpHeaders.AUTHORIZATION, token) + .put(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON) + .build(); + } + + /** + * Evaluates if a given status code instance has a http status within the SUCCESSFUL range. + * + * @param statusCode http status code + * @return true if the response http status is considered tobe successful, otherwise false + */ + public static boolean isSuccessResponse(final int statusCode) { + return javax.ws.rs.core.Response.Status.Family.familyOf(statusCode) == javax.ws.rs.core.Response.Status.Family.SUCCESSFUL; + } + + /** + * Evaluates if a given status code instance has a http status within the SUCCESSFUL range. 
+ * + * @param response http response representation + * @return true if the response http status is considered tobe successful, otherwise false + */ + public static boolean isSuccessResponse(@NotNull final CircuitBreakerUrl.Response response) { + return isSuccessResponse(response.getStatusCode()); + } + + public static CircuitBreakerUrlBuilder builder() { + return new CircuitBreakerUrlBuilder(); } public static class CircuitBreakerConnectionControl { @@ -381,7 +397,6 @@ public CircuitBreakerConnectionControl(final int maxConnTotal) { public void check(final String proxyUrl) { if (threadIdConnectionCountSet.size() >= maxConnTotal) { - Logger.info(this, "The maximum number of connections has been reached, size: " + threadIdConnectionCountSet.size() + ", url: " + proxyUrl); throw new RejectedExecutionException("The maximum number of connections has been reached."); @@ -389,14 +404,13 @@ public void check(final String proxyUrl) { } public void start(final long id) { - threadIdConnectionCountSet.add(id); } public void end(final long id) { - threadIdConnectionCountSet.remove(id); } + } public static class Response implements Serializable { @@ -447,6 +461,7 @@ public String toString() { ", statusCode=" + statusCode + '}'; } + } } diff --git a/dotCMS/src/main/java/com/dotcms/rest/api/v1/system/ConfigurationHelper.java b/dotCMS/src/main/java/com/dotcms/rest/api/v1/system/ConfigurationHelper.java index 6f8f83862eae..b35147c981a5 100644 --- a/dotCMS/src/main/java/com/dotcms/rest/api/v1/system/ConfigurationHelper.java +++ b/dotCMS/src/main/java/com/dotcms/rest/api/v1/system/ConfigurationHelper.java @@ -7,7 +7,6 @@ import com.dotcms.enterprise.cluster.ClusterFactory; import com.dotcms.enterprise.license.LicenseLevel; import com.dotcms.enterprise.license.LicenseManager; -import com.dotcms.util.CollectionsUtils; import com.dotmarketing.business.APILocator; import com.dotmarketing.util.Config; import com.dotmarketing.util.Constants; diff --git 
a/dotCMS/src/main/java/com/dotcms/system/event/local/business/LocalSystemEventSubscribersInitializer.java b/dotCMS/src/main/java/com/dotcms/system/event/local/business/LocalSystemEventSubscribersInitializer.java index c30029e2f858..335b77c7f29f 100644 --- a/dotCMS/src/main/java/com/dotcms/system/event/local/business/LocalSystemEventSubscribersInitializer.java +++ b/dotCMS/src/main/java/com/dotcms/system/event/local/business/LocalSystemEventSubscribersInitializer.java @@ -1,5 +1,6 @@ package com.dotcms.system.event.local.business; +import com.dotcms.ai.listener.AIAppListener; import com.dotcms.analytics.listener.AnalyticsAppListener; import com.dotcms.config.DotInitializer; import com.dotcms.content.elasticsearch.business.event.ContentletCheckinEvent; @@ -68,6 +69,7 @@ public void notify(final ChangeLoggerLevelEvent event) { APILocator.getLocalSystemEventsAPI().subscribe(APILocator.getContainerAPI()); APILocator.getLocalSystemEventsAPI().subscribe(AppSecretSavedEvent.class, AnalyticsAppListener.Instance.get()); + APILocator.getLocalSystemEventsAPI().subscribe(AppSecretSavedEvent.class, AIAppListener.Instance.get()); this.initDotVelocityMacrosVtlFiles(); } diff --git a/dotCMS/src/main/resources/apps/dotAI.yml b/dotCMS/src/main/resources/apps/dotAI.yml index c48eb25b26e0..e51d99bfd4c5 100644 --- a/dotCMS/src/main/resources/apps/dotAI.yml +++ b/dotCMS/src/main/resources/apps/dotAI.yml @@ -61,19 +61,82 @@ params: value: "1920x1080" - label: "256x256 (Small Square 1:1)" value: "256x256" - model: - value: "gpt-3.5-turbo-16k" + textModelNames: + value: "" + hidden: false + type: "STRING" + label: "Model Names" + hint: "Comma delimited list of models used to generate OpenAI API response." + required: true + textModelTokensPerMinute: + value: "180000" + hidden: false + type: "STRING" + label: "Tokens per Minute" + hint: "Tokens per minute used to generate OpenAI API response." 
+ required: false + textModelApiPerMinute: + value: "3500" + hidden: false + type: "STRING" + label: "API per Minute" + hint: "API per minute used to generate OpenAI API response." + required: false + textModelMaxTokens: + value: "16384" + hidden: false + type: "STRING" + label: "Max Tokens" + hint: "Maximum number of tokens used to generate OpenAI API response." + required: false + textModelCompletion: + value: "false" + hidden: false + type: "BOOL" + label: "Completion model enabled" + hint: "Enable completion model used to generate OpenAI API response." + required: false + imageModelNames: + value: "" hidden: false type: "STRING" - label: "Model" - hint: "Model used to generate ChatGPT API response." + label: "Image Model Names" + hint: "Comma delimited list of image models used to generate OpenAI API response." required: true - imageModel: - value: "dall-e-3" + imageModelTokensPerMinute: + value: "0" + hidden: false + type: "STRING" + label: "Image Tokens per Minute" + hint: "Tokens per minute used to generate OpenAI API response." + required: false + imageModelApiPerMinute: + value: "50" + hidden: false + type: "STRING" + label: "Image API per Minute" + hint: "API per minute used to generate OpenAI API response." + required: false + imageModelMaxTokens: + value: "0" + hidden: false + type: "STRING" + label: "Image Max Tokens" + hint: "Maximum number of tokens used to generate OpenAI API response." + required: false + imageModelCompletion: + value: "false" + hidden: false + type: "BOOL" + label: "Image Completion model enabled" + hint: "Enable completion model used to generate OpenAI API response." + required: false + embeddingsModelNames: + value: "" hidden: false type: "STRING" - label: "Image Model" - hint: "Image Model used to generate AI Images" + label: "Embeddings Model Names" + hint: "Comma delimited list of embeddings models used to generate OpenAI API response." 
required: true listenerIndexer: value: "" diff --git a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java new file mode 100644 index 000000000000..536427d9e043 --- /dev/null +++ b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java @@ -0,0 +1,176 @@ +package com.dotcms.ai.app; + +import com.dotcms.security.apps.Secret; +import org.junit.Before; +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for the \AIAppUtil\ class. This test class verifies the functionality + * of methods in \AIAppUtil\ such as discovering secrets, creating models, and resolving + * environment-specific secrets. It uses mock objects to simulate the \Secret\ dependencies. + * + * @author vico + */ +public class AIAppUtilTest { + + private AIAppUtil aiAppUtil; + private Map secrets; + private Secret secret; + + @Before + public void setUp() { + aiAppUtil = AIAppUtil.get(); + secrets = mock(Map.class); + secret = mock(Secret.class); + } + + /** + * Given a map of secrets containing a key with a secret value + * When the discoverSecret method is called with the key and a default value + * Then the secret value should be returned. + */ + @Test + public void testDiscoverSecretWithDefaultValue() { + when(secrets.get("apiKey")).thenReturn(secret); + when(secret.getString()).thenReturn("secretValue"); + + String result = aiAppUtil.discoverSecret(secrets, AppKeys.API_KEY, "defaultValue"); + assertEquals("secretValue", result); + } + + /** + * Given a map of secrets not containing a key + * When the discoverSecret method is called with the key and a default value + * Then the default value should be returned. 
+ */ + @Test + public void testDiscoverSecretWithDefaultValueNotFound() { + when(secrets.get("key")).thenReturn(null); + + String result = aiAppUtil.discoverSecret(secrets, AppKeys.API_KEY, "defaultValue"); + assertEquals("defaultValue", result); + } + + /** + * Given a map of secrets containing a key with a secret value + * When the discoverSecret method is called with the key + * Then the secret value should be returned. + */ + @Test + public void testDiscoverSecretWithKeyDefaultValue() { + when(secrets.get("apiKey")).thenReturn(secret); + when(secret.getString()).thenReturn("secretValue"); + + String result = aiAppUtil.discoverSecret(secrets, AppKeys.API_KEY); + assertEquals("secretValue", result); + } + + /** + * Given a map of secrets not containing a key + * When the discoverSecret method is called with the key + * Then the default value of the key should be returned. + */ + @Test + public void testDiscoverSecretWithKeyDefaultValueNotFound() { + when(secrets.get("key")).thenReturn(null); + + String result = aiAppUtil.discoverSecret(secrets, AppKeys.API_KEY); + assertEquals(AppKeys.API_KEY.defaultValue, result); + } + + /** + * Given a map of secrets containing a key with an environment secret value + * When the discoverEnvSecret method is called with the key + * Then the environment secret value should be returned. + */ + @Test + public void testDiscoverEnvSecret() { + when(secrets.get("apiKey")).thenReturn(secret); + when(secret.getString()).thenReturn("envSecretValue"); + + String result = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY); + assertEquals("envSecretValue", result); + } + + /** + * Given a map of secrets containing a key with an integer secret value + * When the discoverIntSecret method is called with the key + * Then the integer secret value should be returned. 
+ */ + @Test + public void testDiscoverIntSecret() { + when(secrets.get("apiKey")).thenReturn(secret); + when(secret.getString()).thenReturn("123"); + + int result = aiAppUtil.discoverIntSecret(secrets, AppKeys.API_KEY); + assertEquals(123, result); + } + + /** + * Given a map of secrets containing a key with a boolean secret value + * When the discoverBooleanSecret method is called with the key + * Then the boolean secret value should be returned. + */ + @Test + public void testDiscoverBooleanSecret() { + when(secrets.get("apiKey")).thenReturn(secret); + when(secret.getString()).thenReturn("true"); + + boolean result = aiAppUtil.discoverBooleanSecret(secrets, AppKeys.API_KEY); + assertTrue(result); + } + + /** + * Given a map of secrets containing a key with a text model name + * When the createTextModel method is called + * Then an AIModel instance should be created with the specified type and model name. + */ + @Test + public void testCreateTextModel() { + when(secrets.get(AppKeys.TEXT_MODEL_NAMES.key)).thenReturn(secret); + when(secret.getString()).thenReturn("textModel"); + + AIModel model = aiAppUtil.createTextModel(secrets); + assertNotNull(model); + assertEquals(AIModelType.TEXT, model.getType()); + assertTrue(model.getNames().contains("textModel")); + } + + /** + * Given a map of secrets containing a key with an image model name + * When the createImageModel method is called + * Then an AIModel instance should be created with the specified type and model name. 
+ */ + @Test + public void testCreateImageModel() { + when(secrets.get(AppKeys.IMAGE_MODEL_NAMES.key)).thenReturn(secret); + when(secret.getString()).thenReturn("imageModel"); + + AIModel model = aiAppUtil.createImageModel(secrets); + assertNotNull(model); + assertEquals(AIModelType.IMAGE, model.getType()); + assertTrue(model.getNames().contains("imageModel")); + } + + /** + * Given a map of secrets containing a key with an embeddings model name + * When the createEmbeddingsModel method is called + * Then an AIModel instance should be created with the specified type and model name. + */ + @Test + public void testCreateEmbeddingsModel() { + when(secrets.get(AppKeys.EMBEDDINGS_MODEL_NAMES.key)).thenReturn(secret); + when(secret.getString()).thenReturn("embeddingsModel"); + + AIModel model = aiAppUtil.createEmbeddingsModel(secrets); + assertNotNull(model); + assertEquals(AIModelType.EMBEDDINGS, model.getType()); + assertTrue(model.getNames().contains("embeddingsModel")); + } + +} \ No newline at end of file diff --git a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java b/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java index 996098bc51b4..e110608dbb2c 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java @@ -1,5 +1,7 @@ package com.dotcms.ai.service; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotmarketing.util.json.JSONObject; @@ -52,14 +54,15 @@ public void test_sendTextPrompt() { private OpenAIChatService prepareService(final String response) { return new OpenAIChatServiceImpl(config) { @Override - String doRequest(final String urlIn, final String openAiAPIKey, final JSONObject json) { + String doRequest(final String urlIn, final JSONObject json) { return response; } }; } private JSONObject 
prepareJsonObject(final String prompt) { - when(config.getModel()).thenReturn("some-model"); + when(config.getModel()) + .thenReturn(AIModel.builder().withType(AIModelType.TEXT).withNames("some-model").build()); when(config.getConfigFloat(AppKeys.COMPLETION_TEMPERATURE)).thenReturn(123.321F); when(config.getRolePrompt()).thenReturn("some-role-prompt"); diff --git a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java b/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java index 0e7fef5054a2..1338b3110c74 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java @@ -1,5 +1,7 @@ package com.dotcms.ai.service; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.model.AIImageRequestDTO; import com.dotcms.ai.util.StopWordsUtil; @@ -199,7 +201,7 @@ private OpenAIImageService prepareService(final String response, final User user) { return new OpenAIImageServiceImpl(config, user, hostApi, tempFileApi) { @Override - String doRequest(final String urlIn, final String openAiAPIKey, final JSONObject json) { + String doRequest(final String urlIn, final JSONObject json) { return response; } @@ -216,7 +218,7 @@ AIImageRequestDTO.Builder getDtoBuilder() { } private JSONObject prepareJsonObject(final String prompt, final boolean tempFileError) throws Exception { - when(config.getImageModel()).thenReturn("some-image-model"); + when(config.getImageModel()).thenReturn(AIModel.builder().withType(AIModelType.IMAGE).withNames("some-image-model").build()); when(config.getImageSize()).thenReturn("some-image-size"); final File file = mock(File.class); when(file.getName()).thenReturn(UUIDGenerator.shorty()); diff --git a/dotCMS/src/test/java/com/dotcms/analytics/helper/AnalyticsHelperTest.java 
b/dotCMS/src/test/java/com/dotcms/analytics/helper/AnalyticsHelperTest.java index d8593326a51d..2c062cba0365 100644 --- a/dotCMS/src/test/java/com/dotcms/analytics/helper/AnalyticsHelperTest.java +++ b/dotCMS/src/test/java/com/dotcms/analytics/helper/AnalyticsHelperTest.java @@ -47,36 +47,6 @@ public void setup() { when(response.getStatusCode()).thenReturn(HttpStatus.SC_OK); } - /** - * Given an int status code - * Then evaluate it does have a SUCCESS http status - */ - @Test - public void test_isSuccessStatusCode() { - assertTrue(AnalyticsHelper.get().isSuccessResponse(HttpStatus.SC_ACCEPTED)); - assertTrue(AnalyticsHelper.get().isSuccessResponse(HttpStatus.SC_OK)); - assertFalse(AnalyticsHelper.get().isSuccessResponse(HttpStatus.SC_BAD_REQUEST)); - assertFalse(AnalyticsHelper.get().isSuccessResponse(HttpStatus.SC_FORBIDDEN)); - assertFalse(AnalyticsHelper.get().isSuccessResponse(HttpStatus.SC_INTERNAL_SERVER_ERROR)); - } - - /** - * Given a {@link Response} - * Then evaluate it does have a SUCCESS http status - */ - @Test - public void test_isSuccessResponse() { - assertTrue(AnalyticsHelper.get().isSuccessResponse(response)); - when(response.getStatusCode()).thenReturn(HttpStatus.SC_ACCEPTED); - assertTrue(AnalyticsHelper.get().isSuccessResponse(response)); - when(response.getStatusCode()).thenReturn(HttpStatus.SC_BAD_REQUEST); - assertFalse(AnalyticsHelper.get().isSuccessResponse(response)); - when(response.getStatusCode()).thenReturn(HttpStatus.SC_FORBIDDEN); - assertFalse(AnalyticsHelper.get().isSuccessResponse(response)); - when(response.getStatusCode()).thenReturn(HttpStatus.SC_INTERNAL_SERVER_ERROR); - assertFalse(AnalyticsHelper.get().isSuccessResponse(response)); - } - /** * Given an {@link Response} * Then verify that an {@link AccessToken} can be extracted as an entity diff --git a/dotcms-integration/pom.xml b/dotcms-integration/pom.xml index 58dcf3984df9..cf321d39bb74 100644 --- a/dotcms-integration/pom.xml +++ b/dotcms-integration/pom.xml @@ -411,7 
+411,9 @@ ${test.webapp.root}/WEB-INF/velocity ${test.webapp.root}/WEB-INF/geoip2/GeoLite2-City.mmdb ${test.webapp.root}/WEB-INF/bin + http://localhost:50505/e http://localhost:50505/e + http://localhost:50505/m ${it.test.fork-folder}${surefire.forkNumber}/${test.temp.folder} diff --git a/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java b/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java index 0097ec91cbf3..86304a1c8d95 100644 --- a/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java +++ b/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java @@ -1,5 +1,6 @@ package com.dotcms; +import com.dotcms.ai.app.AIModelsTest; import com.dotcms.ai.listener.EmbeddingContentListenerTest; import com.dotcms.ai.viewtool.AIViewToolTest; import com.dotcms.ai.viewtool.CompletionsToolTest; @@ -299,6 +300,7 @@ SearchToolTest.class, EmbeddingsToolTest.class, CompletionsToolTest.class, + AIModelsTest.class, TimeMachineAPITest.class, Task240513UpdateContentTypesSystemFieldTest.class, PruneTimeMachineBackupJobTest.class, diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AiTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java similarity index 82% rename from dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AiTest.java rename to dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java index 32391726afe4..fb529a7dc30f 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AiTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java @@ -1,10 +1,11 @@ -package com.dotcms.ai.viewtool; +package com.dotcms.ai; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.security.apps.Secret; import com.dotcms.security.apps.Type; import com.dotcms.util.WireMockTestHelper; +import com.dotmarketing.beans.Host; import com.github.tomakehurst.wiremock.WireMockServer; import java.util.HashMap; @@ -20,12 +21,12 @@ public interface AiTest { String IMAGE_SIZE = 
"1024x1024"; int PORT = 50505; - static AppConfig prepareConfig(final WireMockServer wireMockServer) { - return new AppConfig(appConfigMap(wireMockServer)); + static AppConfig prepareConfig(final Host host, final WireMockServer wireMockServer) { + return new AppConfig(host.getHostname(), appConfigMap(wireMockServer)); } - static AppConfig prepareCompletionConfig(final WireMockServer wireMockServer) { - return new AppConfig(completionAppConfigMap(appConfigMap(wireMockServer))); + static AppConfig prepareCompletionConfig(final Host host, final WireMockServer wireMockServer) { + return new AppConfig(host.getHostname(), completionAppConfigMap(appConfigMap(wireMockServer))); } static WireMockServer prepareWireMock() { @@ -49,10 +50,10 @@ private static Map completionAppConfigMap(final Map all = new HashMap<>(configMap); @@ -77,10 +78,10 @@ static Map appConfigMap(final WireMockServer wireMockServer) { AppKeys.API_KEY.key, Secret.builder().withType(Type.STRING).withValue(API_KEY.toCharArray()).build(), - AppKeys.MODEL.key, + AppKeys.TEXT_MODEL_NAMES.key, Secret.builder().withType(Type.STRING).withValue(MODEL.toCharArray()).build(), - AppKeys.IMAGE_MODEL.key, + AppKeys.IMAGE_MODEL_NAMES.key, Secret.builder().withType(Type.STRING).withValue(IMAGE_MODEL.toCharArray()).build(), AppKeys.IMAGE_SIZE.key, diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java new file mode 100644 index 000000000000..8a04ae9e5495 --- /dev/null +++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java @@ -0,0 +1,199 @@ +package com.dotcms.ai.app; + +import com.dotcms.ai.AiTest; +import com.dotcms.datagen.SiteDataGen; +import com.dotcms.util.IntegrationTestInitService; +import com.dotcms.util.network.IPUtils; +import com.dotmarketing.beans.Host; +import com.github.tomakehurst.wiremock.WireMockServer; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; 
+import org.junit.Test; + +import java.util.List; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +/** + * Integration tests for the {@code AIModels} class. This test class verifies the functionality + * of methods in {@code AIModels} such as loading models, finding models by host and type, and + * retrieving supported models. It uses {@code WireMockServer} to simulate external dependencies + * and {@code IntegrationTestInitService} for initializing the test environment. + * + * @author vico + */ +public class AIModelsTest { + + private static WireMockServer wireMockServer; + + private final AIModels aiModels = AIModels.get(); + private Host host; + private Host otherHost; + + @BeforeClass + public static void beforeClass() throws Exception { + IntegrationTestInitService.getInstance().init(); + IPUtils.disabledIpPrivateSubnet(true); + wireMockServer = AiTest.prepareWireMock(); + } + + @AfterClass + public static void afterClass() { + wireMockServer.stop(); + IPUtils.disabledIpPrivateSubnet(false); + } + + @Before + public void before() { + host = new SiteDataGen().nextPersisted(); + otherHost = new SiteDataGen().nextPersisted(); + } + + /** + * Given a set of models loaded into the AIModels instance + * When the findModel method is called with various model names and types + * Then the correct models should be found and returned.
+ */ + @Test + public void test_loadModels_andFindThem() { + loadModels(); + + final String hostId = host.getHostname(); + final Optional notFound = aiModels.findModel(hostId, "some-invalid-model-name"); + assertTrue(notFound.isEmpty()); + + final Optional text1 = aiModels.findModel(hostId, "text-model-1"); + final Optional text2 = aiModels.findModel(hostId, "text-model-2"); + assertModels(text1, text2, AIModelType.TEXT); + + final Optional image1 = aiModels.findModel(hostId, "image-model-3"); + final Optional image2 = aiModels.findModel(hostId, "image-model-4"); + assertModels(image1, image2, AIModelType.IMAGE); + + final Optional embeddings1 = aiModels.findModel(hostId, "embeddings-model-5"); + assertTrue(embeddings1.isPresent()); + final Optional embeddings2 = aiModels.findModel(hostId, "embeddings-model-6"); + assertModels(embeddings1, embeddings2, AIModelType.EMBEDDINGS); + + assertNotSame(text1.get(), image1.get()); + assertNotSame(text1.get(), embeddings1.get()); + assertNotSame(image1.get(), embeddings1.get()); + + final Optional text3 = aiModels.findModel(hostId, AIModelType.TEXT); + assertSameModels(text3, text1, text2); + + final Optional image3 = aiModels.findModel(hostId, AIModelType.IMAGE); + assertSameModels(image3, image1, image2); + + final Optional embeddings3 = aiModels.findModel(hostId, AIModelType.EMBEDDINGS); + assertSameModels(embeddings3, embeddings1, embeddings2); + + final Optional text4 = aiModels.findModel(otherHost.getHostname(), "text-model-1"); + assertTrue(text4.isPresent()); + assertNotSame(text1.get(), text4.get()); + } + + /** + * Given a set of models loaded into the AIModels instance + * When the resetModels method is called with a specific host + * Then the models for that host should be reset and no longer found.
+ */ + @Test + public void test_resetModels() { + loadModels(); + final Optional aiModel = aiModels.findModel(host.getHostname(), AIModelType.TEXT); + + aiModels.resetModels(host); + + assertNotSame(aiModel.get(), aiModels.findModel(host.getHostname(), AIModelType.TEXT).orElse(null)); + assertTrue(aiModels.findModel(host.getHostname(), "text-model-1").isEmpty()); + } + + /** + * Given a URL for supported models + * When the getOrPullSupportedModels method is called + * Then a list of supported models should be returned. + */ + @Test + public void test_getOrPullSupportedModules() { + final List supported = aiModels.getOrPullSupportedModels(); + assertNotNull(supported); + assertEquals(32, supported.size()); + } + + /** + * Given an invalid URL for supported models + * When the getOrPullSupportedModels method is called + * Then an empty list of supported models should be returned. + */ + @Test + public void test_getOrPullSupportedModules_invalidEndpoint() { + final List supported = aiModels.getOrPullSupportedModels(); + assertNotNull(supported); + assertTrue(supported.isEmpty()); + } + + private void loadModels() { + aiModels.loadModels( + host.getHostname(), + List.of( + AIModel.builder() + .withType(AIModelType.TEXT) + .withNames("text-model-1", "text-model-2") + .withTokensPerMinute(123) + .withApiPerMinute(321) + .withMaxTokens(555) + .withIsCompletion(true) + .build(), + AIModel.builder() + .withType(AIModelType.IMAGE) + .withNames("image-model-3", "image-model-4") + .withTokensPerMinute(111) + .withApiPerMinute(222) + .withMaxTokens(333) + .withIsCompletion(false) + .build(), + AIModel.builder() + .withType(AIModelType.EMBEDDINGS) + .withNames("embeddings-model-5", "embeddings-model-6") + .withTokensPerMinute(999) + .withApiPerMinute(888) + .withMaxTokens(777) + .withIsCompletion(false) + .build())); + aiModels.loadModels( + otherHost.getHostname(), + List.of( + AIModel.builder() + .withType(AIModelType.TEXT) + .withNames("text-model-1") + .withTokensPerMinute(123) +
.withApiPerMinute(321) + .withMaxTokens(555) + .withIsCompletion(true) + .build())); + } + + private static void assertSameModels(Optional text3, Optional text1, Optional text2) { + assertTrue(text3.isPresent()); + assertSame(text1.get(), text3.get()); + assertSame(text2.get(), text3.get()); + } + + private static void assertModels(final Optional model1, + final Optional model2, + final AIModelType type) { + assertTrue(model1.isPresent()); + assertTrue(model2.isPresent()); + assertSame(model1.get(), model2.get()); + assertSame(type, model1.get().getType()); + assertSame(type, model2.get().getType()); + } + +} diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java index e4c9567aa348..9f76e9da6f33 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java @@ -2,7 +2,7 @@ import com.dotcms.ai.api.EmbeddingsAPI; import com.dotcms.ai.app.AppKeys; -import com.dotcms.ai.viewtool.AiTest; +import com.dotcms.ai.AiTest; import com.dotcms.contenttype.business.ContentTypeAPI; import com.dotcms.contenttype.model.type.ContentType; import com.dotcms.datagen.TestDataUtils; diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java index 29341e19ddfa..0071d755f616 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/AIViewToolTest.java @@ -1,9 +1,11 @@ package com.dotcms.ai.viewtool; +import com.dotcms.ai.AiTest; import com.dotcms.ai.app.AppConfig; import com.dotcms.datagen.UserDataGen; import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; +import 
com.dotmarketing.business.APILocator; import com.dotmarketing.util.json.JSONObject; import com.github.tomakehurst.wiremock.WireMockServer; import com.liferay.portal.model.User; @@ -43,7 +45,7 @@ public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); - config = AiTest.prepareConfig(wireMockServer); + config = AiTest.prepareConfig(APILocator.systemHost(), wireMockServer); } @AfterClass diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java index c99cca9a6a4d..a8769c973d5d 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/CompletionsToolTest.java @@ -1,5 +1,6 @@ package com.dotcms.ai.viewtool; +import com.dotcms.ai.AiTest; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.datagen.EmbeddingsDTODataGen; @@ -39,16 +40,17 @@ public class CompletionsToolTest { private static AppConfig config; private static WireMockServer wireMockServer; + private static Host host; - private Host host; private CompletionsTool completionsTool; @BeforeClass public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); IPUtils.disabledIpPrivateSubnet(true); + host = new SiteDataGen().nextPersisted(); wireMockServer = AiTest.prepareWireMock(); - config = AiTest.prepareCompletionConfig(wireMockServer); + config = AiTest.prepareCompletionConfig(host, wireMockServer); } @AfterClass @@ -61,7 +63,7 @@ public static void afterClass() { public void before() { final ViewContext viewContext = mock(ViewContext.class); when(viewContext.getRequest()).thenReturn(mock(HttpServletRequest.class)); - host = new SiteDataGen().nextPersisted(); + completionsTool = 
prepareCompletionsTool(viewContext); } @@ -80,7 +82,7 @@ public void test_getConfig() { assertNotNull(config); assertEquals(AppKeys.COMPLETION_ROLE_PROMPT.defaultValue, config.get(AppKeys.COMPLETION_ROLE_PROMPT.key)); assertEquals(AppKeys.COMPLETION_TEXT_PROMPT.defaultValue, config.get(AppKeys.COMPLETION_TEXT_PROMPT.key)); - assertEquals(AppKeys.MODEL.defaultValue, config.get(AppKeys.MODEL.key)); + assertEquals(AppKeys.TEXT_MODEL_NAMES.defaultValue, config.get(AppKeys.TEXT_MODEL_NAMES.key)); } /** diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java index 3e7d009c132f..22e352b60a38 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/viewtool/EmbeddingsToolTest.java @@ -1,5 +1,7 @@ package com.dotcms.ai.viewtool; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AIModelType; import com.dotcms.ai.app.AppConfig; import com.dotcms.datagen.EmbeddingsDTODataGen; import com.dotcms.datagen.SiteDataGen; @@ -21,6 +23,14 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +/** + * Integration tests for the {@code EmbeddingsTool} class. This test class verifies the functionality + * of methods in {@code EmbeddingsTool} such as counting tokens, generating embeddings, and + * retrieving index counts. It uses mock objects to simulate the {@code ViewContext} and + * {@code AppConfig} dependencies.
+ * + * @author vico + */ public class EmbeddingsToolTest { private Host host; @@ -107,9 +117,10 @@ AppConfig appConfig() { } private AppConfig prepareAppConfig() { - final AppConfig appConfig = mock(AppConfig.class); - when(appConfig.getModel()).thenReturn("gpt-3.5-turbo-16k"); - return appConfig; + final AppConfig config = mock(AppConfig.class); + final AIModel aiModel = AIModel.builder().withType(AIModelType.TEXT).withNames("gpt-3.5-turbo-16k").build(); + when(config.getModel()).thenReturn(aiModel); + return config; } } diff --git a/dotcms-integration/src/test/java/com/dotcms/http/CircuitBreakerUrlTest.java b/dotcms-integration/src/test/java/com/dotcms/http/CircuitBreakerUrlTest.java index 2877cd65b4c5..6fcbacc07602 100644 --- a/dotcms-integration/src/test/java/com/dotcms/http/CircuitBreakerUrlTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/http/CircuitBreakerUrlTest.java @@ -17,7 +17,9 @@ import com.dotmarketing.util.DateUtil; import org.apache.commons.io.output.NullOutputStream; +import org.apache.http.HttpStatus; import org.junit.Assert; +import org.junit.Before; import org.junit.Ignore; import org.junit.Test; @@ -29,6 +31,13 @@ import net.jodah.failsafe.CircuitBreaker; import net.jodah.failsafe.CircuitBreakerOpenException; +import javax.ws.rs.core.Response; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class CircuitBreakerUrlTest { @@ -43,6 +52,14 @@ public class CircuitBreakerUrlTest { final static String PARAM="X-MY-PARAM"; final static String PARAM_VALUE="PARAM SEEMS TO BE WORKING"; + private CircuitBreakerUrl.Response response; + + @Before + public void setup() { + response = mock(CircuitBreakerUrl.Response.class); + when(response.getStatusCode()).thenReturn(HttpStatus.SC_OK); + } + @Test() public void test_circuitBreakerConnectionControl() { @@ -423,4 +440,34 @@ public void testBadRequest_dontThrow() throws 
Exception { assert (cburl.response() >= 400 && cburl.response() <= 499); } + /** + * Given an int status code + * Then evaluate it does have a SUCCESS http status + */ + @Test + public void test_isSuccessStatusCode() { + assertTrue(CircuitBreakerUrl.isSuccessResponse(HttpStatus.SC_ACCEPTED)); + assertTrue(CircuitBreakerUrl.isSuccessResponse(HttpStatus.SC_OK)); + assertFalse(CircuitBreakerUrl.isSuccessResponse(HttpStatus.SC_BAD_REQUEST)); + assertFalse(CircuitBreakerUrl.isSuccessResponse(HttpStatus.SC_FORBIDDEN)); + assertFalse(CircuitBreakerUrl.isSuccessResponse(HttpStatus.SC_INTERNAL_SERVER_ERROR)); + } + + /** + * Given a {@link Response} + * Then evaluate it does have a SUCCESS http status + */ + @Test + public void test_isSuccessResponse() { + assertTrue(CircuitBreakerUrl.isSuccessResponse(response)); + when(response.getStatusCode()).thenReturn(HttpStatus.SC_ACCEPTED); + assertTrue(CircuitBreakerUrl.isSuccessResponse(response)); + when(response.getStatusCode()).thenReturn(HttpStatus.SC_BAD_REQUEST); + assertFalse(CircuitBreakerUrl.isSuccessResponse(response)); + when(response.getStatusCode()).thenReturn(HttpStatus.SC_FORBIDDEN); + assertFalse(CircuitBreakerUrl.isSuccessResponse(response)); + when(response.getStatusCode()).thenReturn(HttpStatus.SC_INTERNAL_SERVER_ERROR); + assertFalse(CircuitBreakerUrl.isSuccessResponse(response)); + } + } diff --git a/dotcms-integration/src/test/resources/mappings/openai-models.json b/dotcms-integration/src/test/resources/mappings/openai-models.json new file mode 100644 index 000000000000..9bf3d1ca8a0f --- /dev/null +++ b/dotcms-integration/src/test/resources/mappings/openai-models.json @@ -0,0 +1,206 @@ +{ + "request": { + "method": "GET", + "url": "/m" + }, + "response": { + "status": 200, + "jsonBody": { + "object": "list", + "data": [ + { + "id": "dall-e-3", + "object": "model", + "created": 1698785189, + "owned_by": "system" + }, + { + "id": "gpt-4-1106-preview", + "object": "model", + "created": 1698957206, + "owned_by": 
"system" + }, + { + "id": "dall-e-2", + "object": "model", + "created": 1698798177, + "owned_by": "system" + }, + { + "id": "gpt-4o", + "object": "model", + "created": 1715367049, + "owned_by": "system" + }, + { + "id": "tts-1-hd-1106", + "object": "model", + "created": 1699053533, + "owned_by": "system" + }, + { + "id": "tts-1-hd", + "object": "model", + "created": 1699046015, + "owned_by": "system" + }, + { + "id": "gpt-4-0125-preview", + "object": "model", + "created": 1706037612, + "owned_by": "system" + }, + { + "id": "babbage-002", + "object": "model", + "created": 1692634615, + "owned_by": "system" + }, + { + "id": "gpt-4-turbo-preview", + "object": "model", + "created": 1706037777, + "owned_by": "system" + }, + { + "id": "text-embedding-3-small", + "object": "model", + "created": 1705948997, + "owned_by": "system" + }, + { + "id": "text-embedding-3-large", + "object": "model", + "created": 1705953180, + "owned_by": "system" + }, + { + "id": "gpt-3.5-turbo-0613", + "object": "model", + "created": 1686587434, + "owned_by": "openai" + }, + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677610602, + "owned_by": "openai" + }, + { + "id": "gpt-3.5-turbo-instruct", + "object": "model", + "created": 1692901427, + "owned_by": "system" + }, + { + "id": "gpt-3.5-turbo-instruct-0914", + "object": "model", + "created": 1694122472, + "owned_by": "system" + }, + { + "id": "gpt-4o-mini", + "object": "model", + "created": 1721172741, + "owned_by": "system" + }, + { + "id": "gpt-4o-mini-2024-07-18", + "object": "model", + "created": 1721172717, + "owned_by": "system" + }, + { + "id": "whisper-1", + "object": "model", + "created": 1677532384, + "owned_by": "openai-internal" + }, + { + "id": "gpt-4o-2024-05-13", + "object": "model", + "created": 1715368132, + "owned_by": "system" + }, + { + "id": "text-embedding-ada-002", + "object": "model", + "created": 1671217299, + "owned_by": "openai-internal" + }, + { + "id": "gpt-3.5-turbo-16k", + "object": "model", + 
"created": 1683758102, + "owned_by": "openai-internal" + }, + { + "id": "davinci-002", + "object": "model", + "created": 1692634301, + "owned_by": "system" + }, + { + "id": "gpt-3.5-turbo-16k-0613", + "object": "model", + "created": 1685474247, + "owned_by": "openai" + }, + { + "id": "gpt-4-turbo-2024-04-09", + "object": "model", + "created": 1712601677, + "owned_by": "system" + }, + { + "id": "tts-1-1106", + "object": "model", + "created": 1699053241, + "owned_by": "system" + }, + { + "id": "gpt-3.5-turbo-0125", + "object": "model", + "created": 1706048358, + "owned_by": "system" + }, + { + "id": "gpt-4-turbo", + "object": "model", + "created": 1712361441, + "owned_by": "system" + }, + { + "id": "tts-1", + "object": "model", + "created": 1681940951, + "owned_by": "openai-internal" + }, + { + "id": "gpt-3.5-turbo-1106", + "object": "model", + "created": 1698959748, + "owned_by": "system" + }, + { + "id": "gpt-4-0613", + "object": "model", + "created": 1686588896, + "owned_by": "openai" + }, + { + "id": "gpt-3.5-turbo-0301", + "object": "model", + "created": 1677649963, + "owned_by": "openai" + }, + { + "id": "gpt-4", + "object": "model", + "created": 1687882411, + "owned_by": "openai" + } + ] + } + } +} \ No newline at end of file diff --git a/dotcms-postman/pom.xml b/dotcms-postman/pom.xml index 8c4487005f73..2d1962d864d9 100644 --- a/dotcms-postman/pom.xml +++ b/dotcms-postman/pom.xml @@ -137,7 +137,9 @@ http://wm:8080/c http://wm:8080/i ${wiremock.api.key} + http://wm:8080/e http://wm:8080/e + http://wm:8080/m true diff --git a/dotcms-postman/src/main/resources/postman/AI.postman_collection.json b/dotcms-postman/src/main/resources/postman/AI.postman_collection.json index b80c4f376bc1..5843dece1d37 100644 --- a/dotcms-postman/src/main/resources/postman/AI.postman_collection.json +++ b/dotcms-postman/src/main/resources/postman/AI.postman_collection.json @@ -1,6 +1,6 @@ { "info": { - "_postman_id": "0d99325f-2e9a-49d7-a096-6770084ffa49", + "_postman_id": 
"7e9f91c0-35bf-4908-9f25-ba12d1dbf773", "name": "AI", "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json", "_exporter_id": "11174695" @@ -1010,13 +1010,14 @@ "pm.test('Emebeddings are created', function () {", " pm.expect(jsonData.indexName, 'Index name should be \"default\"').equals('default');", " pm.expect(parseInt(jsonData.timeToEmbeddings.split('ms')[0]), 'Time to embeddings must be greater than zero').greaterThan(0);", - " if (currentSeo.embedded) {", - " pm.expect(jsonData.totalToEmbed, 'Total to embed is greater than zero').greaterThan(0);", - " } else {", - " pm.expect(jsonData.totalToEmbed, 'Total to embed is greater than zero').equals(0);", - " }", "});", "", + "if (currentSeo.embedded) {", + " pm.expect(jsonData.totalToEmbed, 'Total to embed is greater than zero').greaterThan(0);", + "} else {", + " pm.expect(jsonData.totalToEmbed, 'Total to embed is zero').equals(0);", + "}", + "", "seoIndex++;", "pm.collectionVariables.set('seoIndex', seoIndex);", "console.log('New seoIndex', seoIndex);", @@ -1082,7 +1083,7 @@ ], "body": { "mode": "raw", - "raw": "{\n \"query\": \"+contentType:{{seoContentTypeVar}}\",\n \"fields\": \"seo\"\n}", + "raw": "{\n \"query\": \"+contentType:{{seoContentTypeVar}}\",\n \"fields\": \"seo\",\n \"model\": \"text-embedding-ada-002\"\n}", "options": { "raw": { "language": "json" @@ -1245,7 +1246,7 @@ ], "body": { "mode": "raw", - "raw": "{\n \"prompt\": \"{{seoText}}\"\n}", + "raw": "{\n \"prompt\": \"{{seoText}}\",\n \"model\": \"text-embedding-ada-002\"\n}", "options": { "raw": { "language": "json" @@ -2724,7 +2725,7 @@ ], "body": { "mode": "raw", - "raw": "{\n \"query\": \"+contentType:{{seoContentTypeVar}}\",\n \"fields\": \"seo\"\n}", + "raw": "{\n \"query\": \"+contentType:{{seoContentTypeVar}}\",\n \"model\": \"text-embedding-ada-002\",\n \"fields\": \"seo\"\n}", "options": { "raw": { "language": "json" @@ -2959,7 +2960,7 @@ ], "body": { "mode": "raw", - "raw": "{\n \"prompt\": \"{{seoText}}\",\n 
\"responseLengthTokens\": 1\n}", + "raw": "{\n \"prompt\": \"{{seoText}}\",\n \"model\": \"text-embedding-ada-002\",\n \"responseLengthTokens\": 1\n}", "options": { "raw": { "language": "json" @@ -3035,7 +3036,7 @@ ], "body": { "mode": "raw", - "raw": "{\n \"prompt\": \"{{seoText}}\",\n \"responseLengthTokens\": 1,\n \"stream\": true\n}", + "raw": "{\n \"prompt\": \"{{seoText}}\",\n \"responseLengthTokens\": 1,\n \"model\": \"text-embedding-ada-002\",\n \"stream\": true\n}", "options": { "raw": { "language": "json" @@ -3269,14 +3270,11 @@ " pm.expect(jsonData).to.have.property(\"apiUrl\");", " pm.expect(jsonData).to.have.property(\"availableModels\");", " pm.expect(jsonData.availableModels).to.be.an(\"array\");", - " const expectedModels = [\"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\", \"gpt-4\", \"gpt-4-1106-preview\", \"gpt-4-turbo-preview\"];", - " expectedModels.forEach(function(model) {", - " pm.expect(jsonData.availableModels).to.include(model);", - " });", + " pm.expect(jsonData.availableModels.length).is.greaterThan(0);", " pm.expect(jsonData[\"com.dotcms.ai.completion.default.temperature\"]).to.equal(\"1\");", " pm.expect(jsonData[\"com.dotcms.ai.debug.logging\"]).to.equal(\"false\");", - " pm.expect(jsonData[\"com.dotcms.ai.embeddings.model\"]).to.equal(\"text-embedding-ada-002\");", - " pm.expect(jsonData.imageModel).to.equal(\"dall-e-3\");", + " pm.expect(jsonData.embeddingsModelNames).to.equal(\"text-embedding-ada-002\");", + " pm.expect(jsonData.imageModelNames).to.equal(\"dall-e-3\");", " pm.expect(jsonData.textPrompt).to.include(\"Descriptive writing style\");", " pm.expect(jsonData.rolePrompt).to.include(\"dotCMSbot\");", " pm.expect(jsonData.apiImageUrl).to.match(/^https?:\\/\\/.+/);", diff --git a/dotcms-postman/src/test/resources/mappings/openai-models.json b/dotcms-postman/src/test/resources/mappings/openai-models.json new file mode 100644 index 000000000000..9bf3d1ca8a0f --- /dev/null +++ 
b/dotcms-postman/src/test/resources/mappings/openai-models.json @@ -0,0 +1,206 @@ +{ + "request": { + "method": "GET", + "url": "/m" + }, + "response": { + "status": 200, + "jsonBody": { + "object": "list", + "data": [ + { + "id": "dall-e-3", + "object": "model", + "created": 1698785189, + "owned_by": "system" + }, + { + "id": "gpt-4-1106-preview", + "object": "model", + "created": 1698957206, + "owned_by": "system" + }, + { + "id": "dall-e-2", + "object": "model", + "created": 1698798177, + "owned_by": "system" + }, + { + "id": "gpt-4o", + "object": "model", + "created": 1715367049, + "owned_by": "system" + }, + { + "id": "tts-1-hd-1106", + "object": "model", + "created": 1699053533, + "owned_by": "system" + }, + { + "id": "tts-1-hd", + "object": "model", + "created": 1699046015, + "owned_by": "system" + }, + { + "id": "gpt-4-0125-preview", + "object": "model", + "created": 1706037612, + "owned_by": "system" + }, + { + "id": "babbage-002", + "object": "model", + "created": 1692634615, + "owned_by": "system" + }, + { + "id": "gpt-4-turbo-preview", + "object": "model", + "created": 1706037777, + "owned_by": "system" + }, + { + "id": "text-embedding-3-small", + "object": "model", + "created": 1705948997, + "owned_by": "system" + }, + { + "id": "text-embedding-3-large", + "object": "model", + "created": 1705953180, + "owned_by": "system" + }, + { + "id": "gpt-3.5-turbo-0613", + "object": "model", + "created": 1686587434, + "owned_by": "openai" + }, + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677610602, + "owned_by": "openai" + }, + { + "id": "gpt-3.5-turbo-instruct", + "object": "model", + "created": 1692901427, + "owned_by": "system" + }, + { + "id": "gpt-3.5-turbo-instruct-0914", + "object": "model", + "created": 1694122472, + "owned_by": "system" + }, + { + "id": "gpt-4o-mini", + "object": "model", + "created": 1721172741, + "owned_by": "system" + }, + { + "id": "gpt-4o-mini-2024-07-18", + "object": "model", + "created": 1721172717, + 
"owned_by": "system" + }, + { + "id": "whisper-1", + "object": "model", + "created": 1677532384, + "owned_by": "openai-internal" + }, + { + "id": "gpt-4o-2024-05-13", + "object": "model", + "created": 1715368132, + "owned_by": "system" + }, + { + "id": "text-embedding-ada-002", + "object": "model", + "created": 1671217299, + "owned_by": "openai-internal" + }, + { + "id": "gpt-3.5-turbo-16k", + "object": "model", + "created": 1683758102, + "owned_by": "openai-internal" + }, + { + "id": "davinci-002", + "object": "model", + "created": 1692634301, + "owned_by": "system" + }, + { + "id": "gpt-3.5-turbo-16k-0613", + "object": "model", + "created": 1685474247, + "owned_by": "openai" + }, + { + "id": "gpt-4-turbo-2024-04-09", + "object": "model", + "created": 1712601677, + "owned_by": "system" + }, + { + "id": "tts-1-1106", + "object": "model", + "created": 1699053241, + "owned_by": "system" + }, + { + "id": "gpt-3.5-turbo-0125", + "object": "model", + "created": 1706048358, + "owned_by": "system" + }, + { + "id": "gpt-4-turbo", + "object": "model", + "created": 1712361441, + "owned_by": "system" + }, + { + "id": "tts-1", + "object": "model", + "created": 1681940951, + "owned_by": "openai-internal" + }, + { + "id": "gpt-3.5-turbo-1106", + "object": "model", + "created": 1698959748, + "owned_by": "system" + }, + { + "id": "gpt-4-0613", + "object": "model", + "created": 1686588896, + "owned_by": "openai" + }, + { + "id": "gpt-3.5-turbo-0301", + "object": "model", + "created": 1677649963, + "owned_by": "openai" + }, + { + "id": "gpt-4", + "object": "model", + "created": 1687882411, + "owned_by": "openai" + } + ] + } + } +} \ No newline at end of file