diff --git a/marklogic-langchain4j/src/main/java/com/marklogic/langchain4j/embedding/EmbeddingAdder.java b/marklogic-langchain4j/src/main/java/com/marklogic/langchain4j/embedding/EmbeddingAdder.java
index 2cbe43fb..69fc4202 100644
--- a/marklogic-langchain4j/src/main/java/com/marklogic/langchain4j/embedding/EmbeddingAdder.java
+++ b/marklogic-langchain4j/src/main/java/com/marklogic/langchain4j/embedding/EmbeddingAdder.java
@@ -26,20 +26,36 @@ public class EmbeddingAdder implements Function<DocumentWriteOperation, Iterator<DocumentWriteOperation>> {
 
     private final List<DocumentWriteOperation> pendingSourceDocuments = new ArrayList<>();
 
-    public EmbeddingAdder(ChunkSelector chunkSelector, EmbeddingGenerator embeddingGenerator, DocumentTextSplitter documentTextSplitter) {
-        this.chunkSelector = chunkSelector;
-        this.embeddingGenerator = embeddingGenerator;
+    /**
+     * Use this when a user has configured a splitter, as the splitter will return {@code DocumentAndChunks} instances
+     * that avoid the need for using a {@code ChunkSelector} to find chunks.
+     *
+     * @param documentTextSplitter
+     * @param embeddingGenerator
+     */
+    public EmbeddingAdder(DocumentTextSplitter documentTextSplitter, EmbeddingGenerator embeddingGenerator) {
         this.documentTextSplitter = documentTextSplitter;
+        this.embeddingGenerator = embeddingGenerator;
+        this.chunkSelector = null;
     }
 
     /**
-     * I think we can hold onto documents here? addEmbeddings could return true/false if it actually sends anything.
+     * Use this constructor when the user has not configured a splitter, as the {@code ChunkSelector} is needed to find
+     * chunks in each document.
      *
-     * @param sourceDocument the function argument
-     * @return
+     * @param chunkSelector
+     * @param embeddingGenerator
      */
+    public EmbeddingAdder(ChunkSelector chunkSelector, EmbeddingGenerator embeddingGenerator) {
+        this.chunkSelector = chunkSelector;
+        this.embeddingGenerator = embeddingGenerator;
+        this.documentTextSplitter = null;
+    }
+
     @Override
     public Iterator<DocumentWriteOperation> apply(DocumentWriteOperation sourceDocument) {
+        // If the user configured a splitter, then follow a path where the source document is split, which will produce
+        // DocumentAndChunks instances, which means the ChunkSelector isn't needed.
         if (documentTextSplitter != null) {
             return splitAndAddEmbeddings(sourceDocument);
         }
diff --git a/marklogic-spark-api/src/main/java/com/marklogic/spark/Context.java b/marklogic-spark-api/src/main/java/com/marklogic/spark/Context.java
index 00f19158..53921c77 100644
--- a/marklogic-spark-api/src/main/java/com/marklogic/spark/Context.java
+++ b/marklogic-spark-api/src/main/java/com/marklogic/spark/Context.java
@@ -9,7 +9,7 @@
 public abstract class Context implements Serializable {
 
-    protected final Map<String, String> properties;
+    private final Map<String, String> properties;
 
     protected Context(Map<String, String> properties) {
         this.properties = properties;
     }
@@ -49,7 +49,7 @@ public final long getNumericOption(String optionName, long defaultValue, long mi
     public final boolean getBooleanOption(String option, boolean defaultValue) {
         return hasOption(option) ? Boolean.parseBoolean(getStringOption(option)) : defaultValue;
     }
-
+
     public final String getOptionNameForMessage(String option) {
         return Util.getOptionNameForErrorMessage(option);
     }
diff --git a/marklogic-spark-connector/src/main/java/com/marklogic/spark/ContextSupport.java b/marklogic-spark-connector/src/main/java/com/marklogic/spark/ContextSupport.java
index e12a914e..ce9cdff7 100644
--- a/marklogic-spark-connector/src/main/java/com/marklogic/spark/ContextSupport.java
+++ b/marklogic-spark-connector/src/main/java/com/marklogic/spark/ContextSupport.java
@@ -79,18 +79,18 @@ protected final Map<String, String> buildConnectionProperties() {
         Map<String, String> connectionProps = new HashMap<>();
         connectionProps.put("spark.marklogic.client.authType", "digest");
         connectionProps.put("spark.marklogic.client.connectionType", "gateway");
-        connectionProps.putAll(this.properties);
+        connectionProps.putAll(getProperties());
 
         if (optionExists(Options.CLIENT_URI)) {
-            parseConnectionString(properties.get(Options.CLIENT_URI), connectionProps);
+            parseConnectionString(getProperties().get(Options.CLIENT_URI), connectionProps);
         }
-        if ("true".equalsIgnoreCase(properties.get(Options.CLIENT_SSL_ENABLED))) {
+        if ("true".equalsIgnoreCase(getProperties().get(Options.CLIENT_SSL_ENABLED))) {
             connectionProps.put("spark.marklogic.client.sslProtocol", "default");
         }
 
         return connectionProps;
     }
 
     public final boolean optionExists(String option) {
-        String value = properties.get(option);
+        String value = getProperties().get(option);
         return value != null && value.trim().length() > 0;
     }
diff --git a/marklogic-spark-connector/src/test/java/com/marklogic/langchain4j/embedding/EmbedderTest.java b/marklogic-spark-connector/src/test/java/com/marklogic/langchain4j/embedding/EmbedderTest.java
index 71229616..29dc0cba 100644
--- a/marklogic-spark-connector/src/test/java/com/marklogic/langchain4j/embedding/EmbedderTest.java
+++ b/marklogic-spark-connector/src/test/java/com/marklogic/langchain4j/embedding/EmbedderTest.java
@@ -34,9 +34,7 @@ class EmbedderTest extends AbstractIntegrationTest {
     @Test
     void defaultPaths() {
         DocumentTextSplitter splitter = newJsonSplitter(500, 2, "/text");
-        EmbeddingAdder embedder = new EmbeddingAdder(
-            new JsonChunkSelector.Builder().build(), new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()), splitter
-        );
+        EmbeddingAdder embedder = new EmbeddingAdder(splitter, new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()));
 
         Iterator<DocumentWriteOperation> docs = embedder.apply(readJsonDocument());
 
@@ -66,8 +64,7 @@ void customizedPaths() {
                 .withTextPointer("/wrapper/custom-text")
                 .withEmbeddingArrayName("custom-embedding")
                 .build(),
-            new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()),
-            null
+            new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel())
         );
 
         DocumentWriteOperation output = embedder.apply(new DocumentWriteOperationImpl("a.json", null, new JacksonHandle(doc))).next();
 
@@ -82,10 +79,7 @@
     @Test
     void xml() {
         DocumentTextSplitter splitter = newXmlSplitter(500, 2, "/node()/text");
-        EmbeddingAdder embedder = new EmbeddingAdder(
-            new DOMChunkSelector(null, new XmlChunkConfig()),
-            new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()), splitter
-        );
+        EmbeddingAdder embedder = new EmbeddingAdder(splitter, new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()));
 
         Iterator<DocumentWriteOperation> docs = embedder.apply(readXmlDocument());
diff --git a/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToXmlTest.java b/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToXmlTest.java
index 8fa10dbb..511c70a0 100644
--- a/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToXmlTest.java
+++ b/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToXmlTest.java
@@ -85,15 +85,36 @@ void sidecarWithNamespace() {
             .mode(SaveMode.Append)
             .save();
 
-        XmlNode doc = readXmlDocument("/split-test.xml-chunks-1.xml", Namespace.getNamespace("ex", "org:example"));
-        doc.assertElementCount("/ex:sidecar/ex:chunks/ex:chunk", 4);
-        for (XmlNode chunk : doc.getXmlNodes("/ex:sidecar/ex:chunks/ex:chunk")) {
-            chunk.assertElementExists("/ex:chunk/ex:text");
-            chunk.assertElementExists("For now, the embedding still defaults to the empty namespace. We may change " +
-                "this soon to be a MarkLogic-specific namespace to better distinguish it from the users " +
-                "content.", "/ex:chunk/embedding");
-        }
+        verifyChunksInNamespacedSidecar();
+        verifyEachChunkIsReturnedByAVectorQuery("namespaced_xml_chunks");
+    }
+
+    /**
+     * This test verifies that when the source document does not have a namespace but the sidecar document does,
+     * the chunks still get embeddings because the connector doesn't need to use a ChunkSelector. That is due to the
+     * connector knowing that the splitter will return instances of DocumentAndChunks, which means the embedder can
+     * access the chunks without having to find them.
+     */
+    @ExtendWith(RequiresMarkLogic12.class)
+    @Test
+    void sidecarWithCustomNamespace() {
+        readDocument("/marklogic-docs/java-client-intro.xml")
+            .write().format(CONNECTOR_IDENTIFIER)
+            .option(Options.CLIENT_URI, makeClientUri())
+            .option(Options.XPATH_NAMESPACE_PREFIX + "ex", "org:example")
+            .option(Options.WRITE_SPLITTER_XPATH, "/node()/text/text()")
+            .option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
+            .option(Options.WRITE_URI_TEMPLATE, "/split-test.xml")
+            .option(Options.WRITE_SPLITTER_MAX_CHUNK_SIZE, 500)
+            .option(Options.WRITE_SPLITTER_SIDECAR_MAX_CHUNKS, 4)
+            .option(Options.WRITE_SPLITTER_SIDECAR_ROOT_NAME, "sidecar")
+            .option(Options.WRITE_SPLITTER_SIDECAR_XML_NAMESPACE, "org:example")
+            .option(Options.WRITE_SPLITTER_SIDECAR_COLLECTIONS, "namespaced-xml-vector-chunks")
+            .option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, TEST_EMBEDDING_FUNCTION_CLASS)
+            .mode(SaveMode.Append)
+            .save();
+
+        verifyChunksInNamespacedSidecar();
 
         verifyEachChunkIsReturnedByAVectorQuery("namespaced_xml_chunks");
     }
@@ -247,4 +268,15 @@ private void verifyEachChunkIsReturnedByAVectorQuery(String viewName) {
         assertEquals(4, counter, "Each test is expected to produce 4 chunks based on the max chunk size of 500.");
     }
+
+    private void verifyChunksInNamespacedSidecar() {
+        XmlNode doc = readXmlDocument("/split-test.xml-chunks-1.xml", Namespace.getNamespace("ex", "org:example"));
+        doc.assertElementCount("/ex:sidecar/ex:chunks/ex:chunk", 4);
+        for (XmlNode chunk : doc.getXmlNodes("/ex:sidecar/ex:chunks/ex:chunk")) {
+            chunk.assertElementExists("/ex:chunk/ex:text");
+            chunk.assertElementExists("For now, the embedding still defaults to the empty namespace. We may change " +
+                "this soon to be a MarkLogic-specific namespace to better distinguish it from the users " +
+                "content.", "/ex:chunk/embedding");
+        }
+    }
 
 }
diff --git a/marklogic-spark-langchain4j/src/main/java/com/marklogic/spark/langchain4j/EmbeddingAdderFactory.java b/marklogic-spark-langchain4j/src/main/java/com/marklogic/spark/langchain4j/EmbeddingAdderFactory.java
index 68d579c4..f50f03f4 100644
--- a/marklogic-spark-langchain4j/src/main/java/com/marklogic/spark/langchain4j/EmbeddingAdderFactory.java
+++ b/marklogic-spark-langchain4j/src/main/java/com/marklogic/spark/langchain4j/EmbeddingAdderFactory.java
@@ -21,9 +21,12 @@ public abstract class EmbeddingAdderFactory {
     public static Optional<EmbeddingAdder> makeEmbedder(Context context, DocumentTextSplitter splitter) {
         Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
         if (embeddingModel.isPresent()) {
-            ChunkSelector chunkSelector = makeChunkSelector(context);
             EmbeddingGenerator embeddingGenerator = makeEmbeddingGenerator(context);
-            return Optional.of(new EmbeddingAdder(chunkSelector, embeddingGenerator, splitter));
+            if (splitter != null) {
+                return Optional.of(new EmbeddingAdder(splitter, embeddingGenerator));
+            }
+            ChunkSelector chunkSelector = makeChunkSelector(context);
+            return Optional.of(new EmbeddingAdder(chunkSelector, embeddingGenerator));
         }
         return Optional.empty();
     }
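
Reviewer note, not part of the patch: the sketch below summarizes how the two new EmbeddingAdder constructors are intended to be used, based only on the changes above. The EmbeddingAdder, EmbeddingGenerator, and ChunkSelector names come from the diff; the import locations for ChunkSelector and DocumentTextSplitter, the assumption that EmbeddingGenerator's constructor accepts a langchain4j EmbeddingModel, and the helper method shapes are illustrative and not verified against the full source tree.

// Usage sketch only; packages marked "assumed" may differ in the repo.
import com.marklogic.client.document.DocumentWriteOperation;
import com.marklogic.langchain4j.embedding.ChunkSelector;       // package assumed from the diff layout
import com.marklogic.langchain4j.embedding.EmbeddingAdder;
import com.marklogic.langchain4j.embedding.EmbeddingGenerator;  // package assumed from the diff layout
import com.marklogic.langchain4j.splitter.DocumentTextSplitter; // package assumed
import dev.langchain4j.model.embedding.EmbeddingModel;

import java.util.Iterator;

public class EmbeddingAdderUsageSketch {

    // Splitter configured: the splitter returns DocumentAndChunks, so the embedder can use the
    // chunks directly and no ChunkSelector is created (first new constructor).
    static Iterator<DocumentWriteOperation> embedWhileSplitting(DocumentTextSplitter splitter,
                                                                EmbeddingModel model,
                                                                DocumentWriteOperation sourceDocument) {
        EmbeddingAdder embedder = new EmbeddingAdder(splitter, new EmbeddingGenerator(model));
        return embedder.apply(sourceDocument);
    }

    // No splitter configured: a ChunkSelector finds the existing chunks in each incoming document
    // so that embeddings can be added to them (second new constructor).
    static Iterator<DocumentWriteOperation> embedExistingChunks(ChunkSelector chunkSelector,
                                                                EmbeddingModel model,
                                                                DocumentWriteOperation chunkDocument) {
        EmbeddingAdder embedder = new EmbeddingAdder(chunkSelector, new EmbeddingGenerator(model));
        return embedder.apply(chunkDocument);
    }
}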