Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#29284: Applying feedback #29806

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ private String getTextPrompt(final String prompt, final String supportingContent
}

private int countTokens(final String testString) {
return EncodingUtil.get().registry
.getEncodingForModel(config.getModel().getCurrentModel())
return EncodingUtil.get()
.getEncoding(config, AIModelType.TEXT)
.map(enc -> enc.countTokens(testString))
.orElseThrow(() -> new DotRuntimeException("Encoder not found"));
}
Expand Down
3 changes: 1 addition & 2 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
*/
public class AIModel {

private static final int NOOP_INDEX = -1;

public static final int NOOP_INDEX = -1;
public static final AIModel NOOP_MODEL = AIModel.builder()
.withType(AIModelType.UNKNOWN)
.withModelNames(List.of())
Expand Down
2 changes: 1 addition & 1 deletion dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ private void activateModels(final String host, boolean wasAdded) {
final String modelName = model.getName().trim().toLowerCase();
final ModelStatus status;
status = ModelStatus.ACTIVE;
if (aiModel.getCurrentModelIndex() == -1) {
if (aiModel.getCurrentModelIndex() == AIModel.NOOP_INDEX) {
aiModel.setCurrentModelIndex(model.getIndex());
}
Logger.debug(
Expand Down
10 changes: 0 additions & 10 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -297,16 +297,6 @@ public AIModel resolveModel(final AIModelType type) {
return AIModels.get().resolveModel(host, type);
}

/**
 * Resolves the {@link AIModel} identified by the given model name and type by delegating to
 * {@code AIModels.get().resolveAIModelOrThrow(...)}, passing this configuration as context.
 * NOTE(review): as the delegate's name suggests, it presumably throws when no matching model
 * is found; the exact exception type is not visible here, so confirm against {@code AIModels}.
 *
 * @param modelName the name of the model to find
 * @param type the type of the model to find
 * @return the model returned by the delegate
 */
public AIModel resolveAIModelOrThrow(final String modelName, final AIModelType type) {
    return AIModels.get().resolveAIModelOrThrow(this, modelName, type);
}

/**
* Resolves a model-specific secret value from the provided secrets map using the specified key and model type.
* If the model is not found or is not operational, it throws an appropriate exception.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,6 @@ private static Tuple2<AIModel, Model> resolveModel(final JSONObjectAIRequest req

private static boolean isSameAsFirst(final Model firstAttempt, final Model model) {
if (firstAttempt.equals(model)) {
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"Model [%s] is the same as the current one [%s].",
model.getName(),
firstAttempt.getName()));
return true;
}

Expand Down Expand Up @@ -150,11 +144,17 @@ private static void handleFailure(final Tuple2<AIModel, Model> modelTuple,
final Model model = modelTuple._2;

if (!responseData.getStatus().doesNeedToThrow()) {
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"Model [%s] failed then setting its status to [%s].",
model.getName(),
responseData.getStatus()));
model.setStatus(responseData.getStatus());
}

if (model.getIndex() == aiModel.getModels().size() - 1) {
aiModel.setCurrentModelIndex(-1);
aiModel.setCurrentModelIndex(AIModel.NOOP_INDEX);
AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
Expand Down Expand Up @@ -217,11 +217,10 @@ private static void logFailure(final Tuple2<AIModel, Model> modelTuple, final AI
response -> AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"Model [%s] failed with response:%s%s%s. Trying next model.",
"Model [%s] failed with response:%s%sTrying next model.",
modelTuple._2.getName(),
System.lineSeparator(),
response,
System.lineSeparator())),
response)),
() -> AppConfig.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
Expand Down
67 changes: 64 additions & 3 deletions dotCMS/src/main/java/com/dotcms/ai/util/EncodingUtil.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
package com.dotcms.ai.util;

import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.app.AIModelType;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.domain.Model;
import com.dotcms.ai.domain.ModelStatus;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import io.vavr.Lazy;

import java.util.Objects;
import java.util.Optional;

/**
Expand All @@ -26,10 +32,65 @@ public static EncodingUtil get() {
return INSTANCE.get();
}

/**
 * Resolves the token {@link Encoding} for the current model of the given type.
 * <p>
 * The model group is obtained from {@code appConfig.resolveModel(type)}. When the group has no
 * current model, an empty {@link Optional} is returned. Otherwise the encoding registry is
 * queried with the current model's name; if the registry has no encoding for it, the lookup
 * falls back to the other models in the group via {@code modelFallback}, which may mark
 * models as invalid and change the group's current model index.
 *
 * @param appConfig the configuration used to resolve the model group
 * @param type the model type whose current model is used
 * @return the resolved encoding, or an empty {@link Optional} if none is available
 */
public Optional<Encoding> getEncoding(final AppConfig appConfig, final AIModelType type) {
    final AIModel aiModel = appConfig.resolveModel(type);
    final Model currentModel = aiModel.getCurrent();

    if (Objects.isNull(currentModel)) {
        AppConfig.debugLogger(
                getClass(),
                () -> String.format(
                        "No current model found for type [%s], meaning they are all exhausted",
                        type));
        return Optional.empty();
    }

    return registry
            .getEncodingForModel(currentModel.getName())
            .or(() -> modelFallback(aiModel, currentModel));
}

/**
 * Resolves the token {@link Encoding} for the embeddings model of the configuration returned
 * by {@code ConfigService.INSTANCE.config()}.
 * <p>
 * Delegates to {@link #getEncoding(AppConfig, AIModelType)}, so the same fallback to other
 * models of the group applies when the current model has no registered encoding.
 *
 * @return the resolved encoding, or an empty {@link Optional} if none is available
 */
public Optional<Encoding> getEncoding() {
    // Single return only: the earlier direct registry lookup left a second, unreachable
    // return statement behind.
    return getEncoding(ConfigService.INSTANCE.config(), AIModelType.EMBEDDINGS);
}

/**
 * Looks for another model in the group that the encoding registry knows, after the current
 * model turned out to have no encoding.
 * <p>
 * The current model is marked {@link ModelStatus#INVALID}. The remaining models are then
 * visited in the order returned by {@code aiModel.getModels()}: a model without an encoding
 * is also marked invalid, and the first model with an encoding becomes the group's current
 * model and its encoding is returned.
 *
 * @param aiModel the model group being searched
 * @param currentModel the model that has no encoding in the registry
 * @return the encoding of the first suitable model, or an empty {@link Optional} if none
 */
private Optional<Encoding> modelFallback(final AIModel aiModel,
                                         final Model currentModel) {
    AppConfig.debugLogger(
            getClass(),
            () -> String.format(
                    "Model [%s] is not suitable for encoding, marking it as invalid and falling back to other models",
                    currentModel.getName()));
    currentModel.setStatus(ModelStatus.INVALID);

    // A plain loop replaces the former stream pipeline, whose map() step mutated model state
    // and used null as a "skip" marker; visiting order and early exit are unchanged.
    for (final Model model : aiModel.getModels()) {
        if (model.equals(currentModel)) {
            continue;
        }

        // Skip while the group's current index no longer points at the failed model.
        // NOTE(review): presumably a guard against the index being changed elsewhere
        // in the meantime; confirm the intended concurrency semantics.
        if (aiModel.getCurrentModelIndex() != currentModel.getIndex()) {
            continue;
        }

        final Optional<Encoding> encoding = registry.getEncodingForModel(model.getName());
        if (encoding.isEmpty()) {
            model.setStatus(ModelStatus.INVALID);
            AppConfig.debugLogger(
                    getClass(),
                    () -> String.format(
                            "Model [%s] is not suitable for encoding, marking as invalid",
                            model.getName()));
            continue;
        }

        aiModel.setCurrentModelIndex(model.getIndex());
        AppConfig.debugLogger(
                getClass(),
                () -> "Model [" + model.getName() + "] found, setting as current model");
        return encoding;
    }

    return Optional.empty();
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.dotcms.ai.viewtool;

import com.dotcms.ai.app.AIModelType;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.util.EncodingUtil;
Expand Down Expand Up @@ -58,8 +59,8 @@ public void init(Object initData) {
* @return The number of tokens in the prompt, or -1 if no encoding is found for the model.
*/
public int countTokens(final String prompt) {
return EncodingUtil.get().registry
.getEncodingForModel(appConfig.getModel().getCurrentModel())
return EncodingUtil.get()
.getEncoding(appConfig, AIModelType.TEXT)
.map(encoding -> encoding.countTokens(prompt))
.orElse(-1);
}
Expand Down
1 change: 0 additions & 1 deletion dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ const writeModelToDropdown = async () => {
}

const newOption = document.createElement("option");
console.log(JSON.stringify(dotAiState.config, null, 2));
newOption.value = dotAiState.config.availableModels[i].name;
newOption.text = `${dotAiState.config.availableModels[i].name}`
if (dotAiState.config.availableModels[i].current) {
Expand Down
Loading