Skip to content

Commit

Permalink
nit
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Sep 26, 2023
1 parent c6c631e commit ec70c34
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ public List<QuerySpec<?>> getQueries() {
@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
Map<String, Processor.Factory> allProcessors = new HashMap<>();
allProcessors.put(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
allProcessors.put(SparseEncodingProcessor.TYPE, new SparseEncodingProcessorFactory(clientAccessor, parameters.env));
return allProcessors;
return Map.of(
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env)
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;

// The abstract class for text processing use cases. Users provide a field name map
// and a model id. During ingestion, the processor will use the corresponding model
// to inference the input texts, and set the target fields according to the field name map.
@Log4j2
public abstract class NLPProcessor extends AbstractProcessor {

Expand Down Expand Up @@ -58,7 +61,7 @@ public NLPProcessor(
) {
super(tag, description);
this.type = type;
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it");
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it");
validateEmbeddingConfiguration(fieldMap);

this.listTypeNestedMapKey = listTypeNestedMapKey;
Expand All @@ -81,14 +84,14 @@ private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
}

@SuppressWarnings({ "rawtypes" })
private static void validateListTypeValue(String sourceKey, Object sourceValue) {
private void validateListTypeValue(String sourceKey, Object sourceValue) {
for (Object value : (List) sourceValue) {
if (value == null) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it");
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
} else if (!(value instanceof String)) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it");
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
} else if (StringUtils.isBlank(value.toString())) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it");
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
}
}
}
Expand All @@ -97,7 +100,7 @@ private static void validateListTypeValue(String sourceKey, Object sourceValue)
private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
int maxDepth = maxDepthSupplier.get();
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it");
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
} else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
validateListTypeValue(sourceKey, sourceValue);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
Expand All @@ -106,9 +109,9 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it");
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it");
throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it");
}
}

Expand All @@ -122,9 +125,9 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
} else if (!String.class.isAssignableFrom(sourceValueClass)) {
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it");
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it");
throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ private static void validateForRewrite(String queryText, String modelId) {
+ QUERY_TEXT_FIELD.getPreferredName()
+ " and "
+ MODEL_ID_FIELD.getPreferredName()
+ " can not be null."
+ " cannot be null."
);
}
}
Expand All @@ -275,7 +275,7 @@ private static void validateFieldType(MappedFieldType fieldType) {

private static void validateQueryTokens(Map<String, Float> queryTokens) {
if (null == queryTokens) {
throw new IllegalArgumentException(QUERY_TOKENS_FIELD.getPreferredName() + " field can not be null.");
throw new IllegalArgumentException(QUERY_TOKENS_FIELD.getPreferredName() + " field cannot be null.");
}
for (Map.Entry<String, Float> entry : queryTokens.entrySet()) {
if (entry.getValue() <= 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() {
Mockito.verify(resultListener).onFailure(illegalStateException);
}

public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() {
// final List<Map<String, String>> map = List.of(Map.of("key", "value"));
public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() {
final Map<String, String> map = Map.of("key", "value");
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class);
Mockito.doAnswer(invocation -> {
Expand All @@ -179,7 +178,7 @@ public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() {
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() {
public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() {
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class);
final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList());
Mockito.doAnswer(invocation -> {
Expand All @@ -200,7 +199,7 @@ public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenE
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() {
public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() {
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class);
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand All @@ -224,7 +223,7 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenEx
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenSuccess() {
public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenSuccess() {
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class);
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand All @@ -246,7 +245,7 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTh
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Times() {
public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Times() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
Expand All @@ -264,7 +263,7 @@ public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Ti
Mockito.verify(resultListener).onFailure(nodeNodeConnectedException);
}

public void test_inferenceSentencesWithMapResult_whenNotRetryableException_thenFail() {
public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFail() {
final IllegalStateException illegalStateException = new IllegalStateException("Illegal state");
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
Expand Down

0 comments on commit ec70c34

Please sign in to comment.