From d980f8bd419cf89397a0e98165ff13e494521985 Mon Sep 17 00:00:00 2001 From: Victor Alfaro Date: Wed, 31 Jul 2024 14:51:01 -0600 Subject: [PATCH] #29281: adding a centralized OpenAI api-key validation procedure --- .../com/dotcms/ai/api/CompletionsAPIImpl.java | 2 +- .../java/com/dotcms/ai/app/AIAppUtil.java | 4 +- .../main/java/com/dotcms/ai/app/AIModel.java | 9 +++ .../main/java/com/dotcms/ai/app/AIModels.java | 60 ++++++++++++------- .../java/com/dotcms/ai/app/AppConfig.java | 28 +++++++-- .../ai/listener/EmbeddingContentListener.java | 2 +- .../com/dotcms/ai/model/OpenAIModels.java | 44 +++++++++++++- .../dotcms/ai/rest/forms/CompletionsForm.java | 2 +- .../com/dotcms/ai/util/OpenAIRequest.java | 31 +++++----- .../com/dotcms/ai/viewtool/AIViewTool.java | 5 +- .../ai/workflow/OpenAIAutoTagActionlet.java | 3 +- .../OpenAIContentPromptActionlet.java | 2 +- .../java/com/dotcms/ai/app/AIAppUtilTest.java | 18 +++++- .../src/test/java/com/dotcms/ai/AiTest.java | 4 +- .../java/com/dotcms/ai/app/AIModelsTest.java | 30 +++++++++- 15 files changed, 185 insertions(+), 59 deletions(-) diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java index f52f874142c1..ca87f3e1e77a 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java +++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java @@ -256,7 +256,7 @@ private JSONObject buildRequestJson(final CompletionsForm form) { final JSONObject json = new JSONObject(); json.put(AiKeys.MESSAGES, messages); - json.putIfAbsent(AiKeys.MODEL, config.get().getConfig(AppKeys.TEXT_MODEL_NAMES)); + json.putIfAbsent(AiKeys.MODEL, config.get().getModel().getCurrentModel()); json.put(AiKeys.TEMPERATURE, form.temperature); json.put(AiKeys.MAX_TOKENS, form.responseLengthTokens); json.put(AiKeys.STREAM, form.stream); diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java index 6feaaf24afba..dc47af0f87cd 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java @@ -42,7 +42,7 @@ public static AIAppUtil get() { public AIModel createTextModel(final Map secrets) { return AIModel.builder() .withType(AIModelType.TEXT) - .withNames(discoverSecret(secrets, AppKeys.TEXT_MODEL_NAMES)) + .withNames(splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES)) .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_TOKENS_PER_MINUTE)) .withApiPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_API_PER_MINUTE)) .withMaxTokens(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_MAX_TOKENS)) @@ -59,7 +59,7 @@ public AIModel createTextModel(final Map secrets) { public AIModel createImageModel(final Map secrets) { return AIModel.builder() .withType(AIModelType.IMAGE) - .withNames(discoverSecret(secrets, AppKeys.IMAGE_MODEL_NAMES)) + .withNames(splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES)) .withTokensPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_TOKENS_PER_MINUTE)) .withApiPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_API_PER_MINUTE)) .withMaxTokens(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_MAX_TOKENS)) diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java index 88b3ef6d58df..d84e2ff86728 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java @@ -18,6 +18,11 @@ */ public class AIModel { + public static final AIModel NOOP_MODEL = AIModel.builder() + .withType(AIModelType.UNKNOWN) + .withNames(List.of()) + .build(); + private final AIModelType type; private final List names; private final int tokensPerMinute; @@ -88,6 +93,10 @@ public void setDecommissioned(final boolean decommissioned) { this.decommissioned.set(decommissioned); } + public boolean isOperational() { + return this != NOOP_MODEL; + } + public String getCurrentModel() { final int currentIndex = this.current.get(); if (!isCurrentValid(currentIndex)) { diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java index 0773d0de5711..acde8b99932f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java @@ -8,6 +8,7 @@ import com.dotmarketing.util.Logger; import com.github.benmanes.caffeine.cache.Cache; import com.github.benmanes.caffeine.cache.Caffeine; +import com.google.common.annotations.VisibleForTesting; import io.vavr.Lazy; import io.vavr.Tuple; import io.vavr.Tuple2; @@ -21,6 +22,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -44,11 +46,6 @@ public class AIModels { private static final int AI_MODELS_CACHE_TTL = 28800; // 8 hours private static final int AI_MODELS_CACHE_SIZE = 128; - public static final AIModel NOOP_MODEL = AIModel.builder() - .withType(AIModelType.UNKNOWN) - .withNames(List.of()) - .build(); - private final ConcurrentMap>> internalModels = new ConcurrentHashMap<>(); private final ConcurrentMap, AIModel> modelsByName = new ConcurrentHashMap<>(); private final Cache> supportedModelsCache = @@ -56,6 +53,7 @@ public class AIModels { .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) .maximumSize(AI_MODELS_CACHE_SIZE) .build(); + private Supplier appConfigSupplier = ConfigService.INSTANCE::config; public static AIModels get() { return INSTANCE.get(); @@ -154,24 +152,31 @@ public void resetModels(final Host host) { * @return a list of supported model names */ public List getOrPullSupportedModels() { - final List cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY); - if (CollectionUtils.isNotEmpty(cached)) { - return cached; + synchronized (supportedModelsCache) { + final List cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY); + if (CollectionUtils.isNotEmpty(cached)) { + return cached; + } + + final AppConfig appConfig = appConfigSupplier.get(); + if (!appConfig.isEnabled()) { + Logger.debug(this, "OpenAI is not enabled, returning empty list of supported models"); + return List.of(); + } + + final List supported = Try.of(() -> + fetchOpenAIModels(appConfig) + .getResponse() + .getData() + .stream() + .map(OpenAIModel::getId) + .map(String::toLowerCase) + .collect(Collectors.toList())) + .getOrElse(Optional.ofNullable(cached).orElse(List.of())); + supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported); + + return supported; } - - final AppConfig appConfig = ConfigService.INSTANCE.config(); - final List supported = Try.of(() -> - fetchOpenAIModels(appConfig) - .getResponse() - .getData() - .stream() - .map(OpenAIModel::getId) - .map(String::toLowerCase) - .collect(Collectors.toList())) - .getOrElse(Optional.ofNullable(cached).orElse(List.of())); - supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported); - - return supported; } /** @@ -212,4 +217,15 @@ private static CircuitBreakerUrl.Response fetchOpenAIModels(final return response; } + + @VisibleForTesting + void setAppConfigSupplier(final Supplier appConfigSupplier) { + this.appConfigSupplier = appConfigSupplier; + } + + @VisibleForTesting + void cleanSupportedModelsCache() { + supportedModelsCache.invalidateAll(); + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java index d3a161daa746..0704258b7e9f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java @@ -1,5 +1,6 @@ package com.dotcms.ai.app; +import com.dotcms.ai.util.OpenAIRequest; import com.dotcms.security.apps.Secret; import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Config; @@ -25,13 +26,13 @@ public class AppConfig implements Serializable { public static final Pattern SPLITTER = Pattern.compile("\\s?,\\s?"); private final String host; + private final String apiKey; private final transient AIModel model; private final transient AIModel imageModel; private final transient AIModel embeddingsModel; private final String apiUrl; private final String apiImageUrl; private final String apiEmbeddingsUrl; - private final String apiKey; private final String rolePrompt; private final String textPrompt; private final String imagePrompt; @@ -43,6 +44,8 @@ public AppConfig(final String host, final Map secrets) { this.host = host; final AIAppUtil aiAppUtil = AIAppUtil.get(); + apiKey = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY); + AIModels.get().loadModels( this.host, List.of( @@ -57,7 +60,7 @@ public AppConfig(final String host, final Map secrets) { apiUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_URL); apiImageUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_IMAGE_URL); apiEmbeddingsUrl = discoverEmbeddingsApiUrl(secrets); - apiKey = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY); + rolePrompt = aiAppUtil.discoverSecret(secrets, AppKeys.ROLE_PROMPT); textPrompt = aiAppUtil.discoverSecret(secrets, AppKeys.TEXT_PROMPT); imagePrompt = aiAppUtil.discoverSecret(secrets, AppKeys.IMAGE_PROMPT); @@ -66,10 +69,10 @@ public AppConfig(final String host, final Map secrets) { configValues = secrets.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + Logger.debug(getClass(), () -> "apiKey: " + apiKey); Logger.debug(getClass(), () -> "apiUrl: " + apiUrl); Logger.debug(getClass(), () -> "apiImageUrl: " + apiImageUrl); Logger.debug(getClass(), () -> "embeddingsUrl: " + apiEmbeddingsUrl); - Logger.debug(getClass(), () -> "apiKey: " + apiKey); Logger.debug(getClass(), () -> "model: " + model); Logger.debug(getClass(), () -> "imageModel: " + imageModel); Logger.debug(getClass(), () -> "embeddingsModel: " + embeddingsModel); @@ -251,7 +254,7 @@ public String getConfig(final AppKeys appKey) { * @param type the type of the model to find */ public AIModel resolveModel(final AIModelType type) { - return AIModels.get().findModel(host, type).orElse(AIModels.NOOP_MODEL); + return AIModels.get().findModel(host, type).orElse(AIModel.NOOP_MODEL); } /** @@ -260,13 +263,24 @@ public AIModel resolveModel(final AIModelType type) { * @param modelName the name of the model to find */ public AIModel resolveModelOrThrow(final String modelName) { - return AIModels.get() + final AIModel model = AIModels.get() .findModel(host, modelName) .orElseThrow(() -> { final String supported = String.join(", ", AIModels.get().getOrPullSupportedModels()); return new DotRuntimeException( "Unable to find model: [" + modelName + "]. Only [" + supported + "] are supported "); }); + + if (!model.isOperational()) { + Logger.debug( + OpenAIRequest.class, + String.format( + "Resolved model [%s] is not operational, avoiding its usage", + model.getCurrentModel())); + throw new DotRuntimeException(String.format("Model [%s] is not operational", model.getCurrentModel())); + } + + return model; } /** @@ -282,6 +296,10 @@ public static void debugLogger(final Class clazz, final Supplier mess } } + public boolean isEnabled() { + return StringUtils.isNotBlank(apiKey); + } + private String discoverEmbeddingsApiUrl(final Map secrets) { final String url = AIAppUtil.get().discoverEnvSecret(secrets, AppKeys.API_EMBEDDINGS_URL); return StringUtils.isBlank(url) diff --git a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java index 24dbbf1072f6..afbb22f19bd0 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java +++ b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java @@ -100,7 +100,7 @@ private AppConfig getAppConfig(final String hostId) { */ private JSONObject getConfigJson(final String hostId) { return Try - .of(() -> new JSONObject(getAppConfig(hostId).getConfig(AppKeys.LISTENER_INDEXER))) + .of(() -> new JSONObject(getAppConfig(hostId).getListenerIndexer())) .onFailure(e -> Logger.debug(getClass(), "error in json config from app: " + e.getMessage())) .getOrElse(new JSONObject()); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModels.java b/dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModels.java index 1c851628489d..b98317557246 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModels.java +++ b/dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModels.java @@ -16,12 +16,15 @@ public class OpenAIModels implements Serializable { private final String object; private final List data; + private final OpenAIError error; @JsonCreator public OpenAIModels(@JsonProperty("object") final String object, - @JsonProperty("data") final List data) { + @JsonProperty("data") final List data, + @JsonProperty("error") final OpenAIError error) { this.object = object; this.data = data; + this.error = error; } public String getObject() { @@ -32,4 +35,43 @@ public List getData() { return data; } + public OpenAIError getError() { + return error; + } + + public static class OpenAIError { + + private final String message; + private final String type; + private final String param; + private final String code; + + @JsonCreator + public OpenAIError(@JsonProperty("object") final String message, + @JsonProperty("type") final String type, + @JsonProperty("param") final String param, + @JsonProperty("code") final String code) { + this.message = message; + this.type = type; + this.param = param; + this.code = code; + } + + public String getMessage() { + return message; + } + + public String getType() { + return type; + } + + public String getParam() { + return param; + } + + public String getCode() { + return code; + } + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java index a6bbbbeeec81..f4eb199d4bf2 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java @@ -117,7 +117,7 @@ private CompletionsForm(final Builder builder) { } else { this.temperature = builder.temperature >= 2 ? 2 : builder.temperature; } - this.model = UtilMethods.isSet(builder.model) ? builder.model : ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES); + this.model = UtilMethods.isSet(builder.model) ? builder.model : ConfigService.INSTANCE.config().getModel().getCurrentModel(); } private String validateBuilderQuery(final String query) { diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java index e851c9b8f871..daf29ec8b846 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java +++ b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java @@ -21,7 +21,6 @@ import java.io.BufferedInputStream; import java.io.ByteArrayOutputStream; import java.io.OutputStream; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; /** @@ -45,22 +44,26 @@ private OpenAIRequest() {} * @param urlIn the URL to send the request to * @param method the HTTP method to use for the request * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param payload the JSON payload to send with the request + * @param json the JSON payload to send with the request * @param out the OutputStream to write the response to */ public static void doRequest(final String urlIn, final String method, final AppConfig appConfig, - final JSONObject payload, + final JSONObject json, final OutputStream out) { - final JSONObject json = Optional.ofNullable(payload).orElse(new JSONObject()); + if (!appConfig.isEnabled()) { + Logger.debug(OpenAIRequest.class, "OpenAI is not enabled and will not send request."); + return; + } + + final AIModel model = appConfig.resolveModelOrThrow(json.optString(AiKeys.MODEL)); if (appConfig.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { Logger.debug(OpenAIRequest.class, "posting: " + json); } - final AIModel model = appConfig.resolveModelOrThrow(json.optString(AiKeys.MODEL)); final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L) + model.minIntervalBetweenCalls() - System.currentTimeMillis(); @@ -117,15 +120,15 @@ public static void doRequest(final String urlIn, * @param url the URL to send the request to * @param method the HTTP method to use for the request * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param payload the JSON payload to send with the request + * @param json the JSON payload to send with the request * @return the response from the request as a string */ public static String doRequest(final String url, final String method, final AppConfig appConfig, - final JSONObject payload) { + final JSONObject json) { final ByteArrayOutputStream out = new ByteArrayOutputStream(); - doRequest(url, method, appConfig, payload, out); + doRequest(url, method, appConfig, json, out); return out.toString(); } @@ -136,14 +139,14 @@ public static String doRequest(final String url, * * @param urlIn the URL to send the request to * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param payload the JSON payload to send with the request + * @param json the JSON payload to send with the request * @param out the OutputStream to write the response to */ public static void doPost(final String urlIn, final AppConfig appConfig, - final JSONObject payload, + final JSONObject json, final OutputStream out) { - doRequest(urlIn, HttpMethod.POST, appConfig, payload, out); + doRequest(urlIn, HttpMethod.POST, appConfig, json, out); } /** @@ -152,14 +155,14 @@ public static void doPost(final String urlIn, * * @param urlIn the URL to send the request to * @param appConfig the AppConfig object containing the OpenAI API key and models - * @param payload the JSON payload to send with the request + * @param json the JSON payload to send with the request * @param out the OutputStream to write the response to */ public static void doGet(final String urlIn, final AppConfig appConfig, - final JSONObject payload, + final JSONObject json, final OutputStream out) { - doRequest(urlIn, HttpMethod.GET, appConfig, payload, out); + doRequest(urlIn, HttpMethod.GET, appConfig, json, out); } private static HttpUriRequest resolveMethod(final String method, final String urlIn) { diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java index a33483840625..7891263733e4 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java +++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java @@ -14,7 +14,6 @@ import com.liferay.portal.model.User; import com.liferay.portal.util.PortalUtil; import io.vavr.control.Try; -import org.apache.commons.lang3.StringUtils; import org.apache.velocity.tools.view.context.ViewContext; import org.apache.velocity.tools.view.tools.ViewTool; @@ -46,9 +45,7 @@ public void init(final Object obj) { * @return true if AI is enabled, false otherwise */ public boolean isAiEnabled() { - return Optional.ofNullable(config) - .map(appConfig -> StringUtils.isNotBlank(appConfig.getApiKey())) - .orElse(false); + return Optional.ofNullable(config).map(AppConfig::isEnabled).orElse(false); } /** diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java index 67819a5d80ac..bb0e93f3c6c9 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java @@ -1,6 +1,5 @@ package com.dotcms.ai.workflow; -import com.dotcms.ai.app.AppKeys; import com.dotcms.ai.app.ConfigService; import com.dotmarketing.portlets.workflows.actionlet.Actionlet; import com.dotmarketing.portlets.workflows.actionlet.WorkFlowActionlet; @@ -41,7 +40,7 @@ public List getParameters() { return List.of( overwriteParameter, limitTagsToHost, - new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES), ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES), false), + new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getModel().getCurrentModel(), ConfigService.INSTANCE.config().getModel().getCurrentModel(), false), new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0.", ".1", false) ); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java index 7c3a13ebdcfd..b6a14ab22d44 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java +++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java @@ -37,7 +37,7 @@ public List getParameters() { "
and the keys of the json object will be used to update the content fields.", "", false), overwriteParameter, new WorkflowActionletParameter(OpenAIParams.OPEN_AI_PROMPT.key, "The prompt that will be sent to the AI", "We need an attractive search result in Google. Return a json object that includes the fields \"pageTitle\" for a meta title of less than 55 characters and \"metaDescription\" for the meta description of less than 300 characters using this content:\\n\\n${fieldContent}\\n\\n", true), - new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES), ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES), false), + new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getModel().getCurrentModel(), ConfigService.INSTANCE.config().getModel().getCurrentModel(), false), new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0. Defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), false) ); } diff --git a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java index 0d7ce095668a..d3fb0fe28261 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java @@ -97,6 +97,20 @@ public void testDiscoverEnvSecret() { assertEquals("envSecretValue", result); } + /** + * Given a map of secrets containing a key with an environment secret value + * When the discoverEnvSecret method is called with the key + * Then the environment secret value should be returned. + */ + @Test + public void testDiscoverNotFoundEnvSecret() { + when(secrets.get("something-else")).thenReturn(secret); + when(secret.getString()).thenReturn("envSecretValue"); + + String result = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY); + assertEquals("", result); + } + /** * Given a map of secrets containing a key with an integer secret value * When the discoverIntSecret method is called with the key @@ -138,7 +152,7 @@ public void testCreateTextModel() { AIModel model = aiAppUtil.createTextModel(secrets); assertNotNull(model); assertEquals(AIModelType.TEXT, model.getType()); - assertTrue(model.getNames().contains("textModel")); + assertTrue(model.getNames().contains("textmodel")); } /** @@ -154,7 +168,7 @@ public void testCreateImageModel() { AIModel model = aiAppUtil.createImageModel(secrets); assertNotNull(model); assertEquals(AIModelType.IMAGE, model.getType()); - assertTrue(model.getNames().contains("imageModel")); + assertTrue(model.getNames().contains("imagemodel")); } /** diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java index fb529a7dc30f..56619523c082 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java @@ -79,10 +79,10 @@ static Map appConfigMap(final WireMockServer wireMockServer) { Secret.builder().withType(Type.STRING).withValue(API_KEY.toCharArray()).build(), AppKeys.TEXT_MODEL_NAMES.key, - Secret.builder().withType(Type.STRING).withValue(MODEL.toCharArray()).build(), + Secret.builder().withType(Type.STRING).withValue(AppKeys.TEXT_MODEL_NAMES.defaultValue.toCharArray()).build(), AppKeys.IMAGE_MODEL_NAMES.key, - Secret.builder().withType(Type.STRING).withValue(IMAGE_MODEL.toCharArray()).build(), + Secret.builder().withType(Type.STRING).withValue(AppKeys.IMAGE_MODEL_NAMES.defaultValue.toCharArray()).build(), AppKeys.IMAGE_SIZE.key, Secret.builder().withType(Type.SELECT).withValue(IMAGE_SIZE.toCharArray()).build(), diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java index deba9567ca88..b996e50e435c 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java @@ -5,6 +5,7 @@ import com.dotcms.util.IntegrationTestInitService; import com.dotcms.util.network.IPUtils; import com.dotmarketing.beans.Host; +import com.dotmarketing.business.APILocator; import com.github.tomakehurst.wiremock.WireMockServer; import org.junit.AfterClass; import org.junit.Before; @@ -31,6 +32,7 @@ public class AIModelsTest { private static WireMockServer wireMockServer; + private static AppConfig config; private final AIModels aiModels = AIModels.get(); private Host host; @@ -41,6 +43,7 @@ public static void beforeClass() throws Exception { IntegrationTestInitService.getInstance().init(); IPUtils.disabledIpPrivateSubnet(true); wireMockServer = AiTest.prepareWireMock(); + config = AiTest.prepareConfig(APILocator.systemHost(), wireMockServer); } @AfterClass @@ -122,9 +125,18 @@ public void test_resetModels() { */ @Test public void test_getOrPullSupportedModules() { - final List supported = aiModels.getOrPullSupportedModels(); + AIModels.get().cleanSupportedModelsCache(); + AIModels.get().setAppConfigSupplier(() -> config); + + List supported = aiModels.getOrPullSupportedModels(); + assertNotNull(supported); + assertEquals(32, supported.size()); + + supported = aiModels.getOrPullSupportedModels(); assertNotNull(supported); assertEquals(32, supported.size()); + + AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); } /** @@ -134,6 +146,8 @@ public void test_getOrPullSupportedModules() { */ @Test public void test_getOrPullSupportedModules_invalidEndpoint() { + AIModels.get().cleanSupportedModelsCache(); + AIModels.get().setAppConfigSupplier(() -> config); IPUtils.disabledIpPrivateSubnet(false); final List supported = aiModels.getOrPullSupportedModels(); @@ -141,6 +155,20 @@ public void test_getOrPullSupportedModules_invalidEndpoint() { assertTrue(supported.isEmpty()); IPUtils.disabledIpPrivateSubnet(true); + AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); + } + + /** + * Given no API key + * When the getOrPullSupportedModules method is called + * Then an empty list of supported models should be returned. + */ + @Test + public void test_getOrPullSupportedModules_noApiKey() { + AIModels.get().cleanSupportedModelsCache(); + final List supported = aiModels.getOrPullSupportedModels(); + assertNotNull(supported); + assertTrue(supported.isEmpty()); } private void loadModels() {