Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#29281: Several dotAI enhancements #29588

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cicd_comp_test-phase.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ jobs:
strategy:
fail-fast: false
matrix:
collection_group: [ 'category-content', 'container', 'experiment', 'graphql', 'page', 'pp', 'template', 'workflow', 'default-split', 'default' ]
collection_group: [ 'ai', 'category-content', 'container', 'experiment', 'graphql', 'page', 'pp', 'template', 'workflow', 'default-split', 'default' ]
steps:
- name: Checkout code
uses: actions/checkout@v4
Expand Down
20 changes: 20 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.dotcms.ai.app;

import com.dotcms.security.apps.AppsUtil;
import com.dotcms.security.apps.Secret;
import com.dotmarketing.util.UtilMethods;
import com.liferay.util.StringPool;
import io.vavr.Lazy;
import io.vavr.control.Try;
Expand Down Expand Up @@ -141,6 +143,24 @@ public boolean discoverBooleanSecret(final Map<String, Secret> secrets, final Ap
return Boolean.parseBoolean(discoverSecret(secrets, key));
}

/**
* Resolves a configuration value, giving precedence to the environment lookup over the stored app secret.
* The value returned by {@link AppsUtil#discoverEnvVarValue} (called with the app key, the secret key and
* the given environment variable name) wins whenever it is not {@code null}; only otherwise is the
* provided secrets map consulted through {@code discoverSecret}.
*
* @param secrets the map of secrets
* @param key the key to look up the secret
* @param envVar the environment variable name handed to the environment lookup
* @return the value from the environment lookup when it is not {@code null}; otherwise the secret value
* when {@code UtilMethods.isSet} accepts it; otherwise {@code null}
*/
public String discoverEnvSecret(final Map<String, Secret> secrets, final AppKeys key, final String envVar) {
return Optional
// NOTE(review): presumably reads an environment variable / property override - confirm against AppsUtil
.ofNullable(AppsUtil.discoverEnvVarValue(AppKeys.APP_KEY, key.key, envVar))
// Fallback only runs when the environment lookup produced null
.orElseGet(() -> {
final String secret = discoverSecret(secrets, key);
// Unset secrets are normalized to null so callers see a single "absent" value
return UtilMethods.isSet(secret) ? secret : null;
});
}

private int toInt(final String value) {
    // Anything Integer.parseInt rejects (null included) is reported as 0.
    try {
        return Integer.parseInt(value);
    } catch (NumberFormatException ignored) {
        return 0;
    }
}
Expand Down
6 changes: 5 additions & 1 deletion dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ public String getCurrentModel() {
logInvalidModelMessage();
return null;
}

return names.get(currentIndex);
}

