Skip to content

Commit

Permalink
#29281: adding a centralized OpenAI api-key validation procedure
Browse files Browse the repository at this point in the history
  • Loading branch information
victoralfaro-dotcms committed Aug 2, 2024
1 parent 07b0135 commit f54a174
Show file tree
Hide file tree
Showing 15 changed files with 185 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ private JSONObject buildRequestJson(final CompletionsForm form) {

final JSONObject json = new JSONObject();
json.put(AiKeys.MESSAGES, messages);
json.putIfAbsent(AiKeys.MODEL, config.get().getConfig(AppKeys.TEXT_MODEL_NAMES));
json.putIfAbsent(AiKeys.MODEL, config.get().getModel().getCurrentModel());
json.put(AiKeys.TEMPERATURE, form.temperature);
json.put(AiKeys.MAX_TOKENS, form.responseLengthTokens);
json.put(AiKeys.STREAM, form.stream);
Expand Down
4 changes: 2 additions & 2 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public static AIAppUtil get() {
public AIModel createTextModel(final Map<String, Secret> secrets) {
return AIModel.builder()
.withType(AIModelType.TEXT)
.withNames(discoverSecret(secrets, AppKeys.TEXT_MODEL_NAMES))
.withNames(splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES))
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_MAX_TOKENS))
Expand All @@ -59,7 +59,7 @@ public AIModel createTextModel(final Map<String, Secret> secrets) {
public AIModel createImageModel(final Map<String, Secret> secrets) {
return AIModel.builder()
.withType(AIModelType.IMAGE)
.withNames(discoverSecret(secrets, AppKeys.IMAGE_MODEL_NAMES))
.withNames(splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES))
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_MAX_TOKENS))
Expand Down
9 changes: 9 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
*/
public class AIModel {

public static final AIModel NOOP_MODEL = AIModel.builder()
.withType(AIModelType.UNKNOWN)
.withNames(List.of())
.build();

private final AIModelType type;
private final List<String> names;
private final int tokensPerMinute;
Expand Down Expand Up @@ -88,6 +93,10 @@ public void setDecommissioned(final boolean decommissioned) {
this.decommissioned.set(decommissioned);
}

public boolean isOperational() {
return this != NOOP_MODEL;
}

public String getCurrentModel() {
final int currentIndex = this.current.get();
if (!isCurrentValid(currentIndex)) {
Expand Down
60 changes: 38 additions & 22 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.dotmarketing.util.Logger;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.google.common.annotations.VisibleForTesting;
import io.vavr.Lazy;
import io.vavr.Tuple;
import io.vavr.Tuple2;
Expand All @@ -21,6 +22,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
Expand All @@ -44,18 +46,14 @@ public class AIModels {
private static final int AI_MODELS_CACHE_TTL = 28800; // 8 hours
private static final int AI_MODELS_CACHE_SIZE = 128;

public static final AIModel NOOP_MODEL = AIModel.builder()
.withType(AIModelType.UNKNOWN)
.withNames(List.of())
.build();

private final ConcurrentMap<String, List<Tuple2<AIModelType, AIModel>>> internalModels = new ConcurrentHashMap<>();
private final ConcurrentMap<Tuple2<String, String>, AIModel> modelsByName = new ConcurrentHashMap<>();
private final Cache<String, List<String>> supportedModelsCache =
Caffeine.newBuilder()
.expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL))
.maximumSize(AI_MODELS_CACHE_SIZE)
.build();
private Supplier<AppConfig> appConfigSupplier = ConfigService.INSTANCE::config;

public static AIModels get() {
return INSTANCE.get();
Expand Down Expand Up @@ -154,24 +152,31 @@ public void resetModels(final Host host) {
* @return a list of supported model names
*/
public List<String> getOrPullSupportedModels() {
final List<String> cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY);
if (CollectionUtils.isNotEmpty(cached)) {
return cached;
synchronized (supportedModelsCache) {
final List<String> cached = supportedModelsCache.getIfPresent(SUPPORTED_MODELS_KEY);
if (CollectionUtils.isNotEmpty(cached)) {
return cached;
}

final AppConfig appConfig = appConfigSupplier.get();
if (!appConfig.isEnabled()) {
Logger.debug(this, "OpenAI is not enabled, returning empty list of supported models");
return List.of();
}

final List<String> supported = Try.of(() ->
fetchOpenAIModels(appConfig)
.getResponse()
.getData()
.stream()
.map(OpenAIModel::getId)
.map(String::toLowerCase)
.collect(Collectors.toList()))
.getOrElse(Optional.ofNullable(cached).orElse(List.of()));
supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported);

return supported;
}

final AppConfig appConfig = ConfigService.INSTANCE.config();
final List<String> supported = Try.of(() ->
fetchOpenAIModels(appConfig)
.getResponse()
.getData()
.stream()
.map(OpenAIModel::getId)
.map(String::toLowerCase)
.collect(Collectors.toList()))
.getOrElse(Optional.ofNullable(cached).orElse(List.of()));
supportedModelsCache.put(SUPPORTED_MODELS_KEY, supported);

return supported;
}

/**
Expand Down Expand Up @@ -212,4 +217,15 @@ private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final

return response;
}

@VisibleForTesting
void setAppConfigSupplier(final Supplier<AppConfig> appConfigSupplier) {
this.appConfigSupplier = appConfigSupplier;
}

@VisibleForTesting
void cleanSupportedModelsCache() {
supportedModelsCache.invalidateAll();
}

}
28 changes: 23 additions & 5 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.dotcms.ai.app;

import com.dotcms.ai.util.OpenAIRequest;
import com.dotcms.security.apps.Secret;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Config;
Expand All @@ -25,13 +26,13 @@ public class AppConfig implements Serializable {
public static final Pattern SPLITTER = Pattern.compile("\\s?,\\s?");

private final String host;
private final String apiKey;
private final transient AIModel model;
private final transient AIModel imageModel;
private final transient AIModel embeddingsModel;
private final String apiUrl;
private final String apiImageUrl;
private final String apiEmbeddingsUrl;
private final String apiKey;
private final String rolePrompt;
private final String textPrompt;
private final String imagePrompt;
Expand All @@ -43,6 +44,8 @@ public AppConfig(final String host, final Map<String, Secret> secrets) {
this.host = host;

final AIAppUtil aiAppUtil = AIAppUtil.get();
apiKey = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY);

AIModels.get().loadModels(
this.host,
List.of(
Expand All @@ -57,7 +60,7 @@ public AppConfig(final String host, final Map<String, Secret> secrets) {
apiUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_URL);
apiImageUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_IMAGE_URL);
apiEmbeddingsUrl = discoverEmbeddingsApiUrl(secrets);
apiKey = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY);

rolePrompt = aiAppUtil.discoverSecret(secrets, AppKeys.ROLE_PROMPT);
textPrompt = aiAppUtil.discoverSecret(secrets, AppKeys.TEXT_PROMPT);
imagePrompt = aiAppUtil.discoverSecret(secrets, AppKeys.IMAGE_PROMPT);
Expand All @@ -66,10 +69,10 @@ public AppConfig(final String host, final Map<String, Secret> secrets) {

configValues = secrets.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

Logger.debug(getClass(), () -> "apiKey: " + apiKey);
Logger.debug(getClass(), () -> "apiUrl: " + apiUrl);
Logger.debug(getClass(), () -> "apiImageUrl: " + apiImageUrl);
Logger.debug(getClass(), () -> "embeddingsUrl: " + apiEmbeddingsUrl);
Logger.debug(getClass(), () -> "apiKey: " + apiKey);
Logger.debug(getClass(), () -> "model: " + model);
Logger.debug(getClass(), () -> "imageModel: " + imageModel);
Logger.debug(getClass(), () -> "embeddingsModel: " + embeddingsModel);
Expand Down Expand Up @@ -251,7 +254,7 @@ public String getConfig(final AppKeys appKey) {
* @param type the type of the model to find
*/
public AIModel resolveModel(final AIModelType type) {
return AIModels.get().findModel(host, type).orElse(AIModels.NOOP_MODEL);
return AIModels.get().findModel(host, type).orElse(AIModel.NOOP_MODEL);
}

/**
Expand All @@ -260,13 +263,24 @@ public AIModel resolveModel(final AIModelType type) {
* @param modelName the name of the model to find
*/
public AIModel resolveModelOrThrow(final String modelName) {
return AIModels.get()
final AIModel model = AIModels.get()
.findModel(host, modelName)
.orElseThrow(() -> {
final String supported = String.join(", ", AIModels.get().getOrPullSupportedModels());
return new DotRuntimeException(
"Unable to find model: [" + modelName + "]. Only [" + supported + "] are supported ");
});

if (!model.isOperational()) {
Logger.debug(
OpenAIRequest.class,
String.format(
"Resolved model [%s] is not operational, avoiding its usage",
model.getCurrentModel()));
throw new DotRuntimeException(String.format("Model [%s] is not operational", model.getCurrentModel()));
}

return model;
}

/**
Expand All @@ -282,6 +296,10 @@ public static void debugLogger(final Class<?> clazz, final Supplier<String> mess
}
}

public boolean isEnabled() {
return StringUtils.isNotBlank(apiKey);
}

private String discoverEmbeddingsApiUrl(final Map<String, Secret> secrets) {
final String url = AIAppUtil.get().discoverEnvSecret(secrets, AppKeys.API_EMBEDDINGS_URL);
return StringUtils.isBlank(url)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private AppConfig getAppConfig(final String hostId) {
*/
private JSONObject getConfigJson(final String hostId) {
return Try
.of(() -> new JSONObject(getAppConfig(hostId).getConfig(AppKeys.LISTENER_INDEXER)))
.of(() -> new JSONObject(getAppConfig(hostId).getListenerIndexer()))
.onFailure(e -> Logger.debug(getClass(), "error in json config from app: " + e.getMessage()))
.getOrElse(new JSONObject());
}
Expand Down
44 changes: 43 additions & 1 deletion dotCMS/src/main/java/com/dotcms/ai/model/OpenAIModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ public class OpenAIModels implements Serializable {

private final String object;
private final List<OpenAIModel> data;
private final OpenAIError error;

@JsonCreator
public OpenAIModels(@JsonProperty("object") final String object,
@JsonProperty("data") final List<OpenAIModel> data) {
@JsonProperty("data") final List<OpenAIModel> data,
@JsonProperty("error") final OpenAIError error) {
this.object = object;
this.data = data;
this.error = error;
}

public String getObject() {
Expand All @@ -32,4 +35,43 @@ public List<OpenAIModel> getData() {
return data;
}

public OpenAIError getError() {
return error;
}

public static class OpenAIError {

private final String message;
private final String type;
private final String param;
private final String code;

@JsonCreator
public OpenAIError(@JsonProperty("object") final String message,
@JsonProperty("type") final String type,
@JsonProperty("param") final String param,
@JsonProperty("code") final String code) {
this.message = message;
this.type = type;
this.param = param;
this.code = code;
}

public String getMessage() {
return message;
}

public String getType() {
return type;
}

public String getParam() {
return param;
}

public String getCode() {
return code;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ private CompletionsForm(final Builder builder) {
} else {
this.temperature = builder.temperature >= 2 ? 2 : builder.temperature;
}
this.model = UtilMethods.isSet(builder.model) ? builder.model : ConfigService.INSTANCE.config().getConfig(AppKeys.TEXT_MODEL_NAMES);
this.model = UtilMethods.isSet(builder.model) ? builder.model : ConfigService.INSTANCE.config().getModel().getCurrentModel();
}

private String validateBuilderQuery(final String query) {
Expand Down
Loading

0 comments on commit f54a174

Please sign in to comment.