Skip to content

Commit

Permalink
#29281: Adding fixes for different multiple hosts AppConfig states (#…
Browse files Browse the repository at this point in the history
…29517)

Fixing inconsistencies when fetching dotAI App secrets caused by a wrong
way of resolving the host to use.
This was causing that AIModels singleton class would store inconsistent
models to be used later to request AI data in this case from OpenAI.
  • Loading branch information
victoralfaro-dotcms authored Aug 9, 2024
1 parent 6d5f918 commit 5205ee6
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 136 deletions.
52 changes: 26 additions & 26 deletions dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
* This class implements the CompletionsAPI interface and provides the specific logic for interacting with the AI service.
Expand All @@ -40,18 +41,17 @@
*/
public class CompletionsAPIImpl implements CompletionsAPI {

private final Lazy<AppConfig> config;

private final Lazy<AppConfig> defaultConfig =
Lazy.of(() -> ConfigService.INSTANCE.config(
Try.of(() -> WebAPILocator
.getHostWebAPI()
.getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest()))
.getOrElse(APILocator.systemHost()))
);

public CompletionsAPIImpl(final Lazy<AppConfig> config) {
this.config = (config != null) ? config : defaultConfig;
private final AppConfig config;
private final Lazy<AppConfig> defaultConfig;

public CompletionsAPIImpl(final AppConfig config) {
defaultConfig =
Lazy.of(() -> ConfigService.INSTANCE.config(
Try.of(() -> WebAPILocator
.getHostWebAPI()
.getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest()))
.getOrElse(APILocator.systemHost())));
this.config = Optional.ofNullable(config).orElse(defaultConfig.get());
}

