[FEATURE] Support batch ingestion in TextEmbeddingProcessor & SparseEncodingProcessor (#744) (#762)

* Support batch ingestion in TextEmbeddingProcessor & SparseEncodingProcessor

Signed-off-by: Liyun Xiu <[email protected]>

* Update Changelog

Signed-off-by: Liyun Xiu <[email protected]>

* Add UT and IT

Signed-off-by: Liyun Xiu <[email protected]>

* Add comments

Signed-off-by: Liyun Xiu <[email protected]>

* Sort texts by length before sending for inference

Signed-off-by: Liyun Xiu <[email protected]>

* Make the check for inferenceList consistent

Signed-off-by: Liyun Xiu <[email protected]>

---------

Signed-off-by: Liyun Xiu <[email protected]>
(cherry picked from commit afd1215)
chishui authored May 27, 2024
1 parent bfa0766 commit 15b4a0f
Showing 13 changed files with 642 additions and 72 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.14...2.x)
### Features
- Support batchExecute in TextEmbeddingProcessor and SparseEncodingProcessor ([#743](https://github.com/opensearch-project/neural-search/issues/743))
### Enhancements
- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
- Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))
@@ -5,20 +5,30 @@
package org.opensearch.neuralsearch.processor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.commons.lang3.StringUtils;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import com.google.common.annotations.VisibleForTesting;
@@ -119,6 +129,121 @@ public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
}
}

/**
* The function that does the actual inference work for the batchExecute interface.
* @param inferenceList a list of Strings to run inference on.
* @param handler a callback that handles the inference results, which are a list of objects.
* @param onException a callback that handles any exception raised during inference.
*/
abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);

@Override
public void batchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
handler.accept(Collections.emptyList());
return;
}

List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
List<String> inferenceList = constructInferenceTexts(dataForInferences);
if (inferenceList.isEmpty()) {
handler.accept(ingestDocumentWrappers);
return;
}
Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
inferenceList = sortedResult.v1();
Map<Integer, Integer> originalOrder = sortedResult.v2();
doBatchExecute(inferenceList, results -> {
int startIndex = 0;
results = restoreToOriginalOrder(results, originalOrder);
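// Each document's texts were appended contiguously in constructInferenceTexts, so the
// restored results can be sliced back out per document with a running offset.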
for (DataForInference dataForInference : dataForInferences) {
if (dataForInference.getIngestDocumentWrapper().getException() != null
|| CollectionUtils.isEmpty(dataForInference.getInferenceList())) {
continue;
}
List<?> inferenceResults = results.subList(startIndex, startIndex + dataForInference.getInferenceList().size());
startIndex += dataForInference.getInferenceList().size();
setVectorFieldsToDocument(
dataForInference.getIngestDocumentWrapper().getIngestDocument(),
dataForInference.getProcessMap(),
inferenceResults
);
}
handler.accept(ingestDocumentWrappers);
}, exception -> {
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
// The IngestDocumentWrapper might already have run into an exception and not been sent for
// inference, so only set the exception on wrappers that don't have one already.
if (ingestDocumentWrapper.getException() == null) {
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), exception);
}
}
handler.accept(ingestDocumentWrappers);
});
}

private Tuple<List<String>, Map<Integer, Integer>> sortByLengthAndReturnOriginalOrder(List<String> inferenceList) {
List<Tuple<Integer, String>> docsWithIndex = new ArrayList<>();
for (int i = 0; i < inferenceList.size(); ++i) {
docsWithIndex.add(Tuple.tuple(i, inferenceList.get(i)));
}
docsWithIndex.sort(Comparator.comparingInt(t -> t.v2().length()));
List<String> sortedInferenceList = docsWithIndex.stream().map(Tuple::v2).collect(Collectors.toList());
Map<Integer, Integer> originalOrderMap = new HashMap<>();
for (int i = 0; i < docsWithIndex.size(); ++i) {
originalOrderMap.put(i, docsWithIndex.get(i).v1());
}
return Tuple.tuple(sortedInferenceList, originalOrderMap);
}
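// For example, an input of ["bbb", "a", "cc"] yields the sorted list ["a", "cc", "bbb"]
// and the map {0=1, 1=2, 2=0}, i.e. sortedIndex -> originalIndex.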

private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> originalOrder) {
List<Object> sortedResults = Arrays.asList(results.toArray());
for (int i = 0; i < results.size(); ++i) {
if (!originalOrder.containsKey(i)) continue;
int oldIndex = originalOrder.get(i);
sortedResults.set(oldIndex, results.get(i));
}
return sortedResults;
}
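// Continuing the example above: results [r_a, r_cc, r_bbb] with map {0=1, 1=2, 2=0}
// are restored to [r_bbb, r_a, r_cc], matching the original input order.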

private List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
List<String> inferenceTexts = new ArrayList<>();
for (DataForInference dataForInference : dataForInferences) {
if (dataForInference.getIngestDocumentWrapper().getException() != null
|| CollectionUtils.isEmpty(dataForInference.getInferenceList())) {
continue;
}
inferenceTexts.addAll(dataForInference.getInferenceList());
}
return inferenceTexts;
}

private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
List<DataForInference> dataForInferences = new ArrayList<>();
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
Map<String, Object> processMap = null;
List<String> inferenceList = null;
try {
validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument());
inferenceList = createInferenceList(processMap);
} catch (Exception e) {
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
} finally {
dataForInferences.add(new DataForInference(ingestDocumentWrapper, processMap, inferenceList));
}
}
return dataForInferences;
}

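/**
* Groups an IngestDocumentWrapper with the processMap and inferenceList derived from it;
* both stay null when validation of the document failed.
*/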
@Getter
@AllArgsConstructor
private static class DataForInference {
private final IngestDocumentWrapper ingestDocumentWrapper;
private final Map<String, Object> processMap;
private final List<String> inferenceList;
}

@SuppressWarnings({ "unchecked" })
private List<String> createInferenceList(Map<String, Object> knnKeyMap) {
List<String> texts = new ArrayList<>();
@@ -7,6 +7,7 @@
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
@@ -49,4 +50,13 @@ public void doExecute(
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
}

@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
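// Sparse results come back from ml-commons as raw maps; convert them into per-text
// token-weight maps before handing them to the shared batch handler.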
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
this.modelId,
inferenceList,
ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException)
);
}
}
@@ -7,6 +7,7 @@
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
@@ -48,4 +49,9 @@ public void doExecute(
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
}

@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
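// Dense embeddings are returned in the same order as inferenceList, so they can be
// passed straight through to the shared batch handler.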
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(handler::accept, onException));
}
}
@@ -0,0 +1,59 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import com.google.common.collect.ImmutableList;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.test.OpenSearchTestCase;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class InferenceProcessorTestCase extends OpenSearchTestCase {

protected List<IngestDocumentWrapper> createIngestDocumentWrappers(int count) {
List<IngestDocumentWrapper> wrapperList = new ArrayList<>();
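// Each wrapper records its slot (i) so batch results can be mapped back to the right document.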
for (int i = 0; i < count; ++i) {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "value1");
wrapperList.add(new IngestDocumentWrapper(i, new IngestDocument(sourceAndMetadata, new HashMap<>()), null));
}
return wrapperList;
}

protected List<List<Float>> createMockVectorWithLength(int size) {
float suffix = .234f;
List<List<Float>> result = new ArrayList<>();
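// Produces `size` two-element vectors: [0.234, 1.234], [2.234, 3.234], ...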
for (int i = 0; i < size * 2;) {
List<Float> number = new ArrayList<>();
number.add(i++ + suffix);
number.add(i++ + suffix);
result.add(number);
}
return result;
}

protected List<List<Float>> createMockVectorResult() {
List<List<Float>> modelTensorList = new ArrayList<>();
List<Float> number1 = ImmutableList.of(1.234f, 2.354f);
List<Float> number2 = ImmutableList.of(3.234f, 4.354f);
List<Float> number3 = ImmutableList.of(5.234f, 6.354f);
List<Float> number4 = ImmutableList.of(7.234f, 8.354f);
List<Float> number5 = ImmutableList.of(9.234f, 10.354f);
List<Float> number6 = ImmutableList.of(11.234f, 12.354f);
List<Float> number7 = ImmutableList.of(13.234f, 14.354f);
modelTensorList.add(number1);
modelTensorList.add(number2);
modelTensorList.add(number3);
modelTensorList.add(number4);
modelTensorList.add(number5);
modelTensorList.add(number6);
modelTensorList.add(number7);
return modelTensorList;
}
}
