From 4137367c3a62f68a4da5379eeaa4806eff8e392b Mon Sep 17 00:00:00 2001 From: Victor Alfaro Date: Sun, 11 Aug 2024 22:36:31 -0600 Subject: [PATCH] test --- .../main/java/com/dotcms/ai/app/AIModel.java | 6 +- .../java/com/dotcms/ai/client/AIClient.java | 11 ++- .../dotcms/ai/client/AIClientStrategy.java | 5 +- .../dotcms/ai/client/AIDefaultStrategy.java | 5 +- .../ai/client/AIModelFallbackStrategy.java | 77 ++++++++++++++++++- .../com/dotcms/ai/client/AIProxiedClient.java | 19 ++++- .../com/dotcms/ai/client/AIProxyClient.java | 6 +- .../dotcms/ai/client/AIResponseValidator.java | 9 +++ .../ai/client/{ => openai}/OpenAIClient.java | 11 ++- .../openai/OpenAIResponseValidator.java | 41 ++++++++++ .../java/com/dotcms/ai/domain/AIRequest.java | 10 +-- .../dotcms/ai/domain/AIResponseMetadata.java | 49 ++++++++++++ .../dotcms/ai/domain/JSONObjectAIRequest.java | 8 +- .../main/java/com/dotcms/ai/domain/Model.java | 5 ++ .../java/com/dotcms/ai/app/AIAppUtilTest.java | 9 ++- 15 files changed, 242 insertions(+), 29 deletions(-) create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIResponseValidator.java rename dotCMS/src/main/java/com/dotcms/ai/client/{ => openai}/OpenAIClient.java (90%) create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseValidator.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseMetadata.java diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java index 100f721c006c..078576955fea 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java @@ -81,7 +81,7 @@ public void setCurrentModelIndex(final int currentModelIndex) { } public boolean isOperational() { - return this != NOOP_MODEL || models.stream().anyMatch(model -> model.getStatus() == ModelStatus.ACTIVE); + return this != NOOP_MODEL || getActiveModels().isEmpty(); } public Model getCurrent() { @@ -101,6 +101,10 @@ public long minIntervalBetweenCalls() { return 60000 / apiPerMinute; } + public List getActiveModels() { + return models.stream().filter(model -> model.getStatus() == ModelStatus.ACTIVE).collect(Collectors.toList()); + } + @Override public String toString() { return "AIModel{" + diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java index 4d9caa4d82be..e99296f99713 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java @@ -1,7 +1,9 @@ package com.dotcms.ai.client; +import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.domain.AIProvider; import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponseMetadata; import org.apache.http.client.methods.HttpDelete; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPatch; @@ -22,11 +24,12 @@ public AIProvider getProvider() { } @Override - public void sendRequest(final AIRequest request, final OutputStream output) { - throwUnsupported(); + public AIResponseMetadata sendRequest(final AIRequest request, + final OutputStream output) { + return throwUnsupported(); } - private void throwUnsupported() { + private AIResponseMetadata throwUnsupported() { throw new UnsupportedOperationException("Noop client does not support sending requests"); } }; @@ -49,6 +52,6 @@ static HttpUriRequest resolveMethod(final String method, final String url) { AIProvider getProvider(); - void sendRequest(final AIRequest request, final OutputStream output); + AIResponseMetadata sendRequest(AIRequest request, OutputStream output); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java index 93a82e447451..e4c484b84646 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java @@ -2,6 +2,7 @@ import com.dotcms.ai.domain.AIRequest; import com.dotcms.ai.domain.AIResponse; +import io.vavr.Tuple2; import java.io.OutputStream; import java.io.Serializable; @@ -10,6 +11,8 @@ public interface AIClientStrategy { AIClientStrategy NOOP = (client, request, output) -> AIResponse.builder().build(); - void applyStrategy(AIClient client, AIRequest request, OutputStream output); + void applyStrategy(Tuple2 clientAndParser, + AIRequest request, + OutputStream output); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java index da14ed4a9a93..f5b123e981a7 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java @@ -1,6 +1,7 @@ package com.dotcms.ai.client; import com.dotcms.ai.domain.AIRequest; +import io.vavr.Tuple2; import java.io.OutputStream; import java.io.Serializable; @@ -8,10 +9,10 @@ public class AIDefaultStrategy implements AIClientStrategy { @Override - public void applyStrategy(final AIClient client, + public void applyStrategy(final Tuple2 clientAndParser, final AIRequest request, final OutputStream output) { - client.sendRequest(request, output); + clientAndParser._1.sendRequest(request, output); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java index 71a4f44252ad..9b1e7ad3504c 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java @@ -1,16 +1,89 @@ package com.dotcms.ai.client; +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.app.AIModel; import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponseMetadata; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import com.dotcms.ai.domain.Model; +import com.dotmarketing.exception.DotRuntimeException; +import com.dotmarketing.util.Logger; +import com.dotmarketing.util.json.JSONObject; +import io.vavr.Tuple2; +import org.apache.commons.io.IOUtils; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.io.OutputStream; import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.stream.Collectors; public class AIModelFallbackStrategy implements AIClientStrategy { @Override - public void applyStrategy(final AIClient client, + public void applyStrategy(final Tuple2 clientAndParser, final AIRequest request, - final OutputStream output) { + final OutputStream originalOutput) { + final JSONObject payload = ((JSONObjectAIRequest) request).getPayload(); + final String modelInPayload = payload.optString(AiKeys.MODEL); + final AIModel aiModel = request.getConfig().resolveModelOrThrow(modelInPayload); + + final List activeModels = aiModel.getActiveModels(); + if (activeModels.isEmpty()) { + Logger.debug( + this, + () -> String.format( + "There are no active models left in model fallback strategy [%s]", + aiModel.getModels().stream().map(Model::getName).collect(Collectors.joining(", ")))); + return; + } + + boolean success = false; + for (int index = 0; index < aiModel.getModels().size(); index++) { + final Model model = aiModel.getModels().get(index); + if (!model.isOperational()) { + Logger.debug("Model [%s] is not operational. Skipping.", model.getName()); + continue; + } + + final ByteArrayOutputStream output = new ByteArrayOutputStream(); + final AIResponseMetadata metadata = clientAndParser._1.sendRequest(request, output); + final String response = output.toString(); + + clientAndParser._2.lookForError(response, metadata); + if (metadata.isSuccess()) { + try { + IOUtils.copy(new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8)), originalOutput); + } catch (IOException e) { + throw new DotRuntimeException(e); + } + + aiModel.setCurrentModelIndex(index); + success = true; + + break; + } + + Logger.debug( + this, + () -> String.format( + "Model [%s] failed with response [%s%s%s]. Trying next model.", + model.getName(), + System.lineSeparator(), + response, + System.lineSeparator())); + model.setStatus(metadata.getStatus()); + Logger.debug( + this, + () -> String.format( + "Model [%s] status updated to [%s].", + model.getName(), + response)); + } + } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java index 874c729fd815..a6aa99c483b8 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java @@ -2,6 +2,7 @@ import com.dotcms.ai.domain.AIRequest; import com.dotcms.ai.domain.AIResponse; +import io.vavr.Tuple; import java.io.ByteArrayOutputStream; import java.io.OutputStream; @@ -11,18 +12,28 @@ public class AIProxiedClient { - public static final AIProxiedClient NOOP = new AIProxiedClient(null, AIClientStrategy.NOOP); + public static final AIProxiedClient NOOP = new AIProxiedClient(null, AIClientStrategy.NOOP, null); private final AIClient client; private final AIClientStrategy strategy; + private final AIResponseValidator responseParser; - private AIProxiedClient(final AIClient client, final AIClientStrategy strategy) { + private AIProxiedClient(final AIClient client, + final AIClientStrategy strategy, + final AIResponseValidator responseParser) { this.client = client; this.strategy = strategy; + this.responseParser = responseParser; + } + + public static AIProxiedClient of(final AIClient client, + final AIProxyStrategy strategy, + final AIResponseValidator responseParser) { + return new AIProxiedClient(client, strategy.getStrategy(), responseParser); } public static AIProxiedClient of(final AIClient client, final AIProxyStrategy strategy) { - return new AIProxiedClient(client, strategy.getStrategy()); + return of(client, strategy, null); } public AIResponse callToAI(final AIRequest request, final OutputStream output) { @@ -30,7 +41,7 @@ public AIResponse callToAI(final AIRequest request, .ofNullable(output) .orElseGet(ByteArrayOutputStream::new); - strategy.applyStrategy(client, request, finalOutput); + strategy.applyStrategy(Tuple.of(client, responseParser), request, finalOutput); return (Objects.nonNull(output)) ? AIResponse.EMPTY diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java index 5c13814768c6..32650313cda2 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java @@ -1,5 +1,7 @@ package com.dotcms.ai.client; +import com.dotcms.ai.client.openai.OpenAIClient; +import com.dotcms.ai.client.openai.OpenAIResponseValidator; import com.dotcms.ai.domain.AIProvider; import com.dotcms.ai.domain.AIRequest; import com.dotcms.ai.domain.AIResponse; @@ -21,7 +23,9 @@ public class AIProxyClient { private AIProxyClient() { proxiedClients = new ConcurrentHashMap<>(); - addClient(AIProvider.OPEN_AI, AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategy.DEFAULT)); + addClient( + AIProvider.OPEN_AI, + AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategy.MODEL_FALLBACK, OpenAIResponseValidator.get())); currentProvider = new AtomicReference<>(AIProvider.OPEN_AI); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseValidator.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseValidator.java new file mode 100644 index 000000000000..bb96f8300248 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseValidator.java @@ -0,0 +1,9 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIResponseMetadata; + +public interface AIResponseValidator { + + void lookForError(String response, AIResponseMetadata metadata); + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/OpenAIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java similarity index 90% rename from dotCMS/src/main/java/com/dotcms/ai/client/OpenAIClient.java rename to dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java index de346e5baa9d..d4a52690a975 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/OpenAIClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java @@ -1,11 +1,13 @@ -package com.dotcms.ai.client; +package com.dotcms.ai.client.openai; import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AIModel; import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.app.AppKeys; +import com.dotcms.ai.client.AIClient; import com.dotcms.ai.domain.AIProvider; import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponseMetadata; import com.dotcms.ai.domain.JSONObjectAIRequest; import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Logger; @@ -47,7 +49,8 @@ public AIProvider getProvider() { } @Override - public void sendRequest(final AIRequest request, final OutputStream output) { + public AIResponseMetadata sendRequest(final AIRequest request, + final OutputStream output) { final AppConfig config = request.getConfig(); if (!config.isEnabled()) { Logger.debug(this, "OpenAI is not enabled and will not send request."); @@ -56,7 +59,7 @@ public void sendRequest(final AIRequest request, fin // When we get rid of JSONObject usage, we can remove this check if (!(request instanceof JSONObjectAIRequest)) { - throw new UnsupportedOperationException("Only JsonAIRequest (JSONObject) is supported"); + throw new UnsupportedOperationException("Only JSONObjectAIRequest (JSONObject) is supported"); } final JSONObject json = ((JSONObjectAIRequest) request).getPayload(); @@ -113,6 +116,8 @@ public void sendRequest(final AIRequest request, fin throw new DotRuntimeException(e); } + + return new AIResponseMetadata(model); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseValidator.java b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseValidator.java new file mode 100644 index 000000000000..34344dfc9fdb --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIResponseValidator.java @@ -0,0 +1,41 @@ +package com.dotcms.ai.client.openai; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.client.AIResponseValidator; +import com.dotcms.ai.domain.AIResponseMetadata; +import com.dotcms.ai.domain.ModelStatus; +import com.dotmarketing.util.json.JSONObject; +import io.vavr.Lazy; + +public class OpenAIResponseValidator implements AIResponseValidator { + + private static final Lazy INSTANCE = Lazy.of(OpenAIResponseValidator::new); + + public static OpenAIResponseValidator get() { + return INSTANCE.get(); + } + + private OpenAIResponseValidator() { + } + + @Override + public void lookForError(final String response, final AIResponseMetadata metadata) { + final JSONObject jsonResponse = new JSONObject(response); + if (jsonResponse.has(AiKeys.ERROR)) { + final String error = jsonResponse.getString(AiKeys.ERROR); + metadata.setError(error); + metadata.setStatus(resolveStatus(error)); + } + } + + private ModelStatus resolveStatus(final String error) { + if (error.contains("has been deprecated")) { + return ModelStatus.DECOMMISSIONED; + } else if (error.contains("does not exist or you do not have access to it")) { + return ModelStatus.INVALID; + } else { + return null; + } + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java index a4176fbae010..af5dda76a68f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java @@ -54,13 +54,16 @@ static String resolveUrl(final AIModelType type, final AppConfig appConfig) { } } + @SuppressWarnings("unchecked") private static , R extends AIRequest> R quick( final String url, final AppConfig appConfig, + final AIModelType type, final T payload) { return (R) AIRequest.builder() .withUrl(url) .withConfig(appConfig) + .withType(type) .withPayload(payload) .build(); } @@ -69,7 +72,7 @@ private static > R quick( final AIModelType type, final AppConfig appConfig, final T payload) { - return quick(resolveUrl(type, appConfig), appConfig, payload); + return quick(resolveUrl(type, appConfig), appConfig, type, payload); } public String getUrl() { @@ -121,11 +124,6 @@ public B withUrl(final String url) { return self(); } - public B withMethod(final String method) { - this.method = method; - return self(); - } - public B withConfig(final AppConfig config) { this.config = config; return self(); diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseMetadata.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseMetadata.java new file mode 100644 index 000000000000..64ad56d7d863 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseMetadata.java @@ -0,0 +1,49 @@ +package com.dotcms.ai.domain; + +import com.dotcms.ai.app.AIModel; +import org.apache.commons.lang3.StringUtils; + +public class AIResponseMetadata { + + private final AIModel model; + private String error; + private ModelStatus status; + + public AIResponseMetadata(final AIModel model) { + this.model = model; + } + + public AIModel getModel() { + return model; + } + + public String getError() { + return error; + } + + public void setError(final String error) { + this.error = error; + } + + public ModelStatus getStatus() { + return status; + } + + public void setStatus(ModelStatus status) { + this.status = status; + } + + public boolean isSuccess() { + return StringUtils.isBlank(error); + } + + @Override + public String toString() { + return "AIResponseMetadata{" + + "model=" + model + + ", error='" + error + '\'' + + ", status=" + status + + '}'; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java index f7603be95435..3b750cc90335 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java @@ -33,10 +33,14 @@ public JSONObjectAIRequest build() { } } - private static JSONObjectAIRequest quick(final String url, final AppConfig appConfig, final JSONObject payload) { + private static JSONObjectAIRequest quick(final String url, + final AppConfig appConfig, + final AIModelType type, + final JSONObject payload) { return JSONObjectAIRequest.builder() .withUrl(url) .withConfig(appConfig) + .withType(type) .withPayload(payload) .build(); } @@ -44,7 +48,7 @@ private static JSONObjectAIRequest quick(final String url, final AppConfig appCo private static JSONObjectAIRequest quick(final AIModelType type, final AppConfig appConfig, final JSONObject payload) { - return quick(resolveUrl(type, appConfig), appConfig, payload); + return quick(resolveUrl(type, appConfig), appConfig, type, payload); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/Model.java b/dotCMS/src/main/java/com/dotcms/ai/domain/Model.java index 03b763fbfe1c..6378854659e7 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/domain/Model.java +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/Model.java @@ -1,5 +1,6 @@ package com.dotcms.ai.domain; +import java.util.List; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; @@ -29,6 +30,10 @@ public void setStatus(final ModelStatus status) { this.status.set(status); } + public boolean isOperational() { + return List.of(ModelStatus.ACTIVE, ModelStatus.VALID).contains(status.get()); + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java index f99766469a0d..d8ff3d02b120 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java @@ -1,10 +1,12 @@ package com.dotcms.ai.app; +import com.dotcms.ai.domain.Model; import com.dotcms.security.apps.Secret; import org.junit.Before; import org.junit.Test; import java.util.Map; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -127,7 +129,7 @@ public void testCreateTextModel() { AIModel model = aiAppUtil.createTextModel(secrets); assertNotNull(model); assertEquals(AIModelType.TEXT, model.getType()); - assertTrue(model.getModels().contains("textmodel")); + assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()).contains("textmodel")); } /** @@ -143,7 +145,7 @@ public void testCreateImageModel() { AIModel model = aiAppUtil.createImageModel(secrets); assertNotNull(model); assertEquals(AIModelType.IMAGE, model.getType()); - assertTrue(model.getModels().contains("imagemodel")); + assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()).contains("imagemodel")); } /** @@ -159,7 +161,8 @@ public void testCreateEmbeddingsModel() { AIModel model = aiAppUtil.createEmbeddingsModel(secrets); assertNotNull(model); assertEquals(AIModelType.EMBEDDINGS, model.getType()); - assertTrue(model.getModels().contains("embeddingsmodel")); + assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()) + .contains("embeddingsmodel")); } } \ No newline at end of file