Skip to content

Commit

Permalink
#28813: Moving OpenAI Model paramenters to dotAI App configuration (#…
Browse files Browse the repository at this point in the history
…29236)

Removing hardcoded OpenAI models at enum class OpenAIModel. Instead they
are now part of the `dotAI.yml` application descriptor so the user can
be the one who configures them not only one but multiple models for
`text`, `image` and `embeddings`.
The way to specify more than one is to provide a comma delimited list in
the new dotAI App params.
Sometimes we accept a model to use in the payload of our AI endpoints,
for this matter we will validate that model and if it's invalid we will
throw an exception.
When is not present in the payload, then our backend will inject the
current model.
Which leads us to the question: How will the current model be
determined?
This is part of the work defined for
#29284 (model fallback)
  • Loading branch information
victoralfaro-dotcms authored Jul 30, 2024
1 parent c5f3961 commit da61861
Show file tree
Hide file tree
Showing 50 changed files with 2,328 additions and 698 deletions.
456 changes: 228 additions & 228 deletions core-web/yarn.lock

Large diffs are not rendered by default.

33 changes: 16 additions & 17 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package com.dotcms.ai.api;

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.db.EmbeddingsDTO;
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotcms.ai.util.EncodingUtil;
import com.dotcms.ai.util.OpenAIModel;
import com.dotcms.ai.util.OpenAIRequest;
import com.dotcms.api.web.HttpServletRequestThreadLocal;
import com.dotcms.mock.request.FakeHttpRequest;
Expand Down Expand Up @@ -42,7 +42,7 @@ public class CompletionsAPIImpl implements CompletionsAPI {

private final Lazy<AppConfig> config;

final Lazy<AppConfig> defaultConfig =
private final Lazy<AppConfig> defaultConfig =
Lazy.of(() -> ConfigService.INSTANCE.config(
Try.of(() -> WebAPILocator
.getHostWebAPI()
Expand All @@ -60,7 +60,7 @@ public JSONObject prompt(final String systemPrompt,
final String modelIn,
final float temperature,
final int maxTokens) {
final OpenAIModel model = OpenAIModel.resolveModel(modelIn);
final AIModel model = config.get().resolveModelOrThrow(modelIn);
final JSONObject json = new JSONObject();

json.put(AiKeys.TEMPERATURE, temperature);
Expand All @@ -70,7 +70,7 @@ public JSONObject prompt(final String systemPrompt,
json.put(AiKeys.MAX_TOKENS, maxTokens);
}

json.put(AiKeys.MODEL, model.modelName);
json.put(AiKeys.MODEL, model.getCurrentModel());

return raw(json);
}
Expand All @@ -91,7 +91,7 @@ public JSONObject summarize(final CompletionsForm summaryRequest) {
Try.of(() -> OpenAIRequest.doRequest(
config.get().getApiUrl(),
HttpMethod.POST,
config.get().getApiKey(),
config.get(),
json))
.getOrElseThrow(DotRuntimeException::new);
final JSONObject dotCMSResponse = EmbeddingsAPI.impl().reduceChunksToContent(searcher, localResults);
Expand All @@ -107,7 +107,7 @@ public void summarizeStream(final CompletionsForm summaryRequest, final OutputSt

final JSONObject json = buildRequestJson(summaryRequest, localResults);
json.put(AiKeys.STREAM, true);
OpenAIRequest.doPost(config.get().getApiUrl(), config.get().getApiKey(), json, out);
OpenAIRequest.doPost(config.get().getApiUrl(), config.get(), json, out);
}

@Override
Expand All @@ -119,7 +119,7 @@ public JSONObject raw(final JSONObject json) {
final String response = OpenAIRequest.doRequest(
config.get().getApiUrl(),
HttpMethod.POST,
config.get().getApiKey(),
config.get(),
json);
if (config.get().getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(this.getClass(), "OpenAI response:" + response);
Expand All @@ -138,7 +138,7 @@ public JSONObject raw(CompletionsForm promptForm) {
public void rawStream(final CompletionsForm promptForm, final OutputStream out) {
final JSONObject json = buildRequestJson(promptForm);
json.put(AiKeys.STREAM, true);
OpenAIRequest.doRequest(config.get().getApiUrl(), HttpMethod.POST, config.get().getApiKey(), json, out);
OpenAIRequest.doRequest(config.get().getApiUrl(), HttpMethod.POST, config.get(), json, out);
}

private void buildMessages(final String systemPrompt, final String userPrompt, final JSONObject json) {
Expand All @@ -151,7 +151,7 @@ private void buildMessages(final String systemPrompt, final String userPrompt, f
}

private JSONObject buildRequestJson(final CompletionsForm form, final List<EmbeddingsDTO> searchResults) {
final OpenAIModel model = OpenAIModel.resolveModel(form.model);
final AIModel model = config.get().resolveModelOrThrow(form.model);
// aggregate matching results into text
final StringBuilder supportingContent = new StringBuilder();
searchResults.forEach(s -> supportingContent.append(s.extractedText).append(" "));
Expand All @@ -162,7 +162,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List<Embed
final int systemPromptTokens = countTokens(systemPrompt);
textPrompt = reduceStringToTokenSize(
textPrompt,
model.maxTokens - form.responseLengthTokens - systemPromptTokens);
model.getMaxTokens() - form.responseLengthTokens - systemPromptTokens);

final JSONObject json = new JSONObject();
json.put(AiKeys.STREAM, form.stream);
Expand All @@ -171,7 +171,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List<Embed
buildMessages(systemPrompt, textPrompt, json);

if (UtilMethods.isSet(form.model)) {
json.put(AiKeys.MODEL, model.modelName);
json.put(AiKeys.MODEL, model.getCurrentModel());
}

json.put(AiKeys.MAX_TOKENS, form.responseLengthTokens);
Expand Down Expand Up @@ -204,8 +204,8 @@ private String getTextPrompt(final String prompt, final String supportingContent
}

private int countTokens(final String testString) {
return EncodingUtil.registry
.getEncodingForModel(config.get().getConfig(AppKeys.MODEL))
return EncodingUtil.REGISTRY
.getEncodingForModel(config.get().getModel().getCurrentModel())
.map(enc -> enc.countTokens(testString))
.orElseThrow(() -> new DotRuntimeException("Encoder not found"));
}
Expand Down Expand Up @@ -244,20 +244,19 @@ private String reduceStringToTokenSize(final String incomingString, final int ma
}

private JSONObject buildRequestJson(final CompletionsForm form) {
final int maxTokenSize = OpenAIModel.resolveModel(config.get().getConfig(AppKeys.MODEL)).maxTokens;
final AIModel aiModel = config.get().getModel();
final int promptTokens = countTokens(form.prompt);

final JSONArray messages = new JSONArray();
final String textPrompt = reduceStringToTokenSize(
form.prompt,
maxTokenSize - form.responseLengthTokens - promptTokens);
aiModel.getMaxTokens() - form.responseLengthTokens - promptTokens);

messages.add(Map.of(AiKeys.ROLE, AiKeys.USER, AiKeys.CONTENT, textPrompt));

final JSONObject json = new JSONObject();
json.put(AiKeys.MESSAGES, messages);
json.putIfAbsent(AiKeys.MODEL, config.get().getConfig(AppKeys.MODEL));

json.putIfAbsent(AiKeys.MODEL, config.get().getConfig(AppKeys.TEXT_MODEL_NAMES));
json.put(AiKeys.TEMPERATURE, form.temperature);
json.put(AiKeys.MAX_TOKENS, form.responseLengthTokens);
json.put(AiKeys.STREAM, form.stream);
Expand Down
17 changes: 7 additions & 10 deletions dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.dotmarketing.exception.DotCorruptedDataException;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.portlets.contentlet.model.Contentlet;
import com.dotmarketing.util.Config;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.StringUtils;
import com.dotmarketing.util.UtilMethods;
Expand All @@ -37,7 +36,6 @@
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.liferay.portal.model.User;
import io.vavr.Lazy;
import io.vavr.Tuple;
import io.vavr.Tuple2;
import io.vavr.Tuple3;
Expand Down Expand Up @@ -69,9 +67,6 @@
*/
class EmbeddingsAPIImpl implements EmbeddingsAPI {

private static final Lazy<String> OPEN_AI_EMBEDDINGS_URL = Lazy.of(()
-> Config.getStringProperty("OPEN_AI_EMBEDDINGS_URL", "https://api.openai.com/v1/embeddings"));

private static final Cache<String, Tuple2<Integer, List<Float>>> EMBEDDING_CACHE =
Caffeine.newBuilder()
.expireAfterWrite(
Expand Down Expand Up @@ -332,7 +327,7 @@ public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String conten
return cachedEmbeddings;
}

final List<Integer> tokens = EncodingUtil.encoding.get().encode(content);
final List<Integer> tokens = EncodingUtil.ENCODING.get().encode(content);
if (tokens.isEmpty()) {
debugLogger(this.getClass(), () -> String.format("No tokens for content ID '%s' were encoded: %s", contentId, content));
return Tuple.of(0, List.of());
Expand All @@ -348,7 +343,9 @@ public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String conten
return Tuple.of(dbEmbeddings._2, dbEmbeddings._3);
}

final Tuple2<Integer, List<Float>> openAiEmbeddings = Tuple.of(tokens.size(), this.sendTokensToOpenAI(contentId, tokens));
final Tuple2<Integer, List<Float>> openAiEmbeddings = Tuple.of(
tokens.size(),
sendTokensToOpenAI(contentId, tokens));
saveEmbeddingsForCache(content, openAiEmbeddings);
EMBEDDING_CACHE.put(hashed, openAiEmbeddings);

Expand Down Expand Up @@ -424,13 +421,13 @@ private void saveEmbeddingsForCache(final String content, final Tuple2<Integer,
*/
private List<Float> sendTokensToOpenAI(final String contentId, @NotNull final List<Integer> tokens) {
final JSONObject json = new JSONObject();
json.put(AiKeys.MODEL, config.getConfig(AppKeys.EMBEDDINGS_MODEL));
json.put(AiKeys.MODEL, config.getEmbeddingsModel().getCurrentModel());
json.put(AiKeys.INPUT, tokens);
debugLogger(this.getClass(), () -> String.format("Content tokens for content ID '%s': %s", contentId, tokens));
final String responseString = OpenAIRequest.doRequest(
OPEN_AI_EMBEDDINGS_URL.get(),
config.getApiEmbeddingsUrl(),
HttpMethod.POST,
this.config.getApiKey(),
config,
json);
debugLogger(this.getClass(), () -> String.format("OpenAI Response for content ID '%s': %s",
contentId, responseString.replace("\n", BLANK)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public void run() {
int totalTokens = 0;
for (int end = iterator.next(); end != BreakIterator.DONE; start = end, end = iterator.next()) {
final String sentence = cleanContent.substring(start, end);
final int tokenCount = EncodingUtil.encoding.get().countTokens(sentence);
final int tokenCount = EncodingUtil.ENCODING.get().countTokens(sentence);
totalTokens += tokenCount;

if (totalTokens < splitAtTokens) {
Expand Down
169 changes: 169 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package com.dotcms.ai.app;

import com.dotcms.security.apps.AppsUtil;
import com.dotcms.security.apps.Secret;
import com.dotmarketing.util.UtilMethods;
import com.liferay.util.StringPool;
import io.vavr.Lazy;
import io.vavr.control.Try;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* Utility class for handling AI application configurations and secrets.
* This class provides methods to resolve secrets, normalize model names,
* split model names, and create AI model instances based on the provided
* configuration and secrets.
*
* @author vico
*/
public class AIAppUtil {

private static final Lazy<AIAppUtil> INSTANCE = Lazy.of(AIAppUtil::new);

private AIAppUtil() {
// Private constructor to prevent instantiation
}

public static AIAppUtil get() {
return INSTANCE.get();
}

/**
* Creates a text model instance based on the provided secrets.
*
* @param secrets the map of secrets
* @return the created text model instance
*/
public AIModel createTextModel(final Map<String, Secret> secrets) {
return AIModel.builder()
.withType(AIModelType.TEXT)
.withNames(discoverSecret(secrets, AppKeys.TEXT_MODEL_NAMES))
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_MAX_TOKENS))
.withIsCompletion(discoverBooleanSecret(secrets, AppKeys.TEXT_MODEL_COMPLETION))
.build();
}

/**
* Creates an image model instance based on the provided secrets.
*
* @param secrets the map of secrets
* @return the created image model instance
*/
public AIModel createImageModel(final Map<String, Secret> secrets) {
return AIModel.builder()
.withType(AIModelType.IMAGE)
.withNames(discoverSecret(secrets, AppKeys.IMAGE_MODEL_NAMES))
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_MAX_TOKENS))
.withIsCompletion(discoverBooleanSecret(secrets, AppKeys.IMAGE_MODEL_COMPLETION))
.build();
}

/**
* Creates an embeddings model instance based on the provided secrets.
*
* @param secrets the map of secrets
* @return the created embeddings model instance
*/
public AIModel createEmbeddingsModel(final Map<String, Secret> secrets) {
return AIModel.builder()
.withType(AIModelType.EMBEDDINGS)
.withNames(splitDiscoveredSecret(secrets, AppKeys.EMBEDDINGS_MODEL_NAMES))
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_MAX_TOKENS))
.withIsCompletion(discoverBooleanSecret(secrets, AppKeys.EMBEDDINGS_MODEL_COMPLETION))
.build();
}

/**
* Resolves a secret value from the provided secrets map using the specified key.
* If the secret is not found, the default value is returned.
*
* @param secrets the map of secrets
* @param key the key to look up the secret
* @param defaultValue the default value to return if the secret is not found
* @return the resolved secret value or the default value if the secret is not found
*/
public String discoverSecret(final Map<String, Secret> secrets, final AppKeys key, final String defaultValue) {
return Try.of(() -> secrets.get(key.key).getString()).getOrElse(defaultValue);
}

/**
* Resolves a secret value from the provided secrets map using the specified key.
* If the secret is not found, the default value defined in the key is returned.
*
* @param secrets the map of secrets
* @param key the key to look up the secret
* @return the resolved secret value or the default value defined in the key if the secret is not found
*/
public String discoverSecret(final Map<String, Secret> secrets, final AppKeys key) {
return discoverSecret(secrets, key, key.defaultValue);
}

/**
* Splits a model-specific secret value from the provided secrets map using the specified key.
*
* @param secrets the map of secrets
* @param key the key to look up the secret
* @return the list of split secret values
*/
public List<String> splitDiscoveredSecret(final Map<String, Secret> secrets, final AppKeys key) {
return Arrays.stream(discoverSecret(secrets, key).split(","))
.map(String::trim)
.map(String::toLowerCase)
.collect(Collectors.toList());
}

/**
* Resolves a model-specific secret value from the provided secrets map using the specified key and model type.
*
* @param secrets the map of secrets
* @param key the key to look up the secret
*/
public int discoverIntSecret(final Map<String, Secret> secrets, final AppKeys key) {
return toInt(discoverSecret(secrets, key));
}

/**
* Resolves a model-specific secret value from the provided secrets map using the specified key and model type.
*
* @param secrets the map of secrets
* @param key the key to look up the secret
*/
public boolean discoverBooleanSecret(final Map<String, Secret> secrets, final AppKeys key) {
return Boolean.parseBoolean(discoverSecret(secrets, key));
}

/**
* Resolves an environment-specific secret value from the provided secrets map using the specified key.
* If the secret is not found, it attempts to discover the value from environment variables.
*
* @param secrets the map of secrets
* @param key the key to look up the secret
* @return the resolved environment-specific secret value or an empty string if not found
*/
public String discoverEnvSecret(final Map<String, Secret> secrets, final AppKeys key) {
final String secret = discoverSecret(secrets, key, StringPool.BLANK);
if (UtilMethods.isSet(secret)) {
return secret;
}

return Optional
.ofNullable(AppsUtil.discoverEnvVarValue(AppKeys.APP_KEY, key.key, null))
.orElse(StringPool.BLANK);
}

private int toInt(final String value) {
return Try.of(() -> Integer.parseInt(value)).getOrElse(0);
}

}
Loading

0 comments on commit da61861

Please sign in to comment.