From a3d4970f7bffc580db543f0b4b23026a0820639d Mon Sep 17 00:00:00 2001 From: Victor Alfaro Date: Thu, 8 Aug 2024 19:26:42 -0600 Subject: [PATCH] feat(dotAI): Adding fallback mechanism when it comes to send models to AI Provider (OpenAI) Refs: #29284 --- .../main/java/com/dotcms/ai/app/AIModels.java | 26 ++-- .../java/com/dotcms/ai/client/AIClient.java | 54 ++++++++ .../dotcms/ai/client/AIDefaultStrategy.java | 20 +++ .../ai/client/AIModelFallbackStrategy.java | 15 ++ .../com/dotcms/ai/client/AIProxiedClient.java | 28 ++++ .../java/com/dotcms/ai/client/AIProxy.java | 44 ++++++ .../com/dotcms/ai/client/AIProxyStrategy.java | 14 ++ .../ai/client/AIProxyStrategyWrapper.java | 18 +++ .../com/dotcms/ai/client/OpenAIClient.java | 128 ++++++++++++++++++ .../java/com/dotcms/ai/domain/AIProvider.java | 20 +++ .../java/com/dotcms/ai/domain/AIRequest.java | 118 ++++++++++++++++ .../java/com/dotcms/ai/domain/AIResponse.java | 47 +++++++ .../dotcms/ai/domain/JSONObjectAIRequest.java | 31 +++++ .../java/com/dotcms/ai/app/AIModelsTest.java | 3 - 14 files changed, 549 insertions(+), 17 deletions(-) create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIProxy.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategyWrapper.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/client/OpenAIClient.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java create mode 100644 dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java diff --git a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java index 8f88e214d9ca..8095f55ec36f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java +++ b/dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java @@ -44,21 +44,24 @@ public class AIModels { "https://api.openai.com/v1/models"); private static final int AI_MODELS_CACHE_TTL = 28800; // 8 hours private static final int AI_MODELS_CACHE_SIZE = 128; + private static final Supplier APP_CONFIG_SUPPLIER = ConfigService.INSTANCE::config; - private final ConcurrentMap>> internalModels = new ConcurrentHashMap<>(); - private final ConcurrentMap, AIModel> modelsByName = new ConcurrentHashMap<>(); - private final Cache> supportedModelsCache = - Caffeine.newBuilder() - .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) - .maximumSize(AI_MODELS_CACHE_SIZE) - .build(); - private Supplier appConfigSupplier = ConfigService.INSTANCE::config; + private final ConcurrentMap>> internalModels; + private final ConcurrentMap, AIModel> modelsByName; + private final Cache> supportedModelsCache; public static AIModels get() { return INSTANCE.get(); } private AIModels() { + internalModels = new ConcurrentHashMap<>(); + modelsByName = new ConcurrentHashMap<>(); + supportedModelsCache = + Caffeine.newBuilder() + .expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL)) + .maximumSize(AI_MODELS_CACHE_SIZE) + .build(); } /** @@ -152,7 +155,7 @@ public List getOrPullSupportedModels() { return cached; } - final AppConfig appConfig = appConfigSupplier.get(); + final AppConfig appConfig = APP_CONFIG_SUPPLIER.get(); if (!appConfig.isEnabled()) { Logger.debug(this, "OpenAI is not enabled, returning empty list of supported models"); return List.of(); @@ -211,11 +214,6 @@ private static CircuitBreakerUrl.Response fetchOpenAIModels(final return response; } - @VisibleForTesting - void setAppConfigSupplier(final Supplier appConfigSupplier) { - this.appConfigSupplier = appConfigSupplier; - } - @VisibleForTesting void cleanSupportedModelsCache() { supportedModelsCache.invalidateAll(); diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java new file mode 100644 index 000000000000..b045a59ce94d --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java @@ -0,0 +1,54 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIProvider; +import com.dotcms.ai.domain.AIRequest; +import org.apache.http.client.methods.HttpDelete; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPatch; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; +import org.apache.http.client.methods.HttpUriRequest; + +import javax.ws.rs.HttpMethod; +import java.io.OutputStream; +import java.io.Serializable; + +public interface AIClient { + + AIClient NOOP = new AIClient() { + @Override + public AIProvider getProvider() { + return AIProvider.NONE; + } + + @Override + public OutputStream sendRequest(final AIRequest request) { + return throwUnsupported(); + } + + private OutputStream throwUnsupported() { + throw new UnsupportedOperationException("Noop client does not support sending requests"); + } + }; + + AIProvider getProvider(); + + static HttpUriRequest resolveMethod(final String method, final String url) { + switch(method) { + case HttpMethod.POST: + return new HttpPost(url); + case HttpMethod.PUT: + return new HttpPut(url); + case HttpMethod.DELETE: + return new HttpDelete(url); + case "patch": + return new HttpPatch(url); + case HttpMethod.GET: + default: + return new HttpGet(url); + } + } + + OutputStream sendRequest(final AIRequest request); + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java new file mode 100644 index 000000000000..7422e2718d6a --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java @@ -0,0 +1,20 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; + +import java.io.OutputStream; +import java.io.Serializable; + +public class AIDefaultStrategy implements AIProxyStrategy { + + @Override + public AIResponse applyStrategy(final AIClient client, final AIRequest request) { + final OutputStream output = client.sendRequest(request); + return AIResponse.builder() + .output(output) + .response(output.toString()) + .build(); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java new file mode 100644 index 000000000000..50421679ecab --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java @@ -0,0 +1,15 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; + +import java.io.Serializable; + +public class AIModelFallbackStrategy implements AIProxyStrategy { + + @Override + public AIResponse applyStrategy(final AIClient client, final AIRequest request) { + return null; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java new file mode 100644 index 000000000000..c6597732cc4d --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java @@ -0,0 +1,28 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; + +import java.io.Serializable; + +public class AIProxiedClient { + + public static final AIProxiedClient NOOP = new AIProxiedClient(null, AIProxyStrategy.NOOP); + + private final AIClient client; + private final AIProxyStrategy strategy; + + private AIProxiedClient(final AIClient client, final AIProxyStrategy strategy) { + this.client = client; + this.strategy = strategy; + } + + public static AIProxiedClient of(final AIClient client, final AIProxyStrategyWrapper strategy) { + return new AIProxiedClient(client, strategy.getStrategy()); + } + + public AIResponse callToAI(final AIRequest request) { + return strategy.applyStrategy(client, request); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxy.java new file mode 100644 index 000000000000..d94eeedec333 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxy.java @@ -0,0 +1,44 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIProvider; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; +import io.vavr.Lazy; + +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; + +public class AIProxy { + + private static final Lazy INSTANCE = Lazy.of(AIProxy::new); + + private final ConcurrentMap proxiedClients; + private final AtomicReference currentProvider; + + private AIProxy() { + proxiedClients = new ConcurrentHashMap<>(); + addClient(AIProvider.OPEN_AI, AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategyWrapper.DEFAULT)); + currentProvider = new AtomicReference<>(AIProvider.OPEN_AI); + } + + public static AIProxy get() { + return INSTANCE.get(); + } + + public AIProxiedClient getClient(final AIProvider provider) { + return proxiedClients.get(provider); + } + + public void addClient(final AIProvider provider, final AIProxiedClient client) { + proxiedClients.put(provider, client); + } + + public AIResponse sendRequest(final AIProvider provider, final AIRequest request) { + return Optional.ofNullable(proxiedClients.getOrDefault(provider, AIProxiedClient.NOOP)) + .map(client -> client.callToAI(request)) + .orElse(null); + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java new file mode 100644 index 000000000000..da5c7359a3c2 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java @@ -0,0 +1,14 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.AIResponse; + +import java.io.Serializable; + +public interface AIProxyStrategy { + + AIProxyStrategy NOOP = (client, request) -> AIResponse.builder().build(); + + AIResponse applyStrategy(final AIClient client, final AIRequest request); + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategyWrapper.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategyWrapper.java new file mode 100644 index 000000000000..98fcfb3e0ba0 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategyWrapper.java @@ -0,0 +1,18 @@ +package com.dotcms.ai.client; + +public enum AIProxyStrategyWrapper { + + DEFAULT(new AIDefaultStrategy()), + MODEL_FALLBACK(new AIModelFallbackStrategy()); + + private final AIProxyStrategy strategy; + + AIProxyStrategyWrapper(final AIProxyStrategy strategy) { + this.strategy = strategy; + } + + public AIProxyStrategy getStrategy() { + return strategy; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/OpenAIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/OpenAIClient.java new file mode 100644 index 000000000000..ef7a3abe8d0a --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/client/OpenAIClient.java @@ -0,0 +1,128 @@ +package com.dotcms.ai.client; + +import com.dotcms.ai.AiKeys; +import com.dotcms.ai.app.AIModel; +import com.dotcms.ai.app.AppConfig; +import com.dotcms.ai.app.AppKeys; +import com.dotcms.ai.domain.AIProvider; +import com.dotcms.ai.domain.AIRequest; +import com.dotcms.ai.domain.JSONObjectAIRequest; +import com.dotcms.ai.util.OpenAIRequest; +import com.dotmarketing.exception.DotRuntimeException; +import com.dotmarketing.util.Logger; +import com.dotmarketing.util.json.JSONObject; +import io.vavr.Lazy; +import io.vavr.control.Try; +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpEntityEnclosingRequestBase; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; + +import javax.ws.rs.core.MediaType; +import java.io.BufferedInputStream; +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.concurrent.ConcurrentHashMap; + +public class OpenAIClient implements AIClient { + + private static final Lazy INSTANCE = Lazy.of(OpenAIClient::new); + + private final ConcurrentHashMap lastRestCall; + + public static OpenAIClient get() { + return INSTANCE.get(); + } + + private OpenAIClient() { + lastRestCall = new ConcurrentHashMap<>(); + } + + @Override + public AIProvider getProvider() { + return AIProvider.OPEN_AI; + } + + @Override + public OutputStream sendRequest(final AIRequest request) { + final ByteArrayOutputStream output = new ByteArrayOutputStream(); + sendRequest(request, output); + return output; + } + + private OutputStream sendRequest(final AIRequest request, final OutputStream output) { + final AppConfig config = request.getConfig(); + if (!config.isEnabled()) { + Logger.debug(OpenAIRequest.class, "OpenAI is not enabled and will not send request."); + throw new IllegalStateException("OpenAI is not enabled"); + } + + // When we get rid of JSONObject usage, we can remove this check + if (!(request instanceof JSONObjectAIRequest)) { + throw new UnsupportedOperationException("Only JsonAIRequest (JSONObject) is supported"); + } + + final JSONObject json = ((JSONObjectAIRequest) request).getPayload(); + final AIModel model = config.resolveModelOrThrow(json.optString(AiKeys.MODEL)); + + if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { + Logger.debug(OpenAIRequest.class, "posting: " + json); + } + + final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L) + + model.minIntervalBetweenCalls() + - System.currentTimeMillis(); + if (sleep > 0) { + Logger.info( + OpenAIRequest.class, + "Rate limit:" + + model.getApiPerMinute() + + "/minute, or 1 every " + + model.minIntervalBetweenCalls() + + "ms. Sleeping:" + + sleep); + Try.run(() -> Thread.sleep(sleep)); + } + + lastRestCall.put(model, System.currentTimeMillis()); + + try (CloseableHttpClient httpClient = HttpClients.createDefault()) { + final StringEntity jsonEntity = new StringEntity(json.toString(), ContentType.APPLICATION_JSON); + final HttpUriRequest httpRequest = AIClient.resolveMethod(request.getMethod(), request.getUrl()); + httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON); + httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + config.getApiKey()); + + if (!json.getAsMap().isEmpty()) { + Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity)); + } + + try (CloseableHttpResponse response = httpClient.execute(httpRequest)) { + final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent()); + final byte[] buffer = new byte[1024]; + int len; + while ((len = in.read(buffer)) != -1) { + output.write(buffer, 0, len); + output.flush(); + } + } + } catch (Exception e) { + if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)){ + Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage(), e); + } else { + Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage()); + } + + Logger.warn(OpenAIRequest.class, " - " + request.getMethod() + " : " + json); + + throw new DotRuntimeException(e); + } + + return output; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java new file mode 100644 index 000000000000..e7681ca84cab --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIProvider.java @@ -0,0 +1,20 @@ +package com.dotcms.ai.domain; + +public enum AIProvider { + + NONE("None"), + OPEN_AI("OpenAI"), + BEDROCK("Amazon Bedrock"), + GEMINI("Google Gemini"); + + private final String provider; + + AIProvider(final String provider) { + this.provider = provider; + } + + public String getProvider() { + return provider; + } + +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java new file mode 100644 index 000000000000..d2b6c3566106 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java @@ -0,0 +1,118 @@ +package com.dotcms.ai.domain; + +import com.dotcms.ai.app.AIModelType; +import com.dotcms.ai.app.AppConfig; + +import java.io.Serializable; + +public class AIRequest { + + private final String url; + private final String method; + private final AppConfig config; + private final AIModelType type; + private final T payload; + private boolean useOutput; + + AIRequest(final String url, + final String method, + final AppConfig config, + final AIModelType type, + final T payload, + final boolean useOutput) { + this.config = config; + this.url = url; + this.method = method; + this.type = type; + this.payload = payload; + this.useOutput = useOutput; + } + + static Builder builder() { + return new Builder<>(); + } + + public String getUrl() { + return url; + } + + public String getMethod() { + return method; + } + + public AppConfig getConfig() { + return config; + } + + public AIModelType getType() { + return type; + } + + public T getPayload() { + return payload; + } + + public boolean isUseOutput() { + return useOutput; + } + + @Override + public String toString() { + return "AIRequest{" + + "url='" + url + '\'' + + ", method='" + method + '\'' + + ", config=" + config + + ", type=" + type + + ", payload=" + payload + + ", useOutput=" + useOutput + + '}'; + } + + static class Builder { + + String url; + String method; + AppConfig config; + AIModelType type; + T data; + boolean useOutput; + + Builder() { + } + + public Builder withUrl(final String url) { + this.url = url; + return this; + } + + public Builder withMethod(final String method) { + this.method = method; + return this; + } + + public Builder withConfig(final AppConfig config) { + this.config = config; + return this; + } + + public Builder withType(final AIModelType type) { + this.type = type; + return this; + } + + public Builder withData(final T data) { + this.data = data; + return this; + } + + public Builder withUseOutput(final boolean useOutput) { + this.useOutput = useOutput; + return this; + } + + public AIRequest build() { + return new AIRequest<>(url, method, config, type, data, useOutput); + } + + } +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java new file mode 100644 index 000000000000..d728a164bce5 --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java @@ -0,0 +1,47 @@ +package com.dotcms.ai.domain; + +import java.io.OutputStream; + +public class AIResponse { + + private final String response; + private final OutputStream output; + + private AIResponse(final String response, final OutputStream output) { + this.response = response; + this.output = output; + } + + public static Builder builder() { + return new Builder(); + } + + public String getResponse() { + return response; + } + + public OutputStream getOutput() { + return output; + } + + public static class Builder { + + private String response; + private OutputStream output; + + public Builder response(final String response) { + this.response = response; + return this; + } + + public Builder output(final OutputStream output) { + this.output = output; + return this; + } + + public AIResponse build() { + return new AIResponse(response, output); + } + + } +} diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java new file mode 100644 index 000000000000..7d08d5b35f9c --- /dev/null +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/JSONObjectAIRequest.java @@ -0,0 +1,31 @@ +package com.dotcms.ai.domain; + +import com.dotcms.ai.app.AIModelType; +import com.dotcms.ai.app.AppConfig; +import com.dotmarketing.util.json.JSONObject; + +public class JSONObjectAIRequest extends AIRequest { + + JSONObjectAIRequest(final String url, + final String method, + final AppConfig config, + final AIModelType type, + final JSONObject data, + final boolean useOutput) { + super(url, method, config, type, data, useOutput); + } + + static Builder builder() { + return new Builder(); + } + + static class Builder extends AIRequest.Builder { + + @Override + public JSONObjectAIRequest build() { + return new JSONObjectAIRequest(url, method, config, type, data, useOutput); + } + + } + +} diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java index 2ea51fe91ab4..3b7dc69675e1 100644 --- a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java +++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java @@ -145,8 +145,6 @@ public void test_getOrPullSupportedModules() throws DotDataException, DotSecurit supported = aiModels.getOrPullSupportedModels(); assertNotNull(supported); assertEquals(32, supported.size()); - - AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); } /** @@ -164,7 +162,6 @@ public void test_getOrPullSupportedModules_invalidEndpoint() { assertTrue(supported.isEmpty()); IPUtils.disabledIpPrivateSubnet(true); - AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config); } /**