Expand All @@ -113,11 +114,14 @@ public long minIntervalBetweenCalls() {
@Override
public String toString() {
return "AIModel{" +
"name='" + names + '\'' +
"type=" + type +
", names=" + names +
", tokensPerMinute=" + tokensPerMinute +
", apiPerMinute=" + apiPerMinute +
", maxTokens=" + maxTokens +
", isCompletion=" + isCompletion +
", current=" + current +
", decommissioned=" + decommissioned +
'}';
}

Expand Down
74 changes: 45 additions & 29 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import com.dotcms.ai.model.OpenAIModel;
import com.dotcms.ai.model.OpenAIModels;
import com.dotcms.ai.model.SimpleModel;
import com.dotcms.http.CircuitBreakerUrl;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Config;
import com.dotmarketing.util.Logger;
import com.github.benmanes.caffeine.cache.Cache;
Expand All @@ -11,12 +13,12 @@
import io.vavr.Lazy;
import io.vavr.Tuple;
import io.vavr.Tuple2;
import io.vavr.control.Try;
import org.apache.commons.collections4.CollectionUtils;

import java.time.Duration;
import java.util.HashSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -39,15 +41,16 @@ public class AIModels {
private static final String AI_MODELS_FETCH_TIMEOUT_KEY = "ai.models.fetch.timeout";
private static final int AI_MODELS_FETCH_TIMEOUT = Config.getIntProperty(AI_MODELS_FETCH_TIMEOUT_KEY, 4000);
private static final Lazy<AIModels> INSTANCE = Lazy.of(AIModels::new);
private static final String OPEN_AI_MODELS_URL = Config.getStringProperty(
"OPEN_AI_MODELS_URL",
private static final String AI_MODELS_API_URL_KEY = "DOT_AI_MODELS_API_URL";
private static final String AI_MODELS_API_URL = Config.getStringProperty(
AI_MODELS_API_URL_KEY,
"https://api.openai.com/v1/models");
private static final int AI_MODELS_CACHE_TTL = 28800; // 8 hours
private static final int AI_MODELS_CACHE_SIZE = 128;

private final ConcurrentMap<String, List<Tuple2<AIModelType, AIModel>>> internalModels = new ConcurrentHashMap<>();
private final ConcurrentMap<Tuple2<String, String>, AIModel> modelsByName = new ConcurrentHashMap<>();
private final Cache<String, List<String>> supportedModelsCache =
private final Cache<String, Set<String>> supportedModelsCache =
Caffeine.newBuilder()
.expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL))
.maximumSize(AI_MODELS_CACHE_SIZE)
Expand Down Expand Up @@ -105,7 +108,11 @@ public void loadModels(final String host, final List<AIModel> loading) {
* @return an Optional containing the found AIModel, or an empty Optional if not found
*/
public Optional<AIModel> findModel(final String host, final String modelName) {
return Optional.ofNullable(modelsByName.get(Tuple.of(host, modelName.toLowerCase())));
final String lowered = modelName.toLowerCase();
final Set<String> supported = getOrPullSupportedModels();
return supported.contains(lowered)
? Optional.ofNullable(modelsByName.get(Tuple.of(host, lowered)))
: Optional.empty();
}

/**
Expand Down Expand Up @@ -144,29 +151,32 @@ public void resetModels(final String host) {
* Retrieves the list of supported models, either from the cache or by fetching them
* from an external source if the cache is empty or expired.
*
* @return a list of supported model names
* @return a set of supported model names
*/
public List<String> getOrPullSupportedModels() {
final List<String> cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY);
public Set<String> getOrPullSupportedModels() {
final Set<String> cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY);
if (CollectionUtils.isNotEmpty(cached)) {
return cached;
}

final AppConfig appConfig = appConfigSupplier.get();
if (!appConfig.isEnabled()) {
Logger.debug(this, "OpenAI is not enabled, returning empty list of supported models");
return List.of();
AppConfig.debugLogger(getClass(), () -> "dotAI is not enabled, returning empty list of supported models");
throw new DotRuntimeException("App dotAI config without API urls or API key");
}

final List<String> supported = Try.of(() ->
fetchOpenAIModels(appConfig)
.getResponse()
.getData()
.stream()
.map(OpenAIModel::getId)
.map(String::toLowerCase)
.collect(Collectors.toList()))
.getOrElse(Optional.ofNullable(cached).orElse(List.of()));
final CircuitBreakerUrl.Response<OpenAIModels> response = fetchOpenAIModels(appConfig);
if (Objects.nonNull(response.getResponse().getError())) {
throw new DotRuntimeException("Found error in AI response: " + response.getResponse().getError().getMessage());
victoralfaro-dotcms marked this conversation as resolved.
Show resolved Hide resolved
}

final Set<String> supported = response
.getResponse()
.getData()
.stream()
.map(OpenAIModel::getId)
.map(String::toLowerCase)
.collect(Collectors.toSet());
supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported);

return supported;
Expand All @@ -177,25 +187,30 @@ public List<String> getOrPullSupportedModels() {
*
* @return a list of available model names
*/
public List<String> getAvailableModels() {
final Set<String> configured = internalModels.entrySet().stream().flatMap(entry -> entry.getValue().stream())
public List<SimpleModel> getAvailableModels() {
final Set<SimpleModel> configured = internalModels.entrySet()
.stream()
.flatMap(entry -> entry.getValue().stream())
.map(Tuple2::_2)
.flatMap(model -> model.getNames().stream())
.flatMap(model -> model.getNames().stream().map(name -> new SimpleModel(name, model.getType())))
.collect(Collectors.toSet());
final Set<SimpleModel> supported = getOrPullSupportedModels()
.stream()
.map(SimpleModel::new)
.collect(Collectors.toSet());
final Set<String> supported = new HashSet<>(getOrPullSupportedModels());
configured.retainAll(supported);
return configured.stream().sorted().collect(Collectors.toList());

return new ArrayList<>(configured);
}

private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final AppConfig appConfig) {

final CircuitBreakerUrl.Response<OpenAIModels> response = CircuitBreakerUrl.builder()
.setMethod(CircuitBreakerUrl.Method.GET)
.setUrl(OPEN_AI_MODELS_URL)
.setUrl(AI_MODELS_API_URL)
.setTimeout(AI_MODELS_FETCH_TIMEOUT)
.setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS)
.setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + appConfig.getApiKey()))
.setThrowWhenNot2xx(false)
.setThrowWhenNot2xx(true)
.build()
.doResponse(OpenAIModels.class);

