Skip to content

Commit

Permalink
Revert InferenceProcessor.java
Browse files Browse the repository at this point in the history
Signed-off-by: Yuye Zhu <[email protected]>
  • Loading branch information
yuye-aws authored Mar 6, 2024
1 parent 36799bd commit 0140fc0
Showing 1 changed file with 81 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
Expand Down Expand Up @@ -43,11 +45,11 @@ public abstract class InferenceProcessor extends AbstractProcessor {

protected final String modelId;

protected final Map<String, Object> fieldMap;
private final Map<String, Object> fieldMap;

protected final MLCommonsClientAccessor mlCommonsClientAccessor;

protected final Environment environment;
private final Environment environment;

public InferenceProcessor(
String tag,
Expand All @@ -61,8 +63,9 @@ public InferenceProcessor(
) {
super(tag, description);
this.type = type;
validateEmbeddingConfiguration(fieldMap);
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it");
validateEmbeddingConfiguration(fieldMap);

this.listTypeNestedMapKey = listTypeNestedMapKey;
this.modelId = modelId;
this.fieldMap = fieldMap;
Expand Down Expand Up @@ -103,32 +106,26 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
processorInputValidator.validateFieldsValue(fieldMap, environment, ingestDocument, false);
Map<String, Object> processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument);
List<String> inferenceList = createInferenceList(processMap);
if (inferenceList.isEmpty()) {
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument);
List<String> inferenceList = createInferenceList(ProcessMap);
if (inferenceList.size() == 0) {
handler.accept(ingestDocument, null);
} else {
doExecute(ingestDocument, processMap, inferenceList, handler);
doExecute(ingestDocument, ProcessMap, inferenceList, handler);
}
} catch (Exception e) {
handler.accept(null, e);
}
}

@SuppressWarnings({ "unchecked" })
protected List<String> createInferenceList(Map<String, Object> knnKeyMap) {
private List<String> createInferenceList(Map<String, Object> knnKeyMap) {
List<String> texts = new ArrayList<>();
knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> {
Object sourceValue = knnMapEntry.getValue();
if (sourceValue instanceof List) {
for (Object nestedValue : (List<Object>) sourceValue) {
if (nestedValue instanceof String) {
texts.add((String) nestedValue);
} else {
texts.addAll((List<String>) nestedValue);
}
}
texts.addAll(((List<String>) sourceValue));
} else if (sourceValue instanceof Map) {
createInferenceListForMapTypeInput(sourceValue, texts);
} else {
Expand Down Expand Up @@ -207,16 +204,68 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType(
}
}

protected void setTargetFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
for (Map.Entry<String, Object> embeddingFieldsEntry : fieldMap.entrySet()) {
Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
if (sourceValue != null) {
String sourceKey = embeddingFieldsEntry.getKey();
Class<?> sourceValueClass = sourceValue.getClass();
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
} else if (!String.class.isAssignableFrom(sourceValueClass)) {
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it");
}
}
}
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
int maxDepth = maxDepthSupplier.get();
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
} else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it");
}
}

@SuppressWarnings({ "rawtypes" })
private void validateListTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
for (Object value : (List) sourceValue) {
if (value instanceof Map) {
validateNestedTypeValue(sourceKey, value, () -> maxDepthSupplier.get() + 1);
} else if (value == null) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
} else if (!(value instanceof String)) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
} else if (StringUtils.isBlank(value.toString())) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
}
}
}

protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
Objects.requireNonNull(results, "embedding failed, inference returns null result!");
log.debug("Model inference result fetched, starting build vector output!");
Map<String, Object> result = buildResult(processorMap, results, ingestDocument.getSourceAndMetadata());
result.forEach(ingestDocument::setFieldValue);
Map<String, Object> nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata());
nlpResult.forEach(ingestDocument::setFieldValue);
}

@SuppressWarnings({ "unchecked" })
@VisibleForTesting
Map<String, Object> buildResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
IndexWrapper indexWrapper = new IndexWrapper(0);
Map<String, Object> result = new LinkedHashMap<>();
for (Map.Entry<String, Object> knnMapEntry : processorMap.entrySet()) {
Expand All @@ -225,16 +274,16 @@ Map<String, Object> buildResult(Map<String, Object> processorMap, List<?> result
if (sourceValue instanceof String) {
result.put(knnKey, results.get(indexWrapper.index++));
} else if (sourceValue instanceof List) {
result.put(knnKey, buildResultForListType((List<Object>) sourceValue, results, indexWrapper));
result.put(knnKey, buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper));
} else if (sourceValue instanceof Map) {
putResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap);
putNLPResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap);
}
}
return result;
}

@SuppressWarnings({ "unchecked" })
private void putResultToSourceMapForMapType(
private void putNLPResultToSourceMapForMapType(
String processorKey,
Object sourceValue,
List<?> results,
Expand All @@ -245,12 +294,12 @@ private void putResultToSourceMapForMapType(
if (sourceValue instanceof Map) {
for (Map.Entry<String, Object> inputNestedMapEntry : ((Map<String, Object>) sourceValue).entrySet()) {
if (sourceAndMetadataMap.get(processorKey) instanceof List) {
// build output for list of nested objects
// build nlp output for list of nested objects
for (Map<String, Object> nestedElement : (List<Map<String, Object>>) sourceAndMetadataMap.get(processorKey)) {
nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++));
}
} else {
putResultToSourceMapForMapType(
putNLPResultToSourceMapForMapType(
inputNestedMapEntry.getKey(),
inputNestedMapEntry.getValue(),
results,
Expand All @@ -262,27 +311,15 @@ private void putResultToSourceMapForMapType(
} else if (sourceValue instanceof String) {
sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++));
} else if (sourceValue instanceof List) {
sourceAndMetadataMap.put(processorKey, buildResultForListType((List<Object>) sourceValue, results, indexWrapper));
sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper));
}
}

protected List<?> buildResultForListType(List<Object> sourceValue, List<?> results, IndexWrapper indexWrapper) {
Object peek = sourceValue.get(0);
if (peek instanceof String) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
return keyToResult;
} else {
List<List<Map<String, Object>>> keyToResult = new ArrayList<>();
for (Object nestedList : sourceValue) {
List<Map<String, Object>> nestedResult = new ArrayList<>();
IntStream.range(0, ((List) nestedList).size())
.forEachOrdered(x -> nestedResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
keyToResult.add(nestedResult);
}
return keyToResult;
}
private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
return keyToResult;
}

@Override
Expand All @@ -299,7 +336,7 @@ public String getType() {
* index: the index pointer of the text embedding result.
*/
static class IndexWrapper {
protected int index;
private int index;

protected IndexWrapper(int index) {
this.index = index;
Expand Down

0 comments on commit 0140fc0

Please sign in to comment.