Skip to content

Commit

Permalink
#29587: Returning available models as object that includes 'name' an…
Browse files Browse the repository at this point in the history
…d 'type' instead of just a list os strings. Throwing exceptions when AppConfig.isEnabled() is false to break the tests when any of AI API url or AI API key are missing. Clean the default values for model names. Added better logging at key places when interacting with OpenAI provider. Tests were added/updated.
  • Loading branch information
victoralfaro-dotcms committed Aug 14, 2024
1 parent 8e57564 commit 7aa5eb8
Show file tree
Hide file tree
Showing 18 changed files with 333 additions and 106 deletions.
20 changes: 20 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.dotcms.ai.app;

import com.dotcms.security.apps.AppsUtil;
import com.dotcms.security.apps.Secret;
import com.dotmarketing.util.UtilMethods;
import com.liferay.util.StringPool;
import io.vavr.Lazy;
import io.vavr.control.Try;
Expand Down Expand Up @@ -141,6 +143,24 @@ public boolean discoverBooleanSecret(final Map<String, Secret> secrets, final Ap
return Boolean.parseBoolean(discoverSecret(secrets, key));
}

/**
* Resolves a secret value from the provided secrets map using the specified key and environment variable.
* If the secret is not found in the secrets map, it attempts to discover the value from the environment variable.
*
* @param secrets the map of secrets
* @param key the key to look up the secret
* @param envVar the environment variable name to look up if the secret is not found in the secrets map
* @return the resolved secret value or the value from the environment variable if the secret is not found
*/
public String discoverEnvSecret(final Map<String, Secret> secrets, final AppKeys key, final String envVar) {
return Optional
.ofNullable(AppsUtil.discoverEnvVarValue(AppKeys.APP_KEY, key.key, envVar))
.orElseGet(() -> {
final String secret = discoverSecret(secrets, key);
return UtilMethods.isSet(secret) ? secret : null;
});
}

private int toInt(final String value) {
return Try.of(() -> Integer.parseInt(value)).getOrElse(0);
}
Expand Down
6 changes: 5 additions & 1 deletion dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ public String getCurrentModel() {
logInvalidModelMessage();
return null;
}

return names.get(currentIndex);
}

Expand All @@ -113,11 +114,14 @@ public long minIntervalBetweenCalls() {
@Override
public String toString() {
return "AIModel{" +
"name='" + names + '\'' +
"type=" + type +
", names=" + names +
", tokensPerMinute=" + tokensPerMinute +
", apiPerMinute=" + apiPerMinute +
", maxTokens=" + maxTokens +
", isCompletion=" + isCompletion +
", current=" + current +
", decommissioned=" + decommissioned +
'}';
}

Expand Down
64 changes: 40 additions & 24 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import com.dotcms.ai.model.OpenAIModel;
import com.dotcms.ai.model.OpenAIModels;
import com.dotcms.ai.model.SimpleModel;
import com.dotcms.http.CircuitBreakerUrl;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Config;
import com.dotmarketing.util.Logger;
import com.github.benmanes.caffeine.cache.Cache;
Expand All @@ -15,8 +17,9 @@
import org.apache.commons.collections4.CollectionUtils;

import java.time.Duration;
import java.util.HashSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -47,7 +50,7 @@ public class AIModels {

private final ConcurrentMap<String, List<Tuple2<AIModelType, AIModel>>> internalModels = new ConcurrentHashMap<>();
private final ConcurrentMap<Tuple2<String, String>, AIModel> modelsByName = new ConcurrentHashMap<>();
private final Cache<String, List<String>> supportedModelsCache =
private final Cache<String, Set<String>> supportedModelsCache =
Caffeine.newBuilder()
.expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL))
.maximumSize(AI_MODELS_CACHE_SIZE)
Expand Down Expand Up @@ -105,7 +108,11 @@ public void loadModels(final String host, final List<AIModel> loading) {
* @return an Optional containing the found AIModel, or an empty Optional if not found
*/
public Optional<AIModel> findModel(final String host, final String modelName) {
return Optional.ofNullable(modelsByName.get(Tuple.of(host, modelName.toLowerCase())));
final String lowered = modelName.toLowerCase();
final Set<String> supported = getOrPullSupportedModels();
return supported.contains(lowered)
? Optional.ofNullable(modelsByName.get(Tuple.of(host, lowered)))
: Optional.empty();
}

/**
Expand Down Expand Up @@ -144,29 +151,32 @@ public void resetModels(final String host) {
* Retrieves the list of supported models, either from the cache or by fetching them
* from an external source if the cache is empty or expired.
*
* @return a list of supported model names
* @return a set of supported model names
*/
public List<String> getOrPullSupportedModels() {
final List<String> cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY);
public Set<String> getOrPullSupportedModels() {
final Set<String> cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY);
if (CollectionUtils.isNotEmpty(cached)) {
return cached;
}

final AppConfig appConfig = appConfigSupplier.get();
if (!appConfig.isEnabled()) {
Logger.debug(this, "OpenAI is not enabled, returning empty list of supported models");
return List.of();
AppConfig.debugLogger(getClass(), () -> "dotAI is not enabled, returning empty list of supported models");
throw new DotRuntimeException("App dotAI config without API urls or API key");
}

final List<String> supported = Try.of(() ->
fetchOpenAIModels(appConfig)
.getResponse()
.getData()
.stream()
.map(OpenAIModel::getId)
.map(String::toLowerCase)
.collect(Collectors.toList()))
.getOrElse(Optional.ofNullable(cached).orElse(List.of()));
final CircuitBreakerUrl.Response<OpenAIModels> response = fetchOpenAIModels(appConfig);
if (Objects.nonNull(response.getResponse().getError())) {
throw new DotRuntimeException("Found error in AI response: " + response.getResponse().getError().getMessage());
}

final Set<String> supported = response
.getResponse()
.getData()
.stream()
.map(OpenAIModel::getId)
.map(String::toLowerCase)
.collect(Collectors.toSet());
supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported);