Expand All @@ -204,8 +219,9 @@ private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final
AIModels.class,
String.format(
"Error fetching OpenAI supported models from [%s] (status code: [%d])",
OPEN_AI_MODELS_URL,
AI_MODELS_API_URL,
response.getStatusCode()));
throw new DotRuntimeException("Error fetching OpenAI supported models");
}

return response;
Expand Down
96 changes: 71 additions & 25 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,38 @@

import com.dotcms.security.apps.Secret;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Config;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.UtilMethods;
import com.liferay.util.StringPool;
import io.vavr.control.Try;
import org.apache.commons.lang3.StringUtils;

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* The AppConfig class provides a configuration for the AI application.
* It includes methods for retrieving configuration values based on given keys.
*/
public class AppConfig implements Serializable {

private static final String AI_API_URL_KEY = "AI_API_URL";
private static final String AI_IMAGE_API_URL_KEY = "AI_IMAGE_API_URL";
private static final String AI_EMBEDDINGS_API_URL_KEY = "AI_EMBEDDINGS_API_URL";
private static final String AI_DEBUG_LOGGER_KEY = "AI_DEBUG_LOGGER";
private static final String SYSTEM_HOST = "System Host";
private static final AtomicReference<AppConfig> SYSTEM_HOST_CONFIG = new AtomicReference<>();
private static final boolean DEBUG_LOGGING = Config.getBooleanProperty(AI_DEBUG_LOGGER_KEY, false);

public static final Pattern SPLITTER = Pattern.compile("\\s?,\\s?");

private final String host;
Expand All @@ -39,12 +53,15 @@ public class AppConfig implements Serializable {

public AppConfig(final String host, final Map<String, Secret> secrets) {
this.host = host;
if (SYSTEM_HOST.equalsIgnoreCase(host)) {
setSystemHostConfig(this);
}

final AIAppUtil aiAppUtil = AIAppUtil.get();
apiKey = aiAppUtil.discoverSecret(secrets, AppKeys.API_KEY);
apiUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_URL);
apiImageUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_IMAGE_URL);
apiEmbeddingsUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_EMBEDDINGS_URL);
apiUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_URL, AI_API_URL_KEY);
apiImageUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_IMAGE_URL, AI_IMAGE_API_URL_KEY);
apiEmbeddingsUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_EMBEDDINGS_URL, AI_EMBEDDINGS_API_URL_KEY);

