diff --git a/src/main/java/com/dotcms/ai/db/EmbeddingsDB.java b/src/main/java/com/dotcms/ai/db/EmbeddingsDB.java index bce28dd..434c201 100644 --- a/src/main/java/com/dotcms/ai/db/EmbeddingsDB.java +++ b/src/main/java/com/dotcms/ai/db/EmbeddingsDB.java @@ -27,6 +27,8 @@ import java.util.Map; import java.util.TreeMap; +import static com.dotcms.ai.db.EmbeddingsDTO.ALL_INDICES; + public class EmbeddingsDB { @@ -285,7 +287,7 @@ List appendParams(StringBuilder sql, EmbeddingsDTO dto) { sql.append(" and host=? "); params.add(dto.host); } - if (UtilMethods.isSet(dto.indexName)) { + if (UtilMethods.isSet(dto.indexName) && !ALL_INDICES.equals(dto.indexName)) { sql.append(" and lower(index_name)=lower(?) "); params.add(dto.indexName); } diff --git a/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java b/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java index 4a25256..22cf133 100644 --- a/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java +++ b/src/main/java/com/dotcms/ai/db/EmbeddingsDTO.java @@ -46,6 +46,8 @@ public class EmbeddingsDTO implements Serializable { public final float temperature; private final String[] operators = {"<->", "<=>", "<#>"}; + public static final String ALL_INDICES = "all"; + private EmbeddingsDTO(Builder builder) { this.embeddings = (builder.embeddings == null) ? new Float[0] : builder.embeddings.toArray(new Float[0]); this.identifier = builder.identifier; diff --git a/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java b/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java index ce287e5..2487a50 100644 --- a/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java +++ b/src/main/java/com/dotcms/ai/listener/EmbeddingContentListener.java @@ -24,6 +24,8 @@ import java.util.Set; import java.util.concurrent.TimeUnit; +import static com.dotcms.ai.db.EmbeddingsDTO.ALL_INDICES; + public class EmbeddingContentListener implements ContentletListener { @@ -137,6 +139,7 @@ void deleteFromIndexes(Contentlet contentlet) { EmbeddingsDTO dto = new EmbeddingsDTO.Builder() .withIdentifier(contentlet.getIdentifier()) .withLanguage(contentlet.getLanguageId()) + .withIndexName(ALL_INDICES) .build(); EmbeddingsAPI.impl().deleteEmbedding(dto); diff --git a/src/main/java/com/dotcms/ai/service/OpenAIChatService.java b/src/main/java/com/dotcms/ai/service/OpenAIChatService.java index 6e59319..f1c8cfe 100644 --- a/src/main/java/com/dotcms/ai/service/OpenAIChatService.java +++ b/src/main/java/com/dotcms/ai/service/OpenAIChatService.java @@ -18,7 +18,7 @@ public interface OpenAIChatService { * * @return a JSONObject including the generated text and metadata *

diff --git a/src/main/java/com/dotcms/ai/viewtool/SearchTool.java b/src/main/java/com/dotcms/ai/viewtool/SearchTool.java index 02b39e3..fee0ce8 100644 --- a/src/main/java/com/dotcms/ai/viewtool/SearchTool.java +++ b/src/main/java/com/dotcms/ai/viewtool/SearchTool.java @@ -23,7 +23,6 @@ import java.util.Optional; public class SearchTool implements ViewTool { - final private HttpServletRequest request; final private Host host; final private AppConfig app; @@ -39,18 +38,15 @@ public class SearchTool implements ViewTool { this.app = ConfigService.INSTANCE.config(this.host); } - @Override public void init(Object initData) { /* unneeded because of constructor */ } - public Object query(Map mapIn) { User user = PortalUtil.getUser(request); EmbeddingsDTO searcher = EmbeddingsDTO.from(mapIn).withUser(user).build(); - try { return EmbeddingsAPI.impl(host).searchForContent(searcher); } catch (Exception e) { @@ -59,29 +55,22 @@ public Object query(Map mapIn) { } public Object query(String query) { - return query(query, "default"); } public Object query(String query, String indexName) { User user = PortalUtil.getUser(request); - EmbeddingsDTO searcher = new EmbeddingsDTO.Builder().withQuery(query).withIndexName(indexName).withUser(user).withLimit(50).withThreshold(.25f).build(); - try { return EmbeddingsAPI.impl(host).searchForContent(searcher); } catch (Exception e) { return Map.of("error", e.getMessage(), "stackTrace", Arrays.asList(e.getStackTrace())); } - - } public Object related(ContentMap contentMap, String indexName) { - return related(contentMap.getContentObject(), indexName); - } public Object related(Contentlet contentlet, String indexName) { @@ -89,7 +78,6 @@ public Object related(Contentlet contentlet, String indexName) { User user = PortalUtil.getUser(request); List fields = ContentToStringUtil.impl.get().guessWhatFieldsToIndex(contentlet); - Optional contentToRelate = ContentToStringUtil.impl.get().parseFields(contentlet, fields); if (contentToRelate.isEmpty()) { return new JSONObject(); @@ -99,8 +87,6 @@ public Object related(Contentlet contentlet, String indexName) { } catch (Exception e) { return Map.of("error", e.getMessage(), "stackTrace", Arrays.asList(e.getStackTrace())); } - - } }