Skip to content

Commit

Permalink
More meaningful name for new targetKey value
Browse files Browse the repository at this point in the history
Signed-off-by: Sanjana679 <[email protected]>
  • Loading branch information
Sanjana679 committed Nov 12, 2023
1 parent 2c7f491 commit a5b1c4a
Showing 1 changed file with 97 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import com.google.common.collect.ImmutableMap;

/**
* The abstract class for text processing use cases. Users provide a field name map and a model id.
* During ingestion, the processor will use the corresponding model to inference the input texts,
* The abstract class for text processing use cases. Users provide a field name
* map and a model id.
* During ingestion, the processor will use the corresponding model to run
* inference on the input texts,
* and set the target fields according to the field name map.
*/
@Log4j2
Expand All @@ -39,7 +41,8 @@ public abstract class InferenceProcessor extends AbstractProcessor {

private final String type;

// This field is used for nested knn_vector/rank_features field. The value of the field will be used as the
// This field is used for nested knn_vector/rank_features field. The value of
// the field will be used as the
// default key for the nested object.
private final String listTypeNestedMapKey;

Expand All @@ -52,18 +55,18 @@ public abstract class InferenceProcessor extends AbstractProcessor {
private final Environment environment;

public InferenceProcessor(
String tag,
String description,
String type,
String listTypeNestedMapKey,
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
) {
String tag,
String description,
String type,
String listTypeNestedMapKey,
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment) {
super(tag, description);
this.type = type;
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it");
if (StringUtils.isBlank(modelId))
throw new IllegalArgumentException("model_id is null or empty, cannot process it");
validateEmbeddingConfiguration(fieldMap);

this.listTypeNestedMapKey = listTypeNestedMapKey;
Expand All @@ -75,33 +78,36 @@ public InferenceProcessor(

private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
if (fieldMap == null
|| fieldMap.size() == 0
|| fieldMap.entrySet()
.stream()
.anyMatch(
x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString())
)) {
|| fieldMap.size() == 0
|| fieldMap.entrySet()
.stream()
.anyMatch(
x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue())
|| StringUtils.isBlank(x.getValue().toString()))) {
throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value");
}
}

public abstract void doExecute(
IngestDocument ingestDocument,
Map<String, Object> ProcessMap,
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
);
IngestDocument ingestDocument,
Map<String, Object> ProcessMap,
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler);

/**
 * Synchronous execute entry point. This processor does its inference work
 * through the async {@code execute(IngestDocument, BiConsumer)} overload, so
 * this synchronous variant is a pass-through that returns the document
 * unchanged.
 *
 * @param ingestDocument the document passed to the processor
 * @return the same document, unmodified
 * @throws Exception never thrown here; declared by the overridden signature
 */
@Override
public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
return ingestDocument;
}

/**
* This method will be invoked by PipelineService to make async inference and then delegate the handler to
* This method will be invoked by PipelineService to make async inference and
* then delegate the handler to
* process the inference response or failure.
* @param ingestDocument {@link IngestDocument} which is the document passed to processor.
* @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done.
*
* @param ingestDocument {@link IngestDocument} which is the document passed to
* processor.
* @param handler {@link BiConsumer} which is the handler which can be
* used after the inference task is done.
*/
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
Expand Down Expand Up @@ -142,7 +148,8 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List<String>
} else if (sourceValue instanceof List) {
texts.addAll(((List<String>) sourceValue));
} else {
if (sourceValue == null) return;
if (sourceValue == null)
return;
texts.add(sourceValue.toString());
}
}
Expand All @@ -157,16 +164,17 @@ Map<String, Object> buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge

int nestedDotIndex = originalKey.indexOf('.');
if (nestedDotIndex != -1) {
Map<String, Object> temp = new LinkedHashMap<>();
temp.put(originalKey.substring(nestedDotIndex + 1), targetKey);
targetKey = temp;
Map<String, Object> newTargetKey = new LinkedHashMap<>();
newTargetKey.put(originalKey.substring(nestedDotIndex + 1), targetKey);
targetKey = newTargetKey;

originalKey = originalKey.substring(0, nestedDotIndex);
}

if (targetKey instanceof Map) {
Map<String, Object> treeRes = new LinkedHashMap<>();
buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes);
buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap,
treeRes);
mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
} else {
mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
Expand All @@ -176,21 +184,20 @@ Map<String, Object> buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge
}