return supported;
Expand All @@ -177,25 +187,30 @@ public List<String> getOrPullSupportedModels() {
*
* @return a list of available model names
*/
public List<String> getAvailableModels() {
final Set<String> configured = internalModels.entrySet().stream().flatMap(entry -> entry.getValue().stream())
public List<SimpleModel> getAvailableModels() {
final Set<SimpleModel> configured = internalModels.entrySet()
.stream()
.flatMap(entry -> entry.getValue().stream())
.map(Tuple2::_2)
.flatMap(model -> model.getNames().stream())
.flatMap(model -> model.getNames().stream().map(name -> new SimpleModel(name, model.getType())))
.collect(Collectors.toSet());
final Set<SimpleModel> supported = getOrPullSupportedModels()
.stream()
.map(SimpleModel::new)
.collect(Collectors.toSet());
final Set<String> supported = new HashSet<>(getOrPullSupportedModels());
configured.retainAll(supported);
return configured.stream().sorted().collect(Collectors.toList());

return new ArrayList<>(configured);
}

private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final AppConfig appConfig) {

final CircuitBreakerUrl.Response<OpenAIModels> response = CircuitBreakerUrl.builder()
.setMethod(CircuitBreakerUrl.Method.GET)
.setUrl(OPEN_AI_MODELS_URL)
.setTimeout(AI_MODELS_FETCH_TIMEOUT)
.setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS)
.setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + appConfig.getApiKey()))
.setThrowWhenNot2xx(false)
.setThrowWhenNot2xx(true)
.build()
.doResponse(OpenAIModels.class);

Expand All @@ -206,6 +221,7 @@ private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final
"Error fetching OpenAI supported models from [%s] (status code: [%d])",
OPEN_AI_MODELS_URL,
response.getStatusCode()));
throw new DotRuntimeException("Error fetching OpenAI supported models");
}

