Support list<list<string>> type in embedding and extract validation logic to common class

Signed-off-by: zane-neo <[email protected]>
zane-neo authored and yuye-aws committed Mar 1, 2024
1 parent 3a41649 commit cd323ab
Showing 8 changed files with 190 additions and 94 deletions.
@@ -29,6 +29,7 @@
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.ProcessorInputValidator;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.DocumentChunkingProcessor;
@@ -109,11 +110,12 @@ public List<QuerySpec<?>> getQueries() {
@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
ProcessorInputValidator processorInputValidator = new ProcessorInputValidator();
return Map.of(
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, processorInputValidator),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env),
new SparseEncodingProcessorFactory(clientAccessor, parameters.env, processorInputValidator),
TextImageEmbeddingProcessor.TYPE,
new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
DocumentChunkingProcessor.TYPE,
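Each embedding factory now takes the shared ProcessorInputValidator alongside the client accessor and environment. A minimal sketch of how such a factory might thread the validator through to the processor it creates follows; only the three-argument factory constructor is confirmed by this diff, while the create body, the config keys, and the TextEmbeddingProcessor constructor arguments are illustrative assumptions.

    import java.util.Map;

    import org.opensearch.env.Environment;
    import org.opensearch.ingest.ConfigurationUtils;
    import org.opensearch.ingest.Processor;
    import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
    import org.opensearch.neuralsearch.processor.ProcessorInputValidator;
    import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;

    // Illustrative only: how a factory could pass the shared validator through to its processor.
    public class TextEmbeddingProcessorFactory implements Processor.Factory {

        private final MLCommonsClientAccessor clientAccessor;
        private final Environment environment;
        private final ProcessorInputValidator processorInputValidator;

        public TextEmbeddingProcessorFactory(
            MLCommonsClientAccessor clientAccessor,
            Environment environment,
            ProcessorInputValidator processorInputValidator
        ) {
            this.clientAccessor = clientAccessor;
            this.environment = environment;
            this.processorInputValidator = processorInputValidator;
        }

        @Override
        public TextEmbeddingProcessor create(
            Map<String, Processor.Factory> registry,
            String tag,
            String description,
            Map<String, Object> config
        ) {
            // Assumed config keys; the real factory reads the model id and field map from the pipeline config.
            String modelId = ConfigurationUtils.readStringProperty(TextEmbeddingProcessor.TYPE, tag, config, "model_id");
            Map<String, Object> fieldMap = ConfigurationUtils.readMap(TextEmbeddingProcessor.TYPE, tag, config, "field_map");
            return new TextEmbeddingProcessor(tag, description, modelId, fieldMap, clientAccessor, environment, processorInputValidator);
        }
    }

Constructing the validator once in getProcessors and handing the same instance to every factory keeps the validation rules in a single place, which is the point of extracting them out of InferenceProcessor.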
@@ -10,13 +10,11 @@
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
@@ -51,6 +49,8 @@ public abstract class InferenceProcessor extends AbstractProcessor {

private final Environment environment;

private final ProcessorInputValidator processorInputValidator;

public InferenceProcessor(
String tag,
String description,
@@ -59,7 +59,8 @@ public InferenceProcessor(
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
Environment environment,
ProcessorInputValidator processorInputValidator
) {
super(tag, description);
this.type = type;
@@ -71,6 +72,7 @@ public InferenceProcessor(
this.fieldMap = fieldMap;
this.mlCommonsClientAccessor = clientAccessor;
this.environment = environment;
this.processorInputValidator = processorInputValidator;
}

private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
@@ -106,13 +108,13 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument);
List<String> inferenceList = createInferenceList(ProcessMap);
processorInputValidator.validateFieldsValue(fieldMap, environment, ingestDocument, false);
Map<String, Object> processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument);
List<String> inferenceList = createInferenceList(processMap);
if (inferenceList.size() == 0) {
handler.accept(ingestDocument, null);
} else {
doExecute(ingestDocument, ProcessMap, inferenceList, handler);
doExecute(ingestDocument, processMap, inferenceList, handler);
}
} catch (Exception e) {
handler.accept(null, e);
@@ -125,7 +127,13 @@ private List<String> createInferenceList(Map<String, Object> knnKeyMap) {
knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> {
Object sourceValue = knnMapEntry.getValue();
if (sourceValue instanceof List) {
texts.addAll(((List<String>) sourceValue));
for (Object nestedValue : (List<Object>) sourceValue) {
if (nestedValue instanceof String) {
texts.add((String) nestedValue);
} else {
texts.addAll((List<String>) nestedValue);
}
}
} else if (sourceValue instanceof Map) {
createInferenceListForMapTypeInput(sourceValue, texts);
} else {
@@ -204,68 +212,16 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType(
}
}

private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
for (Map.Entry<String, Object> embeddingFieldsEntry : fieldMap.entrySet()) {
Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
if (sourceValue != null) {
String sourceKey = embeddingFieldsEntry.getKey();
Class<?> sourceValueClass = sourceValue.getClass();
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
} else if (!String.class.isAssignableFrom(sourceValueClass)) {
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it");
}
}
}
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
int maxDepth = maxDepthSupplier.get();
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
} else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it");
}
}

@SuppressWarnings({ "rawtypes" })
private void validateListTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
for (Object value : (List) sourceValue) {
if (value instanceof Map) {
validateNestedTypeValue(sourceKey, value, () -> maxDepthSupplier.get() + 1);
} else if (value == null) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
} else if (!(value instanceof String)) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
} else if (StringUtils.isBlank(value.toString())) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
}
}
}

protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
protected void setTargetFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
Objects.requireNonNull(results, "embedding failed, inference returns null result!");
log.debug("Model inference result fetched, starting build vector output!");
Map<String, Object> nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata());
nlpResult.forEach(ingestDocument::setFieldValue);
Map<String, Object> result = buildResult(processorMap, results, ingestDocument.getSourceAndMetadata());
result.forEach(ingestDocument::setFieldValue);
}

@SuppressWarnings({ "unchecked" })
@VisibleForTesting
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
Map<String, Object> buildResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
IndexWrapper indexWrapper = new IndexWrapper(0);
Map<String, Object> result = new LinkedHashMap<>();
for (Map.Entry<String, Object> knnMapEntry : processorMap.entrySet()) {
@@ -274,16 +230,16 @@ Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> res
if (sourceValue instanceof String) {
result.put(knnKey, results.get(indexWrapper.index++));
} else if (sourceValue instanceof List) {
result.put(knnKey, buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper));
result.put(knnKey, buildResultForListType((List<Object>) sourceValue, results, indexWrapper));
} else if (sourceValue instanceof Map) {
putNLPResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap);
putResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap);
}
}
return result;
}

@SuppressWarnings({ "unchecked" })
private void putNLPResultToSourceMapForMapType(
private void putResultToSourceMapForMapType(
String processorKey,
Object sourceValue,
List<?> results,
@@ -294,12 +250,12 @@ private void putNLPResultToSourceMapForMapType(
if (sourceValue instanceof Map) {
for (Map.Entry<String, Object> inputNestedMapEntry : ((Map<String, Object>) sourceValue).entrySet()) {
if (sourceAndMetadataMap.get(processorKey) instanceof List) {
// build nlp output for list of nested objects
// build output for list of nested objects
for (Map<String, Object> nestedElement : (List<Map<String, Object>>) sourceAndMetadataMap.get(processorKey)) {
nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++));
}
} else {
putNLPResultToSourceMapForMapType(
putResultToSourceMapForMapType(
inputNestedMapEntry.getKey(),
inputNestedMapEntry.getValue(),
results,
@@ -311,15 +267,27 @@ private void putNLPResultToSourceMapForMapType(
} else if (sourceValue instanceof String) {
sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++));
} else if (sourceValue instanceof List) {
sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper));
sourceAndMetadataMap.put(processorKey, buildResultForListType((List<Object>) sourceValue, results, indexWrapper));
}
}

private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
return keyToResult;
protected List<?> buildResultForListType(List<Object> sourceValue, List<?> results, IndexWrapper indexWrapper) {
Object peek = sourceValue.get(0);
if (peek instanceof String) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
return keyToResult;
} else {
List<List<Map<String, Object>>> keyToResult = new ArrayList<>();
for (Object nestedList : sourceValue) {
List<Map<String, Object>> nestedResult = new ArrayList<>();
IntStream.range(0, ((List) nestedList).size())
.forEachOrdered(x -> nestedResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
keyToResult.add(nestedResult);
}
return keyToResult;
}
}

@Override
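With this change, createInferenceList flattens a list<list<string>> field into one flat batch for the model, and buildResultForListType consumes the flat results in order to rebuild the original nested shape. Below is a self-contained sketch of that round trip using plain collections; the field name, the fake embeddings, and the "knn" key stand in for the processor's real values.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;

    public class NestedListRoundTrip {

        public static void main(String[] args) {
            // A field value of type list<list<string>>, e.g. "paragraphs": [["a", "b"], ["c"]].
            List<List<String>> fieldValue = List.of(List.of("a", "b"), List.of("c"));

            // Flatten, as createInferenceList does, so the model sees one flat batch: [a, b, c].
            List<String> inferenceList = new ArrayList<>();
            for (List<String> nested : fieldValue) {
                inferenceList.addAll(nested);
            }

            // Pretend the model returned one embedding per input string.
            List<float[]> results = List.of(new float[] { 0.1f }, new float[] { 0.2f }, new float[] { 0.3f });

            // Regroup, as buildResultForListType does for nested lists: results are consumed in order
            // and each one is wrapped under the list-type nested map key ("knn" here as a placeholder).
            int index = 0;
            List<List<Map<String, Object>>> regrouped = new ArrayList<>();
            for (List<String> nested : fieldValue) {
                List<Map<String, Object>> nestedResult = new ArrayList<>();
                for (int i = 0; i < nested.size(); i++) {
                    nestedResult.add(Map.of("knn", results.get(index++)));
                }
                regrouped.add(nestedResult);
            }

            // regrouped now mirrors the input shape: [[{knn: e(a)}, {knn: e(b)}], [{knn: e(c)}]].
            System.out.println(regrouped.size() + " outer groups, first group size " + regrouped.get(0).size());
        }
    }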
@@ -0,0 +1,93 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.IngestDocument;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

public class ProcessorInputValidator {

public void validateFieldsValue(
Map<String, Object> fieldMap,
Environment environment,
IngestDocument ingestDocument,
boolean allowEmpty
) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
for (Map.Entry<String, Object> embeddingFieldsEntry : fieldMap.entrySet()) {
Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
if (sourceValue != null) {
String sourceKey = embeddingFieldsEntry.getKey();
Class<?> sourceValueClass = sourceValue.getClass();
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
validateNestedTypeValue(sourceKey, sourceValue, environment, allowEmpty, () -> 1);
} else if (!String.class.isAssignableFrom(sourceValueClass)) {
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it");
} else if (!allowEmpty && StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it");
}
}
}
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void validateNestedTypeValue(
String sourceKey,
Object sourceValue,
Environment environment,
boolean allowEmpty,
Supplier<Integer> maxDepthSupplier
) {
int maxDepth = maxDepthSupplier.get();
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
} else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
validateListTypeValue(sourceKey, sourceValue, environment, allowEmpty, maxDepthSupplier);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, environment, allowEmpty, () -> maxDepth + 1));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
} else if (!allowEmpty && StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it");
}
}

@SuppressWarnings({ "rawtypes" })
private void validateListTypeValue(
String sourceKey,
Object sourceValue,
Environment environment,
boolean allowEmpty,
Supplier<Integer> maxDepthSupplier
) {
for (Object value : (List) sourceValue) {
if (value instanceof Map) {
validateNestedTypeValue(sourceKey, value, environment, allowEmpty, () -> maxDepthSupplier.get() + 1);
} else if (value == null) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
} else if (value instanceof List) {
for (Object nestedValue : (List) value) {
if (!(nestedValue instanceof String)) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
}
}
} else if (!(value instanceof String)) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
} else if (!allowEmpty && StringUtils.isBlank(value.toString())) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
}
}
}
}
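The extracted validator can be exercised on its own. A standalone sketch follows, assuming a hand-built Environment and placeholder field names; in the plugin the Environment comes from the node and the field map from the pipeline configuration.

    import java.nio.file.Path;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    import org.opensearch.common.settings.Settings;
    import org.opensearch.env.Environment;
    import org.opensearch.ingest.IngestDocument;
    import org.opensearch.neuralsearch.processor.ProcessorInputValidator;

    public class ValidatorSketch {

        public static void main(String[] args) {
            // A document whose "paragraphs" field is a list<list<string>>, the shape this commit adds support for.
            Map<String, Object> source = new HashMap<>();
            source.put("paragraphs", List.of(List.of("first chunk", "second chunk"), List.of("third chunk")));
            IngestDocument document = new IngestDocument(source, new HashMap<>());

            // field_map as a processor would carry it: source field -> target field.
            Map<String, Object> fieldMap = Map.of("paragraphs", "paragraphs_embedding");

            // On a node the Environment is injected; it is built by hand only for this sketch.
            Environment environment = new Environment(Settings.builder().put("path.home", "/tmp").build(), Path.of("/tmp"));

            // Throws IllegalArgumentException on null entries, non-string leaves, blank strings
            // (unless allowEmpty is true), or nesting deeper than the index mapping depth limit.
            new ProcessorInputValidator().validateFieldsValue(fieldMap, environment, document, false);
        }
    }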
@@ -32,9 +32,10 @@ public SparseEncodingProcessor(
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
Environment environment,
ProcessorInputValidator processorInputValidator
) {
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment);
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, processorInputValidator);
}

@Override
@@ -45,7 +46,7 @@ public void doExecute(
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
setTargetFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
}
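For completeness, constructing the sparse encoding processor now means passing the validator as the trailing argument. A sketch under the assumption that the constructor's leading parameters are tag and description, as the super call above suggests; the tag, description, model id, and field_map values here are placeholders.

    import java.util.Map;

    import org.opensearch.env.Environment;
    import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
    import org.opensearch.neuralsearch.processor.ProcessorInputValidator;
    import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;

    public class SparseEncodingWiringSketch {

        // Illustrative only: the client accessor and environment come from the plugin at runtime.
        static SparseEncodingProcessor buildProcessor(MLCommonsClientAccessor clientAccessor, Environment environment) {
            return new SparseEncodingProcessor(
                "demo-tag",
                "sparse encoding with nested list support",
                "demo-sparse-model-id",
                Map.of("body", "body_sparse"),
                clientAccessor,
                environment,
                new ProcessorInputValidator() // in the plugin a single instance is shared across factories
            );
        }
    }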