private void buildMapWithProcessorKeyAndOriginalValueForMapType(
String parentKey,
Object processorKey,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> treeRes
) {
if (processorKey == null || sourceAndMetadataMap == null) return;
String parentKey,
Object processorKey,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> treeRes) {
if (processorKey == null || sourceAndMetadataMap == null)
return;
if (processorKey instanceof Map) {
Map<String, Object> next = new LinkedHashMap<>();
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
buildMapWithProcessorKeyAndOriginalValueForMapType(
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
(Map<String, Object>) sourceAndMetadataMap.get(parentKey),
next
);
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
(Map<String, Object>) sourceAndMetadataMap.get(parentKey),
next);
}
treeRes.put(parentKey, next);
} else {
Expand All @@ -209,9 +216,11 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
} else if (!String.class.isAssignableFrom(sourceValueClass)) {
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it");
throw new IllegalArgumentException(
"field [" + sourceKey + "] is neither string nor nested type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it");
throw new IllegalArgumentException(
"field [" + sourceKey + "] has empty string value, cannot process it");
}
}
}
Expand All @@ -221,18 +230,21 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
int maxDepth = maxDepthSupplier.get();
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
throw new IllegalArgumentException(
"map type field [" + sourceKey + "] reached max depth limit, cannot process it");
} else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
validateListTypeValue(sourceKey, sourceValue);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
throw new IllegalArgumentException(
"map type field [" + sourceKey + "] has non-string type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it");
throw new IllegalArgumentException(
"map type field [" + sourceKey + "] has empty string, cannot process it");
}
}

Expand All @@ -242,14 +254,17 @@ private void validateListTypeValue(String sourceKey, Object sourceValue) {
if (value == null) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
} else if (!(value instanceof String)) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
throw new IllegalArgumentException(
"list type field [" + sourceKey + "] has non string value, cannot process it");
} else if (StringUtils.isBlank(value.toString())) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
throw new IllegalArgumentException(
"list type field [" + sourceKey + "] has empty string, cannot process it");
}
}
}

protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap,
List<?> results) {
Objects.requireNonNull(results, "embedding failed, inference returns null result!");
log.debug("Model inference result fetched, starting build vector output!");
Map<String, Object> nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata());
Expand All @@ -258,7 +273,8 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri

@SuppressWarnings({ "unchecked" })
@VisibleForTesting
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results,
Map<String, Object> sourceAndMetadataMap) {
IndexWrapper indexWrapper = new IndexWrapper(0);
Map<String, Object> result = new LinkedHashMap<>();
for (Map.Entry<String, Object> knnMapEntry : processorMap.entrySet()) {
Expand All @@ -277,34 +293,36 @@ Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> res

@SuppressWarnings({ "unchecked" })
private void putNLPResultToSourceMapForMapType(
String processorKey,
Object sourceValue,
List<?> results,
IndexWrapper indexWrapper,
Map<String, Object> sourceAndMetadataMap
) {
if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return;
String processorKey,
Object sourceValue,
List<?> results,
IndexWrapper indexWrapper,
Map<String, Object> sourceAndMetadataMap) {
if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null)
return;
if (sourceValue instanceof Map) {
for (Map.Entry<String, Object> inputNestedMapEntry : ((Map<String, Object>) sourceValue).entrySet()) {
putNLPResultToSourceMapForMapType(
inputNestedMapEntry.getKey(),
inputNestedMapEntry.getValue(),
results,
indexWrapper,
(Map<String, Object>) sourceAndMetadataMap.get(processorKey)
);
inputNestedMapEntry.getKey(),
inputNestedMapEntry.getValue(),
results,
indexWrapper,
(Map<String, Object>) sourceAndMetadataMap.get(processorKey));
}
} else if (sourceValue instanceof String) {
sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++));
} else if (sourceValue instanceof List) {
sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper));
sourceAndMetadataMap.put(processorKey,
buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper));
}
}

private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results,
IndexWrapper indexWrapper) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
.forEachOrdered(
x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
return keyToResult;
}

Expand All @@ -314,10 +332,14 @@ public String getType() {
}

/**
* Since we need to build a {@link List<String>} as the input for text embedding, and the result type is {@link List<Float>} of {@link List},
* we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order
* traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the
* order. And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase
* Since we need to build a {@link List<String>} as the input for text
* embedding, and the result type is {@link List<Float>} of {@link List},
* we need to map the result back to the input one by one in exactly the same order.
* For nested map type input, we're performing a pre-order
* traversal to extract the input strings, so when mapping back to the nested
* map, we still need a pre-order traversal to ensure the
* order. And we also need to ensure the index pointer goes forward during the
* recursion, so here the IndexWrapper is used to store and increase
* the index pointer during the recursive.
* index: the index pointer of the text embedding result.
*/
Expand Down

0 comments on commit a5b1c4a

Please sign in to comment.