Skip to content

Commit

Permalink
Merge branch '2.x' into backport/backport-392-to-2.x
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 authored Oct 6, 2023
2 parents 2090536 + 47d1aee commit b6cd151
Show file tree
Hide file tree
Showing 25 changed files with 1,539 additions and 54 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ dependencies {
runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12'
runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA'
runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}"
runtimeOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
runtimeOnly group: 'org.json', name: 'json', version: '20230227'
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@

package org.opensearch.neuralsearch.ml;

import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.logging.log4j.util.Strings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand All @@ -37,6 +43,8 @@
public class MLCommonsClientAccessor {
private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
private final MachineLearningNodeClient mlClient;
private static final String EXCEPTION_MESSAGE_MODEL_PREDICT_FAILED = "failed while calling model, check error log for details";
private static final String EXCEPTION_MESSAGE_PREFIX_MODEL_PREDICT_FAILED = "encountered following error while calling a model";

/**
* Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating
Expand Down Expand Up @@ -113,13 +121,30 @@ public void inferenceSentencesWithMapResult(
retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
}

/**
* Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the
* custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent
* using the actionListener which will have a list of floats in the order of inputText.
*
* @param modelId {@link String}
* @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
*/
public void inferenceSentences(
@NonNull final String modelId,
@NonNull final Map<String, String> inputObjects,
@NonNull final ActionListener<List<Float>> listener
) {
retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<List<Map<String, ?>>> listener
) {
MLInput mlInput = createMLInput(null, inputText);
MLInput mlInput = createMLTextInput(null, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
listener.onResponse(result);
Expand All @@ -140,7 +165,7 @@ private void retryableInferenceSentencesWithVectorResult(
final int retryTime,
final ActionListener<List<List<Float>>> listener
) {
MLInput mlInput = createMLInput(targetResponseFilters, inputText);
MLInput mlInput = createMLTextInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
listener.onResponse(vector);
Expand All @@ -154,7 +179,7 @@ private void retryableInferenceSentencesWithVectorResult(
}));
}

private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
Expand All @@ -167,6 +192,20 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
if (Objects.isNull(tensor.getData())) {
if (Objects.nonNull(tensor.getDataAsMap()) && Strings.isNotBlank((String) tensor.getDataAsMap().get("message"))) {
String errorFromModel = (String) tensor.getDataAsMap().get("message");
throw new IllegalStateException(
String.format(Locale.ROOT, "%s: %s", EXCEPTION_MESSAGE_PREFIX_MODEL_PREDICT_FAILED, errorFromModel)
);
} else {
log.error(
"Received following output tensor from a model, there is no detailed error message: {}",
tensor.toString()
);
throw new IllegalStateException(EXCEPTION_MESSAGE_MODEL_PREDICT_FAILED);
}
}
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
}
}
Expand All @@ -191,4 +230,41 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return resultMaps;
}

private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
}

private void retryableInferenceSentencesWithSingleVectorResult(
final List<String> targetResponseFilters,
final String modelId,
final Map<String, String> inputObjects,
final int retryTime,
final ActionListener<List<Float>> listener
) {
MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
}

private MLInput createMLMultimodalInput(final List<String> targetResponseFilters, final Map<String, String> input) {
List<String> inputText = new ArrayList<>();
inputText.add(input.get(INPUT_TEXT));
if (input.containsKey(INPUT_IMAGE)) {
inputText.add(input.get(INPUT_IMAGE));
}
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
Expand Down Expand Up @@ -106,7 +108,9 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env)
new SparseEncodingProcessorFactory(clientAccessor, parameters.env),
TextImageEmbeddingProcessor.TYPE,
new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService())
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
@SuppressWarnings({ "unchecked" })
@VisibleForTesting
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
InferenceProcessor.IndexWrapper indexWrapper = new InferenceProcessor.IndexWrapper(0);
IndexWrapper indexWrapper = new IndexWrapper(0);
Map<String, Object> result = new LinkedHashMap<>();
for (Map.Entry<String, Object> knnMapEntry : processorMap.entrySet()) {
String knnKey = knnMapEntry.getKey();
Expand All @@ -270,7 +270,7 @@ private void putNLPResultToSourceMapForMapType(
String processorKey,
Object sourceValue,
List<?> results,
InferenceProcessor.IndexWrapper indexWrapper,
IndexWrapper indexWrapper,
Map<String, Object> sourceAndMetadataMap
) {
if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return;
Expand All @@ -291,11 +291,7 @@ private void putNLPResultToSourceMapForMapType(
}
}

private List<Map<String, Object>> buildNLPResultForListType(
List<String> sourceValue,
List<?> results,
InferenceProcessor.IndexWrapper indexWrapper
) {
private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
Expand Down Expand Up @@ -52,6 +52,9 @@ public void execute(
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

// pre-process data
log.debug("Pre-process query results");
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);
Expand All @@ -67,7 +70,7 @@ public void execute(
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(querySearchResults, queryTopDocs);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
}

/**
Expand Down Expand Up @@ -123,7 +126,8 @@ private void updateOriginalQueryResults(final List<QuerySearchResult> querySearc
*/
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional
final Optional<FetchSearchResult> fetchSearchResultOptional,
final List<Integer> docIds
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
Expand All @@ -135,14 +139,17 @@ private void updateOriginalFetchResults(
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);

// create map of docId to index of search hits. This solves (2), duplicates are from
// delimiter and start/stop elements, they all have same valid doc_id. For this map
// we use doc_id as a key, and all those special elements are collapsed into a single
// key-value pair.
Map<Integer, SearchHit> docIdToSearchHit = Arrays.stream(searchHits.getHits())
.collect(Collectors.toMap(SearchHit::docId, Function.identity(), (a1, a2) -> a1));
Map<Integer, SearchHit> docIdToSearchHit = new HashMap<>();
for (int i = 0; i < searchHitArray.length; i++) {
int originalDocId = docIds.get(i);
docIdToSearchHit.put(originalDocId, searchHitArray[i]);
}

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;
Expand All @@ -161,4 +168,23 @@ private void updateOriginalFetchResults(
);
fetchSearchResult.hits(updatedSearchHits);
}

private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate the both collections are of the same size
if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
throw new IllegalStateException("Score normalization processor cannot produce final query result");
}
return searchHitArray;
}

private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
List<Integer> docIds = querySearchResults.isEmpty()
? List.of()
: Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs)
.map(scoreDoc -> scoreDoc.doc)
.collect(Collectors.toList());
return docIds;
}
}
Loading

0 comments on commit b6cd151

Please sign in to comment.