-
Notifications
You must be signed in to change notification settings - Fork 72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Treat . as a nested field in field_map of text embedding processor #488
Changes from all commits
7006c97
2c7f491
a5b1c4a
0e38fea
2177ba2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
package org.opensearch.neuralsearch.processor; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.LinkedHashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
@@ -27,8 +28,10 @@ | |
import com.google.common.collect.ImmutableMap; | ||
|
||
/** | ||
* The abstract class for text processing use cases. Users provide a field name map and a model id. | ||
* During ingestion, the processor will use the corresponding model to inference the input texts, | ||
* The abstract class for text processing use cases. Users provide a field name | ||
* map and a model id. | ||
* During ingestion, the processor will use the corresponding model to inference | ||
* the input texts, | ||
* and set the target fields according to the field name map. | ||
*/ | ||
@Log4j2 | ||
|
@@ -39,7 +42,8 @@ public abstract class InferenceProcessor extends AbstractProcessor { | |
|
||
private final String type; | ||
|
||
// This field is used for nested knn_vector/rank_features field. The value of the field will be used as the | ||
// This field is used for nested knn_vector/rank_features field. The value of | ||
// the field will be used as the | ||
// default key for the nested object. | ||
private final String listTypeNestedMapKey; | ||
|
||
|
@@ -52,18 +56,18 @@ public abstract class InferenceProcessor extends AbstractProcessor { | |
private final Environment environment; | ||
|
||
public InferenceProcessor( | ||
String tag, | ||
String description, | ||
String type, | ||
String listTypeNestedMapKey, | ||
String modelId, | ||
Map<String, Object> fieldMap, | ||
MLCommonsClientAccessor clientAccessor, | ||
Environment environment | ||
) { | ||
String tag, | ||
String description, | ||
String type, | ||
String listTypeNestedMapKey, | ||
String modelId, | ||
Map<String, Object> fieldMap, | ||
MLCommonsClientAccessor clientAccessor, | ||
Environment environment) { | ||
super(tag, description); | ||
this.type = type; | ||
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it"); | ||
if (StringUtils.isBlank(modelId)) | ||
throw new IllegalArgumentException("model_id is null or empty, cannot process it"); | ||
validateEmbeddingConfiguration(fieldMap); | ||
|
||
this.listTypeNestedMapKey = listTypeNestedMapKey; | ||
|
@@ -75,33 +79,38 @@ public InferenceProcessor( | |
|
||
private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) { | ||
if (fieldMap == null | ||
|| fieldMap.size() == 0 | ||
|| fieldMap.entrySet() | ||
.stream() | ||
.anyMatch( | ||
x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) | ||
)) { | ||
|| fieldMap.size() == 0 | ||
|| fieldMap.entrySet() | ||
.stream() | ||
.anyMatch( | ||
x -> StringUtils.startsWith(x.getKey(), ".") || StringUtils.endsWith(x.getKey(), ".") | ||
|| Arrays.stream(x.getKey().split("\\.")).anyMatch(y -> StringUtils.isBlank(y)) | ||
|| Objects.isNull(x.getValue()) | ||
|| StringUtils.isBlank(x.getValue().toString()))) { | ||
throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value"); | ||
} | ||
} | ||
|
||
public abstract void doExecute( | ||
IngestDocument ingestDocument, | ||
Map<String, Object> ProcessMap, | ||
List<String> inferenceList, | ||
BiConsumer<IngestDocument, Exception> handler | ||
); | ||
IngestDocument ingestDocument, | ||
Map<String, Object> ProcessMap, | ||
List<String> inferenceList, | ||
BiConsumer<IngestDocument, Exception> handler); | ||
|
||
@Override | ||
public IngestDocument execute(IngestDocument ingestDocument) throws Exception { | ||
return ingestDocument; | ||
} | ||
|
||
/** | ||
* This method will be invoked by PipelineService to make async inference and then delegate the handler to | ||
* This method will be invoked by PipelineService to make async inference and | ||
* then delegate the handler to | ||
* process the inference response or failure. | ||
* @param ingestDocument {@link IngestDocument} which is the document passed to processor. | ||
* @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. | ||
* | ||
* @param ingestDocument {@link IngestDocument} which is the document passed to | ||
* processor. | ||
* @param handler {@link BiConsumer} which is the handler which can be | ||
* used after the inference task is done. | ||
*/ | ||
@Override | ||
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) { | ||
|
@@ -142,7 +151,8 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List<String> | |
} else if (sourceValue instanceof List) { | ||
texts.addAll(((List<String>) sourceValue)); | ||
} else { | ||
if (sourceValue == null) return; | ||
if (sourceValue == null) | ||
return; | ||
texts.add(sourceValue.toString()); | ||
} | ||
} | ||
|
@@ -154,9 +164,20 @@ Map<String, Object> buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge | |
for (Map.Entry<String, Object> fieldMapEntry : fieldMap.entrySet()) { | ||
String originalKey = fieldMapEntry.getKey(); | ||
Object targetKey = fieldMapEntry.getValue(); | ||
|
||
int nestedDotIndex = originalKey.indexOf('.'); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a user-provided info, can we add basic validation if it's not already done as part of the processor/pipeline definition. if multiple levels of nested fields are needed this code may need a rework There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is some basic validation done in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: use a static constant and avoid magic characters in code.
|
||
if (nestedDotIndex != -1) { | ||
Map<String, Object> newTargetKey = new LinkedHashMap<>(); | ||
newTargetKey.put(originalKey.substring(nestedDotIndex + 1), targetKey); | ||
targetKey = newTargetKey; | ||
|
||
originalKey = originalKey.substring(0, nestedDotIndex); | ||
} | ||
Comment on lines
+169
to
+175
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Sanjana679 can you please provide details how we are handling multiple level of nesting here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At the moment, I'm not currently handling multiple levels of nesting, as I initially thought it was only for one level. However, I will work on handling multiple levels of nesting. |
||
|
||
if (targetKey instanceof Map) { | ||
Map<String, Object> treeRes = new LinkedHashMap<>(); | ||
buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); | ||
buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, | ||
treeRes); | ||
mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); | ||
} else { | ||
mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); | ||
|
@@ -166,21 +187,20 @@ Map<String, Object> buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge | |
} | ||
|
||
private void buildMapWithProcessorKeyAndOriginalValueForMapType( | ||
String parentKey, | ||
Object processorKey, | ||
Map<String, Object> sourceAndMetadataMap, | ||
Map<String, Object> treeRes | ||
) { | ||
if (processorKey == null || sourceAndMetadataMap == null) return; | ||
String parentKey, | ||
Object processorKey, | ||
Map<String, Object> sourceAndMetadataMap, | ||
Map<String, Object> treeRes) { | ||
Comment on lines
+190
to
+193
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove these indents. |
||
if (processorKey == null || sourceAndMetadataMap == null) | ||
return; | ||
if (processorKey instanceof Map) { | ||
Map<String, Object> next = new LinkedHashMap<>(); | ||
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) { | ||
buildMapWithProcessorKeyAndOriginalValueForMapType( | ||
nestedFieldMapEntry.getKey(), | ||
nestedFieldMapEntry.getValue(), | ||
(Map<String, Object>) sourceAndMetadataMap.get(parentKey), | ||
next | ||
); | ||
nestedFieldMapEntry.getKey(), | ||
nestedFieldMapEntry.getValue(), | ||
(Map<String, Object>) sourceAndMetadataMap.get(parentKey), | ||
next); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Keep |
||
} | ||
treeRes.put(parentKey, next); | ||
} else { | ||
|
@@ -199,9 +219,11 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { | |
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { | ||
validateNestedTypeValue(sourceKey, sourceValue, () -> 1); | ||
} else if (!String.class.isAssignableFrom(sourceValueClass)) { | ||
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); | ||
throw new IllegalArgumentException( | ||
"field [" + sourceKey + "] is neither string nor nested type, cannot process it"); | ||
} else if (StringUtils.isBlank(sourceValue.toString())) { | ||
throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); | ||
throw new IllegalArgumentException( | ||
"field [" + sourceKey + "] has empty string value, cannot process it"); | ||
} | ||
} | ||
} | ||
|
@@ -211,18 +233,21 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { | |
private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we replace |
||
int maxDepth = maxDepthSupplier.get(); | ||
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { | ||
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); | ||
throw new IllegalArgumentException( | ||
"map type field [" + sourceKey + "] reached max depth limit, cannot process it"); | ||
} else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { | ||
validateListTypeValue(sourceKey, sourceValue); | ||
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) { | ||
((Map) sourceValue).values() | ||
.stream() | ||
.filter(Objects::nonNull) | ||
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); | ||
.stream() | ||
.filter(Objects::nonNull) | ||
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); | ||
Comment on lines
+242
to
+244
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove these indents |
||
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) { | ||
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); | ||
throw new IllegalArgumentException( | ||
"map type field [" + sourceKey + "] has non-string type, cannot process it"); | ||
} else if (StringUtils.isBlank(sourceValue.toString())) { | ||
throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); | ||
throw new IllegalArgumentException( | ||
"map type field [" + sourceKey + "] has empty string, cannot process it"); | ||
} | ||
} | ||
|
||
|
@@ -232,14 +257,17 @@ private void validateListTypeValue(String sourceKey, Object sourceValue) { | |
if (value == null) { | ||
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); | ||
} else if (!(value instanceof String)) { | ||
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); | ||
throw new IllegalArgumentException( | ||
"list type field [" + sourceKey + "] has non string value, cannot process it"); | ||
} else if (StringUtils.isBlank(value.toString())) { | ||
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); | ||
throw new IllegalArgumentException( | ||
"list type field [" + sourceKey + "] has empty string, cannot process it"); | ||
} | ||
} | ||
} | ||
|
||
protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) { | ||
protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, | ||
List<?> results) { | ||
Objects.requireNonNull(results, "embedding failed, inference returns null result!"); | ||
log.debug("Model inference result fetched, starting build vector output!"); | ||
Map<String, Object> nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); | ||
|
@@ -248,7 +276,8 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri | |
|
||
@SuppressWarnings({ "unchecked" }) | ||
@VisibleForTesting | ||
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) { | ||
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, | ||
Map<String, Object> sourceAndMetadataMap) { | ||
IndexWrapper indexWrapper = new IndexWrapper(0); | ||
Map<String, Object> result = new LinkedHashMap<>(); | ||
for (Map.Entry<String, Object> knnMapEntry : processorMap.entrySet()) { | ||
|
@@ -267,34 +296,36 @@ Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> res | |
|
||
@SuppressWarnings({ "unchecked" }) | ||
private void putNLPResultToSourceMapForMapType( | ||
String processorKey, | ||
Object sourceValue, | ||
List<?> results, | ||
IndexWrapper indexWrapper, | ||
Map<String, Object> sourceAndMetadataMap | ||
) { | ||
if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; | ||
String processorKey, | ||
Object sourceValue, | ||
List<?> results, | ||
IndexWrapper indexWrapper, | ||
Map<String, Object> sourceAndMetadataMap) { | ||
if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) | ||
return; | ||
if (sourceValue instanceof Map) { | ||
for (Map.Entry<String, Object> inputNestedMapEntry : ((Map<String, Object>) sourceValue).entrySet()) { | ||
putNLPResultToSourceMapForMapType( | ||
inputNestedMapEntry.getKey(), | ||
inputNestedMapEntry.getValue(), | ||
results, | ||
indexWrapper, | ||
(Map<String, Object>) sourceAndMetadataMap.get(processorKey) | ||
); | ||
inputNestedMapEntry.getKey(), | ||
inputNestedMapEntry.getValue(), | ||
results, | ||
indexWrapper, | ||
(Map<String, Object>) sourceAndMetadataMap.get(processorKey)); | ||
} | ||
} else if (sourceValue instanceof String) { | ||
sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++)); | ||
} else if (sourceValue instanceof List) { | ||
sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper)); | ||
sourceAndMetadataMap.put(processorKey, | ||
buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper)); | ||
} | ||
} | ||
|
||
private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) { | ||
private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, | ||
IndexWrapper indexWrapper) { | ||
List<Map<String, Object>> keyToResult = new ArrayList<>(); | ||
IntStream.range(0, sourceValue.size()) | ||
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); | ||
.forEachOrdered( | ||
x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); | ||
return keyToResult; | ||
} | ||
|
||
|
@@ -304,10 +335,14 @@ public String getType() { | |
} | ||
|
||
/** | ||
* Since we need to build a {@link List<String>} as the input for text embedding, and the result type is {@link List<Float>} of {@link List}, | ||
* we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order | ||
* traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the | ||
* order. And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase | ||
* Since we need to build a {@link List<String>} as the input for text | ||
* embedding, and the result type is {@link List<Float>} of {@link List}, | ||
* we need to map the result back to the input one by one with exactly order. | ||
* For nested map type input, we're performing a pre-order | ||
* traversal to extract the input strings, so when mapping back to the nested | ||
* map, we still need a pre-order traversal to ensure the | ||
* order. And we also need to ensure the index pointer goes forward in the | ||
* recursive, so here the IndexWrapper is to store and increase | ||
* the index pointer during the recursive. | ||
* index: the index pointer of the text embedding result. | ||
*/ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why did the formatting change? did you run
./gradlew :spotlessApply
prior?