Skip to content

Commit

Permalink
feat(dotAI): Adding fallback mechanism when it comes to send models t…
Browse files Browse the repository at this point in the history
…o AI Provider (OpenAI)

Refs: #29284
  • Loading branch information
victoralfaro-dotcms committed Aug 9, 2024
1 parent 5205ee6 commit a3d4970
Show file tree
Hide file tree
Showing 14 changed files with 549 additions and 17 deletions.
26 changes: 12 additions & 14 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,24 @@ public class AIModels {
"https://api.openai.com/v1/models");
private static final int AI_MODELS_CACHE_TTL = 28800; // 8 hours
private static final int AI_MODELS_CACHE_SIZE = 128;
private static final Supplier<AppConfig> APP_CONFIG_SUPPLIER = ConfigService.INSTANCE::config;

private final ConcurrentMap<String, List<Tuple2<AIModelType, AIModel>>> internalModels = new ConcurrentHashMap<>();
private final ConcurrentMap<Tuple2<String, String>, AIModel> modelsByName = new ConcurrentHashMap<>();
private final Cache<String, List<String>> supportedModelsCache =
Caffeine.newBuilder()
.expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL))
.maximumSize(AI_MODELS_CACHE_SIZE)
.build();
private Supplier<AppConfig> appConfigSupplier = ConfigService.INSTANCE::config;
private final ConcurrentMap<String, List<Tuple2<AIModelType, AIModel>>> internalModels;
private final ConcurrentMap<Tuple2<String, String>, AIModel> modelsByName;
private final Cache<String, List<String>> supportedModelsCache;

public static AIModels get() {
return INSTANCE.get();
}

private AIModels() {
internalModels = new ConcurrentHashMap<>();
modelsByName = new ConcurrentHashMap<>();
supportedModelsCache =
Caffeine.newBuilder()
.expireAfterWrite(Duration.ofSeconds(AI_MODELS_CACHE_TTL))
.maximumSize(AI_MODELS_CACHE_SIZE)
.build();
}

