diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java
index d980647f7b41..38ea42606703 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPI.java
@@ -1,9 +1,7 @@
package com.dotcms.ai.api;
-import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotmarketing.util.json.JSONObject;
-import io.vavr.Lazy;
import java.io.OutputStream;
@@ -37,9 +35,10 @@ public interface CompletionsAPI {
* this method takes a prompt in the form of json and returns a json AI response based upon that prompt
*
* @param promptJSON
+ * @param userId
* @return
*/
- JSONObject raw(JSONObject promptJSON);
+ JSONObject raw(JSONObject promptJSON, String userId);
/**
* this method takes a prompt and returns the AI response based upon that prompt
@@ -58,9 +57,15 @@ public interface CompletionsAPI {
* @param model
* @param temperature
* @param maxTokens
+ * @param userId
* @return
*/
- JSONObject prompt(String systemPrompt, String userPrompt, String model, float temperature, int maxTokens);
+ JSONObject prompt(String systemPrompt,
+ String userPrompt,
+ String model,
+ float temperature,
+ int maxTokens,
+ String userId);
/**
* this method takes a prompt in the form of json and returns streaming AI response based upon that prompt
diff --git a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
index 008f65609096..56af6ddab956 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/api/CompletionsAPIImpl.java
@@ -2,13 +2,17 @@
import com.dotcms.ai.AiKeys;
import com.dotcms.ai.app.AIModel;
+import com.dotcms.ai.app.AIModelType;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
+import com.dotcms.ai.client.AIProxyClient;
import com.dotcms.ai.db.EmbeddingsDTO;
+import com.dotcms.ai.domain.AIResponse;
+import com.dotcms.ai.domain.JSONObjectAIRequest;
+import com.dotcms.ai.domain.Model;
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotcms.ai.util.EncodingUtil;
-import com.dotcms.ai.util.OpenAIRequest;
import com.dotcms.api.web.HttpServletRequestThreadLocal;
import com.dotcms.mock.request.FakeHttpRequest;
import com.dotcms.mock.response.BaseResponse;
@@ -16,18 +20,18 @@
import com.dotmarketing.business.APILocator;
import com.dotmarketing.business.web.WebAPILocator;
import com.dotmarketing.exception.DotRuntimeException;
-import com.dotmarketing.util.Logger;
import com.dotmarketing.util.UtilMethods;
import com.dotmarketing.util.json.JSONArray;
import com.dotmarketing.util.json.JSONObject;
+import com.liferay.portal.model.User;
import io.vavr.Lazy;
+import io.vavr.Tuple2;
import io.vavr.control.Try;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.velocity.context.Context;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
-import javax.ws.rs.HttpMethod;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
@@ -42,15 +46,13 @@
public class CompletionsAPIImpl implements CompletionsAPI {
private final AppConfig config;
- private final Lazy defaultConfig;
public CompletionsAPIImpl(final AppConfig config) {
- defaultConfig =
- Lazy.of(() -> ConfigService.INSTANCE.config(
- Try.of(() -> WebAPILocator
- .getHostWebAPI()
- .getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest()))
- .getOrElse(APILocator.systemHost())));
+ final Lazy defaultConfig = Lazy.of(() -> ConfigService.INSTANCE.config(
+ Try.of(() -> WebAPILocator
+ .getHostWebAPI()
+ .getCurrentHostNoThrow(HttpServletRequestThreadLocal.INSTANCE.getRequest()))
+ .getOrElse(APILocator.systemHost())));
this.config = Optional.ofNullable(config).orElse(defaultConfig.get());
}
@@ -59,8 +61,9 @@ public JSONObject prompt(final String systemPrompt,
final String userPrompt,
final String modelIn,
final float temperature,
- final int maxTokens) {
- final AIModel model = config.resolveModelOrThrow(modelIn);
+ final int maxTokens,
+ final String userId) {
+ final Model model = config.resolveModelOrThrow(modelIn, AIModelType.TEXT)._2;
final JSONObject json = new JSONObject();
json.put(AiKeys.TEMPERATURE, temperature);
@@ -70,15 +73,17 @@ public JSONObject prompt(final String systemPrompt,
json.put(AiKeys.MAX_TOKENS, maxTokens);
}
- json.put(AiKeys.MODEL, model.getCurrentModel());
+ json.put(AiKeys.MODEL, model.getName());
- return raw(json);
+ return raw(json, userId);
}
@Override
public JSONObject summarize(final CompletionsForm summaryRequest) {
final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build();
- final List localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher);
+ final List localResults = APILocator.getDotAIAPI()
+ .getEmbeddingsAPI()
+ .getEmbeddingResults(searcher);
// send all this as a json blob to OpenAI
final JSONObject json = buildRequestJson(summaryRequest, localResults);
@@ -87,58 +92,64 @@ public JSONObject summarize(final CompletionsForm summaryRequest) {
}
json.put(AiKeys.STREAM, false);
- final String openAiResponse =
- Try.of(() -> OpenAIRequest.doRequest(
- config.getApiUrl(),
- HttpMethod.POST,
- config,
- json))
- .getOrElseThrow(DotRuntimeException::new);
- final JSONObject dotCMSResponse = APILocator.getDotAIAPI().getEmbeddingsAPI().reduceChunksToContent(searcher, localResults);
+ final String openAiResponse = Try.of(() -> sendRequest(config, json, getUserIdIfNotNull(summaryRequest.user)))
+ .getOrElseThrow(DotRuntimeException::new)
+ .getResponse();
+ final JSONObject dotCMSResponse = APILocator.getDotAIAPI()
+ .getEmbeddingsAPI()
+ .reduceChunksToContent(searcher, localResults);
dotCMSResponse.put(AiKeys.OPEN_AI_RESPONSE, new JSONObject(openAiResponse));
return dotCMSResponse;
}
@Override
- public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream out) {
+ public void summarizeStream(final CompletionsForm summaryRequest, final OutputStream output) {
final EmbeddingsDTO searcher = EmbeddingsDTO.from(summaryRequest).build();
final List localResults = APILocator.getDotAIAPI().getEmbeddingsAPI().getEmbeddingResults(searcher);
final JSONObject json = buildRequestJson(summaryRequest, localResults);
json.put(AiKeys.STREAM, true);
- OpenAIRequest.doPost(config.getApiUrl(), config, json, out);
+ AIProxyClient.get().sendRequest(JSONObjectAIRequest.quickText(
+ config,
+ json,
+ getUserIdIfNotNull(summaryRequest.user)),
+ output);
}
@Override
- public JSONObject raw(final JSONObject json) {
- if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
- Logger.info(this.getClass(), "OpenAI request:" + json.toString(2));
- }
+ public JSONObject raw(final JSONObject json, final String userId) {
+ AppConfig.debugLogger(this.getClass(), () -> "OpenAI request:" + json.toString(2));
- final String response = OpenAIRequest.doRequest(
- config.getApiUrl(),
- HttpMethod.POST,
- config,
- json);
- if (config.getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
- Logger.info(this.getClass(), "OpenAI response:" + response);
- }
+ final String response = sendRequest(config, json, userId).getResponse();
+ AppConfig.debugLogger(this.getClass(), () -> "OpenAI response:" + response);
return new JSONObject(response);
}
@Override
- public JSONObject raw(CompletionsForm promptForm) {
+ public JSONObject raw(final CompletionsForm promptForm) {
JSONObject jsonObject = buildRequestJson(promptForm);
- return raw(jsonObject);
+ return raw(jsonObject, getUserIdIfNotNull(promptForm.user));
}
@Override
- public void rawStream(final CompletionsForm promptForm, final OutputStream out) {
+ public void rawStream(final CompletionsForm promptForm, final OutputStream output) {
final JSONObject json = buildRequestJson(promptForm);
json.put(AiKeys.STREAM, true);
- OpenAIRequest.doRequest(config.getApiUrl(), HttpMethod.POST, config, json, out);
+ AIProxyClient.get().sendRequest(JSONObjectAIRequest.quickText(
+ config,
+ json,
+ getUserIdIfNotNull(promptForm.user)),
+ output);
+ }
+
+ private String getUserIdIfNotNull(final User user) {
+ return Optional.ofNullable(user).map(User::getUserId).orElse(null);
+ }
+
+ private AIResponse sendRequest(final AppConfig appConfig, final JSONObject payload, final String userId) {
+ return AIProxyClient.get().sendRequest(JSONObjectAIRequest.quickText(appConfig, payload, userId));
}
private void buildMessages(final String systemPrompt, final String userPrompt, final JSONObject json) {
@@ -151,7 +162,7 @@ private void buildMessages(final String systemPrompt, final String userPrompt, f
}
private JSONObject buildRequestJson(final CompletionsForm form, final List searchResults) {
- final AIModel model = config.resolveModelOrThrow(form.model);
+ final Tuple2 modelTuple = config.resolveModelOrThrow(form.model, AIModelType.TEXT);
// aggregate matching results into text
final StringBuilder supportingContent = new StringBuilder();
searchResults.forEach(s -> supportingContent.append(s.extractedText).append(" "));
@@ -162,7 +173,7 @@ private JSONObject buildRequestJson(final CompletionsForm form, final List
+ *
+ * @param event the {@link AppSecretSavedEvent} that triggered the notification
+ */
@Override
public void notify(final AppSecretSavedEvent event) {
if (Objects.isNull(event)) {
@@ -51,7 +68,9 @@ public void notify(final AppSecretSavedEvent event) {
final Host host = Try.of(() -> hostAPI.find(hostId, APILocator.systemUser(), false)).getOrNull();
Optional.ofNullable(host).ifPresent(found -> AIModels.get().resetModels(found.getHostname()));
- ConfigService.INSTANCE.config(host);
+ final AppConfig appConfig = ConfigService.INSTANCE.config(host);
+
+ AIAppValidator.get().validateAIConfig(appConfig, event.getUserId());
}
@Override
diff --git a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java
index 9739bab313eb..5c5a7b24d5ef 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java
@@ -3,6 +3,7 @@
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.db.EmbeddingsDTO;
+import com.dotcms.ai.exception.DotAIAppConfigDisabledException;
import com.dotcms.content.elasticsearch.business.event.ContentletArchiveEvent;
import com.dotcms.content.elasticsearch.business.event.ContentletDeletedEvent;
import com.dotcms.content.elasticsearch.business.event.ContentletPublishEvent;
@@ -10,7 +11,6 @@
import com.dotcms.system.event.local.model.Subscriber;
import com.dotmarketing.beans.Host;
import com.dotmarketing.business.APILocator;
-import com.dotmarketing.exception.DotRuntimeException;
import com.dotmarketing.portlets.contentlet.model.Contentlet;
import com.dotmarketing.portlets.contentlet.model.ContentletListener;
import com.dotmarketing.util.Logger;
@@ -86,7 +86,7 @@ private AppConfig getAppConfig(final String hostId) {
AppConfig.debugLogger(
getClass(),
() -> "dotAI is not enabled since no API urls or API key found in app config");
- throw new DotRuntimeException("App dotAI config without API urls or API key");
+ throw new DotAIAppConfigDisabledException("App dotAI config without API urls or API key");
}
return appConfig;
diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java b/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java
index 53f83c3ab149..e289b64c9a0d 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/model/AIImageRequestDTO.java
@@ -1,6 +1,7 @@
package com.dotcms.ai.model;
+import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.ConfigService;
import com.fasterxml.jackson.annotation.JsonSetter;
import com.fasterxml.jackson.annotation.Nulls;
@@ -15,12 +16,11 @@ public class AIImageRequestDTO {
private final String model;
- public AIImageRequestDTO(Builder builder) {
+ public AIImageRequestDTO(final Builder builder) {
this.numberOfImages = builder.numberOfImages;
this.model = builder.model;
this.prompt = builder.prompt;
this.size = builder.size;
-
}
public String getSize() {
@@ -40,14 +40,15 @@ public String getModel() {
}
public static class Builder {
+ private AppConfig appConfig = ConfigService.INSTANCE.config();
@JsonSetter(nulls = Nulls.SKIP)
private String prompt;
@JsonSetter(nulls = Nulls.SKIP)
private int numberOfImages = 1;
@JsonSetter(nulls = Nulls.SKIP)
- private String size = ConfigService.INSTANCE.config().getImageSize();
+ private String size = appConfig.getImageSize();
@JsonSetter(nulls = Nulls.SKIP)
- private String model = ConfigService.INSTANCE.config().getImageModel().getCurrentModel();
+ private String model = appConfig.getImageModel().getCurrentModel();
public AIImageRequestDTO build() {
return new AIImageRequestDTO(this);
diff --git a/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java b/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java
index c5486b61191f..b24e042e853f 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/model/SimpleModel.java
@@ -17,16 +17,20 @@ public class SimpleModel implements Serializable {
private final String name;
private final AIModelType type;
+ private final boolean current;
@JsonCreator
- public SimpleModel(@JsonProperty("name") final String name, @JsonProperty("type") final AIModelType type) {
+ public SimpleModel(@JsonProperty("name") final String name,
+ @JsonProperty("type") final AIModelType type,
+ @JsonProperty("current") final boolean current) {
this.name = name;
this.type = type;
+ this.current = current;
}
@JsonCreator
public SimpleModel(@JsonProperty("name") final String name) {
- this(name, null);
+ this(name, null, false);
}
public String getName() {
@@ -37,17 +41,30 @@ public AIModelType getType() {
return type;
}
+ public boolean isCurrent() {
+ return current;
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SimpleModel that = (SimpleModel) o;
- return Objects.equals(name, that.name);
+ return Objects.equals(name, that.name) && type == that.type;
}
@Override
public int hashCode() {
- return Objects.hashCode(name);
+ return Objects.hash(name, type);
+ }
+
+ @Override
+ public String toString() {
+ return "SimpleModel{" +
+ "name='" + name + '\'' +
+ ", type=" + type +
+ ", current=" + current +
+ '}';
}
}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java
index e7b62cf46712..5499de4ce660 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java
@@ -61,7 +61,9 @@ public final Response summarizeFromContent(@Context final HttpServletRequest req
response,
formIn,
() -> APILocator.getDotAIAPI().getCompletionsAPI().summarize(formIn),
- out -> APILocator.getDotAIAPI().getCompletionsAPI().summarizeStream(formIn, new LineReadingOutputStream(out)));
+ output -> APILocator.getDotAIAPI()
+ .getCompletionsAPI()
+ .summarizeStream(formIn, new LineReadingOutputStream(output)));
}
/**
@@ -84,7 +86,9 @@ public final Response rawPrompt(@Context final HttpServletRequest request,
response,
formIn,
() -> APILocator.getDotAIAPI().getCompletionsAPI().raw(formIn),
- out -> APILocator.getDotAIAPI().getCompletionsAPI().rawStream(formIn, new LineReadingOutputStream(out)));
+ output -> APILocator.getDotAIAPI()
+ .getCompletionsAPI()
+ .rawStream(formIn, new LineReadingOutputStream(output)));
}
/**
@@ -107,16 +111,15 @@ public final Response getConfig(@Context final HttpServletRequest request,
.init()
.getUser();
final Host host = WebAPILocator.getHostWebAPI().getCurrentHostNoThrow(request);
- final AppConfig app = ConfigService.INSTANCE.config(host);
-
+ final AppConfig appConfig = ConfigService.INSTANCE.config(host);
final Map map = new HashMap<>();
map.put(AiKeys.CONFIG_HOST, host.getHostname() + " (falls back to system host)");
for (final AppKeys config : AppKeys.values()) {
- map.put(config.key, app.getConfig(config));
+ map.put(config.key, appConfig.getConfig(config));
}
- final String apiKey = UtilMethods.isSet(app.getApiKey()) ? "*****" : "NOT SET";
+ final String apiKey = UtilMethods.isSet(appConfig.getApiKey()) ? "*****" : "NOT SET";
map.put(AppKeys.API_KEY.key, apiKey);
final List models = AIModels.get().getAvailableModels();
@@ -140,19 +143,25 @@ private static CompletionsForm resolveForm(final HttpServletRequest request,
.init()
.getUser();
final Host host = WebAPILocator.getHostWebAPI().getCurrentHostNoThrow(request);
- return (!user.isAdmin())
- ? CompletionsForm
- .copy(formIn)
- .model(ConfigService.INSTANCE.config(host).getModel().getCurrentModel())
- .build()
- : formIn;
+ return withUserId(
+ !user.isAdmin()
+ ? CompletionsForm
+ .copy(formIn)
+ .model(ConfigService.INSTANCE.config(host).getModel().getCurrentModel())
+ .build()
+ : formIn,
+ user);
+ }
+
+ private static CompletionsForm withUserId(final CompletionsForm completionsForm, final User user) {
+ return CompletionsForm.copy(completionsForm).user(user).build();
}
private static Response getResponse(final HttpServletRequest request,
final HttpServletResponse response,
final CompletionsForm formIn,
final Supplier noStream,
- final Consumer stream) {
+ final Consumer outputStream) {
if (StringUtils.isBlank(formIn.prompt)) {
return badRequestResponse();
}
@@ -162,7 +171,7 @@ private static Response getResponse(final HttpServletRequest request,
if (resolvedForm.stream) {
final StreamingOutput streaming = output -> {
- stream.accept(output);
+ outputStream.accept(output);
output.flush();
output.close();
};
@@ -174,5 +183,4 @@ private static Response getResponse(final HttpServletRequest request,
return Response.ok(jsonResponse.toString(), MediaType.APPLICATION_JSON).build();
}
-
}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java
index 375625d58adf..e536de66e87c 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/rest/ImageResource.java
@@ -1,7 +1,7 @@
package com.dotcms.ai.rest;
import com.dotcms.ai.AiKeys;
-import com.dotcms.ai.Marshaller;
+import com.dotcms.ai.util.Marshaller;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.model.AIImageRequestDTO;
diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java
index fae06a565d3b..f0a05c50f4a4 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/rest/TextResource.java
@@ -10,6 +10,7 @@
import com.dotmarketing.util.UtilMethods;
import com.dotmarketing.util.json.JSONArray;
import com.dotmarketing.util.json.JSONObject;
+import com.liferay.portal.model.User;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@@ -56,7 +57,7 @@ public Response doGet(@Context final HttpServletRequest request,
*
* @param request the HTTP request
* @param response the HTTP response
- * @param formIn the form data containing the prompt
+ * @param form the form data containing the prompt
* @return a Response object containing the generated text
* @throws IOException if an I/O error occurs
*/
@@ -65,13 +66,14 @@ public Response doGet(@Context final HttpServletRequest request,
@Produces(MediaType.APPLICATION_JSON)
public Response doPost(@Context final HttpServletRequest request,
@Context final HttpServletResponse response,
- final CompletionsForm formIn) throws IOException {
+ final CompletionsForm form) throws IOException {
- new WebResource.InitBuilder(request, response)
+ final User user = new WebResource.InitBuilder(request, response)
.requiredBackendUser(true)
.requiredFrontendUser(true)
.init()
.getUser();
+ final CompletionsForm formIn = CompletionsForm.copy(form).user(user).build();
if (UtilMethods.isEmpty(formIn.prompt)) {
return Response
@@ -82,7 +84,12 @@ public Response doPost(@Context final HttpServletRequest request,
final AppConfig config = ConfigService.INSTANCE.config(WebAPILocator.getHostWebAPI().getHost(request));
- return Response.ok(APILocator.getDotAIAPI().getCompletionsAPI().raw(generateRequest(formIn, config)).toString()).build();
+ return Response.ok(
+ APILocator.getDotAIAPI()
+ .getCompletionsAPI()
+ .raw(generateRequest(formIn, config), user.getUserId())
+ .toString())
+ .build();
}
/**
diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java
index f4eb199d4bf2..2e1f58923556 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/CompletionsForm.java
@@ -8,6 +8,7 @@
import com.fasterxml.jackson.annotation.JsonSetter;
import com.fasterxml.jackson.annotation.Nulls;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
+import com.liferay.portal.model.User;
import io.vavr.control.Try;
import javax.validation.constraints.Max;
@@ -49,6 +50,7 @@ public class CompletionsForm {
public final String model;
public final String operator;
public final String site;
+ public final User user;
@Override
public boolean equals(final Object o) {
@@ -88,6 +90,7 @@ public String toString() {
", operator='" + operator + '\'' +
", site='" + site + '\'' +
", contentType=" + Arrays.toString(contentType) +
+ ", user=" + user +
'}';
}
@@ -118,6 +121,7 @@ private CompletionsForm(final Builder builder) {
this.temperature = builder.temperature >= 2 ? 2 : builder.temperature;
}
this.model = UtilMethods.isSet(builder.model) ? builder.model : ConfigService.INSTANCE.config().getModel().getCurrentModel();
+ this.user = builder.user;
}
private String validateBuilderQuery(final String query) {
@@ -131,7 +135,6 @@ private long validateLanguage(final String language) {
return Try.of(() -> Long.parseLong(language))
.recover(x -> APILocator.getLanguageAPI().getLanguage(language).getId())
.getOrElseTry(() -> APILocator.getLanguageAPI().getDefaultLanguage().getId());
-
}
public static Builder copy(final CompletionsForm form) {
@@ -149,7 +152,8 @@ public static Builder copy(final CompletionsForm form) {
.operator(form.operator)
.indexName(form.indexName)
.threshold(form.threshold)
- .stream(form.stream);
+ .stream(form.stream)
+ .user(form.user);
}
public static final class Builder {
@@ -182,6 +186,8 @@ public static final class Builder {
private String operator = "cosine";
@JsonSetter(nulls = Nulls.SKIP)
private String site;
+ @JsonSetter(nulls = Nulls.SKIP)
+ private User user;
public Builder prompt(String queryOrPrompt) {
this.prompt = queryOrPrompt;
@@ -224,7 +230,7 @@ public Builder fieldVar(String fieldVar) {
}
public Builder model(String model) {
- this.model =model;
+ this.model = model;
return this;
}
@@ -254,7 +260,12 @@ public Builder operator(String operator) {
}
public Builder site(String site) {
- this.site =site;
+ this.site = site;
+ return this;
+ }
+
+ public Builder user(User user) {
+ this.user = user;
return this;
}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java
index 61815b1307eb..62c61fa9d229 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/rest/forms/EmbeddingsForm.java
@@ -1,7 +1,6 @@
package com.dotcms.ai.rest.forms;
import com.dotcms.ai.app.AppConfig;
-import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotmarketing.business.APILocator;
import com.dotmarketing.util.UtilMethods;
@@ -65,8 +64,6 @@ public static final Builder copy(EmbeddingsForm form) {
.fields(String.join(",", form.fields))
.velocityTemplate(form.velocityTemplate)
.indexName(form.indexName);
-
-
}
@Override
@@ -103,7 +100,6 @@ public String toString() {
'}';
}
-
public static final class Builder {
@JsonSetter(nulls = Nulls.SKIP)
public String fields;
@@ -135,7 +131,6 @@ public Builder limit(int limit) {
return this;
}
-
public Builder offset(int offset) {
this.offset = offset;
return this;
@@ -161,10 +156,10 @@ public Builder velocityTemplate(String velocityTemplate) {
return this;
}
-
public EmbeddingsForm build() {
return new EmbeddingsForm(this);
-
}
+
}
+
}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/Marshaller.java b/dotCMS/src/main/java/com/dotcms/ai/util/Marshaller.java
similarity index 98%
rename from dotCMS/src/main/java/com/dotcms/ai/Marshaller.java
rename to dotCMS/src/main/java/com/dotcms/ai/util/Marshaller.java
index fc39f5f88e8c..0f92396e50be 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/Marshaller.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/util/Marshaller.java
@@ -1,4 +1,4 @@
-package com.dotcms.ai;
+package com.dotcms.ai.util;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
diff --git a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java b/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java
deleted file mode 100644
index b2a9b9adf789..000000000000
--- a/dotCMS/src/main/java/com/dotcms/ai/util/OpenAIRequest.java
+++ /dev/null
@@ -1,189 +0,0 @@
-package com.dotcms.ai.util;
-
-import com.dotcms.ai.AiKeys;
-import com.dotcms.ai.app.AIModel;
-import com.dotcms.ai.app.AppConfig;
-import com.dotcms.ai.app.AppKeys;
-import com.dotcms.ai.app.ConfigService;
-import com.dotmarketing.exception.DotRuntimeException;
-import com.dotmarketing.util.Logger;
-import com.dotmarketing.util.json.JSONObject;
-import io.vavr.control.Try;
-import org.apache.http.HttpHeaders;
-import org.apache.http.client.methods.*;
-import org.apache.http.entity.ContentType;
-import org.apache.http.entity.StringEntity;
-import org.apache.http.impl.client.CloseableHttpClient;
-import org.apache.http.impl.client.HttpClients;
-
-import javax.ws.rs.HttpMethod;
-import javax.ws.rs.core.MediaType;
-import java.io.BufferedInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.OutputStream;
-import java.util.concurrent.ConcurrentHashMap;
-
-/**
- * The OpenAIRequest class is a utility class that handles HTTP requests to the OpenAI API.
- * It provides methods for sending GET, POST, PUT, DELETE, and PATCH requests.
- * This class also manages rate limiting for the OpenAI API by keeping track of the last time a request was made.
- *
- * This class is implemented as a singleton, meaning that only one instance of the class is created throughout the execution of the program.
- */
-public class OpenAIRequest {
-
- private static final ConcurrentHashMap lastRestCall = new ConcurrentHashMap<>();
-
- private OpenAIRequest() {}
-
- /**
- * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload.
- * The response from the request is written to the provided OutputStream.
- * This method also manages rate limiting for the OpenAI API by keeping track of the last time a request was made.
- *
- * @param urlIn the URL to send the request to
- * @param method the HTTP method to use for the request
- * @param appConfig the AppConfig object containing the OpenAI API key and models
- * @param json the JSON payload to send with the request
- * @param out the OutputStream to write the response to
- */
- public static void doRequest(final String urlIn,
- final String method,
- final AppConfig appConfig,
- final JSONObject json,
- final OutputStream out) {
- AppConfig.debugLogger(
- OpenAIRequest.class,
- () -> String.format(
- "Posting to [%s] with method [%s]%s with app config:%s%s the payload: %s",
- urlIn,
- method,
- System.lineSeparator(),
- appConfig.toString(),
- System.lineSeparator(),
- json.toString(2)));
-
- if (!appConfig.isEnabled()) {
- AppConfig.debugLogger(OpenAIRequest.class, () -> "App dotAI is not enabled and will not send request.");
- throw new DotRuntimeException("App dotAI config without API urls or API key");
- }
-
- final AIModel model = appConfig.resolveModelOrThrow(json.optString(AiKeys.MODEL));
- final long sleep = lastRestCall.computeIfAbsent(model, m -> 0L)
- + model.minIntervalBetweenCalls()
- - System.currentTimeMillis();
- if (sleep > 0) {
- Logger.info(
- OpenAIRequest.class,
- "Rate limit:"
- + model.getApiPerMinute()
- + "/minute, or 1 every "
- + model.minIntervalBetweenCalls()
- + "ms. Sleeping:"
- + sleep);
- Try.run(() -> Thread.sleep(sleep));
- }
-
- lastRestCall.put(model, System.currentTimeMillis());
-
- try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
- final StringEntity jsonEntity = new StringEntity(json.toString(), ContentType.APPLICATION_JSON);
- final HttpUriRequest httpRequest = resolveMethod(method, urlIn);
- httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON);
- httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + appConfig.getApiKey());
-
- if (!json.getAsMap().isEmpty()) {
- Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity));
- }
-
- try (CloseableHttpResponse response = httpClient.execute(httpRequest)) {
- final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent());
- final byte[] buffer = new byte[1024];
- int len;
- while ((len = in.read(buffer)) != -1) {
- out.write(buffer, 0, len);
- out.flush();
- }
- }
- } catch (Exception e) {
- if (ConfigService.INSTANCE.config().getConfigBoolean(AppKeys.DEBUG_LOGGING)){
- Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage(), e);
- } else {
- Logger.warn(OpenAIRequest.class, "INVALID REQUEST: " + e.getMessage());
- }
-
- Logger.warn(OpenAIRequest.class, " - " + method + " : " +json);
-
- throw new DotRuntimeException(e);
- }
- }
-
- /**
- * Sends a request to the specified URL with the specified method, OpenAI API key, and JSON payload.
- * The response from the request is returned as a string.
- *
- * @param url the URL to send the request to
- * @param method the HTTP method to use for the request
- * @param appConfig the AppConfig object containing the OpenAI API key and models
- * @param json the JSON payload to send with the request
- * @return the response from the request as a string
- */
- public static String doRequest(final String url,
- final String method,
- final AppConfig appConfig,
- final JSONObject json) {
- final ByteArrayOutputStream out = new ByteArrayOutputStream();
- doRequest(url, method, appConfig, json, out);
-
- return out.toString();
- }
-
- /**
- * Sends a POST request to the specified URL with the specified OpenAI API key and JSON payload.
- * The response from the request is written to the provided OutputStream.
- *
- * @param urlIn the URL to send the request to
- * @param appConfig the AppConfig object containing the OpenAI API key and models
- * @param json the JSON payload to send with the request
- * @param out the OutputStream to write the response to
- */
- public static void doPost(final String urlIn,
- final AppConfig appConfig,
- final JSONObject json,
- final OutputStream out) {
- doRequest(urlIn, HttpMethod.POST, appConfig, json, out);
- }
-
- /**
- * Sends a GET request to the specified URL with the specified OpenAI API key and JSON payload.
- * The response from the request is written to the provided OutputStream.
- *
- * @param urlIn the URL to send the request to
- * @param appConfig the AppConfig object containing the OpenAI API key and models
- * @param json the JSON payload to send with the request
- * @param out the OutputStream to write the response to
- */
- public static void doGet(final String urlIn,
- final AppConfig appConfig,
- final JSONObject json,
- final OutputStream out) {
- doRequest(urlIn, HttpMethod.GET, appConfig, json, out);
- }
-
- private static HttpUriRequest resolveMethod(final String method, final String urlIn) {
- switch(method) {
- case HttpMethod.POST:
- return new HttpPost(urlIn);
- case HttpMethod.PUT:
- return new HttpPut(urlIn);
- case HttpMethod.DELETE:
- return new HttpDelete(urlIn);
- case "patch":
- return new HttpPatch(urlIn);
- case HttpMethod.GET:
- default:
- return new HttpGet(urlIn);
- }
- }
-
-}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java b/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java
new file mode 100644
index 000000000000..344d4eaced34
--- /dev/null
+++ b/dotCMS/src/main/java/com/dotcms/ai/validator/AIAppValidator.java
@@ -0,0 +1,95 @@
+package com.dotcms.ai.validator;
+
+import com.dotcms.ai.app.AIModel;
+import com.dotcms.ai.app.AIModels;
+import com.dotcms.ai.app.AppConfig;
+import com.dotcms.ai.domain.Model;
+import com.dotcms.api.system.event.message.MessageSeverity;
+import com.dotcms.api.system.event.message.SystemMessageEventUtil;
+import com.dotcms.api.system.event.message.builder.SystemMessage;
+import com.dotcms.api.system.event.message.builder.SystemMessageBuilder;
+import com.dotmarketing.util.DateUtil;
+import com.google.common.annotations.VisibleForTesting;
+import com.liferay.portal.language.LanguageUtil;
+import io.vavr.Lazy;
+import io.vavr.control.Try;
+
+import java.util.Collections;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class AIAppValidator {
+
+ private static final Lazy INSTANCE = Lazy.of(AIAppValidator::new);
+
+ private SystemMessageEventUtil systemMessageEventUtil;
+
+ private AIAppValidator() {
+ setSystemMessageEventUtil(SystemMessageEventUtil.getInstance());
+ }
+
+ public static AIAppValidator get() {
+ return INSTANCE.get();
+ }
+
+ public void validateAIConfig(final AppConfig appConfig, final String userId) {
+ if (Objects.isNull(userId)) {
+ AppConfig.debugLogger(getClass(), () -> "User Id is null, skipping AI configuration validation");
+ return;
+ }
+
+ final Set supportedModels = AIModels.get().getOrPullSupportedModels(appConfig.getApiKey());
+ final Set unsupportedModels = Stream.of(
+ appConfig.getModel(),
+ appConfig.getImageModel(),
+ appConfig.getEmbeddingsModel())
+ .flatMap(aiModel -> aiModel.getModels().stream())
+ .map(Model::getName)
+ .filter(model -> !supportedModels.contains(model))
+ .collect(Collectors.toSet());
+ if (unsupportedModels.isEmpty()) {
+ return;
+ }
+
+ final String unsupported = String.join(", ", unsupportedModels);
+ final String message = Try
+ .of(() -> LanguageUtil.get("ai.unsupported.models", unsupported))
+ .getOrElse(String.format("The following models are not supported: [%s]", unsupported));
+ final SystemMessage systemMessage = new SystemMessageBuilder()
+ .setMessage(message)
+ .setSeverity(MessageSeverity.WARNING)
+ .setLife(DateUtil.SEVEN_SECOND_MILLIS)
+ .create();
+
+ systemMessageEventUtil.pushMessage(systemMessage, Collections.singletonList(userId));
+ }
+
+ public void validateModelsUsage(final AIModel aiModel, final String userId) {
+ final String unavailableModels = aiModel.getModels()
+ .stream()
+ .map(Model::getName)
+ .collect(Collectors.joining(", "));
+ final String message = Try
+ .of(() -> LanguageUtil.get("ai.models.exhausted", aiModel.getType(), unavailableModels)).
+ getOrElse(
+ String.format(
+ "All the %s models: [%s] have been exhausted since they are invalid or has been decommissioned",
+ aiModel.getType(),
+ unavailableModels));
+ final SystemMessage systemMessage = new SystemMessageBuilder()
+ .setMessage(message)
+ .setSeverity(MessageSeverity.WARNING)
+ .setLife(DateUtil.SEVEN_SECOND_MILLIS)
+ .create();
+
+ systemMessageEventUtil.pushMessage(systemMessage, Collections.singletonList(userId));
+ }
+
+ @VisibleForTesting
+ void setSystemMessageEventUtil(SystemMessageEventUtil systemMessageEventUtil) {
+ this.systemMessageEventUtil = systemMessageEventUtil;
+ }
+
+}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java
index 050b56b1e535..0ad6d7837a2d 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/AIViewTool.java
@@ -4,9 +4,7 @@
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.api.ChatAPI;
-import com.dotcms.ai.api.OpenAIChatAPIImpl;
import com.dotcms.ai.api.ImageAPI;
-import com.dotcms.ai.api.OpenAIImageAPIImpl;
import com.dotmarketing.business.APILocator;
import com.dotmarketing.business.web.WebAPILocator;
import com.dotmarketing.util.json.JSONObject;
@@ -30,11 +28,13 @@ public class AIViewTool implements ViewTool {
private AppConfig config;
private ChatAPI chatService;
private ImageAPI imageService;
+ private User user;
@Override
public void init(final Object obj) {
context = (ViewContext) obj;
config = config();
+ user = user();
chatService = chatService();
imageService = imageService();
}
@@ -128,12 +128,12 @@ User user() {
@VisibleForTesting
ChatAPI chatService() {
- return APILocator.getDotAIAPI().getChatAPI(config);
+ return APILocator.getDotAIAPI().getChatAPI(config, user);
}
@VisibleForTesting
ImageAPI imageService() {
- return APILocator.getDotAIAPI().getImageAPI(config, user(), APILocator.getHostAPI(), APILocator.getTempFileAPI());
+ return APILocator.getDotAIAPI().getImageAPI(config, user, APILocator.getHostAPI(), APILocator.getTempFileAPI());
}
private
Try generate(final P prompt, final Function
serviceCall) {
diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java
index 5508a23f4e32..899f69efe93a 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/CompletionsTool.java
@@ -9,6 +9,8 @@
import com.dotmarketing.business.web.WebAPILocator;
import com.dotmarketing.util.json.JSONObject;
import com.google.common.annotations.VisibleForTesting;
+import com.liferay.portal.model.User;
+import com.liferay.portal.util.PortalUtil;
import org.apache.velocity.tools.view.context.ViewContext;
import org.apache.velocity.tools.view.tools.ViewTool;
@@ -17,6 +19,7 @@
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Map;
+import java.util.Optional;
/**
* This class is a ViewTool that provides functionality related to completions.
@@ -24,14 +27,18 @@
*/
public class CompletionsTool implements ViewTool {
+ private final ViewContext context;
private final HttpServletRequest request;
private final Host host;
private final AppConfig config;
+ private final User user;
CompletionsTool(Object initData) {
- this.request = ((ViewContext) initData).getRequest();
+ this.context = (ViewContext) initData;
+ this.request = this.context.getRequest();
this.host = host();
this.config = config();
+ this.user = user();
}
@Override
@@ -69,7 +76,11 @@ public Object summarize(final String prompt) {
* @return The summarized object.
*/
public Object summarize(final String prompt, final String indexName) {
- final CompletionsForm form = new CompletionsForm.Builder().indexName(indexName).prompt(prompt).build();
+ final CompletionsForm form = new CompletionsForm.Builder()
+ .indexName(indexName)
+ .prompt(prompt)
+ .user(user)
+ .build();
try {
return APILocator.getDotAIAPI().getCompletionsAPI(config).summarize(form);
} catch (Exception e) {
@@ -112,7 +123,7 @@ public Object raw(String prompt) {
*/
public Object raw(final JSONObject prompt) {
try {
- return APILocator.getDotAIAPI().getCompletionsAPI(config).raw(prompt);
+ return APILocator.getDotAIAPI().getCompletionsAPI(config).raw(prompt, user.getUserId());
} catch (Exception e) {
return handleException(e);
}
@@ -141,4 +152,9 @@ AppConfig config() {
return ConfigService.INSTANCE.config(this.host);
}
+ @VisibleForTesting
+ User user() {
+ return PortalUtil.getUser(context.getRequest());
+ }
+
}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java
index daf7e1756139..89414823aebb 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/viewtool/EmbeddingsTool.java
@@ -8,12 +8,15 @@
import com.dotmarketing.business.web.WebAPILocator;
import com.dotmarketing.util.Logger;
import com.google.common.annotations.VisibleForTesting;
+import com.liferay.portal.model.User;
+import com.liferay.portal.util.PortalUtil;
import org.apache.velocity.tools.view.context.ViewContext;
import org.apache.velocity.tools.view.tools.ViewTool;
import javax.servlet.http.HttpServletRequest;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
/**
* This class provides functionality for generating and managing embeddings.
@@ -22,9 +25,11 @@
*/
public class EmbeddingsTool implements ViewTool {
+ private final ViewContext context;
private final HttpServletRequest request;
private final Host host;
private final AppConfig appConfig;
+ private final User user;
/**
* Constructor for the EmbeddingsTool class.
@@ -33,9 +38,11 @@ public class EmbeddingsTool implements ViewTool {
* @param initData Initialization data for the tool.
*/
EmbeddingsTool(Object initData) {
- this.request = ((ViewContext) initData).getRequest();
+ this.context = (ViewContext) initData;
+ this.request = this.context.getRequest();
this.host = host();
this.appConfig = appConfig();
+ this.user = user();
}
@Override
@@ -71,10 +78,14 @@ public List generateEmbeddings(final String prompt) {
if (tokens > maxTokens) {
Logger.warn(
EmbeddingsTool.class,
- "Prompt is too long. Maximum prompt size is " + maxTokens + " tokens (roughly ~" + maxTokens * .75 + " words). Your prompt was " + tokens + " tokens ");
+ "Prompt is too long. Maximum prompt size is " + maxTokens + " tokens (roughly ~"
+ + maxTokens * .75 + " words). Your prompt was " + tokens + " tokens ");
}
- return APILocator.getDotAIAPI().getEmbeddingsAPI().pullOrGenerateEmbeddings(prompt)._2;
+ return APILocator.getDotAIAPI()
+ .getEmbeddingsAPI()
+ .pullOrGenerateEmbeddings(prompt, Optional.ofNullable(user).map(User::getUserId).orElse(null))
+ ._2;
}
/**
@@ -96,4 +107,9 @@ AppConfig appConfig() {
return ConfigService.INSTANCE.config(host);
}
+ @VisibleForTesting
+ User user() {
+ return PortalUtil.getUser(context.getRequest());
+ }
+
}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java
index 87a068eca10c..d8d01341bd0d 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagActionlet.java
@@ -1,5 +1,6 @@
package com.dotcms.ai.workflow;
+import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.ConfigService;
import com.dotmarketing.portlets.workflows.actionlet.Actionlet;
import com.dotmarketing.portlets.workflows.actionlet.WorkFlowActionlet;
@@ -9,7 +10,6 @@
import com.dotmarketing.portlets.workflows.model.WorkflowActionFailureException;
import com.dotmarketing.portlets.workflows.model.WorkflowActionletParameter;
import com.dotmarketing.portlets.workflows.model.WorkflowProcessor;
-import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
@@ -24,7 +24,7 @@ public List getParameters() {
WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key,
"Overwrite tags ", Boolean.toString(true), true,
- () -> ImmutableList.of(
+ () -> List.of(
new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)),
new MultiKeyValue(Boolean.toString(true), Boolean.toString(true)))
);
@@ -32,16 +32,26 @@ public List getParameters() {
WorkflowActionletParameter limitTagsToHost = new MultiSelectionWorkflowActionletParameter(
OpenAIParams.LIMIT_TAGS_TO_HOST.key,
"Limit the keywords to pre-existing tags", "Limit", false,
- () -> ImmutableList.of(
+ () -> List.of(
new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)),
new MultiKeyValue(Boolean.toString(true), Boolean.toString(true))
)
);
+
+ final AppConfig appConfig = ConfigService.INSTANCE.config();
return List.of(
overwriteParameter,
limitTagsToHost,
- new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getModel().getCurrentModel(), ConfigService.INSTANCE.config().getModel().getCurrentModel(), false),
- new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0.", ".1", false)
+ new WorkflowActionletParameter(
+ OpenAIParams.MODEL.key,
+ "The AI model to use, defaults to " + appConfig.getModel().getCurrentModel(),
+ appConfig.getModel().getCurrentModel(),
+ false),
+ new WorkflowActionletParameter(
+ OpenAIParams.TEMPERATURE.key,
+ "The AI temperature for the response. Between .1 and 2.0.",
+ ".1",
+ false)
);
}
@@ -63,5 +73,4 @@ public void executeAction(final WorkflowProcessor processor,
new OpenAIAutoTagRunner(processor, params).run();
}
-
}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java
index 2f7013bfd362..a63c03743686 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIAutoTagRunner.java
@@ -146,7 +146,7 @@ private String openAIRequest(final Contentlet workingContentlet, final String co
final String parsedContentPrompt = VelocityUtil.eval(contentToTag, ctx);
final JSONObject openAIResponse = APILocator.getDotAIAPI().getCompletionsAPI()
- .prompt(parsedSystemPrompt, parsedContentPrompt, model, temperature, 2000);
+ .prompt(parsedSystemPrompt, parsedContentPrompt, model, temperature, 2000, user.getUserId());
return openAIResponse.getJSONArray("choices").getJSONObject(0).getJSONObject("message").getString("content");
}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java
index b6a14ab22d44..8a8e9320293b 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptActionlet.java
@@ -1,5 +1,6 @@
package com.dotcms.ai.workflow;
+import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotmarketing.portlets.workflows.actionlet.Actionlet;
@@ -10,7 +11,6 @@
import com.dotmarketing.portlets.workflows.model.WorkflowActionFailureException;
import com.dotmarketing.portlets.workflows.model.WorkflowActionletParameter;
import com.dotmarketing.portlets.workflows.model.WorkflowProcessor;
-import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
@@ -23,22 +23,39 @@ public class OpenAIContentPromptActionlet extends WorkFlowActionlet {
@Override
public List getParameters() {
- WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key,
+ final WorkflowActionletParameter overwriteParameter = new MultiSelectionWorkflowActionletParameter(OpenAIParams.OVERWRITE_FIELDS.key,
"Overwrite existing content (true|false)", Boolean.toString(true), true,
- () -> ImmutableList.of(
+ () -> List.of(
new MultiKeyValue(Boolean.toString(false), Boolean.toString(false)),
new MultiKeyValue(Boolean.toString(true), Boolean.toString(true)))
);
-
+ final AppConfig appConfig = ConfigService.INSTANCE.config();
return List.of(
- new WorkflowActionletParameter(OpenAIParams.FIELD_TO_WRITE.key, "The field where you want to write the results. " +
- " If your response is being returned as a json object, this field can be left blank" +
- " and the keys of the json object will be used to update the content fields.", "", false),
+ new WorkflowActionletParameter(
+ OpenAIParams.FIELD_TO_WRITE.key,
+ "The field where you want to write the results. "
+ + " If your response is being returned as a json object, this field can be left blank"
+ + " and the keys of the json object will be used to update the content fields.",
+ "",
+ false),
overwriteParameter,
- new WorkflowActionletParameter(OpenAIParams.OPEN_AI_PROMPT.key, "The prompt that will be sent to the AI", "We need an attractive search result in Google. Return a json object that includes the fields \"pageTitle\" for a meta title of less than 55 characters and \"metaDescription\" for the meta description of less than 300 characters using this content:\\n\\n${fieldContent}\\n\\n", true),
- new WorkflowActionletParameter(OpenAIParams.MODEL.key, "The AI model to use, defaults to " + ConfigService.INSTANCE.config().getModel().getCurrentModel(), ConfigService.INSTANCE.config().getModel().getCurrentModel(), false),
- new WorkflowActionletParameter(OpenAIParams.TEMPERATURE.key, "The AI temperature for the response. Between .1 and 2.0. Defaults to " + ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), ConfigService.INSTANCE.config().getConfig(AppKeys.COMPLETION_TEMPERATURE), false)
+ new WorkflowActionletParameter(
+ OpenAIParams.OPEN_AI_PROMPT.key,
+ "The prompt that will be sent to the AI",
+ "We need an attractive search result in Google. Return a json object that includes the fields \"pageTitle\" for a meta title of less than 55 characters and \"metaDescription\" for the meta description of less than 300 characters using this content:\\n\\n${fieldContent}\\n\\n",
+ true),
+ new WorkflowActionletParameter(
+ OpenAIParams.MODEL.key,
+ "The AI model to use, defaults to " + appConfig.getModel().getCurrentModel(),
+ appConfig.getModel().getCurrentModel(),
+ false),
+ new WorkflowActionletParameter(
+ OpenAIParams.TEMPERATURE.key,
+ "The AI temperature for the response. Between .1 and 2.0. Defaults to "
+ + appConfig.getConfig(AppKeys.COMPLETION_TEMPERATURE),
+ appConfig.getConfig(AppKeys.COMPLETION_TEMPERATURE),
+ false)
);
}
diff --git a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java
index 176b12d860c4..5ebca943a044 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/workflow/OpenAIContentPromptRunner.java
@@ -145,7 +145,9 @@ private String openAIRequest(final Contentlet workingContentlet) throws Exceptio
final Context ctx = VelocityContextFactory.getMockContext(workingContentlet, user);
final String parsedPrompt = VelocityUtil.eval(prompt, ctx);
- final JSONObject openAIResponse = APILocator.getDotAIAPI().getCompletionsAPI().raw(buildRequest(parsedPrompt, model, temperature));
+ final JSONObject openAIResponse = APILocator.getDotAIAPI()
+ .getCompletionsAPI()
+ .raw(buildRequest(parsedPrompt, model, temperature), user.getUserId());
try {
return openAIResponse
diff --git a/dotCMS/src/main/resources/apps/dotAI.yml b/dotCMS/src/main/resources/apps/dotAI.yml
index e8f03c0c5f80..425c9b42b058 100644
--- a/dotCMS/src/main/resources/apps/dotAI.yml
+++ b/dotCMS/src/main/resources/apps/dotAI.yml
@@ -15,11 +15,11 @@ params:
hint: "Your ChatGPT API key"
required: true
textModelNames:
- value: "gpt-4o"
+ value: "gpt-4o-mini"
hidden: false
type: "STRING"
label: "Model Names"
- hint: "Comma delimited list of models used to generate OpenAI API response (e.g. gpt-3.5-turbo-16k)"
+ hint: "Comma delimited list of models used to generate OpenAI API response (e.g. gpt-4o-mini)"
required: true
rolePrompt:
value: "You are dotCMSbot, and AI assistant to help content creators generate and rewrite content in their content management system."
diff --git a/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties b/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties
index 782f5f57793a..63c55fbc5354 100644
--- a/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties
+++ b/dotCMS/src/main/webapp/WEB-INF/messages/Language.properties
@@ -189,6 +189,8 @@ anonymous=Anonymous
another-layout-already-exists=Another Tool Group already exists in the system with the same name
Any-Structure-Type=Any Content Type
Any-Structure=Any Content Type
+ai.unsupported.models=The following models are not supported: [{0}]
+ai.models.exhausted=All the {0} models: [{1}] have been exhausted since they are invalid or has been decommissioned
api.ruleengine.system.conditionlet.CurrentSessionLanguage.inputs.comparison.placeholder=Comparison
api.ruleengine.system.conditionlet.CurrentSessionLanguage.inputs.language.placeholder=Language
api.ruleengine.system.conditionlet.CurrentSessionLanguage.name=Selected Language
diff --git a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js
index 088436aef605..f2db45f1a34c 100644
--- a/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js
+++ b/dotCMS/src/main/webapp/html/portlet/ext/dotai/dotai.js
@@ -136,11 +136,12 @@ const writeModelToDropdown = async () => {
}
const newOption = document.createElement("option");
+ console.log(JSON.stringify(dotAiState.config, null, 2));
newOption.value = dotAiState.config.availableModels[i].name;
newOption.text = `${dotAiState.config.availableModels[i].name}`
- if (dotAiState.config.availableModels[i] === dotAiState.config.model) {
+ if (dotAiState.config.availableModels[i].current) {
newOption.selected = true;
- newOption.text = `${dotAiState.config.availableModels[i]} (default)`
+ newOption.text = `${dotAiState.config.availableModels[i].name} (default)`
}
modelName.appendChild(newOption);
}
diff --git a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java
similarity index 86%
rename from dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java
rename to dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java
index e4c43486c3f1..c51e9c6323a5 100644
--- a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIChatServiceImplTest.java
+++ b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIChatAPIImplTest.java
@@ -1,12 +1,11 @@
-package com.dotcms.ai.service;
+package com.dotcms.ai.api;
-import com.dotcms.ai.api.ChatAPI;
-import com.dotcms.ai.api.OpenAIChatAPIImpl;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.app.AIModelType;
import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotmarketing.util.json.JSONObject;
+import com.liferay.portal.model.User;
import org.junit.Before;
import org.junit.Test;
@@ -17,17 +16,19 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
-public class OpenAIChatServiceImplTest {
+public class OpenAIChatAPIImplTest {
private static final String RESPONSE_JSON =
"{\"data\":[{\"url\":\"http://localhost:8080\",\"value\":\"this is a response\"}]}";
private AppConfig config;
private ChatAPI service;
+ private User user;
@Before
public void setUp() {
config = mock(AppConfig.class);
+ user = mock(User.class);
service = prepareService(RESPONSE_JSON);
}
@@ -54,11 +55,9 @@ public void test_sendTextPrompt() {
}
private ChatAPI prepareService(final String response) {
- return new OpenAIChatAPIImpl(config) {
-
-
+ return new OpenAIChatAPIImpl(config, user) {
@Override
- public String doRequest(final String urlIn, final JSONObject json) {
+ String doRequest(final JSONObject json, final String userId) {
return response;
}
};
@@ -66,7 +65,7 @@ public String doRequest(final String urlIn, final JSONObject json) {
private JSONObject prepareJsonObject(final String prompt) {
when(config.getModel())
- .thenReturn(AIModel.builder().withType(AIModelType.TEXT).withNames("some-model").build());
+ .thenReturn(AIModel.builder().withType(AIModelType.TEXT).withModelNames("some-model").build());
when(config.getConfigFloat(AppKeys.COMPLETION_TEMPERATURE)).thenReturn(123.321F);
when(config.getRolePrompt()).thenReturn("some-role-prompt");
diff --git a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java
similarity index 97%
rename from dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java
rename to dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java
index 6c3fc6822473..e73d9352a59b 100644
--- a/dotCMS/src/test/java/com/dotcms/ai/service/OpenAIImageServiceImplTest.java
+++ b/dotCMS/src/test/java/com/dotcms/ai/api/OpenAIImageAPIImplTest.java
@@ -1,7 +1,5 @@
-package com.dotcms.ai.service;
+package com.dotcms.ai.api;
-import com.dotcms.ai.api.ImageAPI;
-import com.dotcms.ai.api.OpenAIImageAPIImpl;
import com.dotcms.ai.app.AIModel;
import com.dotcms.ai.app.AIModelType;
import com.dotcms.ai.app.AppConfig;
@@ -27,7 +25,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
-public class OpenAIImageServiceImplTest {
+public class OpenAIImageAPIImplTest {
private static final String RESPONSE_JSON =
"{\"data\":[{\"url\":\"http://localhost:8080\",\"value\":\"this is a response\"}]}";
@@ -220,7 +218,7 @@ public AIImageRequestDTO.Builder getDtoBuilder() {
}
private JSONObject prepareJsonObject(final String prompt, final boolean tempFileError) throws Exception {
- when(config.getImageModel()).thenReturn(AIModel.builder().withType(AIModelType.IMAGE).withNames("some-image-model").build());
+ when(config.getImageModel()).thenReturn(AIModel.builder().withType(AIModelType.IMAGE).withModelNames("some-image-model").build());
when(config.getImageSize()).thenReturn("some-image-size");
final File file = mock(File.class);
when(file.getName()).thenReturn(UUIDGenerator.shorty());
diff --git a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java
index c4d5c93b7627..8c1cd1e79c4e 100644
--- a/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java
+++ b/dotCMS/src/test/java/com/dotcms/ai/app/AIAppUtilTest.java
@@ -1,10 +1,12 @@
package com.dotcms.ai.app;
+import com.dotcms.ai.domain.Model;
import com.dotcms.security.apps.Secret;
import org.junit.Before;
import org.junit.Test;
import java.util.Map;
+import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@@ -127,7 +129,7 @@ public void testCreateTextModel() {
AIModel model = aiAppUtil.createTextModel(secrets);
assertNotNull(model);
assertEquals(AIModelType.TEXT, model.getType());
- assertTrue(model.getNames().contains("textmodel"));
+ assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()).contains("textmodel"));
}
/**
@@ -143,7 +145,7 @@ public void testCreateImageModel() {
AIModel model = aiAppUtil.createImageModel(secrets);
assertNotNull(model);
assertEquals(AIModelType.IMAGE, model.getType());
- assertTrue(model.getNames().contains("imagemodel"));
+ assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList()).contains("imagemodel"));
}
/**
@@ -159,7 +161,8 @@ public void testCreateEmbeddingsModel() {
AIModel model = aiAppUtil.createEmbeddingsModel(secrets);
assertNotNull(model);
assertEquals(AIModelType.EMBEDDINGS, model.getType());
- assertTrue(model.getNames().contains("embeddingsmodel"));
+ assertTrue(model.getModels().stream().map(Model::getName).collect(Collectors.toList())
+ .contains("embeddingsmodel"));
}
@Test
diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java
new file mode 100644
index 000000000000..9109d502f60d
--- /dev/null
+++ b/dotCMS/src/test/java/com/dotcms/ai/client/AIProxyClientTest.java
@@ -0,0 +1,65 @@
+package com.dotcms.ai.client;
+
+import com.dotcms.ai.domain.AIRequest;
+import com.dotcms.ai.domain.AIResponse;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Unit tests for the AIProxyClient class.
+ */
+public class AIProxyClientTest {
+
+ private AIProxyClient proxyClient;
+ private AIProxiedClient mockProxiedClient;
+
+ @Before
+ public void setUp() {
+ mockProxiedClient = mock(AIProxiedClient.class);
+ proxyClient = new AIProxyClient(mockProxiedClient);
+ }
+
+ /**
+ * Scenario: Sending a valid AI request with an output stream
+ * Given a valid AI request and an output stream
+ * When the request is sent to the AI service
+ * Then the response should be written to the output stream
+ */
+ @Test
+ public void testSendRequest_withValidRequestAndOutput() {
+ AIRequest request = mock(AIRequest.class);
+ OutputStream output = new ByteArrayOutputStream();
+
+ AIResponse response = proxyClient.sendRequest(request, output);
+
+ verify(mockProxiedClient).callToAI(request, output);
+ assertEquals(AIResponse.EMPTY, response);
+ }
+
+ /**
+ * Scenario: Sending a valid AI request with a null output stream
+ * Given a valid AI request and a null output stream
+ * When the request is sent to the AI service
+ * Then the response should be returned as a string
+ */
+ @Test
+ public void testSendRequest_withValidRequestAndNullOutput() {
+ AIRequest request = mock(AIRequest.class);
+ OutputStream output = null;
+
+ AIResponse response = proxyClient.sendRequest(request, output);
+
+ verify(mockProxiedClient).callToAI(request, output);
+ assertNotNull(response);
+ }
+
+}
\ No newline at end of file
diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/openai/AIProxiedClientTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/openai/AIProxiedClientTest.java
new file mode 100644
index 000000000000..e2f890cfa463
--- /dev/null
+++ b/dotCMS/src/test/java/com/dotcms/ai/client/openai/AIProxiedClientTest.java
@@ -0,0 +1,102 @@
+package com.dotcms.ai.client.openai;
+
+import com.dotcms.ai.client.AIClient;
+import com.dotcms.ai.client.AIClientStrategy;
+import com.dotcms.ai.client.AIProxiedClient;
+import com.dotcms.ai.client.AIProxyStrategy;
+import com.dotcms.ai.client.AIResponseEvaluator;
+import com.dotcms.ai.domain.AIRequest;
+import com.dotcms.ai.domain.AIResponse;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for the AIProxiedClient class.
+ *
+ * @author vico
+ */
+public class AIProxiedClientTest {
+
+ private AIClient mockClient;
+ private AIProxyStrategy mockProxyStrategy;
+ private AIClientStrategy mockClientStrategy;
+ private AIResponseEvaluator mockResponseEvaluator;
+ private AIProxiedClient proxiedClient;
+
+ @Before
+ public void setUp() {
+ mockClient = mock(AIClient.class);
+ mockProxyStrategy = mock(AIProxyStrategy.class);
+ mockClientStrategy = mock(AIClientStrategy.class);
+ when(mockProxyStrategy.getStrategy()).thenReturn(mockClientStrategy);
+ mockResponseEvaluator = mock(AIResponseEvaluator.class);
+ proxiedClient = AIProxiedClient.of(mockClient, mockProxyStrategy, mockResponseEvaluator);
+ }
+
+ /**
+ * Scenario: Sending a valid AI request
+ * Given a valid AI request
+ * When the request is sent to the AI service
+ * Then the strategy should be applied
+ * And the response should be written to the output stream
+ */
+ @Test
+ public void testCallToAI_withValidRequest() {
+ AIRequest request = mock(AIRequest.class);
+ OutputStream output = mock(OutputStream.class);
+
+ AIResponse response = proxiedClient.callToAI(request, output);
+
+ verify(mockClientStrategy).applyStrategy(mockClient, mockResponseEvaluator, request, output);
+ assertEquals(AIResponse.EMPTY, response);
+ }
+
+ /**
+ * Scenario: Sending an AI request with null output stream
+ * Given a valid AI request and a null output stream
+ * When the request is sent to the AI service
+ * Then the strategy should be applied
+ * And the response should be returned as a string
+ */
+ @Test
+ public void testCallToAI_withNullOutput() {
+ AIRequest request = mock(AIRequest.class);
+ AIResponse response = proxiedClient.callToAI(request, null);
+
+ verify(mockClientStrategy).applyStrategy(
+ eq(mockClient),
+ eq(mockResponseEvaluator),
+ eq(request),
+ any(OutputStream.class));
+ assertEquals("", response.getResponse());
+ }
+
+ /**
+ * Scenario: Sending an AI request with NOOP client
+ * Given a valid AI request and a NOOP client
+ * When the request is sent to the AI service
+ * Then no operations should be performed
+ * And the response should be empty
+ */
+ @Test
+ public void testCallToAI_withNoopClient() {
+ proxiedClient = AIProxiedClient.NOOP;
+ AIRequest request = AIRequest.builder().build();
+ OutputStream output = new ByteArrayOutputStream();
+
+ AIResponse response = proxiedClient.callToAI(request, output);
+
+ assertEquals(AIResponse.EMPTY, response);
+ }
+}
\ No newline at end of file
diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluatorTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluatorTest.java
new file mode 100644
index 000000000000..9ce5f40b9257
--- /dev/null
+++ b/dotCMS/src/test/java/com/dotcms/ai/client/openai/OpenAIResponseEvaluatorTest.java
@@ -0,0 +1,104 @@
+package com.dotcms.ai.client.openai;
+
+import com.dotcms.ai.domain.AIResponseData;
+import com.dotcms.ai.domain.ModelStatus;
+import com.dotcms.ai.exception.DotAIModelNotFoundException;
+import com.dotmarketing.exception.DotRuntimeException;
+import org.json.JSONObject;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+
+/**
+ * Tests for the OpenAIResponseEvaluator class.
+ *
+ * @author vico
+ */
+public class OpenAIResponseEvaluatorTest {
+
+ private OpenAIResponseEvaluator evaluator;
+
+ @Before
+ public void setUp() {
+ evaluator = OpenAIResponseEvaluator.get();
+ }
+
+ /**
+ * Scenario: Processing a response with an error
+ * Given a response with an error message "Model has been deprecated"
+ * When the response is processed
+ * Then the metadata should contain the error message "Model has been deprecated"
+ * And the status should be set to DECOMMISSIONED
+ */
+ @Test
+ public void testFromResponse_withError() {
+ String response = new JSONObject().put("error", "Model has been deprecated").toString();
+ AIResponseData metadata = new AIResponseData();
+
+ evaluator.fromResponse(response, metadata);
+
+ assertEquals("Model has been deprecated", metadata.getError());
+ assertEquals(ModelStatus.DECOMMISSIONED, metadata.getStatus());
+ }
+
+ /**
+ * Scenario: Processing a response without an error
+ * Given a response without an error message
+ * When the response is processed
+ * Then the metadata should not contain any error message
+ * And the status should be null
+ */
+ @Test
+ public void testFromResponse_withoutError() {
+ String response = new JSONObject().put("data", "some data").toString();
+ AIResponseData metadata = new AIResponseData();
+
+ evaluator.fromResponse(response, metadata);
+
+ assertNull(metadata.getError());
+ assertNull(metadata.getStatus());
+ }
+
+ /**
+ * Scenario: Processing an exception of type DotRuntimeException
+ * Given an exception of type DotAIModelNotFoundException with message "Model not found"
+ * When the exception is processed
+ * Then the metadata should contain the error message "Model not found"
+ * And the status should be set to INVALID
+ * And the exception should be set to the given DotRuntimeException
+ */
+ @Test
+ public void testFromException_withDotRuntimeException() {
+ DotRuntimeException exception = new DotAIModelNotFoundException("Model not found");
+ AIResponseData metadata = new AIResponseData();
+
+ evaluator.fromException(exception, metadata);
+
+ assertEquals("Model not found", metadata.getError());
+ assertEquals(ModelStatus.INVALID, metadata.getStatus());
+ assertEquals(exception, metadata.getException());
+ }
+
+ /**
+ * Scenario: Processing a general exception
+ * Given a general exception with message "General error"
+ * When the exception is processed
+ * Then the metadata should contain the error message "General error"
+ * And the status should be set to UNKNOWN
+ * And the exception should be wrapped in a DotRuntimeException
+ */
+ @Test
+ public void testFromException_withOtherException() {
+ Exception exception = new Exception("General error");
+ AIResponseData metadata = new AIResponseData();
+
+ evaluator.fromException(exception, metadata);
+
+ assertEquals("General error", metadata.getError());
+ assertEquals(ModelStatus.UNKNOWN, metadata.getStatus());
+ assertEquals(DotRuntimeException.class, metadata.getException().getClass());
+ }
+}
diff --git a/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java b/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java
index 81b74a231e3f..0f98dc89849b 100644
--- a/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java
+++ b/dotcms-integration/src/test/java/com/dotcms/MainSuite2b.java
@@ -1,6 +1,7 @@
package com.dotcms;
import com.dotcms.ai.app.AIModelsTest;
+import com.dotcms.ai.app.ConfigServiceTest;
import com.dotcms.ai.listener.EmbeddingContentListenerTest;
import com.dotcms.ai.viewtool.AIViewToolTest;
import com.dotcms.ai.viewtool.CompletionsToolTest;
@@ -302,6 +303,7 @@
EmbeddingsToolTest.class,
CompletionsToolTest.class,
AIModelsTest.class,
+ ConfigServiceTest.class,
TimeMachineAPITest.class,
Task240513UpdateContentTypesSystemFieldTest.class,
PruneTimeMachineBackupJobTest.class,
diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java
index 855f61ad4572..02f2f31e6172 100644
--- a/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java
+++ b/dotcms-integration/src/test/java/com/dotcms/ai/AiTest.java
@@ -11,6 +11,7 @@
import com.github.tomakehurst.wiremock.WireMockServer;
import java.util.Map;
+import java.util.Objects;
public interface AiTest {
@@ -31,55 +32,55 @@ static WireMockServer prepareWireMock() {
return wireMockServer;
}
- static Map aiAppSecrets(final WireMockServer wireMockServer,
- final Host host,
+ static Map aiAppSecrets(final Host host,
final String apiKey,
final String textModels,
final String imageModels,
final String embeddingsModel)
throws DotDataException, DotSecurityException {
- final AppSecrets appSecrets = new AppSecrets.Builder()
+ final AppSecrets.Builder builder = new AppSecrets.Builder()
.withKey(AppKeys.APP_KEY)
- .withSecret(AppKeys.API_URL.key, String.format(API_URL, wireMockServer.port()))
- .withSecret(AppKeys.API_IMAGE_URL.key, String.format(API_IMAGE_URL, wireMockServer.port()))
- .withSecret(AppKeys.API_EMBEDDINGS_URL.key, String.format(API_EMBEDDINGS_URL, wireMockServer.port()))
+ .withSecret(AppKeys.API_URL.key, String.format(API_URL, PORT))
+ .withSecret(AppKeys.API_IMAGE_URL.key, String.format(API_IMAGE_URL, PORT))
+ .withSecret(AppKeys.API_EMBEDDINGS_URL.key, String.format(API_EMBEDDINGS_URL, PORT))
.withHiddenSecret(AppKeys.API_KEY.key, apiKey)
- .withSecret(AppKeys.TEXT_MODEL_NAMES.key, textModels)
- .withSecret(AppKeys.IMAGE_MODEL_NAMES.key, imageModels)
- .withSecret(AppKeys.EMBEDDINGS_MODEL_NAMES.key, embeddingsModel)
.withSecret(AppKeys.IMAGE_SIZE.key, IMAGE_SIZE)
.withSecret(AppKeys.LISTENER_INDEXER.key, "{\"default\":\"blog\"}")
.withSecret(AppKeys.COMPLETION_ROLE_PROMPT.key, AppKeys.COMPLETION_ROLE_PROMPT.defaultValue)
- .withSecret(AppKeys.COMPLETION_TEXT_PROMPT.key, AppKeys.COMPLETION_TEXT_PROMPT.defaultValue)
- .build();
+ .withSecret(AppKeys.COMPLETION_TEXT_PROMPT.key, AppKeys.COMPLETION_TEXT_PROMPT.defaultValue);
+
+ if (Objects.nonNull(textModels)) {
+ builder.withSecret(AppKeys.TEXT_MODEL_NAMES.key, textModels);
+ }
+ if (Objects.nonNull(imageModels)) {
+ builder.withSecret(AppKeys.IMAGE_MODEL_NAMES.key, imageModels);
+ }
+ if (Objects.nonNull(embeddingsModel)) {
+ builder.withSecret(AppKeys.EMBEDDINGS_MODEL_NAMES.key, embeddingsModel);
+ }
+
+ final AppSecrets appSecrets = builder.build();
APILocator.getAppsAPI().saveSecrets(appSecrets, host, APILocator.systemUser());
return appSecrets.getSecrets();
}
- static Map aiAppSecrets(final WireMockServer wireMockServer,
- final Host host,
- final String apiKey)
+ static Map aiAppSecrets(final Host host, final String apiKey)
throws DotDataException, DotSecurityException {
- return aiAppSecrets(wireMockServer, host, apiKey, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL);
+ return aiAppSecrets(host, apiKey, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL);
}
- static Map aiAppSecrets(final WireMockServer wireMockServer,
- final Host host,
+ static Map aiAppSecrets(final Host host,
final String textModels,
final String imageModels,
final String embeddingsModel)
throws DotDataException, DotSecurityException {
- return aiAppSecrets(wireMockServer, host, API_KEY, textModels, imageModels, embeddingsModel);
+ return aiAppSecrets(host, API_KEY, textModels, imageModels, embeddingsModel);
}
- static Map aiAppSecrets(final WireMockServer wireMockServer, final Host host)
+ static Map aiAppSecrets(final Host host)
throws DotDataException, DotSecurityException {
- return aiAppSecrets(wireMockServer, host, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL);
- }
-
- static void removeSecrets(final Host host) throws DotDataException, DotSecurityException {
- APILocator.getAppsAPI().removeSecretsForSite(host, APILocator.systemUser());
+ return aiAppSecrets(host, MODEL, IMAGE_MODEL, EMBEDDINGS_MODEL);
}
}
diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java
index e08965e20843..3da3e9a57586 100644
--- a/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java
+++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/AIModelsTest.java
@@ -1,16 +1,18 @@
package com.dotcms.ai.app;
import com.dotcms.ai.AiTest;
+import com.dotcms.ai.domain.Model;
+import com.dotcms.ai.domain.ModelStatus;
+import com.dotcms.ai.exception.DotAIModelNotFoundException;
+import com.dotcms.ai.model.SimpleModel;
import com.dotcms.datagen.SiteDataGen;
import com.dotcms.util.IntegrationTestInitService;
import com.dotcms.util.network.IPUtils;
import com.dotmarketing.beans.Host;
import com.dotmarketing.business.APILocator;
-import com.dotmarketing.exception.DotDataException;
import com.dotmarketing.exception.DotRuntimeException;
-import com.dotmarketing.exception.DotSecurityException;
-import com.dotmarketing.util.DateUtil;
import com.github.tomakehurst.wiremock.WireMockServer;
+import io.vavr.Tuple2;
import io.vavr.control.Try;
import org.junit.After;
import org.junit.AfterClass;
@@ -23,9 +25,11 @@
import java.util.Set;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
/**
@@ -59,7 +63,7 @@ public void before() {
IPUtils.disabledIpPrivateSubnet(true);
host = new SiteDataGen().nextPersisted();
otherHost = new SiteDataGen().nextPersisted();
- List.of(host, otherHost).forEach(h -> Try.of(() -> AiTest.aiAppSecrets(wireMockServer, host)).get());
+ List.of(host, otherHost).forEach(h -> Try.of(() -> AiTest.aiAppSecrets(host)).get());
}
@After
@@ -73,31 +77,32 @@ public void after() {
* Then the correct models should be found and returned.
*/
@Test
- public void test_loadModels_andFindThem() throws DotDataException, DotSecurityException {
- AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost());
- saveSecrets(
+ public void test_loadModels_andFindThem() throws Exception {
+ AiTest.aiAppSecrets(APILocator.systemHost());
+ AiTest.aiAppSecrets(
host,
"text-model-1,text-model-2",
"image-model-3,image-model-4",
"embeddings-model-5,embeddings-model-6");
- saveSecrets(otherHost, "text-model-1", null, null);
+ AiTest.aiAppSecrets(otherHost, "text-model-1", null, null);
final String hostId = host.getHostname();
+ final AppConfig appConfig = ConfigService.INSTANCE.config(host);
- final Optional notFound = aiModels.findModel(hostId, "some-invalid-model-name");
+ final Optional notFound = aiModels.findModel(appConfig, "some-invalid-model-name", AIModelType.TEXT);
assertTrue(notFound.isEmpty());
- final Optional text1 = aiModels.findModel(hostId, "text-model-1");
- final Optional text2 = aiModels.findModel(hostId, "text-model-2");
- assertModels(text1, text2, AIModelType.TEXT);
+ final Optional text1 = aiModels.findModel(appConfig, "text-model-1", AIModelType.TEXT);
+ final Optional text2 = aiModels.findModel(appConfig, "text-model-2", AIModelType.TEXT);
+ assertModels(text1, text2, AIModelType.TEXT, true);
- final Optional image1 = aiModels.findModel(hostId, "image-model-3");
- final Optional image2 = aiModels.findModel(hostId, "image-model-4");
- assertModels(image1, image2, AIModelType.IMAGE);
+ final Optional image1 = aiModels.findModel(appConfig, "image-model-3", AIModelType.IMAGE);
+ final Optional image2 = aiModels.findModel(appConfig, "image-model-4", AIModelType.IMAGE);
+ assertModels(image1, image2, AIModelType.IMAGE, true);
- final Optional embeddings1 = aiModels.findModel(hostId, "embeddings-model-5");
- final Optional embeddings2 = aiModels.findModel(hostId, "embeddings-model-6");
- assertModels(embeddings1, embeddings2, AIModelType.EMBEDDINGS);
+ final Optional embeddings1 = aiModels.findModel(appConfig, "embeddings-model-5", AIModelType.EMBEDDINGS);
+ final Optional embeddings2 = aiModels.findModel(appConfig, "embeddings-model-6", AIModelType.EMBEDDINGS);
+ assertModels(embeddings1, embeddings2, AIModelType.EMBEDDINGS, true);
assertNotSame(text1.get(), image1.get());
assertNotSame(text1.get(), embeddings1.get());
@@ -112,27 +117,135 @@ public void test_loadModels_andFindThem() throws DotDataException, DotSecurityEx
final Optional embeddings3 = aiModels.findModel(hostId, AIModelType.EMBEDDINGS);
assertSameModels(embeddings3, embeddings1, embeddings2);
- final Optional text4 = aiModels.findModel(otherHost.getHostname(), "text-model-1");
+ final AppConfig otherAppConfig = ConfigService.INSTANCE.config(otherHost);
+ final Optional text4 = aiModels.findModel(otherAppConfig, "text-model-1", AIModelType.TEXT);
assertTrue(text3.isPresent());
assertNotSame(text1.get(), text4.get());
- saveSecrets(
+ AiTest.aiAppSecrets(
host,
"text-model-7,text-model-8",
"image-model-9,image-model-10",
"embeddings-model-11, embeddings-model-12");
- final Optional text7 = aiModels.findModel(hostId, "text-model-7");
- final Optional text8 = aiModels.findModel(hostId, "text-model-8");
+ final Optional text7 = aiModels.findModel(otherAppConfig, "text-model-7", AIModelType.TEXT);
+ final Optional text8 = aiModels.findModel(otherAppConfig, "text-model-8", AIModelType.TEXT);
assertNotPresentModels(text7, text8);
- final Optional image9 = aiModels.findModel(hostId, "image-model-9");
- final Optional image10 = aiModels.findModel(hostId, "image-model-10");
+ final Optional image9 = aiModels.findModel(otherAppConfig, "image-model-9", AIModelType.IMAGE);
+ final Optional image10 = aiModels.findModel(otherAppConfig, "image-model-10", AIModelType.IMAGE);
assertNotPresentModels(image9, image10);
- final Optional embeddings11 = aiModels.findModel(hostId, "embeddings-model-11");
- final Optional embeddings12 = aiModels.findModel(hostId, "embeddings-model-12");
+ final Optional embeddings11 = aiModels.findModel(otherAppConfig, "embeddings-model-11", AIModelType.EMBEDDINGS);
+ final Optional embeddings12 = aiModels.findModel(otherAppConfig, "embeddings-model-12", AIModelType.EMBEDDINGS);
assertNotPresentModels(embeddings11, embeddings12);
+
+ final List available = aiModels.getAvailableModels();
+ final List availableNames = List.of(
+ "gpt-3.5-turbo-16k", "dall-e-3", "text-embedding-ada-002",
+ "text-model-1", "text-model-7", "text-model-8",
+ "image-model-9", "image-model-10",
+ "embeddings-model-11", "embeddings-model-12");
+ assertTrue(available.stream().anyMatch(model -> availableNames.contains(model.getName())));
+ }
+
+ /**
+ * Given a set of models loaded into the AIModels instance
+ * When the resolveModel method is called with various model names and types
+ * Then the correct models should be resolved and their operational status verified.
+ */
+ @Test
+ public void test_resolveModel() throws Exception {
+ AiTest.aiAppSecrets(APILocator.systemHost());
+ AiTest.aiAppSecrets(host, "text-model-20", "image-model-21", "embeddings-model-22");
+ ConfigService.INSTANCE.config(host);
+ AiTest.aiAppSecrets(otherHost, "text-model-23", null, null);
+ ConfigService.INSTANCE.config(otherHost);
+
+ assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.TEXT).isOperational());
+ assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.IMAGE).isOperational());
+ assertTrue(aiModels.resolveModel(host.getHostname(), AIModelType.EMBEDDINGS).isOperational());
+ assertTrue(aiModels.resolveModel(otherHost.getHostname(), AIModelType.TEXT).isOperational());
+ assertFalse(aiModels.resolveModel(otherHost.getHostname(), AIModelType.IMAGE).isOperational());
+ assertFalse(aiModels.resolveModel(otherHost.getHostname(), AIModelType.EMBEDDINGS).isOperational());
+ }
+
+ /**
+ * Given a set of models loaded into the AIModels instance
+ * When the resolveAIModelOrThrow method is called with various model names and types
+ * Then the correct models should be resolved and their operational status verified.
+ */
+ @Test
+ public void test_resolveAIModelOrThrow() throws Exception {
+ AiTest.aiAppSecrets(APILocator.systemHost());
+ AiTest.aiAppSecrets(host, "text-model-30", "image-model-31", "embeddings-model-32");
+
+ final AppConfig appConfig = ConfigService.INSTANCE.config(host);
+ final AIModel aiModel30 = aiModels.resolveAIModelOrThrow(appConfig, "text-model-30", AIModelType.TEXT);
+ final AIModel aiModel31 = aiModels.resolveAIModelOrThrow(appConfig, "image-model-31", AIModelType.IMAGE);
+ final AIModel aiModel32 = aiModels.resolveAIModelOrThrow(
+ appConfig,
+ "embeddings-model-32",
+ AIModelType.EMBEDDINGS);
+
+ assertNotNull(aiModel30);
+ assertNotNull(aiModel31);
+ assertNotNull(aiModel32);
+ assertEquals("text-model-30", aiModel30.getModel("text-model-30").getName());
+ assertEquals("image-model-31", aiModel31.getModel("image-model-31").getName());
+ assertEquals("embeddings-model-32", aiModel32.getModel("embeddings-model-32").getName());
+
+ assertThrows(
+ DotAIModelNotFoundException.class,
+ () -> aiModels.resolveAIModelOrThrow(appConfig, "text-model-33", AIModelType.TEXT));
+ assertThrows(
+ DotAIModelNotFoundException.class,
+ () -> aiModels.resolveAIModelOrThrow(appConfig, "image-model-34", AIModelType.IMAGE));
+ assertThrows(
+ DotAIModelNotFoundException.class,
+ () -> aiModels.resolveAIModelOrThrow(appConfig, "embeddings-model-35", AIModelType.EMBEDDINGS));
+ }
+
+ /**
+ * Given a set of models loaded into the AIModels instance
+ * When the resolveModelOrThrow method is called with various model names and types
+ * Then the correct models should be resolved and their operational status verified.
+ */
+ @Test
+ public void test_resolveModelOrThrow() throws Exception {
+ AiTest.aiAppSecrets(APILocator.systemHost());
+ AiTest.aiAppSecrets(host, "text-model-40", "image-model-41", "embeddings-model-42");
+
+ final AppConfig appConfig = ConfigService.INSTANCE.config(host);
+ final Tuple2 modelTuple40 = aiModels.resolveModelOrThrow(
+ appConfig,
+ "text-model-40",
+ AIModelType.TEXT);
+ final Tuple2 modelTuple41 = aiModels.resolveModelOrThrow(
+ appConfig,
+ "image-model-41",
+ AIModelType.IMAGE);
+ final Tuple2 modelTuple42 = aiModels.resolveModelOrThrow(
+ appConfig,
+ "embeddings-model-42",
+ AIModelType.EMBEDDINGS);
+
+ assertNotNull(modelTuple40);
+ assertNotNull(modelTuple41);
+ assertNotNull(modelTuple42);
+ assertEquals("text-model-40", modelTuple40._1.getModel("text-model-40").getName());
+ assertEquals("image-model-41", modelTuple41._1.getModel("image-model-41").getName());
+ assertEquals("embeddings-model-42", modelTuple42._1.getModel("embeddings-model-42").getName());
+
+ assertThrows(
+ DotAIModelNotFoundException.class,
+ () -> aiModels.resolveAIModelOrThrow(appConfig, "text-model-43", AIModelType.TEXT));
+ assertThrows(
+ DotAIModelNotFoundException.class,
+ () -> aiModels.resolveAIModelOrThrow(appConfig, "image-model-44", AIModelType.IMAGE));
+ assertThrows(
+ DotAIModelNotFoundException.class,
+ () -> aiModels.resolveAIModelOrThrow(appConfig, "embeddings-model-45", AIModelType.EMBEDDINGS));
}
/**
@@ -141,69 +254,46 @@ public void test_loadModels_andFindThem() throws DotDataException, DotSecurityEx
* Then a list of supported models should be returned.
*/
@Test
- public void test_getOrPullSupportedModules() throws DotDataException, DotSecurityException {
- AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost());
+ public void test_getOrPullSupportedModels() throws Exception {
+ final Host systemHost = APILocator.systemHost();
+ AiTest.aiAppSecrets(systemHost);
AIModels.get().cleanSupportedModelsCache();
- Set supported = aiModels.getOrPullSupportedModels();
+ Set supported = aiModels.getOrPullSupportedModels(AiTest.API_KEY);
assertNotNull(supported);
assertEquals(38, supported.size());
-
- AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config);
}
/**
* Given an invalid URL for supported models
* When the getOrPullSupportedModules method is called
- * Then an empty list of supported models should be returned.
+ * Then an exception should be thrown
*/
@Test(expected = DotRuntimeException.class)
- public void test_getOrPullSupportedModules_withNetworkError() {
+ public void test_getOrPullSupportedModuels_withNetworkError() {
AIModels.get().cleanSupportedModelsCache();
IPUtils.disabledIpPrivateSubnet(false);
- final Set supported = aiModels.getOrPullSupportedModels();
- assertSupported(supported);
-
+ aiModels.getOrPullSupportedModels(AiTest.API_KEY);
IPUtils.disabledIpPrivateSubnet(true);
- AIModels.get().setAppConfigSupplier(ConfigService.INSTANCE::config);
}
/**
* Given no API key
* When the getOrPullSupportedModules method is called
- * Then an empty list of supported models should be returned.
+ * Then an exception should be thrown.
*/
@Test(expected = DotRuntimeException.class)
- public void test_getOrPullSupportedModules_noApiKey() throws DotDataException, DotSecurityException {
- AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost(), null);
+ public void test_getOrPullSupportedModels_noApiKey() throws Exception {
+ AiTest.aiAppSecrets(APILocator.systemHost(), null);
AIModels.get().cleanSupportedModelsCache();
- aiModels.getOrPullSupportedModels();
+ aiModels.getOrPullSupportedModels(null);
}
- /**
- * Given no API key
- * When the getOrPullSupportedModules method is called
- * Then an empty list of supported models should be returned.
- */
- @Test(expected = DotRuntimeException.class)
- public void test_getOrPullSupportedModules_noSystemHost() throws DotDataException, DotSecurityException {
- AiTest.removeSecrets(APILocator.systemHost());
-
- AIModels.get().cleanSupportedModelsCache();
- aiModels.getOrPullSupportedModels();
- }
-
- private void saveSecrets(final Host host,
- final String textModels,
- final String imageModels,
- final String embeddingsModels) throws DotDataException, DotSecurityException {
- AiTest.aiAppSecrets(wireMockServer, host, textModels, imageModels, embeddingsModels);
- DateUtil.sleep(1000);
- }
-
- private static void assertSameModels(Optional text3, Optional text1, Optional text2) {
+ private static void assertSameModels(final Optional text3,
+ final Optional text1,
+ final Optional text2) {
assertTrue(text3.isPresent());
assertSame(text1.get(), text3.get());
assertSame(text2.get(), text3.get());
@@ -211,12 +301,17 @@ private static void assertSameModels(Optional text3, Optional
private static void assertModels(final Optional model1,
final Optional model2,
- final AIModelType type) {
+ final AIModelType type,
+ final boolean assertModelNames) {
assertTrue(model1.isPresent());
assertTrue(model2.isPresent());
assertSame(model1.get(), model2.get());
assertSame(type, model1.get().getType());
assertSame(type, model2.get().getType());
+ if (assertModelNames) {
+ assertTrue(model1.get().getModels().stream().allMatch(model -> model.getStatus() == ModelStatus.ACTIVE));
+ assertTrue(model2.get().getModels().stream().allMatch(model -> model.getStatus() == ModelStatus.ACTIVE));
+ }
}
private static void assertNotPresentModels(final Optional model1, final Optional model2) {
@@ -224,9 +319,4 @@ private static void assertNotPresentModels(final Optional model1, final
assertTrue(model2.isEmpty());
}
- private static void assertSupported(Set supported) {
- assertNotNull(supported);
- assertTrue(supported.isEmpty());
- }
-
}
diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/app/ConfigServiceTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/app/ConfigServiceTest.java
new file mode 100644
index 000000000000..2e6143037095
--- /dev/null
+++ b/dotcms-integration/src/test/java/com/dotcms/ai/app/ConfigServiceTest.java
@@ -0,0 +1,101 @@
+package com.dotcms.ai.app;
+
+import com.dotcms.ai.AiTest;
+import com.dotcms.datagen.SiteDataGen;
+import com.dotcms.util.IntegrationTestInitService;
+import com.dotcms.util.LicenseValiditySupplier;
+import com.dotmarketing.beans.Host;
+import com.dotmarketing.business.APILocator;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Unit tests for the ConfigService class.
+ *
+ *
+ * This class contains tests to verify the behavior of the ConfigService,
+ * including scenarios with valid and invalid licenses, and configurations
+ * with and without secrets.
+ *
+ *
+ *
+ * The tests ensure that the ConfigService correctly initializes and
+ * configures the AppConfig based on the provided Host and license validity.
+ *
+ *
+ * @author vico
+ */
+public class ConfigServiceTest {
+
+ private Host host;
+ private ConfigService configService;
+
+ @BeforeClass
+ public static void beforeClass() throws Exception {
+ IntegrationTestInitService.getInstance().init();
+ }
+
+ @Before
+ public void before() {
+ host = new SiteDataGen().nextPersisted();
+ configService = ConfigService.INSTANCE;
+ }
+
+ /**
+ * Given a ConfigService with an invalid license
+ * When the config method is called with a host
+ * Then the models should not be operational.
+ */
+ @Test
+ public void test_invalidLicense() {
+ configService = new ConfigService(new LicenseValiditySupplier() {
+ @Override
+ public boolean hasValidLicense() {
+ return false;
+ }
+ });
+ final AppConfig appConfig = configService.config(host);
+
+ assertFalse(appConfig.getModel().isOperational());
+ assertFalse(appConfig.getImageModel().isOperational());
+ assertFalse(appConfig.getEmbeddingsModel().isOperational());
+ }
+
+ /**
+ * Given a host with secrets and a ConfigService
+ * When the config method is called with the host
+ * Then the models should be operational and the host should be correctly set in the AppConfig.
+ */
+ @Test
+ public void test_config_hostWithSecrets() throws Exception {
+ AiTest.aiAppSecrets(host, "text-model-0", "image-model-1", "embeddings-model-2");
+ final AppConfig appConfig = configService.config(host);
+
+ assertTrue(appConfig.getModel().isOperational());
+ assertTrue(appConfig.getImageModel().isOperational());
+ assertTrue(appConfig.getEmbeddingsModel().isOperational());
+ assertEquals(host.getHostname(), appConfig.getHost());
+ }
+
+ /**
+ * Given a host without secrets and a ConfigService
+ * When the config method is called with the host
+ * Then the models should be operational and the host should be set to "System Host" in the AppConfig.
+ */
+ @Test
+ public void test_config_hostWithoutSecrets() throws Exception {
+ AiTest.aiAppSecrets(APILocator.systemHost(), "text-model-10", "image-model-11", "embeddings-model-12");
+ final AppConfig appConfig = configService.config(host);
+
+ assertTrue(appConfig.getModel().isOperational());
+ assertTrue(appConfig.getImageModel().isOperational());
+ assertTrue(appConfig.getEmbeddingsModel().isOperational());
+ assertEquals("System Host", appConfig.getHost());
+ }
+
+}
diff --git a/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java b/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java
index 3c61cd335f55..a41bcc9b1398 100644
--- a/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java
+++ b/dotcms-integration/src/test/java/com/dotcms/ai/listener/EmbeddingContentListenerTest.java
@@ -191,8 +191,8 @@ private static boolean waitForEmbeddings(final Contentlet blogContent, final Str
}
private static void addDotAISecrets() throws DotDataException, DotSecurityException {
- AiTest.aiAppSecrets(wireMockServer, host, AiTest.API_KEY);
- AiTest.aiAppSecrets(wireMockServer, APILocator.systemHost(), AiTest.API_KEY);
+ AiTest.aiAppSecrets(host, AiTest.API_KEY);
+ AiTest.aiAppSecrets(APILocator.systemHost(), AiTest.API_KEY);
}
private static void removeDotAISecrets() throws DotDataException, DotSecurityException {
@@ -232,4 +232,5 @@ public static java.util.function.Predicate distinctByKey(
Set