return response;
Expand Down
94 changes: 69 additions & 25 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,36 @@
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.UtilMethods;
import com.liferay.util.StringPool;
import io.vavr.control.Try;
import org.apache.commons.lang3.StringUtils;

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* The AppConfig class provides a configuration for the AI application.
* It includes methods for retrieving configuration values based on given keys.
*/
public class AppConfig implements Serializable {

private static final String AI_API_KEY_KEY = "AI_API_KEY";
private static final String AI_API_URL_KEY = "AI_API_URL";
private static final String AI_IMAGE_API_URL_KEY = "AI_IMAGE_API_URL";
private static final String AI_EMBEDDINGS_API_URL_KEY = "AI_EMBEDDINGS_API_URL";
private static final String SYSTEM_HOST = "System Host";
public static final Pattern SPLITTER = Pattern.compile("\\s?,\\s?");

private static final AtomicReference<AppConfig> SYSTEM_HOST_CONFIG = new AtomicReference<>();

private final String host;
private final String apiKey;
private final transient AIModel model;
Expand All @@ -39,12 +51,15 @@ public class AppConfig implements Serializable {

public AppConfig(final String host, final Map<String, Secret> secrets) {
this.host = host;
if (SYSTEM_HOST.equalsIgnoreCase(host)) {
setSystemHostConfig(this);
}

final AIAppUtil aiAppUtil = AIAppUtil.get();
apiKey = aiAppUtil.discoverSecret(secrets, AppKeys.API_KEY);
apiUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_URL);
apiImageUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_IMAGE_URL);
apiEmbeddingsUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_EMBEDDINGS_URL);
apiUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_URL, AI_API_URL_KEY);
apiImageUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_IMAGE_URL, AI_IMAGE_API_URL_KEY);
apiEmbeddingsUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_EMBEDDINGS_URL, AI_EMBEDDINGS_API_URL_KEY);