if (!secrets.isEmpty() || isEnabled()) {
AIModels.get().loadModels(
Expand All @@ -67,18 +84,36 @@ public AppConfig(final String host, final Map<String, Secret> secrets) {

configValues = secrets.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

Logger.debug(getClass(), () -> "apiKey: " + apiKey);
Logger.debug(getClass(), () -> "apiUrl: " + apiUrl);
Logger.debug(getClass(), () -> "apiImageUrl: " + apiImageUrl);
Logger.debug(getClass(), () -> "embeddingsUrl: " + apiEmbeddingsUrl);
Logger.debug(getClass(), () -> "rolePrompt: " + rolePrompt);
Logger.debug(getClass(), () -> "textPrompt: " + textPrompt);
Logger.debug(getClass(), () -> "model: " + model);
Logger.debug(getClass(), () -> "imagePrompt: " + imagePrompt);
Logger.debug(getClass(), () -> "imageModel: " + imageModel);
Logger.debug(getClass(), () -> "imageSize: " + imageSize);
Logger.debug(getClass(), () -> "embeddingsModel: " + embeddingsModel);
Logger.debug(getClass(), () -> "listerIndexer: " + listenerIndexer);
Logger.debug(this, this::toString);
}

/**
 * Returns the cached system host configuration. When the cache is still empty it is first
 * populated with the result of {@code ConfigService.INSTANCE.config()}.
 *
 * @return the system host configuration held by the cache after the lookup
 */
public static AppConfig getSystemHostConfig() {
    final AppConfig cached = SYSTEM_HOST_CONFIG.get();
    if (cached != null) {
        return cached;
    }

    setSystemHostConfig(ConfigService.INSTANCE.config());
    return SYSTEM_HOST_CONFIG.get();
}

/**
 * Writes a diagnostic message through {@code Logger.info} when AI debug logging is switched on,
 * independently of the usual Log4j level configuration. Logging is considered on when the system
 * host configuration reports {@link AppKeys#DEBUG_LOGGING} as true or when the static
 * {@code DEBUG_LOGGING} flag is set. The supplier is only evaluated when the message is logged.
 *
 * @param clazz   The {@link Class} to log the message for.
 * @param message The {@link Supplier} with the message to log.
 */
public static void debugLogger(final Class<?> clazz, final Supplier<String> message) {
    final boolean enabled = getSystemHostConfig().getConfigBoolean(AppKeys.DEBUG_LOGGING) || DEBUG_LOGGING;
    if (!enabled) {
        return;
    }

    Logger.info(clazz, message.get());
}

/**
 * Replaces the cached system host configuration.
 *
 * @param systemHostConfig the configuration to store in the cache
 */
public static void setSystemHostConfig(final AppConfig systemHostConfig) {
    SYSTEM_HOST_CONFIG.set(systemHostConfig);
}

/**
Expand Down Expand Up @@ -282,20 +317,31 @@ public AIModel resolveModelOrThrow(final String modelName) {
}

/**
* Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING}
* property instead of the usual Log4j configuration.
* Checks if the configuration is enabled.
*
* @param clazz The {@link Class} to log the message for.
* @param message The {@link Supplier} with the message to log.
* @return true if the configuration is enabled, false otherwise
*/
public static void debugLogger(final Class<?> clazz, final Supplier<String> message) {
if (ConfigService.INSTANCE.config().getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(clazz, message.get());
}
public boolean isEnabled() {
    // Enabled only when all three API URLs and the API key are non-blank.
    return StringUtils.isNotBlank(apiUrl)
            && StringUtils.isNotBlank(apiImageUrl)
            && StringUtils.isNotBlank(apiEmbeddingsUrl)
            && StringUtils.isNotBlank(apiKey);
}

public boolean isEnabled() {
return StringUtils.isNotBlank(apiKey);
/**
 * Builds a multi-line summary of this configuration, intended for log output.
 * The API key itself is never printed: a fixed mask is emitted when a key is present and
 * {@code StringPool.BLANK} otherwise.
 *
 * @return the string representation of this configuration
 */
@Override
public String toString() {
    // Every value is wrapped in a matching pair of single quotes; the model fields previously
    // had only the closing quote, which produced unbalanced output.
    return "AppConfig{\n" +
            "  host='" + host + "',\n" +
            "  apiKey='" + Optional.ofNullable(apiKey).map(key -> "*****").orElse(StringPool.BLANK) + "',\n" +
            "  model='" + model + "',\n" +
            "  imageModel='" + imageModel + "',\n" +
            "  embeddingsModel='" + embeddingsModel + "',\n" +
            "  apiUrl='" + apiUrl + "',\n" +
            "  apiImageUrl='" + apiImageUrl + "',\n" +
            "  apiEmbeddingsUrl='" + apiEmbeddingsUrl + "',\n" +
            "  rolePrompt='" + rolePrompt + "',\n" +
            "  textPrompt='" + textPrompt + "',\n" +
            "  imagePrompt='" + imagePrompt + "',\n" +
            "  imageSize='" + imageSize + "',\n" +
            "  listenerIndexer='" + listenerIndexer + "'\n" +
            '}';
}

}
Loading
Loading