Skip to content

Commit

Permalink
apply spotless after rebase
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Nov 28, 2023
1 parent b413994 commit 5bec616
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
*/
package org.opensearch.ml.engine.algorithms.text_similarity;

import static org.opensearch.ml.engine.ModelHelper.PYTORCH_ENGINE;

import java.util.ArrayList;
import java.util.List;

Expand All @@ -40,7 +38,7 @@

@Function(FunctionName.TEXT_SIMILARITY)
public class TextSimilarityCrossEncoderModel extends DLModel {

@Override
public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
MLInputDataset inputDataSet = mlInput.getInputDataset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

public class TextSimilarityTranslator extends SentenceTransformerTranslator {
public final String SIMILARITY_NAME = "similarity";

@Override
public NDList processInput(TranslatorContext ctx, Input input) {
String sentence = input.getAsString(0);
Expand Down Expand Up @@ -78,13 +79,14 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
DataType dataType = ndArray.getDataType();
MLResultDataType mlResultDataType = MLResultDataType.valueOf(dataType.name());
ByteBuffer buffer = ndArray.toByteBuffer();
ModelTensor tensor = ModelTensor.builder()
.name(name)
.data(data)
.shape(shape)
.dataType(mlResultDataType)
.byteBuffer(buffer)
.build();
ModelTensor tensor = ModelTensor
.builder()
.name(name)
.data(data)
.shape(shape)
.dataType(mlResultDataType)
.byteBuffer(buffer)
.build();
outputs.add(tensor);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
*/
package org.opensearch.ml.engine.algorithms.text_similarity;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.engine.algorithms.DLModel.*;

import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
Expand All @@ -30,7 +39,6 @@
import java.util.Map;
import java.util.UUID;

import org.apache.commons.lang3.tuple.Pair;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -61,15 +69,6 @@
import ai.djl.translate.TranslatorContext;
import lombok.extern.log4j.Log4j2;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.engine.algorithms.DLModel.*;

@Log4j2
public class TextSimilarityCrossEncoderModelTest {

Expand All @@ -88,14 +87,15 @@ public void setUp() throws URISyntaxException {
mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID());
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(mlCachePath, encryptor);
model = MLModel.builder()
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.name("test_model_name")
.modelId("test_model_id")
.algorithm(FunctionName.TEXT_SIMILARITY)
.version("1.0.0")
.modelState(MLModelState.TRAINED)
.build();
model = MLModel
.builder()
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.name("test_model_name")
.modelId("test_model_id")
.algorithm(FunctionName.TEXT_SIMILARITY)
.version("1.0.0")
.modelState(MLModelState.TRAINED)
.build();
modelHelper = new ModelHelper(mlEngine);
params = new HashMap<>();
modelZipFile = new File(getClass().getResource("TinyBERT-CE.zip").toURI());
Expand All @@ -104,10 +104,9 @@ public void setUp() throws URISyntaxException {
params.put(ML_ENGINE, mlEngine);
textSimilarityCrossEncoderModel = new TextSimilarityCrossEncoderModel();

inputDataSet = TextSimilarityInputDataSet.builder()
.textDocs(Arrays.asList(
"That is a happy dog",
"it's summer"))
inputDataSet = TextSimilarityInputDataSet
.builder()
.textDocs(Arrays.asList("That is a happy dog", "it's summer"))
.queryText("it's summer")
.build();
}
Expand All @@ -128,15 +127,15 @@ public void test_TextSimilarity_Translator_ProcessInput() throws URISyntaxExcept
when(input.getAsString(0)).thenReturn(testSentence);
when(input.getAsString(1)).thenReturn(testSentence);
NDArray indiceNdArray = mock(NDArray.class);
when(indiceNdArray.toLongArray()).thenReturn(new long[]{102l, 101l});
when(indiceNdArray.toLongArray()).thenReturn(new long[] { 102l, 101l });
when(manager.create((long[]) any())).thenReturn(indiceNdArray);
doNothing().when(indiceNdArray).setName(any());
NDList outputList = textSimilarityTranslator.processInput(translatorContext, input);
assertEquals(3, outputList.size());
Iterator<NDArray> iterator = outputList.iterator();
while (iterator.hasNext()) {
NDArray ndArray = iterator.next();
long [] output = ndArray.toLongArray();
long[] output = ndArray.toLongArray();
assertEquals(2, output.length);
}
}
Expand All @@ -155,10 +154,10 @@ public void test_TextSimilarity_Translator_ProcessOutput() throws URISyntaxExcep
when(ndArray.nonzero()).thenReturn(ndArray);
when(ndArray.squeeze()).thenReturn(ndArray);
when(ndArray.getFloat(any())).thenReturn(1.0f);
when(ndArray.toArray()).thenReturn(new Number[]{1.245f});
when(ndArray.toArray()).thenReturn(new Number[] { 1.245f });
when(ndArray.getName()).thenReturn("output");
when(ndArray.getShape()).thenReturn(shape);
when(shape.getShape()).thenReturn(new long[]{1});
when(shape.getShape()).thenReturn(new long[] { 1 });
when(ndArray.getDataType()).thenReturn(DataType.FLOAT32);
List<NDArray> ndArrayList = Collections.singletonList(ndArray);
NDList ndList = new NDList(ndArrayList);
Expand All @@ -181,7 +180,7 @@ public void initModel_predict_TorchScript_CrossEncoder() throws URISyntaxExcepti
ModelTensorOutput output = (ModelTensorOutput) textSimilarityCrossEncoderModel.predict(mlInput);
List<ModelTensors> mlModelOutputs = output.getMlModelOutputs();
assertEquals(2, mlModelOutputs.size());
for (int i=0;i<mlModelOutputs.size();i++) {
for (int i = 0; i < mlModelOutputs.size(); i++) {
ModelTensors tensors = mlModelOutputs.get(i);
List<ModelTensor> mlModelTensors = tensors.getMlModelTensors();
assertEquals(1, mlModelTensors.size());
Expand All @@ -190,13 +189,14 @@ public void initModel_predict_TorchScript_CrossEncoder() throws URISyntaxExcepti
textSimilarityCrossEncoderModel.close();
}


@Test
public void initModel_NullModelHelper() throws URISyntaxException {
Map<String, Object> params = new HashMap<>();
params.put(MODEL_ZIP_FILE, new File(getClass().getResource("TinyBERT-CE.zip").toURI()));
IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor));
IllegalArgumentException e = assertThrows(
IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)
);
assert (e.getMessage().equals("model helper is null"));
}

Expand All @@ -205,16 +205,20 @@ public void initModel_NullMLEngine() throws URISyntaxException {
Map<String, Object> params = new HashMap<>();
params.put(MODEL_ZIP_FILE, new File(getClass().getResource("TinyBERT-CE.zip").toURI()));
params.put(MODEL_HELPER, modelHelper);
IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor));
IllegalArgumentException e = assertThrows(
IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)
);
assert (e.getMessage().equals("ML engine is null"));
}

