diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
index 9c04e189ed55..4b33ca0bb2b3 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
@@ -214,8 +214,8 @@ private String getTextPrompt(final String prompt, final String supportingContent
     }
 
     private int countTokens(final String testString) {
-        return EncodingUtil.get().registry
-                .getEncodingForModel(config.getModel().getCurrentModel())
+        return EncodingUtil.get()
+                .getEncoding(config, AIModelType.TEXT)
                 .map(enc -> enc.countTokens(testString))
                 .orElseThrow(() -> new DotRuntimeException("Encoder not found"));
     }
diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
index 3e98b716ce04..39558755bbe7 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
@@ -21,8 +21,7 @@
  */
 public class AIModel {
 
-    private static final int NOOP_INDEX = -1;
-
+    public static final int NOOP_INDEX = -1;
     public static final AIModel NOOP_MODEL = AIModel.builder()
             .withType(AIModelType.UNKNOWN)
             .withModelNames(List.of())
diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
index 91be3359de7f..acead42c48d5 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
@@ -293,7 +293,7 @@ private void activateModels(final String host, boolean wasAdded) {
                     final String modelName = model.getName().trim().toLowerCase();
                     final ModelStatus status;
                     status = ModelStatus.ACTIVE;
-                    if (aiModel.getCurrentModelIndex() == -1) {
+                    if (aiModel.getCurrentModelIndex() == AIModel.NOOP_INDEX) {
                         aiModel.setCurrentModelIndex(model.getIndex());
                     }
                     Logger.debug(
diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
index 04647f8be4b0..0ab3b83b4273 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
@@ -297,16 +297,6 @@ public AIModel resolveModel(final AIModelType type) {
         return AIModels.get().resolveModel(host, type);
     }
 
-    /**
-     * Resolves a model-specific secret value from the provided secrets map using the specified key and model type.
-     *
-     * @param modelName the name of the model to find
-     * @param type the type of the model to find
-     */
-    public AIModel resolveAIModelOrThrow(final String modelName, final AIModelType type) {
-        return AIModels.get().resolveAIModelOrThrow(this, modelName, type);
-    }
-
     /**
      * Resolves a model-specific secret value from the provided secrets map using the specified key and model type.
      * If the model is not found or is not operational, it throws an appropriate exception.
diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java
index d5689c71da2e..0553645ece58 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java
@@ -97,12 +97,6 @@ private static Tuple2 resolveModel(final JSONObjectAIRequest req
 
     private static boolean isSameAsFirst(final Model firstAttempt, final Model model) {
         if (firstAttempt.equals(model)) {
-            AppConfig.debugLogger(
-                    AIModelFallbackStrategy.class,
-                    () -> String.format(
-                            "Model [%s] is the same as the current one [%s].",
-                            model.getName(),
-                            firstAttempt.getName()));
             return true;
         }
 
@@ -150,11 +144,17 @@ private static void handleFailure(final Tuple2 modelTuple,
         final Model model = modelTuple._2;
 
         if (!responseData.getStatus().doesNeedToThrow()) {
+            AppConfig.debugLogger(
+                    AIModelFallbackStrategy.class,
+                    () -> String.format(
+                            "Model [%s] failed then setting its status to [%s].",
+                            model.getName(),
+                            responseData.getStatus()));
             model.setStatus(responseData.getStatus());
         }
 
         if (model.getIndex() == aiModel.getModels().size() - 1) {
-            aiModel.setCurrentModelIndex(-1);
+            aiModel.setCurrentModelIndex(AIModel.NOOP_INDEX);
             AppConfig.debugLogger(
                     AIModelFallbackStrategy.class,
                     () -> String.format(
@@ -217,11 +217,10 @@ private static void logFailure(final Tuple2 modelTuple, final AI
                         response -> AppConfig.debugLogger(
                                 AIModelFallbackStrategy.class,
                                 () -> String.format(
-                                        "Model [%s] failed with response:%s%s%s. Trying next model.",
+                                        "Model [%s] failed with response:%s%sTrying next model.",
                                         modelTuple._2.getName(),
                                         System.lineSeparator(),
-                                        response,
-                                        System.lineSeparator())),
+                                        response)),
                         () -> AppConfig.debugLogger(
                                 AIModelFallbackStrategy.class,
                                 () -> String.format(
diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/EncodingUtil.java b/dotCMS/src/main/java/com/dotcms/ai/util/EncodingUtil.java
index e279e3085ae9..9aed5869213a 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/util/EncodingUtil.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/util/EncodingUtil.java
@@ -1,11 +1,17 @@
 package com.dotcms.ai.util;
 
+import com.dotcms.ai.app.AIModel;
+import com.dotcms.ai.app.AIModelType;
+import com.dotcms.ai.app.AppConfig;
 import com.dotcms.ai.app.ConfigService;
+import com.dotcms.ai.domain.Model;
+import com.dotcms.ai.domain.ModelStatus;
 import com.knuddels.jtokkit.Encodings;
 import com.knuddels.jtokkit.api.Encoding;
 import com.knuddels.jtokkit.api.EncodingRegistry;
 import io.vavr.Lazy;
 
+import java.util.Objects;
 import java.util.Optional;
 
 /**
@@ -26,10 +32,65 @@ public static EncodingUtil get() {
         return INSTANCE.get();
     }
 
+    public Optional<Encoding> getEncoding(final AppConfig appConfig, final AIModelType type) {
+        final AIModel aiModel = appConfig.resolveModel(type);
+        final Model currentModel = aiModel.getCurrent();
+
+        if (Objects.isNull(currentModel)) {
+            AppConfig.debugLogger(
+                    getClass(),
+                    () -> String.format(
+                            "No current model found for type [%s], meaning they are all exhausted",
+                            type));
+            return Optional.empty();
+        }
+
+        return registry
+                .getEncodingForModel(currentModel.getName())
+                .or(() -> modelFallback(aiModel, currentModel));
+    }
+
     public Optional<Encoding> getEncoding() {
-        return Optional
-                .ofNullable(ConfigService.INSTANCE.config().getEmbeddingsModel().getCurrentModel())
-                .flatMap(registry::getEncodingForModel);
+        return getEncoding(ConfigService.INSTANCE.config(), AIModelType.EMBEDDINGS);
+    }
+
+    private Optional<Encoding> modelFallback(final AIModel aiModel,
+                                             final Model currentModel) {
+        AppConfig.debugLogger(
+                getClass(),
+                () -> String.format(
+                        "Model [%s] is not suitable for encoding, marking it as invalid and falling back to other models",
+                        currentModel.getName()));
+        currentModel.setStatus(ModelStatus.INVALID);
+
+        return aiModel.getModels()
+                .stream()
+                .filter(model -> !model.equals(currentModel))
+                .map(model -> {
+                    if (aiModel.getCurrentModelIndex() != currentModel.getIndex()) {
+                        return null;
+                    }
+
+                    final Optional<Encoding> encoding = registry.getEncodingForModel(model.getName());
+                    if (encoding.isEmpty()) {
+                        model.setStatus(ModelStatus.INVALID);
+                        AppConfig.debugLogger(
+                                getClass(),
+                                () -> String.format(
+                                        "Model [%s] is not suitable for encoding, marking as invalid",
+                                        model.getName()));
+                        return null;
+                    }
+
+                    aiModel.setCurrentModelIndex(model.getIndex());
+                    AppConfig.debugLogger(
+                            getClass(),
+                            () -> "Model [" + model.getName() + "] found, setting as current model");
+                    return encoding.get();
+
+                })
+                .filter(Objects::nonNull)
+                .findFirst();
     }
 
 }
diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java
index e8dc49722255..414f18926747 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java
@@ -1,5 +1,6 @@
 package com.dotcms.ai.viewtool;
 
+import com.dotcms.ai.app.AIModelType;
 import com.dotcms.ai.app.AppConfig;
 import com.dotcms.ai.app.ConfigService;
 import com.dotcms.ai.util.EncodingUtil;
@@ -58,8 +59,8 @@ public void init(Object initData) {
      * @return The number of tokens in the prompt, or -1 if no encoding is found for the model.
      */
     public int countTokens(final String prompt) {
-        return EncodingUtil.get().registry
-                .getEncodingForModel(appConfig.getModel().getCurrentModel())
+        return EncodingUtil.get()
+                .getEncoding(appConfig, AIModelType.TEXT)
                 .map(encoding -> encoding.countTokens(prompt))
                 .orElse(-1);
     }
diff --git a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js
index f2db45f1a34c..34a4462162ba 100644
--- a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js
+++ b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js
@@ -136,7 +136,6 @@ const writeModelToDropdown = async () => {
         }
 
         const newOption = document.createElement("option");
-        console.log(JSON.stringify(dotAiState.config, null, 2));
         newOption.value = dotAiState.config.availableModels[i].name;
         newOption.text = `${dotAiState.config.availableModels[i].name}`
         if (dotAiState.config.availableModels[i].current) {