diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java index 63f58056609a..a4f6d2c8fb12 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java @@ -1,6 +1,8 @@ package com.dotcms.ai.app; +import com.dotcms.security.apps.AppsUtil; import com.dotcms.security.apps.Secret; +import com.dotmarketing.util.UtilMethods; import com.liferay.util.StringPool; import io.vavr.Lazy; import io.vavr.control.Try; @@ -141,6 +143,24 @@ public boolean discoverBooleanSecret(final Map secrets, final Ap return Boolean.parseBoolean(discoverSecret(secrets, key)); } + /** + * Resolves a secret value from the provided secrets map using the specified key and environment variable. + * If the secret is not found in the secrets map, it attempts to discover the value from the environment variable. + * + * @param secrets the map of secrets + * @param key the key to look up the secret + * @param envVar the environment variable name to look up if the secret is not found in the secrets map + * @return the resolved secret value or the value from the environment variable if the secret is not found + */ + public String discoverEnvSecret(final Map secrets, final AppKeys key, final String envVar) { + return Optional + .ofNullable(AppsUtil.discoverEnvVarValue(AppKeys.APP_KEY, key.key, envVar)) + .orElseGet(() -> { + final String secret = discoverSecret(secrets, key); + return UtilMethods.isSet(secret) ? secret : null; + }); + } + private int toInt(final String value) { return Try.of(() -> Integer.parseInt(value)).getOrElse(0); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java index d84e2ff86728..efbcc09a0872 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java @@ -103,6 +103,7 @@ public String getCurrentModel() { logInvalidModelMessage(); return null; } + return names.get(currentIndex); } @@ -113,11 +114,14 @@ public long minIntervalBetweenCalls() { @Override public String toString() { return "AIModel{" + - "name='" + names + '\'' + + "type=" + type + + ", names=" + names + ", tokensPerMinute=" + tokensPerMinute + ", apiPerMinute=" + apiPerMinute + ", maxTokens=" + maxTokens + ", isCompletion=" + isCompletion + + ", current=" + current + + ", decommissioned=" + decommissioned + '}'; } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java index 8f88e214d9ca..2a9ff3ba0577 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java @@ -2,7 +2,9 @@ import com.dotcms.ai.model.OpenAIModel; import com.dotcms.ai.model.OpenAIModels; +import com.dotcms.ai.model.SimpleModel; import com.dotcms.http.CircuitBreakerUrl; +import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Config; import com.dotmarketing.util.Logger; import com.github.benmanes.caffeine.cache.Cache; @@ -15,8 +17,9 @@ import org.apache.commons.collections4.CollectionUtils; import java.time.Duration; -import java.util.HashSet; +import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -47,7 +50,7 @@ public class AIModels { private final ConcurrentMap>> internalModels = new ConcurrentHashMap<>(); private final ConcurrentMap, AIModel> modelsByName = new ConcurrentHashMap<>(); - private final Cache> supportedModelsCache = + private final Cache> supportedModelsCache = Caffeine.newBuilder() .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) .maximumSize(AI_MODELS_CACHE_SIZE) @@ -105,7 +108,11 @@ public void loadModels(final String host, final List loading) { * @return an Optional containing the found AIModel, or an empty Optional if not found */ public Optional findModel(final String host, final String modelName) { - return Optional.ofNullable(modelsByName.get(Tuple.of(host, modelName.toLowerCase()))); + final String lowered = modelName.toLowerCase(); + final Set supported = getOrPullSupportedModels(); + return supported.contains(lowered) + ? Optional.ofNullable(modelsByName.get(Tuple.of(host, lowered))) + : Optional.empty(); } /** @@ -144,29 +151,32 @@ public void resetModels(final String host) { * Retrieves the list of supported models, either from the cache or by fetching them * from an external source if the cache is empty or expired. * - * @return a list of supported model names + * @return a set of supported model names */ - public List getOrPullSupportedModels() { - final List cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY); + public Set getOrPullSupportedModels() { + final Set cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY); if (CollectionUtils.isNotEmpty(cached)) { return cached; } final AppConfig appConfig = appConfigSupplier.get(); if (!appConfig.isEnabled()) { - Logger.debug(this, "OpenAI is not enabled, returning empty list of supported models"); - return List.of(); + AppConfig.debugLogger(getClass(), () -> "dotAI is not enabled, returning empty list of supported models"); + throw new DotRuntimeException("App dotAI config without API urls or API key"); } - final List supported = Try.of(() -> - fetchOpenAIModels(appConfig) - .getResponse() - .getData() - .stream() - .map(OpenAIModel::getId) - .map(String::toLowerCase) - .collect(Collectors.toList())) - .getOrElse(Optional.ofNullable(cached).orElse(List.of())); + final CircuitBreakerUrl.Response response = fetchOpenAIModels(appConfig); + if (Objects.nonNull(response.getResponse().getError())) { + throw new DotRuntimeException("Found error in AI response: " + response.getResponse().getError().getMessage()); + } + + final Set supported = response + .getResponse() + .getData() + .stream() + .map(OpenAIModel::getId) + .map(String::toLowerCase) + .collect(Collectors.toSet()); supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported); return supported; @@ -177,25 +187,30 @@ public List getOrPullSupportedModels() { * * @return a list of available model names */ - public List getAvailableModels() { - final Set configured = internalModels.entrySet().stream().flatMap(entry -> entry.getValue().stream()) + public List getAvailableModels() { + final Set configured = internalModels.entrySet() + .stream() + .flatMap(entry -> entry.getValue().stream()) .map(Tuple2::_2) - .flatMap(model -> model.getNames().stream()) + .flatMap(model -> model.getNames().stream().map(name -> new SimpleModel(name, model.getType()))) + .collect(Collectors.toSet()); + final Set supported = getOrPullSupportedModels() + .stream() + .map(SimpleModel::new) .collect(Collectors.toSet()); - final Set supported = new HashSet<>(getOrPullSupportedModels()); configured.retainAll(supported); - return configured.stream().sorted().collect(Collectors.toList()); + + return new ArrayList<>(configured); } private static CircuitBreakerUrl.Response fetchOpenAIModels(final AppConfig appConfig) { - final CircuitBreakerUrl.Response response = CircuitBreakerUrl.builder() .setMethod(CircuitBreakerUrl.Method.GET) .setUrl(OPEN_AI_MODELS_URL) .setTimeout(AI_MODELS_FETCH_TIMEOUT) .setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS) .setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + appConfig.getApiKey())) - .setThrowWhenNot2xx(false) + .setThrowWhenNot2xx(true) .build() .doResponse(OpenAIModels.class); @@ -206,6 +221,7 @@ private static CircuitBreakerUrl.Response fetchOpenAIModels(final "Error fetching OpenAI supported models from [%s] (status code: [%d])", OPEN_AI_MODELS_URL, response.getStatusCode())); + throw new DotRuntimeException("Error fetching OpenAI supported models"); } return response; diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java index 630ca4138a1e..5dc503f97272 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java @@ -4,15 +4,20 @@ import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; +import com.liferay.util.StringPool; import io.vavr.control.Try; import org.apache.commons.lang3.StringUtils; import java.io.Serializable; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import java.util.regex.Pattern; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * The AppConfig class provides a configuration for the AI application. @@ -20,8 +25,15 @@ */ public class AppConfig implements Serializable { + private static final String AI_API_KEY_KEY = "AI_API_KEY"; + private static final String AI_API_URL_KEY = "AI_API_URL"; + private static final String AI_IMAGE_API_URL_KEY = "AI_IMAGE_API_URL"; + private static final String AI_EMBEDDINGS_API_URL_KEY = "AI_EMBEDDINGS_API_URL"; + private static final String SYSTEM_HOST = "System Host"; public static final Pattern SPLITTER = Pattern.compile("\\s?,\\s?"); + private static final AtomicReference SYSTEM_HOST_CONFIG = new AtomicReference<>(); + private final String host; private final String apiKey; private final transient AIModel model; @@ -39,12 +51,15 @@ public class AppConfig implements Serializable { public AppConfig(final String host, final Map secrets) { this.host = host; + if (SYSTEM_HOST.equalsIgnoreCase(host)) { + setSystemHostConfig(this); + } final AIAppUtil aiAppUtil = AIAppUtil.get(); - apiKey = aiAppUtil.discoverSecret(secrets, AppKeys.API_KEY); - apiUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_URL); - apiImageUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_IMAGE_URL); - apiEmbeddingsUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_EMBEDDINGS_URL); + apiKey = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY, AI_API_KEY_KEY); + apiUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_URL, AI_API_URL_KEY); + apiImageUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_IMAGE_URL, AI_IMAGE_API_URL_KEY); + apiEmbeddingsUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_EMBEDDINGS_URL, AI_EMBEDDINGS_API_URL_KEY); if (!secrets.isEmpty() || isEnabled()) { AIModels.get().loadModels( @@ -67,18 +82,36 @@ public AppConfig(final String host, final Map secrets) { configValues = secrets.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - Logger.debug(getClass(), () -> "apiKey: " + apiKey); - Logger.debug(getClass(), () -> "apiUrl: " + apiUrl); - Logger.debug(getClass(), () -> "apiImageUrl: " + apiImageUrl); - Logger.debug(getClass(), () -> "embeddingsUrl: " + apiEmbeddingsUrl); - Logger.debug(getClass(), () -> "rolePrompt: " + rolePrompt); - Logger.debug(getClass(), () -> "textPrompt: " + textPrompt); - Logger.debug(getClass(), () -> "model: " + model); - Logger.debug(getClass(), () -> "imagePrompt: " + imagePrompt); - Logger.debug(getClass(), () -> "imageModel: " + imageModel); - Logger.debug(getClass(), () -> "imageSize: " + imageSize); - Logger.debug(getClass(), () -> "embeddingsModel: " + embeddingsModel); - Logger.debug(getClass(), () -> "listerIndexer: " + listenerIndexer); + Logger.debug(this, this::toString); + } + + /** + * Retrieves the system host configuration. + * + * @return the system host configuration + */ + public static AppConfig getSystemHostConfig() { + if (Objects.isNull(SYSTEM_HOST_CONFIG.get())) { + setSystemHostConfig(ConfigService.INSTANCE.config()); + } + return SYSTEM_HOST_CONFIG.get(); + } + + /** + * Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING} + * property instead of the usual Log4j configuration. + * + * @param clazz The {@link Class} to log the message for. + * @param message The {@link Supplier} with the message to log. + */ + public static void debugLogger(final Class clazz, final Supplier message) { + if (getSystemHostConfig().getConfigBoolean(AppKeys.DEBUG_LOGGING)) { + Logger.info(clazz, message.get()); + } + } + + public static void setSystemHostConfig(final AppConfig systemHostConfig) { + AppConfig.SYSTEM_HOST_CONFIG.set(systemHostConfig); } /** @@ -282,20 +315,31 @@ public AIModel resolveModelOrThrow(final String modelName) { } /** - * Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING} - * property instead of the usual Log4j configuration. + * Checks if the configuration is enabled. * - * @param clazz The {@link Class} to log the message for. - * @param message The {@link Supplier} with the message to log. + * @return true if the configuration is enabled, false otherwise */ - public static void debugLogger(final Class clazz, final Supplier message) { - if (ConfigService.INSTANCE.config().getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.info(clazz, message.get()); - } + public boolean isEnabled() { + return Stream.of(apiUrl, apiImageUrl, apiEmbeddingsUrl, apiKey).allMatch(StringUtils::isNotBlank); } - public boolean isEnabled() { - return StringUtils.isNotBlank(apiKey); + @Override + public String toString() { + return "AppConfig{\n" + + " host='" + host + "',\n" + + " apiKey='" + Optional.ofNullable(apiKey).map(key -> "*****").orElse(StringPool.BLANK) + "',\n" + + " model=" + model + "',\n" + + " imageModel=" + imageModel + "',\n" + + " embeddingsModel=" + embeddingsModel + "',\n" + + " apiUrl='" + apiUrl + "',\n" + + " apiImageUrl='" + apiImageUrl + "',\n" + + " apiEmbeddingsUrl='" + apiEmbeddingsUrl + "',\n" + + " rolePrompt='" + rolePrompt + "',\n" + + " textPrompt='" + textPrompt + "',\n" + + " imagePrompt='" + imagePrompt + "',\n" + + " imageSize='" + imageSize + "',\n" + + " listenerIndexer='" + listenerIndexer + "'\n" + + '}'; } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java index 2f8bcdcf1c55..a40c57c959f8 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java @@ -1,11 +1,13 @@ package com.dotcms.ai.app; +import com.liferay.util.StringPool; + public enum AppKeys { + API_KEY("apiKey", null), API_URL("apiUrl", "https://api.openai.com/v1/chat/completions"), API_IMAGE_URL("apiImageUrl", "https://api.openai.com/v1/images/generations"), API_EMBEDDINGS_URL("apiEmbeddingsUrl", "https://api.openai.com/v1/embeddings"), - API_KEY("apiKey", null), ROLE_PROMPT( "rolePrompt", "You are dotCMSbot, and AI assistant to help content" + @@ -22,12 +24,12 @@ public enum AppKeys { IMAGE_MODEL_TOKENS_PER_MINUTE("imageModelTokensPerMinute", "0"), IMAGE_MODEL_API_PER_MINUTE("imageModelApiPerMinute", "50"), IMAGE_MODEL_MAX_TOKENS("imageModelMaxTokens", "0"), - IMAGE_MODEL_COMPLETION("imageModelCompletion", "false"), + IMAGE_MODEL_COMPLETION("imageModelCompletion", StringPool.FALSE), EMBEDDINGS_MODEL_NAMES("embeddingsModelNames", null), EMBEDDINGS_MODEL_TOKENS_PER_MINUTE("embeddingsModelTokensPerMinute", "1000000"), EMBEDDINGS_MODEL_API_PER_MINUTE("embeddingsModelApiPerMinute", "3000"), EMBEDDINGS_MODEL_MAX_TOKENS("embeddingsModelMaxTokens", "8191"), - EMBEDDINGS_MODEL_COMPLETION("embeddingsModelCompletion", "false"), + EMBEDDINGS_MODEL_COMPLETION("embeddingsModelCompletion", StringPool.FALSE), EMBEDDINGS_SPLIT_AT_TOKENS("com.dotcms.ai.embeddings.split.at.tokens", "512"), EMBEDDINGS_MINIMUM_TEXT_LENGTH_TO_INDEX("com.dotcms.ai.embeddings.minimum.text.length", "64"), EMBEDDINGS_MINIMUM_FILE_SIZE_TO_INDEX("com.dotcms.ai.embeddings.minimum.file.size", "1024"), @@ -39,7 +41,7 @@ public enum AppKeys { EMBEDDINGS_CACHE_TTL_SECONDS("com.dotcms.ai.embeddings.cache.ttl.seconds", "600"), EMBEDDINGS_CACHE_SIZE("com.dotcms.ai.embeddings.cache.size", "1000"), EMBEDDINGS_DB_DELETE_OLD_ON_UPDATE("com.dotcms.ai.embeddings.delete.old.on.update", "true"), - DEBUG_LOGGING("com.dotcms.ai.debug.logging", "false"), + DEBUG_LOGGING("com.dotcms.ai.debug.logging", StringPool.FALSE), COMPLETION_TEMPERATURE("com.dotcms.ai.completion.default.temperature", "1"), COMPLETION_ROLE_PROMPT( "com.dotcms.ai.completion.role.prompt", diff --git a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java index 24f7b2b21b7c..9739bab313eb 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java +++ b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java @@ -83,7 +83,10 @@ private AppConfig getAppConfig(final String hostId) { final AppConfig appConfig = ConfigService.INSTANCE.config(host); if (!appConfig.isEnabled()) { - throw new DotRuntimeException("No API key found in app config"); + AppConfig.debugLogger( + getClass(), + () -> "dotAI is not enabled since no API urls or API key found in app config"); + throw new DotRuntimeException("App dotAI config without API urls or API key"); } return appConfig; diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java b/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java new file mode 100644 index 000000000000..c5486b61191f --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java @@ -0,0 +1,53 @@ +package com.dotcms.ai.model; + +import com.dotcms.ai.app.AIModelType; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; +import java.util.Objects; + +/** + * Represents a simple model with a name and type. + * This class is immutable and uses Jackson annotations for JSON serialization and deserialization. + * + * @author vico + */ +public class SimpleModel implements Serializable { + + private final String name; + private final AIModelType type; + + @JsonCreator + public SimpleModel(@JsonProperty("name") final String name, @JsonProperty("type") final AIModelType type) { + this.name = name; + this.type = type; + } + + @JsonCreator + public SimpleModel(@JsonProperty("name") final String name) { + this(name, null); + } + + public String getName() { + return name; + } + + public AIModelType getType() { + return type; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SimpleModel that = (SimpleModel) o; + return Objects.equals(name, that.name); + } + + @Override + public int hashCode() { + return Objects.hashCode(name); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java index 98591c6502f3..e7b62cf46712 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java @@ -5,6 +5,7 @@ import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; +import com.dotcms.ai.model.SimpleModel; import com.dotcms.ai.rest.forms.CompletionsForm; import com.dotcms.ai.util.LineReadingOutputStream; import com.dotcms.rest.WebResource; @@ -118,7 +119,7 @@ public final Response getConfig(@Context final HttpServletRequest request, final String apiKey = UtilMethods.isSet(app.getApiKey()) ? "*****" : "NOT SET"; map.put(AppKeys.API_KEY.key, apiKey); - final List models = AIModels.get().getAvailableModels(); + final List models = AIModels.get().getAvailableModels(); map.put(AiKeys.AVAILABLE_MODELS, models); return Response.ok(map).build(); diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java index daf29ec8b846..b2a9b9adf789 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java +++ b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java @@ -52,18 +52,23 @@ public static void doRequest(final String urlIn, final AppConfig appConfig, final JSONObject json, final OutputStream out) { + AppConfig.debugLogger( + OpenAIRequest.class, + () -> String.format( + "Posting to [%s] with method [%s]%s with app config:%s%s the payload: %s", + urlIn, + method, + System.lineSeparator(), + appConfig.toString(), + System.lineSeparator(), + json.toString(2))); if (!appConfig.isEnabled()) { - Logger.debug(OpenAIRequest.class, "OpenAI is not enabled and will not send request."); - return; + AppConfig.debugLogger(OpenAIRequest.class, () -> "App dotAI is not enabled and will not send request."); + throw new DotRuntimeException("App dotAI config without API urls or API key"); } final AIModel model = appConfig.resolveModelOrThrow(json.optString(AiKeys.MODEL)); - - if (appConfig.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.debug(OpenAIRequest.class, "posting: " + json); - } - final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L) + model.minIntervalBetweenCalls() - System.currentTimeMillis(); diff --git a/dotCMS/src/main/java/com/dotcms/security/apps/AppsUtil.java b/dotCMS/src/main/java/com/dotcms/security/apps/AppsUtil.java index 3a7b32847215..271d51e1c59b 100644 --- a/dotCMS/src/main/java/com/dotcms/security/apps/AppsUtil.java +++ b/dotCMS/src/main/java/com/dotcms/security/apps/AppsUtil.java @@ -676,7 +676,8 @@ private static String guessEnvVar(final String key, final String paramName) { private static String discoverEnvVarValue(final Supplier envVarSupplier, final String envVar) { return Optional .ofNullable(envVarSupplier.get()) - .map(discovered -> Config.getStringProperty(discovered, null)) + .map(supplied -> Config.getStringProperty(supplied, null)) + .or(() -> Optional.ofNullable(envVar).map(ev -> Config.getStringProperty(ev, null))) .or(() -> Optional.ofNullable(envVar).map(System::getenv)) .orElse(null); } diff --git a/dotCMS/src/main/java/com/liferay/util/StringPool.java b/dotCMS/src/main/java/com/liferay/util/StringPool.java index 80b5d2c29f57..478ef31f3dc6 100644 --- a/dotCMS/src/main/java/com/liferay/util/StringPool.java +++ b/dotCMS/src/main/java/com/liferay/util/StringPool.java @@ -89,4 +89,6 @@ public class StringPool { public static final String TRUE = Boolean.TRUE.toString(); + public static final String FALSE = Boolean.FALSE.toString(); + } diff --git a/dotCMS/src/main/resources/apps/dotAI.yml b/dotCMS/src/main/resources/apps/dotAI.yml index 6ad3a46c6ed6..d23962e1e4f0 100644 --- a/dotCMS/src/main/resources/apps/dotAI.yml +++ b/dotCMS/src/main/resources/apps/dotAI.yml @@ -14,6 +14,13 @@ params: label: "API Key" hint: "Your ChatGPT API key" required: true + textModelNames: + value: "" + hidden: false + type: "STRING" + label: "Model Names" + hint: "Comma delimited list of models used to generate OpenAI API response (e.g. gpt-3.5-turbo-16k)" + required: true rolePrompt: value: "You are dotCMSbot, and AI assistant to help content creators generate and rewrite content in their content management system." hidden: false @@ -28,13 +35,6 @@ params: label: "Text Prompt" hint: "A prompt describing writing style." required: false - textModelNames: - value: "gpt-3.5-turbo-16k" - hidden: false - type: "STRING" - label: "Model Names" - hint: "Comma delimited list of models used to generate OpenAI API response." - required: true textModelTokensPerMinute: value: "180000" hidden: false @@ -63,6 +63,13 @@ params: label: "Completion model enabled" hint: "Enable completion model used to generate OpenAI API response." required: false + imageModelNames: + value: "" + hidden: false + type: "STRING" + label: "Image Model Names" + hint: "Comma delimited list of image models used to generate OpenAI API response(e.g. dall-e-3)." + required: true imagePrompt: value: "Use 16:9 aspect ratio." hidden: false @@ -96,13 +103,6 @@ params: value: "1920x1080" - label: "256x256 (Small Square 1:1)" value: "256x256" - imageModelNames: - value: "dall-e-3" - hidden: false - type: "STRING" - label: "Image Model Names" - hint: "Comma delimited list of image models used to generate OpenAI API response." - required: true imageModelTokensPerMinute: value: "0" hidden: false @@ -132,11 +132,11 @@ params: hint: "Enable completion model used to generate OpenAI API response." required: false embeddingsModelNames: - value: "text-embedding-ada-002" + value: "" hidden: false type: "STRING" label: "Embeddings Model Names" - hint: "Comma delimited list of embeddings models used to generate OpenAI API response." + hint: "Comma delimited list of embeddings models used to generate OpenAI API response (e.g. text-embedding-ada-002)." required: true embeddingsModelTokensPerMinute: value: "1000000" diff --git a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js index b76879dfd125..088436aef605 100644 --- a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js +++ b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js @@ -131,10 +131,13 @@ const writeModelToDropdown = async () => { } for (i = 0; i < dotAiState.config.availableModels.length; i++) { + if (dotAiState.config.availableModels[i].type !== 'TEXT') { + continue; + } const newOption = document.createElement("option"); - newOption.value = dotAiState.config.availableModels[i]; - newOption.text = `${dotAiState.config.availableModels[i]}` + newOption.value = dotAiState.config.availableModels[i].name; + newOption.text = `${dotAiState.config.availableModels[i].name}` if (dotAiState.config.availableModels[i] === dotAiState.config.model) { newOption.selected = true; newOption.text = `${dotAiState.config.availableModels[i]} (default)` diff --git a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java index a95eca8a4d5c..c4d5c93b7627 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java @@ -162,4 +162,17 @@ public void testCreateEmbeddingsModel() { assertTrue(model.getNames().contains("embeddingsmodel")); } + @Test + public void testDiscoverEnvSecret() { + // Mock the secret value in the secrets map + when(secrets.get("apiKey")).thenReturn(secret); + when(secret.getString()).thenReturn("secretValue"); + + // Call the method with the key and environment variable + String result = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY, "ENV_API_KEY"); + + // Assert the expected outcome + assertEquals("secretValue", result); + } + } \ No newline at end of file diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java index 8b0fca114036..855f61ad4572 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java @@ -78,4 +78,8 @@ static Map aiAppSecrets(final WireMockServer wireMockServer, fin return aiAppSecrets(wireMockServer, host, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); } + static void removeSecrets(final Host host) throws DotDataException, DotSecurityException { + APILocator.getAppsAPI().removeSecretsForSite(host, APILocator.systemUser()); + } + } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java index 2ea51fe91ab4..e08965e20843 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java @@ -7,10 +7,12 @@ import com.dotmarketing.beans.Host; import com.dotmarketing.business.APILocator; import com.dotmarketing.exception.DotDataException; +import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.exception.DotSecurityException; import com.dotmarketing.util.DateUtil; import com.github.tomakehurst.wiremock.WireMockServer; import io.vavr.control.Try; +import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -18,6 +20,7 @@ import java.util.List; import java.util.Optional; +import java.util.Set; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -43,23 +46,27 @@ public class AIModelsTest { @BeforeClass public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); - IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); } @AfterClass public static void afterClass() { wireMockServer.stop(); - IPUtils.disabledIpPrivateSubnet(false); } @Before public void before() { + IPUtils.disabledIpPrivateSubnet(true); host = new SiteDataGen().nextPersisted(); otherHost = new SiteDataGen().nextPersisted(); List.of(host, otherHost).forEach(h -> Try.of(() -> AiTest.aiAppSecrets(wireMockServer, host)).get()); } + @After + public void after() { + IPUtils.disabledIpPrivateSubnet(false); + } + /** * Given a set of models loaded into the AIModels instance * When the findModel method is called with various model names and types @@ -67,6 +74,7 @@ public void before() { */ @Test public void test_loadModels_andFindThem() throws DotDataException, DotSecurityException { + AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); saveSecrets( host, "text-model-1,text-model-2", @@ -116,16 +124,15 @@ public void test_loadModels_andFindThem() throws DotDataException, DotSecurityEx final Optional text7 = aiModels.findModel(hostId, "text-model-7"); final Optional text8 = aiModels.findModel(hostId, "text-model-8"); - assertModels(text7, text8, AIModelType.TEXT); + assertNotPresentModels(text7, text8); final Optional image9 = aiModels.findModel(hostId, "image-model-9"); final Optional image10 = aiModels.findModel(hostId, "image-model-10"); - assertModels(image9, image10, AIModelType.IMAGE); + assertNotPresentModels(image9, image10); final Optional embeddings11 = aiModels.findModel(hostId, "embeddings-model-11"); - assertTrue(embeddings11.isPresent()); final Optional embeddings12 = aiModels.findModel(hostId, "embeddings-model-12"); - assertModels(embeddings11, embeddings12, AIModelType.EMBEDDINGS); + assertNotPresentModels(embeddings11, embeddings12); } /** @@ -138,13 +145,9 @@ public void test_getOrPullSupportedModules() throws DotDataException, DotSecurit AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); AIModels.get().cleanSupportedModelsCache(); - List supported = aiModels.getOrPullSupportedModels(); + Set supported = aiModels.getOrPullSupportedModels(); assertNotNull(supported); - assertEquals(32, supported.size()); - - supported = aiModels.getOrPullSupportedModels(); - assertNotNull(supported); - assertEquals(32, supported.size()); + assertEquals(38, supported.size()); AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); } @@ -154,14 +157,13 @@ public void test_getOrPullSupportedModules() throws DotDataException, DotSecurit * When the getOrPullSupportedModules method is called * Then an empty list of supported models should be returned. */ - @Test - public void test_getOrPullSupportedModules_invalidEndpoint() { + @Test(expected = DotRuntimeException.class) + public void test_getOrPullSupportedModules_withNetworkError() { AIModels.get().cleanSupportedModelsCache(); IPUtils.disabledIpPrivateSubnet(false); - final List supported = aiModels.getOrPullSupportedModels(); - assertNotNull(supported); - assertTrue(supported.isEmpty()); + final Set supported = aiModels.getOrPullSupportedModels(); + assertSupported(supported); IPUtils.disabledIpPrivateSubnet(true); AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); @@ -172,14 +174,25 @@ public void test_getOrPullSupportedModules_invalidEndpoint() { * When the getOrPullSupportedModules method is called * Then an empty list of supported models should be returned. */ - @Test + @Test(expected = DotRuntimeException.class) public void test_getOrPullSupportedModules_noApiKey() throws DotDataException, DotSecurityException { AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost(), null); AIModels.get().cleanSupportedModelsCache(); - final List supported = aiModels.getOrPullSupportedModels(); - assertNotNull(supported); - assertTrue(supported.isEmpty()); + aiModels.getOrPullSupportedModels(); + } + + /** + * Given no API key + * When the getOrPullSupportedModules method is called + * Then an empty list of supported models should be returned. + */ + @Test(expected = DotRuntimeException.class) + public void test_getOrPullSupportedModules_noSystemHost() throws DotDataException, DotSecurityException { + AiTest.removeSecrets(APILocator.systemHost()); + + AIModels.get().cleanSupportedModelsCache(); + aiModels.getOrPullSupportedModels(); } private void saveSecrets(final Host host, @@ -206,4 +219,14 @@ private static void assertModels(final Optional model1, assertSame(type, model2.get().getType()); } + private static void assertNotPresentModels(final Optional model1, final Optional model2) { + assertTrue(model1.isEmpty()); + assertTrue(model2.isEmpty()); + } + + private static void assertSupported(Set supported) { + assertNotNull(supported); + assertTrue(supported.isEmpty()); + } + } diff --git a/dotcms-integration/src/test/resources/mappings/openai-models.json b/dotcms-integration/src/test/resources/mappings/openai-models.json index 9bf3d1ca8a0f..0d9ab6aa7a51 100644 --- a/dotcms-integration/src/test/resources/mappings/openai-models.json +++ b/dotcms-integration/src/test/resources/mappings/openai-models.json @@ -9,6 +9,36 @@ "object": "list", "data": [ { + "id": "text-model-1", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ + "id": "text-model-2", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ + "id": "image-model-3", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ + "id": "image-model-4", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ + "id": "embeddings-model-5", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ + "id": "embeddings-model-6", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ "id": "dall-e-3", "object": "model", "created": 1698785189,