From 5bec6162878fdd2fe15489c94a215f295ce00099 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 28 Nov 2023 15:36:20 -0800 Subject: [PATCH] apply spotless after rebase Signed-off-by: HenryL27 --- .../TextSimilarityCrossEncoderModel.java | 4 +- .../TextSimilarityTranslator.java | 16 +-- .../TextSimilarityCrossEncoderModelTest.java | 110 ++++++++++-------- 3 files changed, 73 insertions(+), 57 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java index 84b22986df..d3049c851a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java @@ -17,8 +17,6 @@ */ package org.opensearch.ml.engine.algorithms.text_similarity; -import static org.opensearch.ml.engine.ModelHelper.PYTORCH_ENGINE; - import java.util.ArrayList; import java.util.List; @@ -40,7 +38,7 @@ @Function(FunctionName.TEXT_SIMILARITY) public class TextSimilarityCrossEncoderModel extends DLModel { - + @Override public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException { MLInputDataset inputDataSet = mlInput.getInputDataset(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java index 996a592464..4967d2035b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java @@ -38,6 +38,7 @@ public class TextSimilarityTranslator extends SentenceTransformerTranslator { public final String SIMILARITY_NAME = "similarity"; + @Override public NDList processInput(TranslatorContext ctx, Input input) { String sentence = input.getAsString(0); @@ -78,13 +79,14 @@ public Output processOutput(TranslatorContext ctx, NDList list) { DataType dataType = ndArray.getDataType(); MLResultDataType mlResultDataType = MLResultDataType.valueOf(dataType.name()); ByteBuffer buffer = ndArray.toByteBuffer(); - ModelTensor tensor = ModelTensor.builder() - .name(name) - .data(data) - .shape(shape) - .dataType(mlResultDataType) - .byteBuffer(buffer) - .build(); + ModelTensor tensor = ModelTensor + .builder() + .name(name) + .data(data) + .shape(shape) + .dataType(mlResultDataType) + .byteBuffer(buffer) + .build(); outputs.add(tensor); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java index 300d2e77f9..73a25a4ee7 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java @@ -17,6 +17,15 @@ */ package org.opensearch.ml.engine.algorithms.text_similarity; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.algorithms.DLModel.*; + import java.io.File; import java.io.IOException; import java.net.URISyntaxException; @@ -30,7 +39,6 @@ import java.util.Map; import java.util.UUID; -import org.apache.commons.lang3.tuple.Pair; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -61,15 +69,6 @@ import ai.djl.translate.TranslatorContext; import lombok.extern.log4j.Log4j2; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.opensearch.ml.engine.algorithms.DLModel.*; - @Log4j2 public class TextSimilarityCrossEncoderModelTest { @@ -88,14 +87,15 @@ public void setUp() throws URISyntaxException { mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(mlCachePath, encryptor); - model = MLModel.builder() - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .name("test_model_name") - .modelId("test_model_id") - .algorithm(FunctionName.TEXT_SIMILARITY) - .version("1.0.0") - .modelState(MLModelState.TRAINED) - .build(); + model = MLModel + .builder() + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .name("test_model_name") + .modelId("test_model_id") + .algorithm(FunctionName.TEXT_SIMILARITY) + .version("1.0.0") + .modelState(MLModelState.TRAINED) + .build(); modelHelper = new ModelHelper(mlEngine); params = new HashMap<>(); modelZipFile = new File(getClass().getResource("TinyBERT-CE.zip").toURI()); @@ -104,10 +104,9 @@ public void setUp() throws URISyntaxException { params.put(ML_ENGINE, mlEngine); textSimilarityCrossEncoderModel = new TextSimilarityCrossEncoderModel(); - inputDataSet = TextSimilarityInputDataSet.builder() - .textDocs(Arrays.asList( - "That is a happy dog", - "it's summer")) + inputDataSet = TextSimilarityInputDataSet + .builder() + .textDocs(Arrays.asList("That is a happy dog", "it's summer")) .queryText("it's summer") .build(); } @@ -128,7 +127,7 @@ public void test_TextSimilarity_Translator_ProcessInput() throws URISyntaxExcept when(input.getAsString(0)).thenReturn(testSentence); when(input.getAsString(1)).thenReturn(testSentence); NDArray indiceNdArray = mock(NDArray.class); - when(indiceNdArray.toLongArray()).thenReturn(new long[]{102l, 101l}); + when(indiceNdArray.toLongArray()).thenReturn(new long[] { 102l, 101l }); when(manager.create((long[]) any())).thenReturn(indiceNdArray); doNothing().when(indiceNdArray).setName(any()); NDList outputList = textSimilarityTranslator.processInput(translatorContext, input); @@ -136,7 +135,7 @@ public void test_TextSimilarity_Translator_ProcessInput() throws URISyntaxExcept Iterator iterator = outputList.iterator(); while (iterator.hasNext()) { NDArray ndArray = iterator.next(); - long [] output = ndArray.toLongArray(); + long[] output = ndArray.toLongArray(); assertEquals(2, output.length); } } @@ -155,10 +154,10 @@ public void test_TextSimilarity_Translator_ProcessOutput() throws URISyntaxExcep when(ndArray.nonzero()).thenReturn(ndArray); when(ndArray.squeeze()).thenReturn(ndArray); when(ndArray.getFloat(any())).thenReturn(1.0f); - when(ndArray.toArray()).thenReturn(new Number[]{1.245f}); + when(ndArray.toArray()).thenReturn(new Number[] { 1.245f }); when(ndArray.getName()).thenReturn("output"); when(ndArray.getShape()).thenReturn(shape); - when(shape.getShape()).thenReturn(new long[]{1}); + when(shape.getShape()).thenReturn(new long[] { 1 }); when(ndArray.getDataType()).thenReturn(DataType.FLOAT32); List ndArrayList = Collections.singletonList(ndArray); NDList ndList = new NDList(ndArrayList); @@ -181,7 +180,7 @@ public void initModel_predict_TorchScript_CrossEncoder() throws URISyntaxExcepti ModelTensorOutput output = (ModelTensorOutput) textSimilarityCrossEncoderModel.predict(mlInput); List mlModelOutputs = output.getMlModelOutputs(); assertEquals(2, mlModelOutputs.size()); - for (int i=0;i mlModelTensors = tensors.getMlModelTensors(); assertEquals(1, mlModelTensors.size()); @@ -190,13 +189,14 @@ public void initModel_predict_TorchScript_CrossEncoder() throws URISyntaxExcepti textSimilarityCrossEncoderModel.close(); } - @Test public void initModel_NullModelHelper() throws URISyntaxException { Map params = new HashMap<>(); params.put(MODEL_ZIP_FILE, new File(getClass().getResource("TinyBERT-CE.zip").toURI())); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, - () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor) + ); assert (e.getMessage().equals("model helper is null")); } @@ -205,16 +205,20 @@ public void initModel_NullMLEngine() throws URISyntaxException { Map params = new HashMap<>(); params.put(MODEL_ZIP_FILE, new File(getClass().getResource("TinyBERT-CE.zip").toURI())); params.put(MODEL_HELPER, modelHelper); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, - () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor) + ); assert (e.getMessage().equals("ML engine is null")); } @Test public void initModel_NullModelId() { model.setModelId(null); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, - () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor) + ); assert (e.getMessage().equals("model id is null")); } @@ -224,8 +228,7 @@ public void initModel_WrongModelFile() throws URISyntaxException { params.put(MODEL_HELPER, modelHelper); params.put(MODEL_ZIP_FILE, new File(getClass().getResource("../text_embedding/wrong_zip_with_2_pt_file.zip").toURI())); params.put(ML_ENGINE, mlEngine); - MLException e = assertThrows(MLException.class, - () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)); + MLException e = assertThrows(MLException.class, () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)); Throwable rootCause = e.getCause(); assert (rootCause instanceof IllegalArgumentException); assert (rootCause.getMessage().equals("found multiple models")); @@ -234,26 +237,36 @@ public void initModel_WrongModelFile() throws URISyntaxException { @Test public void initModel_WrongFunctionName() { MLModel mlModel = model.toBuilder().algorithm(FunctionName.KMEANS).build(); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, - () -> textSimilarityCrossEncoderModel.initModel(mlModel, params, encryptor)); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel.initModel(mlModel, params, encryptor) + ); assert (e.getMessage().equals("wrong function name")); } @Test public void predict_NullModelHelper() { - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, - () -> textSimilarityCrossEncoderModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build())); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel + .predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build()) + ); assert (e.getMessage().equals("model not deployed")); } @Test public void predict_NullModelId() { model.setModelId(null); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, - () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor) + ); assert (e.getMessage().equals("model id is null")); - IllegalArgumentException e2 = assertThrows(IllegalArgumentException.class, - () -> textSimilarityCrossEncoderModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build())); + IllegalArgumentException e2 = assertThrows( + IllegalArgumentException.class, + () -> textSimilarityCrossEncoderModel + .predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build()) + ); assert (e2.getMessage().equals("model not deployed")); } @@ -261,12 +274,15 @@ public void predict_NullModelId() { public void predict_AfterModelClosed() { textSimilarityCrossEncoderModel.initModel(model, params, encryptor); textSimilarityCrossEncoderModel.close(); - MLException e = assertThrows(MLException.class, - () -> textSimilarityCrossEncoderModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build())); + MLException e = assertThrows( + MLException.class, + () -> textSimilarityCrossEncoderModel + .predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build()) + ); log.info(e.getMessage()); assert (e.getMessage().startsWith("Failed to inference TEXT_SIMILARITY")); } - + @After public void tearDown() { FileUtils.deleteFileQuietly(mlCachePath);