Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dotAI): Adding fallback mechanism when it comes to send models #29761

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIAppUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.liferay.util.StringPool;
import io.vavr.Lazy;
import io.vavr.control.Try;
import org.apache.commons.collections4.CollectionUtils;

import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -40,9 +41,14 @@ public static AIAppUtil get() {
* @return the created text model instance
*/
public AIModel createTextModel(final Map<String, Secret> secrets) {
final List<String> modelNames = splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES);
if (CollectionUtils.isEmpty(modelNames)) {
return AIModel.NOOP_MODEL;
}

return AIModel.builder()
.withType(AIModelType.TEXT)
.withNames(splitDiscoveredSecret(secrets, AppKeys.TEXT_MODEL_NAMES))
.withModelNames(modelNames)
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.TEXT_MODEL_MAX_TOKENS))
Expand All @@ -57,9 +63,14 @@ public AIModel createTextModel(final Map<String, Secret> secrets) {
* @return the created image model instance
*/
public AIModel createImageModel(final Map<String, Secret> secrets) {
final List<String> modelNames = splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES);
if (CollectionUtils.isEmpty(modelNames)) {
return AIModel.NOOP_MODEL;
}

return AIModel.builder()
.withType(AIModelType.IMAGE)
.withNames(splitDiscoveredSecret(secrets, AppKeys.IMAGE_MODEL_NAMES))
.withModelNames(modelNames)
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.IMAGE_MODEL_MAX_TOKENS))
Expand All @@ -74,9 +85,14 @@ public AIModel createImageModel(final Map<String, Secret> secrets) {
* @return the created embeddings model instance
*/
public AIModel createEmbeddingsModel(final Map<String, Secret> secrets) {
final List<String> modelNames = splitDiscoveredSecret(secrets, AppKeys.EMBEDDINGS_MODEL_NAMES);
if (CollectionUtils.isEmpty(modelNames)) {
return AIModel.NOOP_MODEL;
}

return AIModel.builder()
.withType(AIModelType.EMBEDDINGS)
.withNames(splitDiscoveredSecret(secrets, AppKeys.EMBEDDINGS_MODEL_NAMES))
.withModelNames(modelNames)
.withTokensPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_TOKENS_PER_MINUTE))
.withApiPerMinute(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_API_PER_MINUTE))
.withMaxTokens(discoverIntSecret(secrets, AppKeys.EMBEDDINGS_MODEL_MAX_TOKENS))
Expand Down Expand Up @@ -117,9 +133,11 @@ public String discoverSecret(final Map<String, Secret> secrets, final AppKeys ke
* @return the list of split secret values
*/
public List<String> splitDiscoveredSecret(final Map<String, Secret> secrets, final AppKeys key) {
return Arrays.stream(Optional.ofNullable(discoverSecret(secrets, key)).orElse(StringPool.BLANK).split(","))
.map(String::trim)
.map(String::toLowerCase)
return Arrays
.stream(Optional
.ofNullable(discoverSecret(secrets, key))
.map(secret -> secret.split(StringPool.COMMA))
.orElse(new String[0]))
.collect(Collectors.toList());
}

Expand Down
128 changes: 74 additions & 54 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package com.dotcms.ai.app;

import com.dotcms.ai.domain.Model;
import com.dotcms.ai.exception.DotAIModelNotFoundException;
import com.dotcms.util.DotPreconditions;
import com.dotmarketing.util.Logger;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
* Represents an AI model with various attributes such as type, names, tokens per minute,
Expand All @@ -18,43 +21,38 @@
*/
public class AIModel {

private static final int NOOP_INDEX = -1;

public static final AIModel NOOP_MODEL = AIModel.builder()
.withType(AIModelType.UNKNOWN)
.withNames(List.of())
.withModelNames(List.of())
.build();

private final AIModelType type;
private final List<String> names;
private final List<Model> models;
private final int tokensPerMinute;
private final int apiPerMinute;
private final int maxTokens;
private final boolean isCompletion;
private final AtomicInteger current;
private final AtomicBoolean decommissioned;

private AIModel(final AIModelType type,
final List<String> names,
final int tokensPerMinute,
final int apiPerMinute,
final int maxTokens,
final boolean isCompletion) {
DotPreconditions.checkNotNull(type, "type cannot be null");
this.type = type;
this.names = Optional.ofNullable(names).orElse(List.of());
this.tokensPerMinute = tokensPerMinute;
this.apiPerMinute = apiPerMinute;
this.maxTokens = maxTokens;
this.isCompletion = isCompletion;
current = new AtomicInteger(this.names.isEmpty() ? -1 : 0);
decommissioned = new AtomicBoolean(false);
private final AtomicInteger currentModelIndex;

/**
 * Creates an instance from the values collected by the given builder.
 *
 * @param builder the builder holding type, models and rate/limit settings; its type must not be null
 */
private AIModel(final Builder builder) {
    DotPreconditions.checkNotNull(builder.type, "type cannot be null");
    this.type = builder.type;
    // The builder only assigns its model list inside withModels/withModelNames, so it is still
    // null when neither was called; fall back to an empty list instead of failing on isEmpty().
    this.models = Optional.ofNullable(builder.models).orElse(List.of());
    this.tokensPerMinute = builder.tokensPerMinute;
    this.apiPerMinute = builder.apiPerMinute;
    this.maxTokens = builder.maxTokens;
    this.isCompletion = builder.isCompletion;
    // No models means there is nothing to point at, hence the NOOP index.
    this.currentModelIndex = new AtomicInteger(this.models.isEmpty() ? NOOP_INDEX : 0);
}

/**
 * @return the type this instance was built with
 */
public AIModelType getType() {
    return this.type;
}

public List<String> getNames() {
return names;
public List<Model> getModels() {
return models;
}

public int getTokensPerMinute() {
Expand All @@ -73,38 +71,49 @@ public boolean isCompletion() {
return isCompletion;
}

public int getCurrent() {
return current.get();
/**
 * @return the value currently held by the model index counter
 */
public int getCurrentModelIndex() {
    final int index = this.currentModelIndex.get();
    return index;
}

public void setCurrent(final int current) {
if (!isCurrentValid(current)) {
logInvalidModelMessage();
return;
}
this.current.set(current);
}

public boolean isDecommissioned() {
return decommissioned.get();
}

public void setDecommissioned(final boolean decommissioned) {
this.decommissioned.set(decommissioned);
/**
 * Stores the given index as the current model index. No range validation is performed here.
 *
 * @param currentModelIndex the index to store
 */
public void setCurrentModelIndex(final int currentModelIndex) {
    final AtomicInteger holder = this.currentModelIndex;
    holder.set(currentModelIndex);
}

public boolean isOperational() {
return this != NOOP_MODEL;
return this != NOOP_MODEL && models.stream().anyMatch(Model::isOperational);
}

public String getCurrentModel() {
final int currentIndex = this.current.get();
/**
 * Returns the model located at the current model index.
 *
 * @return the model at the current index, or {@code null} (after logging a debug message)
 *         when the index is rejected by {@code isCurrentValid}
 */
public Model getCurrent() {
    final int index = currentModelIndex.get();
    if (isCurrentValid(index)) {
        return models.get(index);
    }

    logInvalidModelMessage();
    return null;
}

return names.get(currentIndex);
/**
 * Returns the name of the model returned by {@link #getCurrent()}.
 *
 * @return the current model's name, or {@code null} when there is no current model
 */
public String getCurrentModel() {
    final Model active = getCurrent();
    if (active == null) {
        return null;
    }
    return active.getName();
}

/**
 * Finds the first model whose name matches the given one.
 * <p>
 * Both the requested name and each stored name are trimmed and compared case-insensitively, so the
 * lookup does not depend on how the stored names were normalized. A {@code null} requested name or a
 * model with a {@code null} name never matches (instead of raising a {@link NullPointerException}).
 *
 * @param modelName the name of the model to look for
 * @return the first matching model
 * @throws DotAIModelNotFoundException when no model matches the given name
 */
public Model getModel(final String modelName) {
    final String requested = Optional.ofNullable(modelName).map(String::trim).orElse(null);
    return models.stream()
            .filter(model -> {
                final String candidate = model.getName();
                return requested != null && candidate != null && requested.equalsIgnoreCase(candidate.trim());
            })
            .findFirst()
            .orElseThrow(() -> new DotAIModelNotFoundException(String.format("Model [%s] not found", modelName)));
}

/**
 * Re-points the current model index when it holds the NOOP value.
 * <p>
 * Nothing happens if the index is anything other than {@code NOOP_INDEX}. Otherwise the index of the
 * first model reporting itself as operational is stored, or {@code NOOP_INDEX} again when none does.
 */
public void repairCurrentIndexIfNeeded() {
    if (getCurrentModelIndex() == NOOP_INDEX) {
        final Optional<Model> firstOperational = getModels()
                .stream()
                .filter(Model::isOperational)
                .findFirst();
        final int repaired = firstOperational.map(Model::getIndex).orElse(NOOP_INDEX);
        setCurrentModelIndex(repaired);
    }
}

public long minIntervalBetweenCalls() {
Expand All @@ -115,22 +124,21 @@ public long minIntervalBetweenCalls() {
public String toString() {
return "AIModel{" +
"type=" + type +
", names=" + names +
", models='" + models + '\'' +
", tokensPerMinute=" + tokensPerMinute +
", apiPerMinute=" + apiPerMinute +
", maxTokens=" + maxTokens +
", isCompletion=" + isCompletion +
", current=" + current +
", decommissioned=" + decommissioned +
", currentModelIndex=" + currentModelIndex.get() +
'}';
}

private boolean isCurrentValid(final int current) {
return !names.isEmpty() && current >= 0 && current < names.size();
return !models.isEmpty() && current >= 0 && current < models.size();
}

private void logInvalidModelMessage() {
Logger.debug(getClass(), String.format("Current model index must be between 0 and %d", names.size()));
Logger.debug(getClass(), String.format("Current model index must be between 0 and %d", models.size()));
}

public static Builder builder() {
Expand All @@ -140,7 +148,7 @@ public static Builder builder() {
public static class Builder {

private AIModelType type;
private List<String> names;
private List<Model> models;
private int tokensPerMinute;
private int apiPerMinute;
private int maxTokens;
Expand All @@ -154,13 +162,25 @@ public Builder withType(final AIModelType type) {
return this;
}

public Builder withNames(final List<String> names) {
this.names = names;
/**
 * Sets the models for the instance being built.
 *
 * @param models the models to use; {@code null} is replaced by an empty list
 * @return this builder
 */
public Builder withModels(final List<Model> models) {
    if (models == null) {
        this.models = List.of();
    } else {
        this.models = models;
    }
    return this;
}

public Builder withNames(final String... names) {
return withNames(List.of(names));
/**
 * Builds one {@link Model} per given name, each carrying its position in the list as index, and
 * sets them as the models of the instance being built.
 *
 * @param names the model names; {@code null} results in an empty model list
 * @return this builder
 */
public Builder withModelNames(final List<String> names) {
    if (names == null) {
        return withModels(List.of());
    }

    final List<Model> built = IntStream.range(0, names.size())
            .mapToObj(position -> Model.builder()
                    .withName(names.get(position))
                    .withIndex(position)
                    .build())
            .collect(Collectors.toList());
    return withModels(built);
}

/**
 * Varargs convenience overload that delegates to the list-based {@code withModelNames}.
 *
 * @param names the model names
 * @return this builder
 */
public Builder withModelNames(final String... names) {
    final List<String> asList = List.of(names);
    return withModelNames(asList);
}

public Builder withTokensPerMinute(final int tokensPerMinute) {
Expand All @@ -184,7 +204,7 @@ public Builder withIsCompletion(final boolean isCompletion) {
}

public AIModel build() {
return new AIModel(type, names, tokensPerMinute, apiPerMinute, maxTokens, isCompletion);
return new AIModel(this);
}

}
Expand Down
16 changes: 9 additions & 7 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,22 @@ public void loadModels(final String host, final List<AIModel> loading) {
loading.stream()
.map(model -> Tuple.of(model.getType(), model))
.collect(Collectors.toList())));
loading.forEach(model -> model
.getNames()
.forEach(name -> {
loading.forEach(aiModel -> aiModel
.getModels()
.forEach(model -> {
final Tuple2<String, String> key = Tuple.of(
host,
name.toLowerCase().trim());
model.getName().toLowerCase().trim());
if (modelsByName.containsKey(key)) {
Logger.debug(
this,
String.format(
"Model [%s] already exists for host [%s], ignoring it",
name,
model.getName(),
host));
return;
}
modelsByName.putIfAbsent(key, model);
modelsByName.putIfAbsent(key, aiModel);
}));
}

Expand Down Expand Up @@ -192,7 +192,9 @@ public List<SimpleModel> getAvailableModels() {
.stream()
.flatMap(entry -> entry.getValue().stream())
.map(Tuple2::_2)
.flatMap(model -> model.getNames().stream().map(name -> new SimpleModel(name, model.getType())))
.flatMap(aiModel -> aiModel.getModels()
.stream()
.map(model -> new SimpleModel(model.getName(), aiModel.getType())))
.collect(Collectors.toSet());
final Set<SimpleModel> supported = getOrPullSupportedModels()
.stream()
Expand Down
32 changes: 14 additions & 18 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.dotcms.ai.app;

import com.dotcms.ai.exception.DotAIModelNotFoundException;
import com.dotcms.security.apps.Secret;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
Expand Down Expand Up @@ -113,6 +114,15 @@ public static void setSystemHostConfig(final AppConfig systemHostConfig) {
AppConfig.SYSTEM_HOST_CONFIG.set(systemHostConfig);
}

/**
 * Returns the value of this configuration's {@code host} field.
 *
 * @return the host
 */
public String getHost() {
    return this.host;
}

/**
* Retrieves the API URL.
*
Expand All @@ -134,7 +144,7 @@ public String getApiImageUrl() {
/**
* Retrieves the API Embeddings URL.
*
* @return
* @return the API Embeddings URL
*/
public String getApiEmbeddingsUrl() {
return UtilMethods.isEmpty(apiEmbeddingsUrl) ? AppKeys.API_EMBEDDINGS_URL.defaultValue : apiEmbeddingsUrl;
Expand Down Expand Up @@ -293,24 +303,10 @@ public AIModel resolveModel(final AIModelType type) {
* @param modelName the name of the model to find
*/
public AIModel resolveModelOrThrow(final String modelName) {
final AIModel aiModel = AIModels.get()
return AIModels.get()
.findModel(host, modelName)
.orElseThrow(() -> {
final String supported = String.join(", ", AIModels.get().getOrPullSupportedModels());
return new DotRuntimeException(
"Unable to find model: [" + modelName + "]. Only [" + supported + "] are supported ");
});

if (!aiModel.isOperational()) {
debugLogger(
AppConfig.class,
() -> String.format(
"Resolved model [%s] is not operational, avoiding its usage",
aiModel.getCurrentModel()));
throw new DotRuntimeException(String.format("Model [%s] is not operational", aiModel.getCurrentModel()));
}

return aiModel;
.orElseThrow(() ->
new DotAIModelNotFoundException(String.format("Unable to find model: [%s].", modelName)));
}

/**
Expand Down
Loading
Loading