Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored EmbeddingAdder #371

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,36 @@ public class EmbeddingAdder implements Function<DocumentWriteOperation, Iterator

private List<DocumentWriteOperation> pendingSourceDocuments = new ArrayList<>();

public EmbeddingAdder(ChunkSelector chunkSelector, EmbeddingGenerator embeddingGenerator, DocumentTextSplitter documentTextSplitter) {
this.chunkSelector = chunkSelector;
this.embeddingGenerator = embeddingGenerator;
/**
* Use this when a user has configured a splitter, as the splitter will return {@code DocumentAndChunks} instances
* that avoid the need for using a {@code ChunkSelector} to find chunks.
*
* @param documentTextSplitter
* @param embeddingGenerator
*/
public EmbeddingAdder(DocumentTextSplitter documentTextSplitter, EmbeddingGenerator embeddingGenerator) {
this.documentTextSplitter = documentTextSplitter;
this.embeddingGenerator = embeddingGenerator;
this.chunkSelector = null;
}

/**
* I think we can hold onto documents here? addEmbeddings could return true/false if it actually sends anything.
* Use this constructor when the user has not configured a splitter, as the {@code ChunkSelector} is needed to find
* chunks in each document.
*
* @param sourceDocument the function argument
* @return
* @param chunkSelector
* @param embeddingGenerator
*/
public EmbeddingAdder(ChunkSelector chunkSelector, EmbeddingGenerator embeddingGenerator) {
this.chunkSelector = chunkSelector;
this.embeddingGenerator = embeddingGenerator;
this.documentTextSplitter = null;
}

@Override
public Iterator<DocumentWriteOperation> apply(DocumentWriteOperation sourceDocument) {
// If the user configured a splitter, then follow a path where the source document is split, which will produce
// DocumentAndChunks instances. Which means the ChunkSelector isn't needed.
if (documentTextSplitter != null) {
return splitAndAddEmbeddings(sourceDocument);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

public abstract class Context implements Serializable {

protected final Map<String, String> properties;
private final Map<String, String> properties;

protected Context(Map<String, String> properties) {
this.properties = properties;
Expand Down Expand Up @@ -49,7 +49,7 @@ public final long getNumericOption(String optionName, long defaultValue, long mi
public final boolean getBooleanOption(String option, boolean defaultValue) {
return hasOption(option) ? Boolean.parseBoolean(getStringOption(option)) : defaultValue;
}

public final String getOptionNameForMessage(String option) {
return Util.getOptionNameForErrorMessage(option);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,18 @@ protected final Map<String, String> buildConnectionProperties() {
Map<String, String> connectionProps = new HashMap<>();
connectionProps.put("spark.marklogic.client.authType", "digest");
connectionProps.put("spark.marklogic.client.connectionType", "gateway");
connectionProps.putAll(this.properties);
connectionProps.putAll(getProperties());
if (optionExists(Options.CLIENT_URI)) {
parseConnectionString(properties.get(Options.CLIENT_URI), connectionProps);
parseConnectionString(getProperties().get(Options.CLIENT_URI), connectionProps);
}
if ("true".equalsIgnoreCase(properties.get(Options.CLIENT_SSL_ENABLED))) {
if ("true".equalsIgnoreCase(getProperties().get(Options.CLIENT_SSL_ENABLED))) {
connectionProps.put("spark.marklogic.client.sslProtocol", "default");
}
return connectionProps;
}

public final boolean optionExists(String option) {
String value = properties.get(option);
String value = getProperties().get(option);
return value != null && value.trim().length() > 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ class EmbedderTest extends AbstractIntegrationTest {
@Test
void defaultPaths() {
DocumentTextSplitter splitter = newJsonSplitter(500, 2, "/text");
EmbeddingAdder embedder = new EmbeddingAdder(
new JsonChunkSelector.Builder().build(), new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()), splitter
);
EmbeddingAdder embedder = new EmbeddingAdder(splitter, new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()));

Iterator<DocumentWriteOperation> docs = embedder.apply(readJsonDocument());

Expand Down Expand Up @@ -66,8 +64,7 @@ void customizedPaths() {
.withTextPointer("/wrapper/custom-text")
.withEmbeddingArrayName("custom-embedding")
.build(),
new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()),
null
new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel())
);

DocumentWriteOperation output = embedder.apply(new DocumentWriteOperationImpl("a.json", null, new JacksonHandle(doc))).next();
Expand All @@ -82,10 +79,7 @@ void customizedPaths() {
@Test
void xml() {
DocumentTextSplitter splitter = newXmlSplitter(500, 2, "/node()/text");
EmbeddingAdder embedder = new EmbeddingAdder(
new DOMChunkSelector(null, new XmlChunkConfig()),
new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()), splitter
);
EmbeddingAdder embedder = new EmbeddingAdder(splitter, new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()));

Iterator<DocumentWriteOperation> docs = embedder.apply(readXmlDocument());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,36 @@ void sidecarWithNamespace() {
.mode(SaveMode.Append)
.save();

XmlNode doc = readXmlDocument("/split-test.xml-chunks-1.xml", Namespace.getNamespace("ex", "org:example"));
doc.assertElementCount("/ex:sidecar/ex:chunks/ex:chunk", 4);
for (XmlNode chunk : doc.getXmlNodes("/ex:sidecar/ex:chunks/ex:chunk")) {
chunk.assertElementExists("/ex:chunk/ex:text");
chunk.assertElementExists("For now, the embedding still defaults to the empty namespace. We may change " +
"this soon to be a MarkLogic-specific namespace to better distinguish it from the users " +
"content.", "/ex:chunk/embedding");
}
verifyChunksInNamespacedSidecar();
verifyEachChunkIsReturnedByAVectorQuery("namespaced_xml_chunks");
}

/**
* This test verifies that when the source document does not have a namespace but the sidecar document does,
* the chunks still get embeddings because the connector doesn't need to use a ChunkSelector. That is due to the
* connector knowing that the splitter will return instances of DocumentAndChunks, which means the embedder can
* access the chunks without having to find them.
*/
@ExtendWith(RequiresMarkLogic12.class)
@Test
void sidecarWithCustomNamespace() {
readDocument("/marklogic-docs/java-client-intro.xml")
.write().format(CONNECTOR_IDENTIFIER)
.option(Options.CLIENT_URI, makeClientUri())
.option(Options.XPATH_NAMESPACE_PREFIX + "ex", "org:example")
.option(Options.WRITE_SPLITTER_XPATH, "/node()/text/text()")
.option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
.option(Options.WRITE_URI_TEMPLATE, "/split-test.xml")
.option(Options.WRITE_SPLITTER_MAX_CHUNK_SIZE, 500)
.option(Options.WRITE_SPLITTER_SIDECAR_MAX_CHUNKS, 4)
.option(Options.WRITE_SPLITTER_SIDECAR_ROOT_NAME, "sidecar")
.option(Options.WRITE_SPLITTER_SIDECAR_XML_NAMESPACE, "org:example")
.option(Options.WRITE_SPLITTER_SIDECAR_COLLECTIONS, "namespaced-xml-vector-chunks")
.option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, TEST_EMBEDDING_FUNCTION_CLASS)
.mode(SaveMode.Append)
.save();

verifyChunksInNamespacedSidecar();
verifyEachChunkIsReturnedByAVectorQuery("namespaced_xml_chunks");
}

Expand Down Expand Up @@ -247,4 +268,15 @@ private void verifyEachChunkIsReturnedByAVectorQuery(String viewName) {

assertEquals(4, counter, "Each test is expected to produce 4 chunks based on the max chunk size of 500.");
}

private void verifyChunksInNamespacedSidecar() {
XmlNode doc = readXmlDocument("/split-test.xml-chunks-1.xml", Namespace.getNamespace("ex", "org:example"));
doc.assertElementCount("/ex:sidecar/ex:chunks/ex:chunk", 4);
for (XmlNode chunk : doc.getXmlNodes("/ex:sidecar/ex:chunks/ex:chunk")) {
chunk.assertElementExists("/ex:chunk/ex:text");
chunk.assertElementExists("For now, the embedding still defaults to the empty namespace. We may change " +
"this soon to be a MarkLogic-specific namespace to better distinguish it from the users " +
"content.", "/ex:chunk/embedding");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ public abstract class EmbeddingAdderFactory {
public static Optional<EmbeddingAdder> makeEmbedder(Context context, DocumentTextSplitter splitter) {
Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
if (embeddingModel.isPresent()) {
ChunkSelector chunkSelector = makeChunkSelector(context);
EmbeddingGenerator embeddingGenerator = makeEmbeddingGenerator(context);
return Optional.of(new EmbeddingAdder(chunkSelector, embeddingGenerator, splitter));
if (splitter != null) {
return Optional.of(new EmbeddingAdder(splitter, embeddingGenerator));
}
ChunkSelector chunkSelector = makeChunkSelector(context);
return Optional.of(new EmbeddingAdder(chunkSelector, embeddingGenerator));
}
return Optional.empty();
}
Expand Down