-
Notifications
You must be signed in to change notification settings - Fork 467
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(dotAI): Adding fallback mechanism when it comes to send models t…
…o AI Provider (OpenAI) Refs: #29284
- Loading branch information
1 parent
5205ee6
commit a3d4970
Showing
14 changed files
with
549 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package com.dotcms.ai.client; | ||
|
||
import com.dotcms.ai.domain.AIProvider; | ||
import com.dotcms.ai.domain.AIRequest; | ||
import org.apache.http.client.methods.HttpDelete; | ||
import org.apache.http.client.methods.HttpGet; | ||
import org.apache.http.client.methods.HttpPatch; | ||
import org.apache.http.client.methods.HttpPost; | ||
import org.apache.http.client.methods.HttpPut; | ||
import org.apache.http.client.methods.HttpUriRequest; | ||
|
||
import javax.ws.rs.HttpMethod; | ||
import java.io.OutputStream; | ||
import java.io.Serializable; | ||
|
||
public interface AIClient { | ||
|
||
AIClient NOOP = new AIClient() { | ||
@Override | ||
public AIProvider getProvider() { | ||
return AIProvider.NONE; | ||
} | ||
|
||
@Override | ||
public OutputStream sendRequest(final AIRequest<? extends Serializable> request) { | ||
return throwUnsupported(); | ||
} | ||
|
||
private OutputStream throwUnsupported() { | ||
throw new UnsupportedOperationException("Noop client does not support sending requests"); | ||
} | ||
}; | ||
|
||
AIProvider getProvider(); | ||
|
||
static HttpUriRequest resolveMethod(final String method, final String url) { | ||
switch(method) { | ||
case HttpMethod.POST: | ||
return new HttpPost(url); | ||
case HttpMethod.PUT: | ||
return new HttpPut(url); | ||
case HttpMethod.DELETE: | ||
return new HttpDelete(url); | ||
case "patch": | ||
return new HttpPatch(url); | ||
case HttpMethod.GET: | ||
default: | ||
return new HttpGet(url); | ||
} | ||
} | ||
|
||
OutputStream sendRequest(final AIRequest<? extends Serializable> request); | ||
|
||
} |
20 changes: 20 additions & 0 deletions
20
dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package com.dotcms.ai.client; | ||
|
||
import com.dotcms.ai.domain.AIRequest; | ||
import com.dotcms.ai.domain.AIResponse; | ||
|
||
import java.io.OutputStream; | ||
import java.io.Serializable; | ||
|
||
public class AIDefaultStrategy implements AIProxyStrategy { | ||
|
||
@Override | ||
public AIResponse applyStrategy(final AIClient client, final AIRequest<? extends Serializable> request) { | ||
final OutputStream output = client.sendRequest(request); | ||
return AIResponse.builder() | ||
.output(output) | ||
.response(output.toString()) | ||
.build(); | ||
} | ||
|
||
} |
15 changes: 15 additions & 0 deletions
15
dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package com.dotcms.ai.client; | ||
|
||
import com.dotcms.ai.domain.AIRequest; | ||
import com.dotcms.ai.domain.AIResponse; | ||
|
||
import java.io.Serializable; | ||
|
||
public class AIModelFallbackStrategy implements AIProxyStrategy { | ||
|
||
@Override | ||
public AIResponse applyStrategy(final AIClient client, final AIRequest<? extends Serializable> request) { | ||
return null; | ||
} | ||
|
||
} |
28 changes: 28 additions & 0 deletions
28
dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package com.dotcms.ai.client; | ||
|
||
import com.dotcms.ai.domain.AIRequest; | ||
import com.dotcms.ai.domain.AIResponse; | ||
|
||
import java.io.Serializable; | ||
|
||
public class AIProxiedClient { | ||
|
||
public static final AIProxiedClient NOOP = new AIProxiedClient(null, AIProxyStrategy.NOOP); | ||
|
||
private final AIClient client; | ||
private final AIProxyStrategy strategy; | ||
|
||
private AIProxiedClient(final AIClient client, final AIProxyStrategy strategy) { | ||
this.client = client; | ||
this.strategy = strategy; | ||
} | ||
|
||
public static AIProxiedClient of(final AIClient client, final AIProxyStrategyWrapper strategy) { | ||
return new AIProxiedClient(client, strategy.getStrategy()); | ||
} | ||
|
||
public AIResponse callToAI(final AIRequest<? extends Serializable> request) { | ||
return strategy.applyStrategy(client, request); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package com.dotcms.ai.client; | ||
|
||
import com.dotcms.ai.domain.AIProvider; | ||
import com.dotcms.ai.domain.AIRequest; | ||
import com.dotcms.ai.domain.AIResponse; | ||
import io.vavr.Lazy; | ||
|
||
import java.util.Optional; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
import java.util.concurrent.ConcurrentMap; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
|
||
public class AIProxy { | ||
|
||
private static final Lazy<AIProxy> INSTANCE = Lazy.of(AIProxy::new); | ||
|
||
private final ConcurrentMap<AIProvider, AIProxiedClient> proxiedClients; | ||
private final AtomicReference<AIProvider> currentProvider; | ||
|
||
private AIProxy() { | ||
proxiedClients = new ConcurrentHashMap<>(); | ||
addClient(AIProvider.OPEN_AI, AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategyWrapper.DEFAULT)); | ||
currentProvider = new AtomicReference<>(AIProvider.OPEN_AI); | ||
} | ||
|
||
public static AIProxy get() { | ||
return INSTANCE.get(); | ||
} | ||
|
||
public AIProxiedClient getClient(final AIProvider provider) { | ||
return proxiedClients.get(provider); | ||
} | ||
|
||
public void addClient(final AIProvider provider, final AIProxiedClient client) { | ||
proxiedClients.put(provider, client); | ||
} | ||
|
||
public AIResponse sendRequest(final AIProvider provider, final AIRequest request) { | ||
return Optional.ofNullable(proxiedClients.getOrDefault(provider, AIProxiedClient.NOOP)) | ||
.map(client -> client.callToAI(request)) | ||
.orElse(null); | ||
} | ||
|
||
} |
14 changes: 14 additions & 0 deletions
14
dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategy.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package com.dotcms.ai.client; | ||
|
||
import com.dotcms.ai.domain.AIRequest; | ||
import com.dotcms.ai.domain.AIResponse; | ||
|
||
import java.io.Serializable; | ||
|
||
public interface AIProxyStrategy { | ||
|
||
AIProxyStrategy NOOP = (client, request) -> AIResponse.builder().build(); | ||
|
||
AIResponse applyStrategy(final AIClient client, final AIRequest<? extends Serializable> request); | ||
|
||
} |
18 changes: 18 additions & 0 deletions
18
dotCMS/src/main/java/com/dotcms/ai/client/AIProxyStrategyWrapper.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
package com.dotcms.ai.client; | ||
|
||
public enum AIProxyStrategyWrapper { | ||
|
||
DEFAULT(new AIDefaultStrategy()), | ||
MODEL_FALLBACK(new AIModelFallbackStrategy()); | ||
|
||
private final AIProxyStrategy strategy; | ||
|
||
AIProxyStrategyWrapper(final AIProxyStrategy strategy) { | ||
this.strategy = strategy; | ||
} | ||
|
||
public AIProxyStrategy getStrategy() { | ||
return strategy; | ||
} | ||
|
||
} |
128 changes: 128 additions & 0 deletions
128
dotCMS/src/main/java/com/dotcms/ai/client/OpenAIClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
package com.dotcms.ai.client; | ||
|
||
import com.dotcms.ai.AiKeys; | ||
import com.dotcms.ai.app.AIModel; | ||
import com.dotcms.ai.app.AppConfig; | ||
import com.dotcms.ai.app.AppKeys; | ||
import com.dotcms.ai.domain.AIProvider; | ||
import com.dotcms.ai.domain.AIRequest; | ||
import com.dotcms.ai.domain.JSONObjectAIRequest; | ||
import com.dotcms.ai.util.OpenAIRequest; | ||
import com.dotmarketing.exception.DotRuntimeException; | ||
import com.dotmarketing.util.Logger; | ||
import com.dotmarketing.util.json.JSONObject; | ||
import io.vavr.Lazy; | ||
import io.vavr.control.Try; | ||
import org.apache.http.HttpHeaders; | ||
import org.apache.http.client.methods.CloseableHttpResponse; | ||
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase; | ||
import org.apache.http.client.methods.HttpUriRequest; | ||
import org.apache.http.entity.ContentType; | ||
import org.apache.http.entity.StringEntity; | ||
import org.apache.http.impl.client.CloseableHttpClient; | ||
import org.apache.http.impl.client.HttpClients; | ||
|
||
import javax.ws.rs.core.MediaType; | ||
import java.io.BufferedInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.OutputStream; | ||
import java.io.Serializable; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
|
||
public class OpenAIClient implements AIClient { | ||
|
||
private static final Lazy<OpenAIClient> INSTANCE = Lazy.of(OpenAIClient::new); | ||
|
||
private final ConcurrentHashMap<AIModel, Long> lastRestCall; | ||
|
||
public static OpenAIClient get() { | ||
return INSTANCE.get(); | ||
} | ||
|
||
private OpenAIClient() { | ||
lastRestCall = new ConcurrentHashMap<>(); | ||
} | ||
|
||
@Override | ||
public AIProvider getProvider() { | ||
return AIProvider.OPEN_AI; | ||
} | ||
|
||
@Override | ||
public OutputStream sendRequest(final AIRequest<? extends Serializable> request) { | ||
final ByteArrayOutputStream output = new ByteArrayOutputStream(); | ||
sendRequest(request, output); | ||
return output; | ||
} | ||
|
||
private OutputStream sendRequest(final AIRequest<? extends Serializable> request, final OutputStream output) { | ||
final AppConfig config = request.getConfig(); | ||
if (!config.isEnabled()) { | ||
Logger.debug(OpenAIRequest.class, "OpenAI is not enabled and will not send request."); | ||
throw new IllegalStateException("OpenAI is not enabled"); | ||
} | ||
|
||
// When we get rid of JSONObject usage, we can remove this check | ||
if (!(request instanceof JSONObjectAIRequest)) { | ||
throw new UnsupportedOperationException("Only JsonAIRequest (JSONObject) is supported"); | ||
} | ||
|
||
final JSONObject json = ((JSONObjectAIRequest) request).getPayload(); | ||
final AIModel model = config.resolveModelOrThrow(json.optString(AiKeys.MODEL)); | ||
|
||
if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) { | ||
Logger.debug(OpenAIRequest.class, "posting: " + json); | ||
} | ||
|
||
final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L) | ||
+ model.minIntervalBetweenCalls() | ||
- System.currentTimeMillis(); | ||
if (sleep > 0) { | ||
Logger.info( | ||
OpenAIRequest.class, | ||
"Rate limit:" | ||
+ model.getApiPerMinute() | ||
+ "/minute, or 1 every " | ||
+ model.minIntervalBetweenCalls() | ||
+ "ms. Sleeping:" | ||
+ sleep); | ||
Try.run(() -> Thread.sleep(sleep)); | ||
} | ||
|
||
lastRestCall.put(model, System.currentTimeMillis()); | ||
|
||
try (CloseableHttpClient httpClient = HttpClients.createDefault()) { | ||
final StringEntity jsonEntity = new StringEntity(json.toString(), ContentType.APPLICATION_JSON); | ||
final HttpUriRequest httpRequest = AIClient.resolveMethod(request.getMethod(), request.getUrl()); | ||
httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON); | ||
httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + config.getApiKey()); | ||
|
||
if (!json.getAsMap().isEmpty()) { | ||
Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity)); | ||
} | ||
|
||
try (CloseableHttpResponse response = httpClient.execute(httpRequest)) { | ||
final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent()); | ||
final byte[] buffer = new byte[1024]; | ||
int len; | ||
while ((len = in.read(buffer)) != -1) { | ||
output.write(buffer, 0, len); | ||
output.flush(); | ||
} | ||
} | ||
} catch (Exception e) { | ||
if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)){ | ||
Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage(), e); | ||
} else { | ||
Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage()); | ||
} | ||
|
||
Logger.warn(OpenAIRequest.class, " - " + request.getMethod() + " : " + json); | ||
|
||
throw new DotRuntimeException(e); | ||
} | ||
|
||
return output; | ||
} | ||
|
||
} |
Oops, something went wrong.