@Override
Expand All @@ -60,7 +60,7 @@ public JSONObject prompt(final String systemPrompt,
final String modelIn,
final float temperature,
final int maxTokens) {
final AIModel model = config.get().resolveModelOrThrow(modelIn);
final AIModel model = config.resolveModelOrThrow(modelIn);
final JSONObject json = new JSONObject();

json.put(AiKeys.TEMPERATURE, temperature);
Expand Down Expand Up @@ -89,9 +89,9 @@ public JSONObject summarize(final CompletionsForm summaryRequest) {
json.put(AiKeys.STREAM, false);
final String openAiResponse =
Try.of(() -> OpenAIRequest.doRequest(
config.get().getApiUrl(),
config.getApiUrl(),
HttpMethod.POST,
config.get(),
config,
json))
.getOrElseThrow(DotRuntimeException::new);
final JSONObject dotCMSResponse = APILocator.getDotAIAPI().getEmbeddingsAPI().reduceChunksToContent(searcher, localResults);
Expand All @@ -107,21 +107,21 @@ public void summarizeStream(final CompletionsForm summaryRequest, final OutputSt

final JSONObject json = buildRequestJson(summaryRequest, localResults);
json.put(AiKeys.STREAM, true);
OpenAIRequest.doPost(config.get().getApiUrl(), config.get(), json, out);
OpenAIRequest.doPost(config.getApiUrl(), config, json, out);
}

@Override
public JSONObject raw(final JSONObject json) {
if (config.get().getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(this.getClass(), "OpenAI request:" + json.toString(2));
}

final String response = OpenAIRequest.doRequest(
config.get().getApiUrl(),
config.getApiUrl(),
HttpMethod.POST,
config.get(),
config,
json);
if (config.get().getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.info(this.getClass(), "OpenAI response:" + response);
}

Expand All @@ -138,7 +138,7 @@ public JSONObject raw(CompletionsForm promptForm) {
public void rawStream(final CompletionsForm promptForm, final OutputStream out) {
final JSONObject json = buildRequestJson(promptForm);
json.put(AiKeys.STREAM, true);
OpenAIRequest.doRequest(config.get().getApiUrl(), HttpMethod.POST, config.get(), json, out);
OpenAIRequest.doRequest(config.getApiUrl(), HttpMethod.POST, config, json, out);
}

private void buildMessages(final String systemPrompt, final String userPrompt, final JSONObject json) {
Expand All @@ -151,7 +151,7 @@ private void buildMessages(final String systemPrompt, final String userPrompt, f
}

private JSONObject buildRequestJson(final CompletionsForm form, final List<EmbeddingsDTO> searchResults) {
final AIModel model = config.get().resolveModelOrThrow(form.model);
final AIModel model = config.resolveModelOrThrow(form.model);
// aggregate matching results into text
final StringBuilder supportingContent = new StringBuilder();
searchResults.forEach(s -> supportingContent.append(s.extractedText).append(" "));
Expand Down Expand Up @@ -184,7 +184,7 @@ private String getPrompt(final String prompt, final String supportingContent, fi
throw new DotRuntimeException("no prompt or supporting content to summarize found");
}

final String resolvedPrompt = config.get().getConfig(key);
final String resolvedPrompt = config.getConfig(key);
final HttpServletRequest requestProxy = new FakeHttpRequest("localhost", "/").request();
final HttpServletResponse responseProxy = new BaseResponse().response();

Expand All @@ -205,7 +205,7 @@ private String getTextPrompt(final String prompt, final String supportingContent

private int countTokens(final String testString) {
return EncodingUtil.get().registry
.getEncodingForModel(config.get().getModel().getCurrentModel())
.getEncodingForModel(config.getModel().getCurrentModel())
.map(enc -> enc.countTokens(testString))
.orElseThrow(() -> new DotRuntimeException("Encoder not found"));
}
Expand Down Expand Up @@ -244,7 +244,7 @@ private String reduceStringToTokenSize(final String incomingString, final int ma
}

private JSONObject buildRequestJson(final CompletionsForm form) {
final AIModel aiModel = config.get().getModel();
final AIModel aiModel = config.getModel();
final int promptTokens = countTokens(form.prompt);

final JSONArray messages = new JSONArray();
Expand All @@ -256,7 +256,7 @@ private JSONObject buildRequestJson(final CompletionsForm form) {

final JSONObject json = new JSONObject();
json.put(AiKeys.MESSAGES, messages);
json.putIfAbsent(AiKeys.MODEL, config.get().getModel().getCurrentModel());
json.putIfAbsent(AiKeys.MODEL, config.getModel().getCurrentModel());
json.put(AiKeys.TEMPERATURE, form.temperature);
json.put(AiKeys.MAX_TOKENS, form.responseLengthTokens);
json.put(AiKeys.STREAM, form.stream);
Expand Down
27 changes: 11 additions & 16 deletions dotCMS/src/main/java/com/dotcms/ai/api/DotAIAPIFacadeImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.dotcms.ai.app.AppConfig;
import com.dotmarketing.beans.Host;
import com.dotmarketing.util.Logger;
import io.vavr.Lazy;

import java.util.Map;
import java.util.Objects;
Expand All @@ -30,40 +29,36 @@ public class DotAIAPIFacadeImpl implements DotAIAPI {
}
}

private static <T> T unwrap(final Class<T> clazz, final Object... initArguments) {
return Objects.nonNull(initArguments)
&& initArguments.length > 0
&& clazz.isInstance(initArguments[0]) ? clazz.cast(initArguments[0]) : null;
}

private static class DefaultCompletionsAPIProvider implements CompletionsAPIProvider {

private final CompletionsAPI defaultCompletionAPI = new CompletionsAPIImpl(null);
@Override
public CompletionsAPI getCompletionsAPI(final Object... initArguments) {
return Objects.nonNull(initArguments) && initArguments.length > 0?
new CompletionsAPIImpl(unwrap(initArguments)):
defaultCompletionAPI;
return new CompletionsAPIImpl(unwrap(initArguments));
}

private Lazy<AppConfig> unwrap(final Object... initArguments) {
return initArguments[0] instanceof AppConfig?
Lazy.of (()-> (AppConfig) initArguments[0]):(Lazy<AppConfig>) initArguments[0];
private AppConfig unwrap(final Object... initArguments) {
return DotAIAPIFacadeImpl.unwrap(AppConfig.class, initArguments);
}
}

public static class DefaultEmbeddingsAPIProvider implements EmbeddingsAPIProvider {

private final EmbeddingsAPI defaultEmbeddingsAPI = new EmbeddingsAPIImpl(null);

@Override
public EmbeddingsAPI getEmbeddingsAPI(final Object... initArguments) {
return Objects.nonNull(initArguments) && initArguments.length > 0?
new EmbeddingsAPIImpl(unwrap(initArguments)):
defaultEmbeddingsAPI;
return new EmbeddingsAPIImpl(unwrap(initArguments));
}

private Host unwrap(final Object... initArguments) {
return initArguments[0] instanceof Host ?
(Host) initArguments[0]:null;
return DotAIAPIFacadeImpl.unwrap(Host.class, initArguments);
}
}


/**
* Sets the current API implementation name.
* @param apiName
Expand Down
21 changes: 0 additions & 21 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package com.dotcms.ai.app;

import com.dotcms.security.apps.AppsUtil;
import com.dotcms.security.apps.Secret;
import com.dotmarketing.util.UtilMethods;
import com.liferay.util.StringPool;
import io.vavr.Lazy;
import io.vavr.control.Try;
Expand Down Expand Up @@ -143,25 +141,6 @@ public boolean discoverBooleanSecret(final Map<String, Secret> secrets, final Ap
return Boolean.parseBoolean(discoverSecret(secrets, key));
}

/**
* Resolves an environment-specific secret value from the provided secrets map using the specified key.
* If the secret is not found, it attempts to discover the value from environment variables.
*
* @param secrets the map of secrets
* @param key the key to look up the secret
* @return the resolved environment-specific secret value or an empty string if not found
*/
public String discoverEnvSecret(final Map<String, Secret> secrets, final AppKeys key) {
final String secret = discoverSecret(secrets, key, StringPool.BLANK);
if (UtilMethods.isSet(secret)) {
return secret;
}

return Optional
.ofNullable(AppsUtil.discoverEnvVarValue(AppKeys.APP_KEY, key.key, null))
.orElse(StringPool.BLANK);
}

private int toInt(final String value) {
return Try.of(() -> Integer.parseInt(value)).getOrElse(0);
}
Expand Down
11 changes: 4 additions & 7 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.dotcms.ai.model.OpenAIModel;
import com.dotcms.ai.model.OpenAIModels;
import com.dotcms.http.CircuitBreakerUrl;
import com.dotmarketing.beans.Host;
import com.dotmarketing.util.Config;
import com.dotmarketing.util.Logger;
import com.github.benmanes.caffeine.cache.Cache;
Expand Down Expand Up @@ -129,18 +128,16 @@ public Optional<AIModel> findModel(final String host, final AIModelType type) {
*
* @param host the host for which the models are being reset
*/
public void resetModels(final Host host) {
final String hostKey = host.getHostname();
Optional.ofNullable(internalModels.get(hostKey)).ifPresent(models -> {
public void resetModels(final String host) {
Optional.ofNullable(internalModels.get(host)).ifPresent(models -> {
models.clear();
internalModels.remove(hostKey);
internalModels.remove(host);
});
modelsByName.keySet()
.stream()
.filter(key -> key._1.equals(hostKey))
.filter(key -> key._1.equals(host))
.collect(Collectors.toSet())
.forEach(modelsByName::remove);
ConfigService.INSTANCE.config(host);
}

/**
Expand Down
34 changes: 13 additions & 21 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.dotcms.security.apps.Secret;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Config;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.UtilMethods;
import io.vavr.control.Try;
Expand All @@ -21,7 +20,6 @@
*/
public class AppConfig implements Serializable {

private static final String OPEN_AI_EMBEDDINGS_URL_KEY = "OPEN_AI_EMBEDDINGS_URL";
public static final Pattern SPLITTER = Pattern.compile("\\s?,\\s?");

private final String host;
Expand All @@ -43,18 +41,19 @@ public AppConfig(final String host, final Map<String, Secret> secrets) {
this.host = host;

final AIAppUtil aiAppUtil = AIAppUtil.get();

AIModels.get().loadModels(
this.host,
List.of(
aiAppUtil.createTextModel(secrets),
aiAppUtil.createImageModel(secrets),
aiAppUtil.createEmbeddingsModel(secrets)));

apiUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_URL);
apiImageUrl = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_IMAGE_URL);
apiEmbeddingsUrl = discoverEmbeddingsApiUrl(secrets);
apiKey = aiAppUtil.discoverEnvSecret(secrets, AppKeys.API_KEY);
apiKey = aiAppUtil.discoverSecret(secrets, AppKeys.API_KEY);
apiUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_URL);
apiImageUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_IMAGE_URL);
apiEmbeddingsUrl = aiAppUtil.discoverSecret(secrets, AppKeys.API_EMBEDDINGS_URL);

if (!secrets.isEmpty() || isEnabled()) {
AIModels.get().loadModels(
this.host,
List.of(
aiAppUtil.createTextModel(secrets),
aiAppUtil.createImageModel(secrets),
aiAppUtil.createEmbeddingsModel(secrets)));
}

model = resolveModel(AIModelType.TEXT);
imageModel = resolveModel(AIModelType.IMAGE);
Expand Down Expand Up @@ -299,11 +298,4 @@ public boolean isEnabled() {
return StringUtils.isNotBlank(apiKey);
}

private String discoverEmbeddingsApiUrl(final Map<String, Secret> secrets) {
final String url = AIAppUtil.get().discoverEnvSecret(secrets, AppKeys.API_EMBEDDINGS_URL);
return StringUtils.isBlank(url)
? Config.getStringProperty(OPEN_AI_EMBEDDINGS_URL_KEY, "https://api.openai.com/v1/embeddings")
: url;
}

}
6 changes: 3 additions & 3 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppKeys.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ public enum AppKeys {
TEXT_PROMPT("textPrompt", "Use Descriptive writing style."),
IMAGE_PROMPT("imagePrompt", "Use 16:9 aspect ratio."),
IMAGE_SIZE("imageSize", "1024x1024"),
TEXT_MODEL_NAMES("textModelNames", "gpt-3.5-turbo-16k"),
TEXT_MODEL_NAMES("textModelNames", null),
TEXT_MODEL_TOKENS_PER_MINUTE("textModelTokensPerMinute", "180000"),
TEXT_MODEL_API_PER_MINUTE("textModelApiPerMinute", "3500"),
TEXT_MODEL_MAX_TOKENS("textModelMaxTokens", "16384"),
TEXT_MODEL_COMPLETION("textModelCompletion", "true"),
IMAGE_MODEL_NAMES("imageModelNames", "dall-e-3"),
IMAGE_MODEL_NAMES("imageModelNames", null),
IMAGE_MODEL_TOKENS_PER_MINUTE("imageModelTokensPerMinute", "0"),
IMAGE_MODEL_API_PER_MINUTE("imageModelApiPerMinute", "50"),
IMAGE_MODEL_MAX_TOKENS("imageModelMaxTokens", "0"),
IMAGE_MODEL_COMPLETION("imageModelCompletion", "false"),
EMBEDDINGS_MODEL_NAMES("embeddingsModelNames", "text-embedding-ada-002"),
EMBEDDINGS_MODEL_NAMES("embeddingsModelNames", null),
EMBEDDINGS_MODEL_TOKENS_PER_MINUTE("embeddingsModelTokensPerMinute", "1000000"),
EMBEDDINGS_MODEL_API_PER_MINUTE("embeddingsModelApiPerMinute", "3000"),
EMBEDDINGS_MODEL_MAX_TOKENS("embeddingsModelMaxTokens", "8191"),
Expand Down
23 changes: 18 additions & 5 deletions dotCMS/src/main/java/com/dotcms/ai/app/ConfigService.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import com.dotmarketing.beans.Host;
import com.dotmarketing.business.APILocator;
import com.dotmarketing.business.web.WebAPILocator;
import com.dotmarketing.exception.DotDataException;
import com.dotmarketing.exception.DotSecurityException;
import com.liferay.portal.model.User;
import io.vavr.control.Try;

import java.util.Map;
Expand All @@ -25,13 +28,23 @@ private ConfigService() {
* by dotCMS.
*/
public AppConfig config(final Host host) {
final User systemUser = APILocator.systemUser();
final Host resolved = resolveHost(host);
final Optional<AppSecrets> appSecrets = Try.of(() -> APILocator
.getAppsAPI()
.getSecrets(AppKeys.APP_KEY, true, resolved, APILocator.systemUser()))
.getOrElse(Optional.empty());
Optional<AppSecrets> appSecrets = Try
.of(() -> APILocator.getAppsAPI().getSecrets(AppKeys.APP_KEY, false, resolved, systemUser))
.get();
final Host realHost;
if (appSecrets.isEmpty()) {
realHost = APILocator.systemHost();
appSecrets = Try
.of(() -> APILocator.getAppsAPI().getSecrets(AppKeys.APP_KEY, false, realHost, systemUser))
.get();
} else {
realHost = resolved;
}

return new AppConfig(resolved.getHostname(), appSecrets.map(AppSecrets::getSecrets).orElse(Map.of()));

return new AppConfig(realHost.getHostname(), appSecrets.map(AppSecrets::getSecrets).orElse(Map.of()));
}

/**
Expand Down
Loading

0 comments on commit 5205ee6

Please sign in to comment.