/**
Expand Down Expand Up @@ -152,7 +155,7 @@ public List<String> getOrPullSupportedModels() {
return cached;
}

final AppConfig appConfig = appConfigSupplier.get();
final AppConfig appConfig = APP_CONFIG_SUPPLIER.get();
if (!appConfig.isEnabled()) {
Logger.debug(this, "OpenAI is not enabled, returning empty list of supported models");
return List.of();
Expand Down Expand Up @@ -211,11 +214,6 @@ private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final
return response;
}

@VisibleForTesting
void setAppConfigSupplier(final Supplier<AppConfig> appConfigSupplier) {
this.appConfigSupplier = appConfigSupplier;
}

@VisibleForTesting
void cleanSupportedModelsCache() {
supportedModelsCache.invalidateAll();
Expand Down
54 changes: 54 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIProvider;
import com.dotcms.ai.domain.AIRequest;
import org.apache.http.client.methods.HttpDelete;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPatch;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpPut;
import org.apache.http.client.methods.HttpUriRequest;

import javax.ws.rs.HttpMethod;
import java.io.OutputStream;
import java.io.Serializable;

public interface AIClient {

AIClient NOOP = new AIClient() {
@Override
public AIProvider getProvider() {
return AIProvider.NONE;
}

@Override
public OutputStream sendRequest(final AIRequest<? extends Serializable> request) {
return throwUnsupported();
}

private OutputStream throwUnsupported() {
throw new UnsupportedOperationException("Noop client does not support sending requests");
}
};

AIProvider getProvider();

static HttpUriRequest resolveMethod(final String method, final String url) {
switch(method) {
case HttpMethod.POST:
return new HttpPost(url);
case HttpMethod.PUT:
return new HttpPut(url);
case HttpMethod.DELETE:
return new HttpDelete(url);
case "patch":
return new HttpPatch(url);
case HttpMethod.GET:
default:
return new HttpGet(url);
}
}

OutputStream sendRequest(final AIRequest<? extends Serializable> request);

}
20 changes: 20 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;

import java.io.OutputStream;
import java.io.Serializable;

public class AIDefaultStrategy implements AIProxyStrategy {

@Override
public AIResponse applyStrategy(final AIClient client, final AIRequest<? extends Serializable> request) {
final OutputStream output = client.sendRequest(request);
return AIResponse.builder()
.output(output)
.response(output.toString())
.build();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;

import java.io.Serializable;

public class AIModelFallbackStrategy implements AIProxyStrategy {

@Override
public AIResponse applyStrategy(final AIClient client, final AIRequest<? extends Serializable> request) {
return null;
}

}
28 changes: 28 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;

import java.io.Serializable;

public class AIProxiedClient {

public static final AIProxiedClient NOOP = new AIProxiedClient(null, AIProxyStrategy.NOOP);

private final AIClient client;
private final AIProxyStrategy strategy;

private AIProxiedClient(final AIClient client, final AIProxyStrategy strategy) {
this.client = client;
this.strategy = strategy;
}

public static AIProxiedClient of(final AIClient client, final AIProxyStrategyWrapper strategy) {
return new AIProxiedClient(client, strategy.getStrategy());
}

public AIResponse callToAI(final AIRequest<? extends Serializable> request) {
return strategy.applyStrategy(client, request);
}

}
44 changes: 44 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIProxy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIProvider;
import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;
import io.vavr.Lazy;

import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicReference;

public class AIProxy {

private static final Lazy<AIProxy> INSTANCE = Lazy.of(AIProxy::new);

private final ConcurrentMap<AIProvider, AIProxiedClient> proxiedClients;
private final AtomicReference<AIProvider> currentProvider;

private AIProxy() {
proxiedClients = new ConcurrentHashMap<>();
addClient(AIProvider.OPEN_AI, AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategyWrapper.DEFAULT));
currentProvider = new AtomicReference<>(AIProvider.OPEN_AI);
}

public static AIProxy get() {
return INSTANCE.get();
}

public AIProxiedClient getClient(final AIProvider provider) {
return proxiedClients.get(provider);
}

public void addClient(final AIProvider provider, final AIProxiedClient client) {
proxiedClients.put(provider, client);
}

public AIResponse sendRequest(final AIProvider provider, final AIRequest request) {
return Optional.ofNullable(proxiedClients.getOrDefault(provider, AIProxiedClient.NOOP))
.map(client -> client.callToAI(request))
.orElse(null);
}

}
14 changes: 14 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;

import java.io.Serializable;

public interface AIProxyStrategy {

AIProxyStrategy NOOP = (client, request) -> AIResponse.builder().build();

AIResponse applyStrategy(final AIClient client, final AIRequest<? extends Serializable> request);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.dotcms.ai.client;

public enum AIProxyStrategyWrapper {

DEFAULT(new AIDefaultStrategy()),
MODEL_FALLBACK(new AIModelFallbackStrategy());

private final AIProxyStrategy strategy;

AIProxyStrategyWrapper(final AIProxyStrategy strategy) {
this.strategy = strategy;
}

public AIProxyStrategy getStrategy() {
return strategy;
}

}
128 changes: 128 additions & 0 deletions dotCMS/src/main/java/com/dotcms/ai/client/OpenAIClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package com.dotcms.ai.client;

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.domain.AIProvider;
import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.JSONObjectAIRequest;
import com.dotcms.ai.util.OpenAIRequest;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Lazy;
import io.vavr.control.Try;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;

import javax.ws.rs.core.MediaType;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.concurrent.ConcurrentHashMap;

public class OpenAIClient implements AIClient {

private static final Lazy<OpenAIClient> INSTANCE = Lazy.of(OpenAIClient::new);

private final ConcurrentHashMap<AIModel, Long> lastRestCall;

public static OpenAIClient get() {
return INSTANCE.get();
}

private OpenAIClient() {
lastRestCall = new ConcurrentHashMap<>();
}

@Override
public AIProvider getProvider() {
return AIProvider.OPEN_AI;
}

@Override
public OutputStream sendRequest(final AIRequest<? extends Serializable> request) {
final ByteArrayOutputStream output = new ByteArrayOutputStream();
sendRequest(request, output);
return output;
}

private OutputStream sendRequest(final AIRequest<? extends Serializable> request, final OutputStream output) {
final AppConfig config = request.getConfig();
if (!config.isEnabled()) {
Logger.debug(OpenAIRequest.class, "OpenAI is not enabled and will not send request.");
throw new IllegalStateException("OpenAI is not enabled");
}

// When we get rid of JSONObject usage, we can remove this check
if (!(request instanceof JSONObjectAIRequest)) {
throw new UnsupportedOperationException("Only JsonAIRequest (JSONObject) is supported");
}

final JSONObject json = ((JSONObjectAIRequest) request).getPayload();
final AIModel model = config.resolveModelOrThrow(json.optString(AiKeys.MODEL));

if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.debug(OpenAIRequest.class, "posting: " + json);
}

final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L)
+ model.minIntervalBetweenCalls()
- System.currentTimeMillis();
if (sleep > 0) {
Logger.info(
OpenAIRequest.class,
"Rate limit:"
+ model.getApiPerMinute()
+ "/minute, or 1 every "
+ model.minIntervalBetweenCalls()
+ "ms. Sleeping:"
+ sleep);
Try.run(() -> Thread.sleep(sleep));
}

lastRestCall.put(model, System.currentTimeMillis());

try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
final StringEntity jsonEntity = new StringEntity(json.toString(), ContentType.APPLICATION_JSON);
final HttpUriRequest httpRequest = AIClient.resolveMethod(request.getMethod(), request.getUrl());
httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON);
httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + config.getApiKey());

if (!json.getAsMap().isEmpty()) {
Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity));
}

try (CloseableHttpResponse response = httpClient.execute(httpRequest)) {
final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent());
final byte[] buffer = new byte[1024];
int len;
while ((len = in.read(buffer)) != -1) {
output.write(buffer, 0, len);
output.flush();
}
}
} catch (Exception e) {
if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)){
Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage(), e);
} else {
Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage());
}

Logger.warn(OpenAIRequest.class, " - " + request.getMethod() + " : " + json);

throw new DotRuntimeException(e);
}

return output;
}

}
Loading

0 comments on commit a3d4970

Please sign in to comment.