Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
victoralfaro-dotcms committed Aug 12, 2024
1 parent 735fd7a commit 4137367
Show file tree
Hide file tree
Showing 15 changed files with 242 additions and 29 deletions.
6 changes: 5 additions & 1 deletion dotCMS/src/main/java/com/dotcms/ai/app/AIModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public void setCurrentModelIndex(final int currentModelIndex) {
}

public boolean isOperational() {
return this != NOOP_MODEL || models.stream().anyMatch(model -> model.getStatus() == ModelStatus.ACTIVE);
return this != NOOP_MODEL || getActiveModels().isEmpty();
}

public Model getCurrent() {
Expand All @@ -101,6 +101,10 @@ public long minIntervalBetweenCalls() {
return 60000 / apiPerMinute;
}

public List<Model> getActiveModels() {
return models.stream().filter(model -> model.getStatus() == ModelStatus.ACTIVE).collect(Collectors.toList());
}

@Override
public String toString() {
return "AIModel{" +
Expand Down
11 changes: 7 additions & 4 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.dotcms.ai.client;

import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.domain.AIProvider;
import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponseMetadata;
import org.apache.http.client.methods.HttpDelete;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPatch;
Expand All @@ -22,11 +24,12 @@ public AIProvider getProvider() {
}

@Override
public <T extends Serializable> void sendRequest(final AIRequest<T> request, final OutputStream output) {
throwUnsupported();
public <T extends Serializable> AIResponseMetadata sendRequest(final AIRequest<T> request,
final OutputStream output) {
return throwUnsupported();
}

private void throwUnsupported() {
private AIResponseMetadata throwUnsupported() {
throw new UnsupportedOperationException("Noop client does not support sending requests");
}
};
Expand All @@ -49,6 +52,6 @@ static HttpUriRequest resolveMethod(final String method, final String url) {

AIProvider getProvider();

<T extends Serializable> void sendRequest(final AIRequest<T> request, final OutputStream output);
<T extends Serializable> AIResponseMetadata sendRequest(AIRequest<T> request, OutputStream output);

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;
import io.vavr.Tuple2;

import java.io.OutputStream;
import java.io.Serializable;
Expand All @@ -10,6 +11,8 @@ public interface AIClientStrategy {

AIClientStrategy NOOP = (client, request, output) -> AIResponse.builder().build();

void applyStrategy(AIClient client, AIRequest<? extends Serializable> request, OutputStream output);
void applyStrategy(Tuple2<AIClient, AIResponseValidator> clientAndParser,
AIRequest<? extends Serializable> request,
OutputStream output);

}
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIRequest;
import io.vavr.Tuple2;

import java.io.OutputStream;
import java.io.Serializable;

public class AIDefaultStrategy implements AIClientStrategy {

@Override
public void applyStrategy(final AIClient client,
public void applyStrategy(final Tuple2<AIClient, AIResponseValidator> clientAndParser,
final AIRequest<? extends Serializable> request,
final OutputStream output) {
client.sendRequest(request, output);
clientAndParser._1.sendRequest(request, output);
}

}
Original file line number Diff line number Diff line change
@@ -1,16 +1,89 @@
package com.dotcms.ai.client;

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponseMetadata;
import com.dotcms.ai.domain.JSONObjectAIRequest;
import com.dotcms.ai.domain.Model;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Tuple2;
import org.apache.commons.io.IOUtils;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;

public class AIModelFallbackStrategy implements AIClientStrategy {

@Override
public void applyStrategy(final AIClient client,
public void applyStrategy(final Tuple2<AIClient, AIResponseValidator> clientAndParser,
final AIRequest<? extends Serializable> request,
final OutputStream output) {
final OutputStream originalOutput) {
final JSONObject payload = ((JSONObjectAIRequest) request).getPayload();
final String modelInPayload = payload.optString(AiKeys.MODEL);
final AIModel aiModel = request.getConfig().resolveModelOrThrow(modelInPayload);

final List<Model> activeModels = aiModel.getActiveModels();
if (activeModels.isEmpty()) {
Logger.debug(
this,
() -> String.format(
"There are no active models left in model fallback strategy [%s]",
aiModel.getModels().stream().map(Model::getName).collect(Collectors.joining(", "))));
return;
}

boolean success = false;
for (int index = 0; index < aiModel.getModels().size(); index++) {
final Model model = aiModel.getModels().get(index);
if (!model.isOperational()) {
Logger.debug("Model [%s] is not operational. Skipping.", model.getName());
continue;
}

final ByteArrayOutputStream output = new ByteArrayOutputStream();
final AIResponseMetadata metadata = clientAndParser._1.sendRequest(request, output);
final String response = output.toString();

clientAndParser._2.lookForError(response, metadata);
if (metadata.isSuccess()) {
try {
IOUtils.copy(new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8)), originalOutput);
} catch (IOException e) {
throw new DotRuntimeException(e);
}

aiModel.setCurrentModelIndex(index);
success = true;

break;
}

Logger.debug(
this,
() -> String.format(
"Model [%s] failed with response [%s%s%s]. Trying next model.",
model.getName(),
System.lineSeparator(),
response,
System.lineSeparator()));
model.setStatus(metadata.getStatus());
Logger.debug(
this,
() -> String.format(
"Model [%s] status updated to [%s].",
model.getName(),
response));
}

}

}
19 changes: 15 additions & 4 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;
import io.vavr.Tuple;

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
Expand All @@ -11,26 +12,36 @@

public class AIProxiedClient {

public static final AIProxiedClient NOOP = new AIProxiedClient(null, AIClientStrategy.NOOP);
public static final AIProxiedClient NOOP = new AIProxiedClient(null, AIClientStrategy.NOOP, null);

private final AIClient client;
private final AIClientStrategy strategy;
private final AIResponseValidator responseParser;

private AIProxiedClient(final AIClient client, final AIClientStrategy strategy) {
private AIProxiedClient(final AIClient client,
final AIClientStrategy strategy,
final AIResponseValidator responseParser) {
this.client = client;
this.strategy = strategy;
this.responseParser = responseParser;
}

public static AIProxiedClient of(final AIClient client,
final AIProxyStrategy strategy,
final AIResponseValidator responseParser) {
return new AIProxiedClient(client, strategy.getStrategy(), responseParser);
}

public static AIProxiedClient of(final AIClient client, final AIProxyStrategy strategy) {
return new AIProxiedClient(client, strategy.getStrategy());
return of(client, strategy, null);
}

public <T extends Serializable> AIResponse callToAI(final AIRequest<T> request, final OutputStream output) {
final OutputStream finalOutput = Optional
.ofNullable(output)
.orElseGet(ByteArrayOutputStream::new);

strategy.applyStrategy(client, request, finalOutput);
strategy.applyStrategy(Tuple.of(client, responseParser), request, finalOutput);

return (Objects.nonNull(output))
? AIResponse.EMPTY
Expand Down
6 changes: 5 additions & 1 deletion dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.dotcms.ai.client;

import com.dotcms.ai.client.openai.OpenAIClient;
import com.dotcms.ai.client.openai.OpenAIResponseValidator;
import com.dotcms.ai.domain.AIProvider;
import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;
Expand All @@ -21,7 +23,9 @@ public class AIProxyClient {

private AIProxyClient() {
proxiedClients = new ConcurrentHashMap<>();
addClient(AIProvider.OPEN_AI, AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategy.DEFAULT));
addClient(
AIProvider.OPEN_AI,
AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategy.MODEL_FALLBACK, OpenAIResponseValidator.get()));
currentProvider = new AtomicReference<>(AIProvider.OPEN_AI);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIResponseMetadata;

public interface AIResponseValidator {

void lookForError(String response, AIResponseMetadata metadata);

}
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package com.dotcms.ai.client;
package com.dotcms.ai.client.openai;

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.client.AIClient;
import com.dotcms.ai.domain.AIProvider;
import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponseMetadata;
import com.dotcms.ai.domain.JSONObjectAIRequest;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
Expand Down Expand Up @@ -47,7 +49,8 @@ public AIProvider getProvider() {
}

@Override
public <T extends Serializable> void sendRequest(final AIRequest<T> request, final OutputStream output) {
public <T extends Serializable> AIResponseMetadata sendRequest(final AIRequest<T> request,
final OutputStream output) {
final AppConfig config = request.getConfig();
if (!config.isEnabled()) {
Logger.debug(this, "OpenAI is not enabled and will not send request.");
Expand All @@ -56,7 +59,7 @@ public <T extends Serializable> void sendRequest(final AIRequest<T> request, fin

// When we get rid of JSONObject usage, we can remove this check
if (!(request instanceof JSONObjectAIRequest)) {
throw new UnsupportedOperationException("Only JsonAIRequest (JSONObject) is supported");
throw new UnsupportedOperationException("Only JSONObjectAIRequest (JSONObject) is supported");
}

final JSONObject json = ((JSONObjectAIRequest) request).getPayload();
Expand Down Expand Up @@ -113,6 +116,8 @@ public <T extends Serializable> void sendRequest(final AIRequest<T> request, fin

throw new DotRuntimeException(e);
}

return new AIResponseMetadata(model);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package com.dotcms.ai.client.openai;

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.client.AIResponseValidator;
import com.dotcms.ai.domain.AIResponseMetadata;
import com.dotcms.ai.domain.ModelStatus;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Lazy;

public class OpenAIResponseValidator implements AIResponseValidator {

private static final Lazy<OpenAIResponseValidator> INSTANCE = Lazy.of(OpenAIResponseValidator::new);

public static OpenAIResponseValidator get() {
return INSTANCE.get();
}

private OpenAIResponseValidator() {
}

@Override
public void lookForError(final String response, final AIResponseMetadata metadata) {
final JSONObject jsonResponse = new JSONObject(response);
if (jsonResponse.has(AiKeys.ERROR)) {
final String error = jsonResponse.getString(AiKeys.ERROR);
metadata.setError(error);
metadata.setStatus(resolveStatus(error));
}
}

private ModelStatus resolveStatus(final String error) {
if (error.contains("has been deprecated")) {
return ModelStatus.DECOMMISSIONED;
} else if (error.contains("does not exist or you do not have access to it")) {
return ModelStatus.INVALID;
} else {
return null;
}
}

}
10 changes: 4 additions & 6 deletions dotCMS/src/main/java/com/dotcms/ai/domain/AIRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,16 @@ static String resolveUrl(final AIModelType type, final AppConfig appConfig) {
}
}

@SuppressWarnings("unchecked")
private static <T extends Serializable, B extends AIRequest.Builder<T, B>, R extends AIRequest<T>> R quick(
final String url,
final AppConfig appConfig,
final AIModelType type,
final T payload) {
return (R) AIRequest.<T, B>builder()
.withUrl(url)
.withConfig(appConfig)
.withType(type)
.withPayload(payload)
.build();
}
Expand All @@ -69,7 +72,7 @@ private static <T extends Serializable, R extends AIRequest<T>> R quick(
final AIModelType type,
final AppConfig appConfig,
final T payload) {
return quick(resolveUrl(type, appConfig), appConfig, payload);
return quick(resolveUrl(type, appConfig), appConfig, type, payload);
}

public String getUrl() {
Expand Down Expand Up @@ -121,11 +124,6 @@ public B withUrl(final String url) {
return self();
}

public B withMethod(final String method) {
this.method = method;
return self();
}

public B withConfig(final AppConfig config) {
this.config = config;
return self();
Expand Down
Loading

0 comments on commit 4137367

Please sign in to comment.