From 3b8be741cad9f80a0e82f902d4a5cc95fefb352c Mon Sep 17 00:00:00 2001 From: Victor Alfaro Date: Tue, 13 Aug 2024 17:16:28 -0600 Subject: [PATCH] #29281: Applying feedback --- .../main/java/com/dotcms/ai/app/AIModel.java | 5 +- .../main/java/com/dotcms/ai/app/AIModels.java | 29 ++++--- .../java/com/dotcms/ai/app/AppConfig.java | 85 +++++++++++++------ .../main/java/com/dotcms/ai/app/AppKeys.java | 10 +-- .../com/dotcms/ai/util/OpenAIRequest.java | 17 ++-- .../java/com/liferay/util/StringPool.java | 2 + .../src/test/java/com/dotcms/ai/AiTest.java | 4 + .../java/com/dotcms/ai/app/AIModelsTest.java | 44 +++++++--- .../resources/mappings/openai-models.json | 30 +++++++ 9 files changed, 166 insertions(+), 60 deletions(-) diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java index d84e2ff86728..495e68320278 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java @@ -113,11 +113,14 @@ public long minIntervalBetweenCalls() { @Override public String toString() { return "AIModel{" + - "name='" + names + '\'' + + "type=" + type + + ", names=" + names + ", tokensPerMinute=" + tokensPerMinute + ", apiPerMinute=" + apiPerMinute + ", maxTokens=" + maxTokens + ", isCompletion=" + isCompletion + + ", current=" + current + + ", decommissioned=" + decommissioned + '}'; } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java index 26e397f03237..2a9ff3ba0577 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java @@ -19,6 +19,7 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -49,7 +50,7 @@ public class AIModels { private final ConcurrentMap>> internalModels = new ConcurrentHashMap<>(); private final ConcurrentMap, AIModel> modelsByName = new ConcurrentHashMap<>(); - private final Cache> supportedModelsCache = + private final Cache> supportedModelsCache = Caffeine.newBuilder() .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) .maximumSize(AI_MODELS_CACHE_SIZE) @@ -107,7 +108,11 @@ public void loadModels(final String host, final List loading) { * @return an Optional containing the found AIModel, or an empty Optional if not found */ public Optional findModel(final String host, final String modelName) { - return Optional.ofNullable(modelsByName.get(Tuple.of(host, modelName.toLowerCase()))); + final String lowered = modelName.toLowerCase(); + final Set supported = getOrPullSupportedModels(); + return supported.contains(lowered) + ? Optional.ofNullable(modelsByName.get(Tuple.of(host, lowered))) + : Optional.empty(); } /** @@ -146,10 +151,10 @@ public void resetModels(final String host) { * Retrieves the list of supported models, either from the cache or by fetching them * from an external source if the cache is empty or expired. * - * @return a list of supported model names + * @return a set of supported model names */ - public List getOrPullSupportedModels() { - final List cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY); + public Set getOrPullSupportedModels() { + final Set cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY); if (CollectionUtils.isNotEmpty(cached)) { return cached; } @@ -160,17 +165,18 @@ public List getOrPullSupportedModels() { throw new DotRuntimeException("App dotAI config without API urls or API key"); } - final CircuitBreakerUrl.Response response = Try - .of(() -> fetchOpenAIModels(appConfig)) - .getOrElseThrow(() -> new DotRuntimeException("Error fetching OpenAI supported models")); + final CircuitBreakerUrl.Response response = fetchOpenAIModels(appConfig); + if (Objects.nonNull(response.getResponse().getError())) { + throw new DotRuntimeException("Found error in AI response: " + response.getResponse().getError().getMessage()); + } - final List supported = response + final Set supported = response .getResponse() .getData() .stream() .map(OpenAIModel::getId) .map(String::toLowerCase) - .collect(Collectors.toList()); + .collect(Collectors.toSet()); supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported); return supported; @@ -204,7 +210,7 @@ private static CircuitBreakerUrl.Response fetchOpenAIModels(final .setTimeout(AI_MODELS_FETCH_TIMEOUT) .setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS) .setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + appConfig.getApiKey())) - .setThrowWhenNot2xx(false) + .setThrowWhenNot2xx(true) .build() .doResponse(OpenAIModels.class); @@ -215,6 +221,7 @@ private static CircuitBreakerUrl.Response fetchOpenAIModels(final "Error fetching OpenAI supported models from [%s] (status code: [%d])", OPEN_AI_MODELS_URL, response.getStatusCode())); + throw new DotRuntimeException("Error fetching OpenAI supported models"); } return response; diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java index fe1c7d067b05..5dc503f97272 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java @@ -4,12 +4,16 @@ import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; +import com.liferay.util.StringPool; import io.vavr.control.Try; import org.apache.commons.lang3.StringUtils; import java.io.Serializable; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -25,9 +29,11 @@ public class AppConfig implements Serializable { private static final String AI_API_URL_KEY = "AI_API_URL"; private static final String AI_IMAGE_API_URL_KEY = "AI_IMAGE_API_URL"; private static final String AI_EMBEDDINGS_API_URL_KEY = "AI_EMBEDDINGS_API_URL"; - + private static final String SYSTEM_HOST = "System Host"; public static final Pattern SPLITTER = Pattern.compile("\\s?,\\s?"); + private static final AtomicReference SYSTEM_HOST_CONFIG = new AtomicReference<>(); + private final String host; private final String apiKey; private final transient AIModel model; @@ -45,6 +51,9 @@ public class AppConfig implements Serializable { public AppConfig(final String host, final Map secrets) { this.host = host; + if (SYSTEM_HOST.equalsIgnoreCase(host)) { + setSystemHostConfig(this); + } final AIAppUtil aiAppUtil = AIAppUtil.get(); apiKey = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY, AI_API_KEY_KEY); @@ -73,18 +82,36 @@ public AppConfig(final String host, final Map secrets) { configValues = secrets.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - Logger.debug(getClass(), () -> "apiKey: " + apiKey); - Logger.debug(getClass(), () -> "apiUrl: " + apiUrl); - Logger.debug(getClass(), () -> "apiImageUrl: " + apiImageUrl); - Logger.debug(getClass(), () -> "embeddingsUrl: " + apiEmbeddingsUrl); - Logger.debug(getClass(), () -> "rolePrompt: " + rolePrompt); - Logger.debug(getClass(), () -> "textPrompt: " + textPrompt); - Logger.debug(getClass(), () -> "model: " + model); - Logger.debug(getClass(), () -> "imagePrompt: " + imagePrompt); - Logger.debug(getClass(), () -> "imageModel: " + imageModel); - Logger.debug(getClass(), () -> "imageSize: " + imageSize); - Logger.debug(getClass(), () -> "embeddingsModel: " + embeddingsModel); - Logger.debug(getClass(), () -> "listerIndexer: " + listenerIndexer); + Logger.debug(this, this::toString); + } + + /** + * Retrieves the system host configuration. + * + * @return the system host configuration + */ + public static AppConfig getSystemHostConfig() { + if (Objects.isNull(SYSTEM_HOST_CONFIG.get())) { + setSystemHostConfig(ConfigService.INSTANCE.config()); + } + return SYSTEM_HOST_CONFIG.get(); + } + + /** + * Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING} + * property instead of the usual Log4j configuration. + * + * @param clazz The {@link Class} to log the message for. + * @param message The {@link Supplier} with the message to log. + */ + public static void debugLogger(final Class clazz, final Supplier message) { + if (getSystemHostConfig().getConfigBoolean(AppKeys.DEBUG_LOGGING)) { + Logger.info(clazz, message.get()); + } + } + + public static void setSystemHostConfig(final AppConfig systemHostConfig) { + AppConfig.SYSTEM_HOST_CONFIG.set(systemHostConfig); } /** @@ -287,19 +314,6 @@ public AIModel resolveModelOrThrow(final String modelName) { return aiModel; } - /** - * Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING} - * property instead of the usual Log4j configuration. - * - * @param clazz The {@link Class} to log the message for. - * @param message The {@link Supplier} with the message to log. - */ - public static void debugLogger(final Class clazz, final Supplier message) { - if (ConfigService.INSTANCE.config().getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.info(clazz, message.get()); - } - } - /** * Checks if the configuration is enabled. * @@ -309,4 +323,23 @@ public boolean isEnabled() { return Stream.of(apiUrl, apiImageUrl, apiEmbeddingsUrl, apiKey).allMatch(StringUtils::isNotBlank); } + @Override + public String toString() { + return "AppConfig{\n" + + " host='" + host + "',\n" + + " apiKey='" + Optional.ofNullable(apiKey).map(key -> "*****").orElse(StringPool.BLANK) + "',\n" + + " model=" + model + "',\n" + + " imageModel=" + imageModel + "',\n" + + " embeddingsModel=" + embeddingsModel + "',\n" + + " apiUrl='" + apiUrl + "',\n" + + " apiImageUrl='" + apiImageUrl + "',\n" + + " apiEmbeddingsUrl='" + apiEmbeddingsUrl + "',\n" + + " rolePrompt='" + rolePrompt + "',\n" + + " textPrompt='" + textPrompt + "',\n" + + " imagePrompt='" + imagePrompt + "',\n" + + " imageSize='" + imageSize + "',\n" + + " listenerIndexer='" + listenerIndexer + "'\n" + + '}'; + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java index 7afdad1c380b..a40c57c959f8 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java @@ -1,5 +1,7 @@ package com.dotcms.ai.app; +import com.liferay.util.StringPool; + public enum AppKeys { API_KEY("apiKey", null), @@ -22,12 +24,12 @@ public enum AppKeys { IMAGE_MODEL_TOKENS_PER_MINUTE("imageModelTokensPerMinute", "0"), IMAGE_MODEL_API_PER_MINUTE("imageModelApiPerMinute", "50"), IMAGE_MODEL_MAX_TOKENS("imageModelMaxTokens", "0"), - IMAGE_MODEL_COMPLETION("imageModelCompletion", AppKeys.FALSE), + IMAGE_MODEL_COMPLETION("imageModelCompletion", StringPool.FALSE), EMBEDDINGS_MODEL_NAMES("embeddingsModelNames", null), EMBEDDINGS_MODEL_TOKENS_PER_MINUTE("embeddingsModelTokensPerMinute", "1000000"), EMBEDDINGS_MODEL_API_PER_MINUTE("embeddingsModelApiPerMinute", "3000"), EMBEDDINGS_MODEL_MAX_TOKENS("embeddingsModelMaxTokens", "8191"), - EMBEDDINGS_MODEL_COMPLETION("embeddingsModelCompletion", AppKeys.FALSE), + EMBEDDINGS_MODEL_COMPLETION("embeddingsModelCompletion", StringPool.FALSE), EMBEDDINGS_SPLIT_AT_TOKENS("com.dotcms.ai.embeddings.split.at.tokens", "512"), EMBEDDINGS_MINIMUM_TEXT_LENGTH_TO_INDEX("com.dotcms.ai.embeddings.minimum.text.length", "64"), EMBEDDINGS_MINIMUM_FILE_SIZE_TO_INDEX("com.dotcms.ai.embeddings.minimum.file.size", "1024"), @@ -39,7 +41,7 @@ public enum AppKeys { EMBEDDINGS_CACHE_TTL_SECONDS("com.dotcms.ai.embeddings.cache.ttl.seconds", "600"), EMBEDDINGS_CACHE_SIZE("com.dotcms.ai.embeddings.cache.size", "1000"), EMBEDDINGS_DB_DELETE_OLD_ON_UPDATE("com.dotcms.ai.embeddings.delete.old.on.update", "true"), - DEBUG_LOGGING("com.dotcms.ai.debug.logging", AppKeys.FALSE), + DEBUG_LOGGING("com.dotcms.ai.debug.logging", StringPool.FALSE), COMPLETION_TEMPERATURE("com.dotcms.ai.completion.default.temperature", "1"), COMPLETION_ROLE_PROMPT( "com.dotcms.ai.completion.role.prompt", @@ -52,8 +54,6 @@ public enum AppKeys { AI_MODELS_CACHE_TTL("com.dotcms.ai.models.supported.ttl", "28800"), AI_MODELS_CACHE_SIZE("com.dotcms.ai.models.supported.size", "64"); - private static final String FALSE = "false"; - public static final String APP_KEY = "dotAI"; public final String key; diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java index d022156f15e3..b2a9b9adf789 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java +++ b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java @@ -52,18 +52,23 @@ public static void doRequest(final String urlIn, final AppConfig appConfig, final JSONObject json, final OutputStream out) { + AppConfig.debugLogger( + OpenAIRequest.class, + () -> String.format( + "Posting to [%s] with method [%s]%s with app config:%s%s the payload: %s", + urlIn, + method, + System.lineSeparator(), + appConfig.toString(), + System.lineSeparator(), + json.toString(2))); if (!appConfig.isEnabled()) { - AppConfig.debugLogger(OpenAIRequest.class, () -> "dotAI is not enabled and will not send request."); + AppConfig.debugLogger(OpenAIRequest.class, () -> "App dotAI is not enabled and will not send request."); throw new DotRuntimeException("App dotAI config without API urls or API key"); } final AIModel model = appConfig.resolveModelOrThrow(json.optString(AiKeys.MODEL)); - - if (appConfig.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { - Logger.debug(OpenAIRequest.class, "posting: " + json); - } - final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L) + model.minIntervalBetweenCalls() - System.currentTimeMillis(); diff --git a/dotCMS/src/main/java/com/liferay/util/StringPool.java b/dotCMS/src/main/java/com/liferay/util/StringPool.java index 80b5d2c29f57..478ef31f3dc6 100644 --- a/dotCMS/src/main/java/com/liferay/util/StringPool.java +++ b/dotCMS/src/main/java/com/liferay/util/StringPool.java @@ -89,4 +89,6 @@ public class StringPool { public static final String TRUE = Boolean.TRUE.toString(); + public static final String FALSE = Boolean.FALSE.toString(); + } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java index 8b0fca114036..855f61ad4572 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java @@ -78,4 +78,8 @@ static Map aiAppSecrets(final WireMockServer wireMockServer, fin return aiAppSecrets(wireMockServer, host, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL); } + static void removeSecrets(final Host host) throws DotDataException, DotSecurityException { + APILocator.getAppsAPI().removeSecretsForSite(host, APILocator.systemUser()); + } + } diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java index 15483a02ad4e..e08965e20843 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.Optional; +import java.util.Set; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -51,7 +52,6 @@ public static void beforeClass() throws Exception { @AfterClass public static void afterClass() { wireMockServer.stop(); - IPUtils.disabledIpPrivateSubnet(false); } @Before @@ -74,6 +74,7 @@ public void after() { */ @Test public void test_loadModels_andFindThem() throws DotDataException, DotSecurityException { + AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); saveSecrets( host, "text-model-1,text-model-2", @@ -123,16 +124,15 @@ public void test_loadModels_andFindThem() throws DotDataException, DotSecurityEx final Optional text7 = aiModels.findModel(hostId, "text-model-7"); final Optional text8 = aiModels.findModel(hostId, "text-model-8"); - assertModels(text7, text8, AIModelType.TEXT); + assertNotPresentModels(text7, text8); final Optional image9 = aiModels.findModel(hostId, "image-model-9"); final Optional image10 = aiModels.findModel(hostId, "image-model-10"); - assertModels(image9, image10, AIModelType.IMAGE); + assertNotPresentModels(image9, image10); final Optional embeddings11 = aiModels.findModel(hostId, "embeddings-model-11"); - assertTrue(embeddings11.isPresent()); final Optional embeddings12 = aiModels.findModel(hostId, "embeddings-model-12"); - assertModels(embeddings11, embeddings12, AIModelType.EMBEDDINGS); + assertNotPresentModels(embeddings11, embeddings12); } /** @@ -145,9 +145,9 @@ public void test_getOrPullSupportedModules() throws DotDataException, DotSecurit AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost()); AIModels.get().cleanSupportedModelsCache(); - List supported = aiModels.getOrPullSupportedModels(); + Set supported = aiModels.getOrPullSupportedModels(); assertNotNull(supported); - assertEquals(32, supported.size()); + assertEquals(38, supported.size()); AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); } @@ -158,13 +158,12 @@ public void test_getOrPullSupportedModules() throws DotDataException, DotSecurit * Then an empty list of supported models should be returned. */ @Test(expected = DotRuntimeException.class) - public void test_getOrPullSupportedModules_invalidEndpoint() { + public void test_getOrPullSupportedModules_withNetworkError() { AIModels.get().cleanSupportedModelsCache(); IPUtils.disabledIpPrivateSubnet(false); - final List supported = aiModels.getOrPullSupportedModels(); - assertNotNull(supported); - assertTrue(supported.isEmpty()); + final Set supported = aiModels.getOrPullSupportedModels(); + assertSupported(supported); IPUtils.disabledIpPrivateSubnet(true); AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); @@ -183,6 +182,19 @@ public void test_getOrPullSupportedModules_noApiKey() throws DotDataException, D aiModels.getOrPullSupportedModels(); } + /** + * Given no API key + * When the getOrPullSupportedModules method is called + * Then an empty list of supported models should be returned. + */ + @Test(expected = DotRuntimeException.class) + public void test_getOrPullSupportedModules_noSystemHost() throws DotDataException, DotSecurityException { + AiTest.removeSecrets(APILocator.systemHost()); + + AIModels.get().cleanSupportedModelsCache(); + aiModels.getOrPullSupportedModels(); + } + private void saveSecrets(final Host host, final String textModels, final String imageModels, @@ -207,4 +219,14 @@ private static void assertModels(final Optional model1, assertSame(type, model2.get().getType()); } + private static void assertNotPresentModels(final Optional model1, final Optional model2) { + assertTrue(model1.isEmpty()); + assertTrue(model2.isEmpty()); + } + + private static void assertSupported(Set supported) { + assertNotNull(supported); + assertTrue(supported.isEmpty()); + } + } diff --git a/dotcms-integration/src/test/resources/mappings/openai-models.json b/dotcms-integration/src/test/resources/mappings/openai-models.json index 9bf3d1ca8a0f..0d9ab6aa7a51 100644 --- a/dotcms-integration/src/test/resources/mappings/openai-models.json +++ b/dotcms-integration/src/test/resources/mappings/openai-models.json @@ -9,6 +9,36 @@ "object": "list", "data": [ { + "id": "text-model-1", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ + "id": "text-model-2", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ + "id": "image-model-3", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ + "id": "image-model-4", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ + "id": "embeddings-model-5", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ + "id": "embeddings-model-6", + "object": "model", + "created": 1698785189, + "owned_by": "system" + },{ "id": "dall-e-3", "object": "model", "created": 1698785189,