Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
victoralfaro-dotcms committed Aug 12, 2024
1 parent 4137367 commit a6653c9
Show file tree
Hide file tree
Showing 15 changed files with 144 additions and 83 deletions.
9 changes: 5 additions & 4 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package com.dotcms.ai.app;

import com.dotcms.ai.exception.DotAIModelNotFound;
import com.dotcms.ai.exception.DotAIModelNotOperational;
import com.dotcms.security.apps.Secret;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.UtilMethods;
import io.vavr.control.Try;
Expand Down Expand Up @@ -103,7 +104,7 @@ public String getApiImageUrl() {
/**
* Retrieves the API Embeddings URL.
*
* @return
* @return the API Embeddings URL
*/
public String getApiEmbeddingsUrl() {
return UtilMethods.isEmpty(apiEmbeddingsUrl) ? AppKeys.API_EMBEDDINGS_URL.defaultValue : apiEmbeddingsUrl;
Expand Down Expand Up @@ -266,7 +267,7 @@ public AIModel resolveModelOrThrow(final String modelName) {
.findModel(host, modelName)
.orElseThrow(() -> {
final String supported = String.join(", ", AIModels.get().getOrPullSupportedModels(apiKey));
return new DotRuntimeException(
return new DotAIModelNotFound(
"Unable to find model: [" + modelName + "]. Only [" + supported + "] are supported ");
});

Expand All @@ -276,7 +277,7 @@ public AIModel resolveModelOrThrow(final String modelName) {
() -> String.format(
"Resolved model [%s] is not operational, avoiding its usage",
aiModel.getCurrentModel()));
throw new DotRuntimeException(String.format("Model [%s] is not operational", aiModel.getCurrentModel()));
throw new DotAIModelNotOperational(String.format("Model [%s] is not operational", aiModel.getCurrentModel()));
}

return aiModel;
Expand Down
11 changes: 5 additions & 6 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIClient.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package com.dotcms.ai.client;

import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.domain.AIProvider;
import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponseMetadata;
import com.dotcms.ai.domain.AIResponseData;
import org.apache.http.client.methods.HttpDelete;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPatch;
Expand All @@ -24,12 +23,12 @@ public AIProvider getProvider() {
}

@Override
public <T extends Serializable> AIResponseMetadata sendRequest(final AIRequest<T> request,
final OutputStream output) {
public <T extends Serializable> AIResponseData sendRequest(final AIRequest<T> request,
final OutputStream output) {
return throwUnsupported();
}

private AIResponseMetadata throwUnsupported() {
private AIResponseData throwUnsupported() {
throw new UnsupportedOperationException("Noop client does not support sending requests");
}
};
Expand All @@ -52,6 +51,6 @@ static HttpUriRequest resolveMethod(final String method, final String url) {

AIProvider getProvider();

<T extends Serializable> AIResponseMetadata sendRequest(AIRequest<T> request, OutputStream output);
<T extends Serializable> AIResponseData sendRequest(AIRequest<T> request, OutputStream output);

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;
import io.vavr.Tuple2;

import java.io.OutputStream;
import java.io.Serializable;

public interface AIClientStrategy {

AIClientStrategy NOOP = (client, request, output) -> AIResponse.builder().build();
AIClientStrategy NOOP = (client, handler, request, output) -> AIResponse.builder().build();

void applyStrategy(Tuple2<AIClient, AIResponseValidator> clientAndParser,
void applyStrategy(AIClient client,
AIResponseHandler handler,
AIRequest<? extends Serializable> request,
OutputStream output);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIRequest;
import io.vavr.Tuple2;

import java.io.OutputStream;
import java.io.Serializable;

public class AIDefaultStrategy implements AIClientStrategy {

@Override
public void applyStrategy(final Tuple2<AIClient, AIResponseValidator> clientAndParser,
public void applyStrategy(final AIClient client,
final AIResponseHandler handler,
final AIRequest<? extends Serializable> request,
final OutputStream output) {
clientAndParser._1.sendRequest(request, output);
client.sendRequest(request, output);
}

}
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
package com.dotcms.ai.client;

import com.dotcms.ai.AiKeys;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponseMetadata;
import com.dotcms.ai.domain.JSONObjectAIRequest;
import com.dotcms.ai.domain.AIResponseData;
import com.dotcms.ai.domain.Model;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Tuple2;
import org.apache.commons.io.IOUtils;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
Expand All @@ -24,12 +21,13 @@
public class AIModelFallbackStrategy implements AIClientStrategy {

@Override
public void applyStrategy(final Tuple2<AIClient, AIResponseValidator> clientAndParser,
public void applyStrategy(final AIClient client,
final AIResponseHandler handler,
final AIRequest<? extends Serializable> request,
final OutputStream originalOutput) {
final JSONObject payload = ((JSONObjectAIRequest) request).getPayload();
final String modelInPayload = payload.optString(AiKeys.MODEL);
final AIModel aiModel = request.getConfig().resolveModelOrThrow(modelInPayload);
final AIResponseData responseData = doSend(client, request);
if (handleResponse(originalOutput, responseData)) return;


final List<Model> activeModels = aiModel.getActiveModels();
if (activeModels.isEmpty()) {
Expand All @@ -49,41 +47,72 @@ public void applyStrategy(final Tuple2<AIClient, AIResponseValidator> clientAndP
continue;
}

final ByteArrayOutputStream output = new ByteArrayOutputStream();
final AIResponseMetadata metadata = clientAndParser._1.sendRequest(request, output);
final String response = output.toString();
if (sendAttempt(clientAndParser, request, originalOutput, aiModel, index, model)) break;
}

clientAndParser._2.lookForError(response, metadata);
if (metadata.isSuccess()) {
try {
IOUtils.copy(new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8)), originalOutput);
} catch (IOException e) {
throw new DotRuntimeException(e);
}
}

aiModel.setCurrentModelIndex(index);
success = true;
private static boolean handleResponse(final OutputStream originalOutput, final AIResponseData responseData) {
if (responseData.isSuccess()) {
redirectOutput(originalOutput, responseData.getResponse());
return true;
}

break;
}
return false;
}

private AIResponseData doSend(final AIClient client, final AIRequest<? extends Serializable> request) {
final ByteArrayOutputStream output = new ByteArrayOutputStream();
final AIResponseData responseData = client.sendRequest(request, output);

responseData.setResponse(output.toString());
IOUtils.closeQuietly(output);

return responseData;
}

private static void redirectOutput(final OutputStream originalOutput, final String response) {
try (final InputStream input = new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))) {
IOUtils.copy(input, originalOutput);
} catch (IOException e) {
throw new DotRuntimeException(e);
}
}

private boolean sendAttempt(final AIClient client,
final AIResponseHandler handler,
final AIRequest<? extends Serializable> request,
final OutputStream originalOutput) {

final AIResponseData responseData = doSend(client, request);
final String response = responseData.getResponse();

handler.handleResponse(response, responseData);
if (!responseData.isSuccess()) {
final AIModel aiModel = resolveModelFromPayload(request);
Logger.debug(
this,
() -> String.format(
"Model [%s] failed with response [%s%s%s]. Trying next model.",
model.getName(),
aiModel.getCurrentModel(),
System.lineSeparator(),
response,
System.lineSeparator()));
model.setStatus(metadata.getStatus());
Logger.debug(
this,
() -> String.format(
"Model [%s] status updated to [%s].",
model.getName(),
response));
return false;
}

redirectOutput(originalOutput, response);

return true;

/*model.setStatus(responseData.getStatus());
Logger.debug(
this,
() -> String.format(
"Model [%s] status updated to [%s].",
model.getName(),
response));*/
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;
import io.vavr.Tuple;

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
Expand All @@ -16,19 +15,19 @@ public class AIProxiedClient {

private final AIClient client;
private final AIClientStrategy strategy;
private final AIResponseValidator responseParser;
private final AIResponseHandler responseParser;

private AIProxiedClient(final AIClient client,
final AIClientStrategy strategy,
final AIResponseValidator responseParser) {
final AIResponseHandler responseParser) {
this.client = client;
this.strategy = strategy;
this.responseParser = responseParser;
}

public static AIProxiedClient of(final AIClient client,
final AIProxyStrategy strategy,
final AIResponseValidator responseParser) {
final AIResponseHandler responseParser) {
return new AIProxiedClient(client, strategy.getStrategy(), responseParser);
}

Expand All @@ -41,7 +40,7 @@ public <T extends Serializable> AIResponse callToAI(final AIRequest<T> request,
.ofNullable(output)
.orElseGet(ByteArrayOutputStream::new);

strategy.applyStrategy(Tuple.of(client, responseParser), request, finalOutput);
strategy.applyStrategy(client, responseParser, request, finalOutput);

return (Objects.nonNull(output))
? AIResponse.EMPTY
Expand Down
4 changes: 2 additions & 2 deletions dotCMS/src/main/java/com/dotcms/ai/client/AIProxyClient.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.dotcms.ai.client;

import com.dotcms.ai.client.openai.OpenAIClient;
import com.dotcms.ai.client.openai.OpenAIResponseValidator;
import com.dotcms.ai.client.openai.OpenAIResponseHandler;
import com.dotcms.ai.domain.AIProvider;
import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponse;
Expand All @@ -25,7 +25,7 @@ private AIProxyClient() {
proxiedClients = new ConcurrentHashMap<>();
addClient(
AIProvider.OPEN_AI,
AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategy.MODEL_FALLBACK, OpenAIResponseValidator.get()));
AIProxiedClient.of(OpenAIClient.get(), AIProxyStrategy.MODEL_FALLBACK, OpenAIResponseHandler.get()));
currentProvider = new AtomicReference<>(AIProvider.OPEN_AI);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.dotcms.ai.client;

import com.dotcms.ai.domain.AIResponseData;

public interface AIResponseHandler {

void handleResponse(String response, AIResponseData metadata);

}

This file was deleted.

21 changes: 11 additions & 10 deletions dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import com.dotcms.ai.client.AIClient;
import com.dotcms.ai.domain.AIProvider;
import com.dotcms.ai.domain.AIRequest;
import com.dotcms.ai.domain.AIResponseMetadata;
import com.dotcms.ai.domain.AIResponseData;
import com.dotcms.ai.domain.JSONObjectAIRequest;
import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.util.Logger;
Expand Down Expand Up @@ -49,8 +49,8 @@ public AIProvider getProvider() {
}

@Override
public <T extends Serializable> AIResponseMetadata sendRequest(final AIRequest<T> request,
final OutputStream output) {
public <T extends Serializable> AIResponseData sendRequest(final AIRequest<T> request,
final OutputStream output) {
final AppConfig config = request.getConfig();
if (!config.isEnabled()) {
Logger.debug(this, "OpenAI is not enabled and will not send request.");
Expand All @@ -63,28 +63,29 @@ public <T extends Serializable> AIResponseMetadata sendRequest(final AIRequest<T
}

final JSONObject json = ((JSONObjectAIRequest) request).getPayload();
final AIModel model = config.resolveModelOrThrow(json.optString(AiKeys.MODEL));
final AIModel aiModel = config.resolveModelOrThrow(json.optString(AiKeys.MODEL));
final AIResponseData responseData = new AIResponseData(aiModel);

if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
Logger.debug(this, "posting: " + json);
}

final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L)
+ model.minIntervalBetweenCalls()
final long sleep = lastRestCall.computeIfAbsent(aiModel, m -> 0L)
+ aiModel.minIntervalBetweenCalls()
- System.currentTimeMillis();
if (sleep > 0) {
Logger.info(
this,
"Rate limit:"
+ model.getApiPerMinute()
+ aiModel.getApiPerMinute()
+ "/minute, or 1 every "
+ model.minIntervalBetweenCalls()
+ aiModel.minIntervalBetweenCalls()
+ "ms. Sleeping:"
+ sleep);
Try.run(() -> Thread.sleep(sleep));
}

lastRestCall.put(model, System.currentTimeMillis());
lastRestCall.put(aiModel, System.currentTimeMillis());

try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
final StringEntity jsonEntity = new StringEntity(json.toString(), ContentType.APPLICATION_JSON);
Expand Down Expand Up @@ -117,7 +118,7 @@ public <T extends Serializable> AIResponseMetadata sendRequest(final AIRequest<T
throw new DotRuntimeException(e);
}

return new AIResponseMetadata(model);
return responseData;
}

}
Loading

0 comments on commit a6653c9

Please sign in to comment.