From aee2a0b97f688f037c3dd060edded26388faec2b Mon Sep 17 00:00:00 2001 From: zane-neo Date: Thu, 6 Jun 2024 15:56:55 +0800 Subject: [PATCH] Address comments Signed-off-by: zane-neo --- .../connector/MLPreProcessFunction.java | 4 ++-- .../ConnectorPreProcessFunction.java | 10 +++++++++- ...ultiModalConnectorPreProcessFunction.java} | 18 ++++++++++++++--- ...ModalConnectorPreProcessFunctionTest.java} | 20 +++++++++---------- 4 files changed, 35 insertions(+), 17 deletions(-) rename common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/{MultiModalEmbeddingPreProcessFunction.java => MultiModalConnectorPreProcessFunction.java} (54%) rename common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/{MultiModalEmbeddingPreProcessFunctionTest.java => MultiModalConnectorPreProcessFunctionTest.java} (85%) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index c7aabb3aff..758e1e3ce0 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -8,7 +8,7 @@ import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; -import org.opensearch.ml.common.connector.functions.preprocess.MultiModalEmbeddingPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -36,7 +36,7 @@ public class MLPreProcessFunction { OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction(); BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction(); CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); - MultiModalEmbeddingPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalEmbeddingPreProcessFunction(); + MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction(); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index 906f719157..d049dc8956 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -16,11 +16,18 @@ import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.function.Function; import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod; +/** + * This abstract class represents a pre-processing function for a connector. + * It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. + * The input data is expected to be of type {@link MLInput}, and the pre-processing function can be customized by implementing the {@link #validate(MLInput)} and {@link #process(MLInput)} methods. + * If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true. + */ @Log4j2 public abstract class ConnectorPreProcessFunction implements Function { @@ -45,10 +52,11 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) { public void validateTextDocsInput(MLInput mlInput) { if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) { + log.error(String.format(Locale.ROOT, "This pre_process_function can only support TextDocsInputDataSet, actual input type is: %s", mlInput.getInputDataset().getClass().getName())); throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet"); } List docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); - if (docs.size() == 1 && docs.get(0) == null) { + if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) { throw new IllegalArgumentException("No input text or image provided"); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java similarity index 54% rename from common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java rename to common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index 19ce98f1b7..05cc317856 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -15,9 +15,16 @@ import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; -public class MultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { +/** + * This class provides a pre-processing function for multi-modal input data. + * It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. + * The input data is expected to be of type {@link TextDocsInputDataSet}, with the first document representing text input and the second document representing an image input. + * The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. + * If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly. + */ +public class MultiModalConnectorPreProcessFunction extends ConnectorPreProcessFunction { - public MultiModalEmbeddingPreProcessFunction() { + public MultiModalConnectorPreProcessFunction() { this.returnDirectlyForRemoteInferenceInput = true; } @@ -26,7 +33,12 @@ public void validate(MLInput mlInput) { validateTextDocsInput(mlInput); } - // The input will must have inputText even it's null, input image is optional. + /** + * @param mlInput The input data to be processed. + * This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. + * If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly. + * The inputText will always show up in the first document, even it's null. + */ @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java similarity index 85% rename from common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java rename to common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java index d5319fd6ce..4bc4c4cd8f 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java @@ -14,9 +14,7 @@ import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.w3c.dom.Text; -import java.rmi.Remote; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -24,11 +22,11 @@ import static org.junit.Assert.assertEquals; -public class MultiModalEmbeddingPreProcessFunctionTest { +public class MultiModalConnectorPreProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - MultiModalEmbeddingPreProcessFunction function; + MultiModalConnectorPreProcessFunction function; TextSimilarityInputDataSet textSimilarityInputDataSet; TextDocsInputDataSet textDocsInputDataSet; @@ -40,7 +38,7 @@ public class MultiModalEmbeddingPreProcessFunctionTest { @Before public void setUp() { - function = new MultiModalEmbeddingPreProcessFunction(); + function = new MultiModalConnectorPreProcessFunction(); textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build(); @@ -51,21 +49,21 @@ public void setUp() { } @Test - public void process_NullInput() { + public void testProcess_whenNullInput_expectIllegalArgumentException() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Preprocess function input can't be null"); function.apply(null); } @Test - public void process_WrongInput() { + public void testProcess_whenWrongInput_expectIllegalArgumentException() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); function.apply(textSimilarityInput); } @Test - public void process_input_text_image() { + public void testProcess_whenCorrectInput_expectCorrectOutput() { MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); RemoteInferenceInputDataSet dataSet = function.apply(mlInput); assertEquals(2, dataSet.getParameters().size()); @@ -74,7 +72,7 @@ public void process_input_text_image() { } @Test - public void process_input_text_only() { + public void testProcess_whenInputTextOnly_expectInputTextShowUp() { TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); RemoteInferenceInputDataSet dataSet = function.apply(mlInput); @@ -83,7 +81,7 @@ public void process_input_text_only() { } @Test - public void process_input_text_null() { + public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("No input text or image provided"); List docs = new ArrayList<>(); @@ -94,7 +92,7 @@ public void process_input_text_null() { } @Test - public void process_RemoteInferenceInput() { + public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() { RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); assertEquals(remoteInferenceInputDataSet, dataSet); }