if (!secrets.isEmpty() || isEnabled()) {
AIModels.get().loadModels(
Expand All @@ -67,18 +82,36 @@ public AppConfig(final String host, final Map<String, Secret> secrets) {

configValues = secrets.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

Logger.debug(getClass(), () -> "apiKey: " + apiKey);
Logger.debug(getClass(), () -> "apiUrl: " + apiUrl);
Logger.debug(getClass(), () -> "apiImageUrl: " + apiImageUrl);
Logger.debug(getClass(), () -> "embeddingsUrl: " + apiEmbeddingsUrl);
Logger.debug(getClass(), () -> "rolePrompt: " + rolePrompt);
Logger.debug(getClass(), () -> "textPrompt: " + textPrompt);
Logger.debug(getClass(), () -> "model: " + model);
Logger.debug(getClass(), () -> "imagePrompt: " + imagePrompt);
Logger.debug(getClass(), () -> "imageModel: " + imageModel);
Logger.debug(getClass(), () -> "imageSize: " + imageSize);
Logger.debug(getClass(), () -> "embeddingsModel: " + embeddingsModel);
Logger.debug(getClass(), () -> "listerIndexer: " + listenerIndexer);
Logger.debug(this, this::toString);
}

/**
* Retrieves the system host configuration.
*
* @return the system host configuration
*/
public static AppConfig getSystemHostConfig() {
if (Objects.isNull(SYSTEM_HOST_CONFIG.get())) {
setSystemHostConfig(ConfigService.INSTANCE.config());
}
return SYSTEM_HOST_CONFIG.get();
}

/**
* Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING}
* property instead of the usual Log4j configuration.
*
* @param clazz The {@link Class} to log the message for.
* @param message The {@link Supplier} with the message to log.
*/
public static void debugLogger(final Class<?> clazz, final Supplier<String> message) {
if (getSystemHostConfig().getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(clazz, message.get());
}
}

public static void setSystemHostConfig(final AppConfig systemHostConfig) {
AppConfig.SYSTEM_HOST_CONFIG.set(systemHostConfig);
}

/**
Expand Down Expand Up @@ -282,20 +315,31 @@ public AIModel resolveModelOrThrow(final String modelName) {
}

/**
* Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING}
* property instead of the usual Log4j configuration.
* Checks if the configuration is enabled.
*
* @param clazz The {@link Class} to log the message for.
* @param message The {@link Supplier} with the message to log.
* @return true if the configuration is enabled, false otherwise
*/
public static void debugLogger(final Class<?> clazz, final Supplier<String> message) {
if (ConfigService.INSTANCE.config().getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(clazz, message.get());
}
public boolean isEnabled() {
return Stream.of(apiUrl, apiImageUrl, apiEmbeddingsUrl, apiKey).allMatch(StringUtils::isNotBlank);
}

public boolean isEnabled() {
return StringUtils.isNotBlank(apiKey);
@Override
public String toString() {
return "AppConfig{\n" +
" host='" + host + "',\n" +
" apiKey='" + Optional.ofNullable(apiKey).map(key -> "*****").orElse(StringPool.BLANK) + "',\n" +
" model=" + model + "',\n" +
" imageModel=" + imageModel + "',\n" +
" embeddingsModel=" + embeddingsModel + "',\n" +
" apiUrl='" + apiUrl + "',\n" +
" apiImageUrl='" + apiImageUrl + "',\n" +
" apiEmbeddingsUrl='" + apiEmbeddingsUrl + "',\n" +
" rolePrompt='" + rolePrompt + "',\n" +
" textPrompt='" + textPrompt + "',\n" +
" imagePrompt='" + imagePrompt + "',\n" +
" imageSize='" + imageSize + "',\n" +
" listenerIndexer='" + listenerIndexer + "'\n" +
'}';
}

}
10 changes: 6 additions & 4 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package com.dotcms.ai.app;

import com.liferay.util.StringPool;

public enum AppKeys {

API_KEY("apiKey", null),
API_URL("apiUrl", "https://api.openai.com/v1/chat/completions"),
API_IMAGE_URL("apiImageUrl", "https://api.openai.com/v1/images/generations"),
API_EMBEDDINGS_URL("apiEmbeddingsUrl", "https://api.openai.com/v1/embeddings"),
API_KEY("apiKey", null),
ROLE_PROMPT(
"rolePrompt",
"You are dotCMSbot, and AI assistant to help content" +
Expand All @@ -22,12 +24,12 @@ public enum AppKeys {
IMAGE_MODEL_TOKENS_PER_MINUTE("imageModelTokensPerMinute", "0"),
IMAGE_MODEL_API_PER_MINUTE("imageModelApiPerMinute", "50"),
IMAGE_MODEL_MAX_TOKENS("imageModelMaxTokens", "0"),
IMAGE_MODEL_COMPLETION("imageModelCompletion", "false"),
IMAGE_MODEL_COMPLETION("imageModelCompletion", StringPool.FALSE),
EMBEDDINGS_MODEL_NAMES("embeddingsModelNames", null),
EMBEDDINGS_MODEL_TOKENS_PER_MINUTE("embeddingsModelTokensPerMinute", "1000000"),
EMBEDDINGS_MODEL_API_PER_MINUTE("embeddingsModelApiPerMinute", "3000"),
EMBEDDINGS_MODEL_MAX_TOKENS("embeddingsModelMaxTokens", "8191"),
EMBEDDINGS_MODEL_COMPLETION("embeddingsModelCompletion", "false"),
EMBEDDINGS_MODEL_COMPLETION("embeddingsModelCompletion", StringPool.FALSE),
EMBEDDINGS_SPLIT_AT_TOKENS("com.dotcms.ai.embeddings.split.at.tokens", "512"),
EMBEDDINGS_MINIMUM_TEXT_LENGTH_TO_INDEX("com.dotcms.ai.embeddings.minimum.text.length", "64"),
EMBEDDINGS_MINIMUM_FILE_SIZE_TO_INDEX("com.dotcms.ai.embeddings.minimum.file.size", "1024"),
Expand All @@ -39,7 +41,7 @@ public enum AppKeys {
EMBEDDINGS_CACHE_TTL_SECONDS("com.dotcms.ai.embeddings.cache.ttl.seconds", "600"),
EMBEDDINGS_CACHE_SIZE("com.dotcms.ai.embeddings.cache.size", "1000"),
EMBEDDINGS_DB_DELETE_OLD_ON_UPDATE("com.dotcms.ai.embeddings.delete.old.on.update", "true"),
DEBUG_LOGGING("com.dotcms.ai.debug.logging", "false"),
DEBUG_LOGGING("com.dotcms.ai.debug.logging", StringPool.FALSE),
COMPLETION_TEMPERATURE("com.dotcms.ai.completion.default.temperature", "1"),
COMPLETION_ROLE_PROMPT(
"com.dotcms.ai.completion.role.prompt",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ private AppConfig getAppConfig(final String hostId) {

final AppConfig appConfig = ConfigService.INSTANCE.config(host);
if (!appConfig.isEnabled()) {
throw new DotRuntimeException("No API key found in app config");
AppConfig.debugLogger(
getClass(),
() -> "dotAI is not enabled since no API urls or API key found in app config");
throw new DotRuntimeException("App dotAI config without API urls or API key");
}

return appConfig;
Expand Down
Loading

0 comments on commit 7aa5eb8

Please sign in to comment.