@Test
public void initModel_NullModelId() {
model.setModelId(null);
IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor));
IllegalArgumentException e = assertThrows(
IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)
);
assert (e.getMessage().equals("model id is null"));
}

Expand All @@ -224,8 +228,7 @@ public void initModel_WrongModelFile() throws URISyntaxException {
params.put(MODEL_HELPER, modelHelper);
params.put(MODEL_ZIP_FILE, new File(getClass().getResource("../text_embedding/wrong_zip_with_2_pt_file.zip").toURI()));
params.put(ML_ENGINE, mlEngine);
MLException e = assertThrows(MLException.class,
() -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor));
MLException e = assertThrows(MLException.class, () -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor));
Throwable rootCause = e.getCause();
assert (rootCause instanceof IllegalArgumentException);
assert (rootCause.getMessage().equals("found multiple models"));
Expand All @@ -234,39 +237,52 @@ public void initModel_WrongModelFile() throws URISyntaxException {
@Test
public void initModel_WrongFunctionName() {
MLModel mlModel = model.toBuilder().algorithm(FunctionName.KMEANS).build();
IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.initModel(mlModel, params, encryptor));
IllegalArgumentException e = assertThrows(
IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.initModel(mlModel, params, encryptor)
);
assert (e.getMessage().equals("wrong function name"));
}

@Test
public void predict_NullModelHelper() {
IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build()));
IllegalArgumentException e = assertThrows(
IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel
.predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build())
);
assert (e.getMessage().equals("model not deployed"));
}

@Test
public void predict_NullModelId() {
model.setModelId(null);
IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor));
IllegalArgumentException e = assertThrows(
IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)
);
assert (e.getMessage().equals("model id is null"));
IllegalArgumentException e2 = assertThrows(IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build()));
IllegalArgumentException e2 = assertThrows(
IllegalArgumentException.class,
() -> textSimilarityCrossEncoderModel
.predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build())
);
assert (e2.getMessage().equals("model not deployed"));
}

@Test
public void predict_AfterModelClosed() {
textSimilarityCrossEncoderModel.initModel(model, params, encryptor);
textSimilarityCrossEncoderModel.close();
MLException e = assertThrows(MLException.class,
() -> textSimilarityCrossEncoderModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build()));
MLException e = assertThrows(
MLException.class,
() -> textSimilarityCrossEncoderModel
.predict(MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build())
);
log.info(e.getMessage());
assert (e.getMessage().startsWith("Failed to inference TEXT_SIMILARITY"));
}

@After
public void tearDown() {
FileUtils.deleteFileQuietly(mlCachePath);
Expand Down

0 comments on commit 5bec616

Please sign in to comment.