Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated Testcases for MLEngine.java #1536

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -41,6 +42,7 @@
import java.util.UUID;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;
Expand All @@ -54,7 +56,7 @@ public class MLEngineTest {
@Before
public void setUp() {
Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
MLEngine mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
divitr marked this conversation as resolved.
Show resolved Hide resolved
}

@Test
Expand All @@ -68,6 +70,22 @@ public void testPrebuiltModelPath() {
assertEquals("https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/config.json", prebuiltModelConfigPath);
}

@Test
public void testDeployModelZipPath() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about "testGetDeployModelZipPath"?

String modelId = "test_id";
String modelName = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b";
String modelZipPath = mlEngine.getDeployModelZipPath(modelId, modelName);
assertEquals(mlEngine.getMlCachePath() + "/models_cache/deploy/test_id/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b.zip", modelZipPath);
}

@Test
public void testGetDeployModelChunkPath() {
String modelId = "test_id";
Integer chunkNum = 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dhrubo-os is the number of chunks hard coded to be always 10? Wondering if we should have a loop here to test all of them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dhrubo-os is the number of chunks hard coded to be always 10? Wondering if we should have a loop here to test all of them.

@austintlee Thanks for the feedback! I added a loop here to check for chunks between 1 and 10.

Path chunkPath = mlEngine.getDeployModelChunkPath(modelId, chunkNum);
assertEquals(Path.of(mlEngine.getMlCachePath().toString() + "/models_cache/deploy/test_id/chunks/1"), chunkPath);
}

@Test
public void predictKMeans() {
MLModel model = trainKMeansModel();
Expand Down Expand Up @@ -142,6 +160,33 @@ public void train_NullInput() {
}
}

@Test
public void train_NullTrainable() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

testTrainNullTraininable? (no underscore).

exceptionRule.expect(IllegalArgumentException.class);
MLInput mlInput = Mockito.mock(MLInput.class);
when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION);
when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null);
mlEngine.train(mlInput);
}

@Test
public void predict_NullTrainAndPredictable() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we improve the name of the test here?

exceptionRule.expect(IllegalArgumentException.class);
MLInput mlInput = Mockito.mock(MLInput.class);
MLModel mlModel = Mockito.mock(MLModel.class);
when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION);
when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null);
mlEngine.predict(mlInput, mlModel);
}

@Test
public void trainAndPredict_NullTrainable() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment on the testcase name.

exceptionRule.expect(IllegalArgumentException.class);
MLInput mlInput = Mockito.mock(MLInput.class);
when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION);
when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null);
mlEngine.trainAndPredict(mlInput);
}
//TODO: fix mockito error
@Ignore
@Test
Comment on lines 192 to 194
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dhrubo-os we should fix this while we're at it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. @divitr if you think it's easy enough to fix in this PR, please go forward. Otherwise we can create another issue to take out the @ignore and make this test working.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dhrubo-os I don't think this issue is that easy to fix in this PR, I think we should open a new issue. I can open up a new issue if you'd like.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll create an issue later.

Expand Down
Loading