diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java index 0282dd1f8987..5461b12f98bf 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java @@ -1,7 +1,8 @@ package com.dotcms.ai.app; +import com.dotcms.ai.exception.DotAIModelNotFound; +import com.dotcms.ai.exception.DotAIModelNotOperational; import com.dotcms.security.apps.Secret; -import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Logger; import com.dotmarketing.util.UtilMethods; import io.vavr.control.Try; @@ -103,7 +104,7 @@ public String getApiImageUrl() { /** * Retrieves the API Embeddings URL. * - * @return + * @return the API Embeddings URL */ public String getApiEmbeddingsUrl() { return UtilMethods.isEmpty(apiEmbeddingsUrl) ? AppKeys.API_EMBEDDINGS_URL.defaultValue : apiEmbeddingsUrl; @@ -266,7 +267,7 @@ public AIModel resolveModelOrThrow(final String modelName) { .findModel(host, modelName) .orElseThrow(() -> { final String supported = String.join(", ", AIModels.get().getOrPullSupportedModels(apiKey)); - return new DotRuntimeException( + return new DotAIModelNotFound( "Unable to find model: [" + modelName + "]. Only [" + supported + "] are supported "); }); @@ -276,7 +277,7 @@ public AIModel resolveModelOrThrow(final String modelName) { () -> String.format( "Resolved model [%s] is not operational, avoiding its usage", aiModel.getCurrentModel())); - throw new DotRuntimeException(String.format("Model [%s] is not operational", aiModel.getCurrentModel())); + throw new DotAIModelNotOperational(String.format("Model [%s] is not operational", aiModel.getCurrentModel())); } return aiModel; diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java index e99296f99713..7e52d082e3ad 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java @@ -1,9 +1,8 @@ package com.dotcms.ai.client; -import com.dotcms.ai.app.AppConfig; import com.dotcms.ai.domain.AIProvider; import com.dotcms.ai.domain.AIRequest; -import com.dotcms.ai.domain.AIResponseMetadata; +import com.dotcms.ai.domain.AIResponseData; import org.apache.http.client.methods.HttpDelete; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPatch; @@ -24,12 +23,12 @@ public AIProvider getProvider() { } @Override - public AIResponseMetadata sendRequest(final AIRequest request, - final OutputStream output) { + public AIResponseData sendRequest(final AIRequest request, + final OutputStream output) { return throwUnsupported(); } - private AIResponseMetadata throwUnsupported() { + private AIResponseData throwUnsupported() { throw new UnsupportedOperationException("Noop client does not support sending requests"); } }; @@ -52,6 +51,6 @@ static HttpUriRequest resolveMethod(final String method, final String url) { AIProvider getProvider(); - AIResponseMetadata sendRequest(AIRequest request, OutputStream output); + AIResponseData sendRequest(AIRequest request, OutputStream output); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java index e4c484b84646..8529f1e3adf8 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java @@ -2,16 +2,16 @@ import com.dotcms.ai.domain.AIRequest; import com.dotcms.ai.domain.AIResponse; -import io.vavr.Tuple2; import java.io.OutputStream; import java.io.Serializable; public interface AIClientStrategy { - AIClientStrategy NOOP = (client, request, output) -> AIResponse.builder().build(); + AIClientStrategy NOOP = (client, handler, request, output) -> AIResponse.builder().build(); - void applyStrategy(Tuple2 clientAndParser, + void applyStrategy(AIClient client, + AIResponseHandler handler, AIRequest request, OutputStream output); diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java index f5b123e981a7..54cf914a9a02 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java @@ -1,7 +1,6 @@ package com.dotcms.ai.client; import com.dotcms.ai.domain.AIRequest; -import io.vavr.Tuple2; import java.io.OutputStream; import java.io.Serializable; @@ -9,10 +8,11 @@ public class AIDefaultStrategy implements AIClientStrategy { @Override - public void applyStrategy(final Tuple2 clientAndParser, + public void applyStrategy(final AIClient client, + final AIResponseHandler handler, final AIRequest request, final OutputStream output) { - clientAndParser._1.sendRequest(request, output); + client.sendRequest(request, output); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java index 9b1e7ad3504c..851d986467fb 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java @@ -1,20 +1,17 @@ package com.dotcms.ai.client; -import com.dotcms.ai.AiKeys; import com.dotcms.ai.app.AIModel; import com.dotcms.ai.domain.AIRequest; -import com.dotcms.ai.domain.AIResponseMetadata; -import com.dotcms.ai.domain.JSONObjectAIRequest; +import com.dotcms.ai.domain.AIResponseData; import com.dotcms.ai.domain.Model; import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Logger; -import com.dotmarketing.util.json.JSONObject; -import io.vavr.Tuple2; import org.apache.commons.io.IOUtils; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; import java.io.Serializable; import java.nio.charset.StandardCharsets; @@ -24,12 +21,13 @@ public class AIModelFallbackStrategy implements AIClientStrategy { @Override - public void applyStrategy(final Tuple2 clientAndParser, + public void applyStrategy(final AIClient client, + final AIResponseHandler handler, final AIRequest request, final OutputStream originalOutput) { - final JSONObject payload = ((JSONObjectAIRequest) request).getPayload(); - final String modelInPayload = payload.optString(AiKeys.MODEL); - final AIModel aiModel = request.getConfig().resolveModelOrThrow(modelInPayload); + final AIResponseData responseData = doSend(client, request); + if (handleResponse(originalOutput, responseData)) return; + final List activeModels = aiModel.getActiveModels(); if (activeModels.isEmpty()) { @@ -49,41 +47,72 @@ public void applyStrategy(final Tuple2 clientAndP continue; } - final ByteArrayOutputStream output = new ByteArrayOutputStream(); - final AIResponseMetadata metadata = clientAndParser._1.sendRequest(request, output); - final String response = output.toString(); + if (sendAttempt(clientAndParser, request, originalOutput, aiModel, index, model)) break; + } - clientAndParser._2.lookForError(response, metadata); - if (metadata.isSuccess()) { - try { - IOUtils.copy(new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8)), originalOutput); - } catch (IOException e) { - throw new DotRuntimeException(e); - } + } - aiModel.setCurrentModelIndex(index); - success = true; + private static boolean handleResponse(final OutputStream originalOutput, final AIResponseData responseData) { + if (responseData.isSuccess()) { + redirectOutput(originalOutput, responseData.getResponse()); + return true; + } - break; - } + return false; + } + private AIResponseData doSend(final AIClient client, final AIRequest request) { + final ByteArrayOutputStream output = new ByteArrayOutputStream(); + final AIResponseData responseData = client.sendRequest(request, output); + + responseData.setResponse(output.toString()); + IOUtils.closeQuietly(output); + + return responseData; + } + + private static void redirectOutput(final OutputStream originalOutput, final String response) { + try (final InputStream input = new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))) { + IOUtils.copy(input, originalOutput); + } catch (IOException e) { + throw new DotRuntimeException(e); + } + } + + private boolean sendAttempt(final AIClient client, + final AIResponseHandler handler, + final AIRequest request, + final OutputStream originalOutput) { + + final AIResponseData responseData = doSend(client, request); + final String response = responseData.getResponse(); + + handler.handleResponse(response, responseData); + if (!responseData.isSuccess()) { + final AIModel aiModel = resolveModelFromPayload(request); Logger.debug( this, () -> String.format( "Model [%s] failed with response [%s%s%s]. Trying next model.", - model.getName(), + aiModel.getCurrentModel(), System.lineSeparator(), response, System.lineSeparator())); - model.setStatus(metadata.getStatus()); - Logger.debug( - this, - () -> String.format( - "Model [%s] status updated to [%s].", - model.getName(), - response)); + return false; } + redirectOutput(originalOutput, response); + + return true; + + /*model.setStatus(responseData.getStatus()); + Logger.debug( + this, + () -> String.format( + "Model [%s] status updated to [%s].", + model.getName(), + response));*/ } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java index a6aa99c483b8..08ab8375241a 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java @@ -2,7 +2,6 @@ import com.dotcms.ai.domain.AIRequest; import com.dotcms.ai.domain.AIResponse; -import io.vavr.Tuple; import java.io.ByteArrayOutputStream; import java.io.OutputStream; @@ -16,11 +15,11 @@ public class AIProxiedClient { private final AIClient client; private final AIClientStrategy strategy; - private final AIResponseValidator responseParser; + private final AIResponseHandler responseParser; private AIProxiedClient(final AIClient client, final AIClientStrategy strategy, - final AIResponseValidator responseParser) { + final AIResponseHandler responseParser) { this.client = client; this.strategy = strategy; this.responseParser = responseParser; @@ -28,7 +27,7 @@ private AIProxiedClient(final AIClient client, public static AIProxiedClient of(final AIClient client, final AIProxyStrategy strategy, - final AIResponseValidator responseParser) { + final AIResponseHandler responseParser) { return new AIProxiedClient(client, strategy.getStrategy(), responseParser); } @@ -41,7 +40,7 @@ public AIResponse callToAI(final AIRequest request, .ofNullable(output) .orElseGet(ByteArrayOutputStream::new); - strategy.applyStrategy(Tuple.of(client, responseParser), request, finalOutput); + strategy.applyStrategy(client, responseParser, request, finalOutput); return (Objects.nonNull(output)) ? AIResponse.EMPTY diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java index 32650313cda2..f8e0d3e772a4 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java @@ -1,7 +1,7 @@ package com.dotcms.ai.client; import com.dotcms.ai.client.openai.OpenAIClient; -import com.dotcms.ai.client.openai.OpenAIResponseValidator; +import com.dotcms.ai.client.openai.OpenAIResponseHandler; import com.dotcms.ai.domain.AIProvider; import com.dotcms.ai.domain.AIRequest; import com.dotcms.ai.domain.AIResponse; @@ -25,7 +25,7 @@ private AIProxyClient() { proxiedClients = new ConcurrentHashMap<>(); addClient( AIProvider.OPEN_AI, - AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategy.MODEL_FALLBACK, OpenAIResponseValidator.get())); + AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategy.MODEL_FALLBACK, OpenAIResponseHandler.get())); currentProvider = new AtomicReference<>(AIProvider.OPEN_AI); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseHandler.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseHandler.java new file mode 100644 index 000000000000..c674c738398d --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseHandler.java @@ -0,0 +1,9 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIResponseData; + +public interface AIResponseHandler { + + void handleResponse(String response, AIResponseData metadata); + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseValidator.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseValidator.java deleted file mode 100644 index bb96f8300248..000000000000 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIResponseValidator.java +++ /dev/null @@ -1,9 +0,0 @@ -package com.dotcms.ai.client; - -import com.dotcms.ai.domain.AIResponseMetadata; - -public interface AIResponseValidator { - - void lookForError(String response, AIResponseMetadata metadata); - -} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java index d4a52690a975..1fa602810a08 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java @@ -7,7 +7,7 @@ import com.dotcms.ai.client.AIClient; import com.dotcms.ai.domain.AIProvider; import com.dotcms.ai.domain.AIRequest; -import com.dotcms.ai.domain.AIResponseMetadata; +import com.dotcms.ai.domain.AIResponseData; import com.dotcms.ai.domain.JSONObjectAIRequest; import com.dotmarketing.exception.DotRuntimeException; import com.dotmarketing.util.Logger; @@ -49,8 +49,8 @@ public AIProvider getProvider() { } @Override - public AIResponseMetadata sendRequest(final AIRequest request, - final OutputStream output) { + public AIResponseData sendRequest(final AIRequest request, + final OutputStream output) { final AppConfig config = request.getConfig(); if (!config.isEnabled()) { Logger.debug(this, "OpenAI is not enabled and will not send request."); @@ -63,28 +63,29 @@ public AIResponseMetadata sendRequest(final AIRequest 0L) - + model.minIntervalBetweenCalls() + final long sleep = lastRestCall.computeIfAbsent(aiModel, m -> 0L) + + aiModel.minIntervalBetweenCalls() - System.currentTimeMillis(); if (sleep > 0) { Logger.info( this, "Rate limit:" - + model.getApiPerMinute() + + aiModel.getApiPerMinute() + "/minute, or 1 every " - + model.minIntervalBetweenCalls() + + aiModel.minIntervalBetweenCalls() + "ms. Sleeping:" + sleep); Try.run(() -> Thread.sleep(sleep)); } - lastRestCall.put(model, System.currentTimeMillis()); + lastRestCall.put(aiModel, System.currentTimeMillis()); try (CloseableHttpClient httpClient = HttpClients.createDefault()) { final StringEntity jsonEntity = new StringEntity(json.toString(), ContentType.APPLICATION_JSON); @@ -117,7 +118,7 @@ public AIResponseMetadata sendRequest(final AIRequest INSTANCE = Lazy.of(OpenAIResponseValidator::new); + private static final Lazy INSTANCE = Lazy.of(OpenAIResponseHandler::new); - public static OpenAIResponseValidator get() { + public static OpenAIResponseHandler get() { return INSTANCE.get(); } - private OpenAIResponseValidator() { + private OpenAIResponseHandler() { } @Override - public void lookForError(final String response, final AIResponseMetadata metadata) { + public void handleResponse(final String response, final AIResponseData metadata) { final JSONObject jsonResponse = new JSONObject(response); if (jsonResponse.has(AiKeys.ERROR)) { final String error = jsonResponse.getString(AiKeys.ERROR); diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseMetadata.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java similarity index 74% rename from dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseMetadata.java rename to dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java index 64ad56d7d863..1ba1a2ef1d4b 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseMetadata.java +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponseData.java @@ -3,13 +3,14 @@ import com.dotcms.ai.app.AIModel; import org.apache.commons.lang3.StringUtils; -public class AIResponseMetadata { +public class AIResponseData { private final AIModel model; + private String response; private String error; private ModelStatus status; - public AIResponseMetadata(final AIModel model) { + public AIResponseData(final AIModel model) { this.model = model; } @@ -17,6 +18,14 @@ public AIModel getModel() { return model; } + public String getResponse() { + return response; + } + + public void setResponse(String response) { + this.response = response; + } + public String getError() { return error; } @@ -41,6 +50,7 @@ public boolean isSuccess() { public String toString() { return "AIResponseMetadata{" + "model=" + model + + ", response='" + response + '\'' + ", error='" + error + '\'' + ", status=" + status + '}'; diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotFound.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotFound.java new file mode 100644 index 000000000000..3c1a4a3d6ad0 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotFound.java @@ -0,0 +1,11 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +public class DotAIModelNotFound extends DotRuntimeException { + + public DotAIModelNotFound(String message) { + super(message); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotOperational.java b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotOperational.java new file mode 100644 index 000000000000..2e4aac05ed3c --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/exception/DotAIModelNotOperational.java @@ -0,0 +1,11 @@ +package com.dotcms.ai.exception; + +import com.dotmarketing.exception.DotRuntimeException; + +public class DotAIModelNotOperational extends DotRuntimeException { + + public DotAIModelNotOperational(String message) { + super(message); + } + +} diff --git a/parent/pom.xml b/parent/pom.xml index be41efd3b0bd..4572a123e924 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -77,7 +77,7 @@ true ${project.build.directory}/starter - empty_20240719 + 20240729 starter.zip ${starter.deploy.version} starter-${starter.run.version}.zip