diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index d72e1a1ed..601bf9003 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -99,10 +99,12 @@ public List> getQueries() { @Override public Map getProcessors(Processor.Parameters parameters) { clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); - Map allProcessors = new HashMap<>(); - allProcessors.put(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env)); - allProcessors.put(SparseEncodingProcessor.TYPE, new SparseEncodingProcessorFactory(clientAccessor, parameters.env)); - return allProcessors; + return Map.of( + TextEmbeddingProcessor.TYPE, + new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), + SparseEncodingProcessor.TYPE, + new SparseEncodingProcessorFactory(clientAccessor, parameters.env) + ); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 7e81c3922..09edc301f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -26,6 +26,9 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +// The abstract class for text processing use cases. Users provide a field name map +// and a model id. During ingestion, the processor will use the corresponding model +// to inference the input texts, and set the target fields according to the field name map. @Log4j2 public abstract class NLPProcessor extends AbstractProcessor { @@ -58,7 +61,7 @@ public NLPProcessor( ) { super(tag, description); this.type = type; - if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it"); validateEmbeddingConfiguration(fieldMap); this.listTypeNestedMapKey = listTypeNestedMapKey; @@ -81,14 +84,14 @@ private void validateEmbeddingConfiguration(Map fieldMap) { } @SuppressWarnings({ "rawtypes" }) - private static void validateListTypeValue(String sourceKey, Object sourceValue) { + private void validateListTypeValue(String sourceKey, Object sourceValue) { for (Object value : (List) sourceValue) { if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); } } } @@ -97,7 +100,7 @@ private static void validateListTypeValue(String sourceKey, Object sourceValue) private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { int maxDepth = maxDepthSupplier.get(); if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { validateListTypeValue(sourceKey, sourceValue); } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { @@ -106,9 +109,9 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl .filter(Objects::nonNull) .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); } } @@ -122,9 +125,9 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { validateNestedTypeValue(sourceKey, sourceValue, () -> 1); } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); } } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index 3da98a186..eb369d1a7 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -260,7 +260,7 @@ private static void validateForRewrite(String queryText, String modelId) { + QUERY_TEXT_FIELD.getPreferredName() + " and " + MODEL_ID_FIELD.getPreferredName() - + " can not be null." + + " cannot be null." ); } } @@ -275,7 +275,7 @@ private static void validateFieldType(MappedFieldType fieldType) { private static void validateQueryTokens(Map queryTokens) { if (null == queryTokens) { - throw new IllegalArgumentException(QUERY_TOKENS_FIELD.getPreferredName() + " field can not be null."); + throw new IllegalArgumentException(QUERY_TOKENS_FIELD.getPreferredName() + " field cannot be null."); } for (Map.Entry entry : queryTokens.entrySet()) { if (entry.getValue() <= 0) { diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index a51c62977..000b9598b 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -162,8 +162,7 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { Mockito.verify(resultListener).onFailure(illegalStateException); } - public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { - // final List> map = List.of(Map.of("key", "value")); + public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { final Map map = Map.of("key", "value"); final ActionListener>> resultListener = mock(ActionListener.class); Mockito.doAnswer(invocation -> { @@ -179,7 +178,7 @@ public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() { + public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() { final ActionListener>> resultListener = mock(ActionListener.class); final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList()); Mockito.doAnswer(invocation -> { @@ -200,7 +199,7 @@ public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenE Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() { + public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() { final ActionListener>> resultListener = mock(ActionListener.class); final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -224,7 +223,7 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenEx Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenSuccess() { + public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenSuccess() { final ActionListener>> resultListener = mock(ActionListener.class); final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -246,7 +245,7 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTh Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Times() { + public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Times() { final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( mock(DiscoveryNode.class), "Node not connected" @@ -264,7 +263,7 @@ public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Ti Mockito.verify(resultListener).onFailure(nodeNodeConnectedException); } - public void test_inferenceSentencesWithMapResult_whenNotRetryableException_thenFail() { + public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFail() { final IllegalStateException illegalStateException = new IllegalStateException("Illegal state"); Mockito.doAnswer(invocation -> { final ActionListener actionListener = invocation.getArgument(2);