Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Oct 4, 2023
1 parent 923fd0f commit ee6ddc3
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

import java.util.List;
Expand Down Expand Up @@ -43,12 +42,10 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
ActionListener<MLDeployModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
if (mlDeployModelResponse.getStatus() == MLTaskState.COMPLETED.name()) {
logger.info("Model deployed successfully");
deployModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("deploy-model-status", mlDeployModelResponse.getStatus())))
);
}
logger.info("Model deployment state {}", mlDeployModelResponse.getStatus());
deployModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("deploy-model-status", mlDeployModelResponse.getStatus())))
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,6 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {

/*ActionListener<MLDeployModelResponse> deployActionListener = new ActionListener<>() {
@Override
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
if (mlDeployModelResponse.getStatus() == MLTaskState.COMPLETED.name()) {
logger.info("Model deployment successful");
registerModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("modelId", mlRegisterModelResponse.getModelId())))
);
}
}
@Override
public void onFailure(Exception e) {
logger.error("Model deployment failed");
registerModelFuture.completeExceptionally(new IOException("Model deployment failed"));
}
};
machineLearningNodeClient.deploy(mlRegisterModelResponse.getModelId(), deployActionListener);*/
// scheduledFuture = threadPool.scheduleWithFixedDelay(new GetTask(machineLearningNodeClient,
// mlRegisterModelResponse.getTaskId()), TimeValue.timeValueMillis(10L), ThreadPool.Names.GENERIC);

/*DeployModel deployModel = new DeployModel();
deployModel.deployModel(machineLearningNodeClient, mlRegisterModelResponse.getModelId());*/
logger.info("Model registration successful");
registerModelFuture.complete(
new WorkflowData(
Expand Down Expand Up @@ -116,7 +92,6 @@ public void onFailure(Exception e) {
for (WorkflowData workflowData : data) {
if (workflowData != null) {
Map<String, Object> content = workflowData.getContent();
logger.info("Previous step sent content: {}", content);

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.client.NoOpNodeClient;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.Mockito.*;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class DeployModelTests extends OpenSearchTestCase {

private WorkflowData inputData = WorkflowData.EMPTY;

@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private NodeClient nodeClient;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

@Override
public void setUp() throws Exception {
super.setUp();

inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")));

MockitoAnnotations.openMocks(this);

nodeClient = new NoOpNodeClient("xyz");

}

public void testDeployModel() {

String taskId = "taskId";
String status = MLTaskState.CREATED.name();
MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL;

DeployModel deployModel = new DeployModel(nodeClient);

ArgumentCaptor<ActionListener<MLDeployModelResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = deployModel.execute(List.of(inputData));

// TODO: Find a way to verify the below
// verify(machineLearningNodeClient).deploy(eq(MLRegisterModelInput.class), actionListenerCaptor.capture());

assertTrue(future.isCompletedExceptionally());

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.*;
import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.Mockito.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class RegisterModelStepTests extends OpenSearchTestCase {
Expand All @@ -38,6 +42,9 @@ public class RegisterModelStepTests extends OpenSearchTestCase {
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private NodeClient nodeClient;

@Mock
ActionListener<MLRegisterModelResponse> registerModelActionListener;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -68,6 +75,9 @@ public void setUp() throws Exception {

public void testRegisterModel() throws ExecutionException, InterruptedException {

String taskId = "abcd";
String modelId = "efgh";
String status = MLTaskState.CREATED.name();
MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.functionName(FunctionName.from("REMOTE"))
.modelName("testModelName")
Expand All @@ -78,25 +88,20 @@ public void testRegisterModel() throws ExecutionException, InterruptedException
RegisterModelStep registerModelStep = new RegisterModelStep(nodeClient);

ArgumentCaptor<ActionListener<MLRegisterModelResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
CompletableFuture<WorkflowData> future = registerModelStep.execute(List.of(inputData));

/*try (MockedStatic<MLClient> mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) {
mlClientMockedStatic
.when(() -> MLClient.createMLClient(any(NodeClient.class)))
.thenReturn(machineLearningNodeClient);
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(2);
MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture());

}*/
// when(spy(MLClient.createMLClient(nodeClient))).thenReturn(machineLearningNodeClient);
verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture());
actionListenerCaptor.getValue().onResponse(new MLRegisterModelResponse("xyz", MLTaskState.COMPLETED.name(), "abc"));

assertTrue(future.isDone() && !future.isCompletedExceptionally());

Map<String, Object> outputData = Map.of("index-name", "demo");
CompletableFuture<WorkflowData> future = registerModelStep.execute(List.of(inputData));

assertTrue(future.isDone() && future.isCompletedExceptionally());
// TODO: Find a way to verify the below
// verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture());

assertEquals(outputData, future.get().getContent());
assertTrue(future.isCompletedExceptionally());

}

Expand Down
49 changes: 43 additions & 6 deletions src/test/resources/template/demo.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,27 @@
"nodes": [
{
"id": "fetch_model",
"type": "demo_delay_3",
"inputs": {
"ingest_key": "ingest_value"
}
"type": "demo_delay_3"
},
{
"id": "create_index",
"type": "demo_delay_3"
},
{
"id": "create_ingest_pipeline",
"type": "demo_delay_3"
},
{
"id": "create_search_pipeline",
"type": "demo_delay_5"
},
{
"id": "create_neural_search_index",
"type": "demo_delay_3"
},
{
"id": "register_model",
"type": "register_model",
"type": "demo_delay_3",
"inputs": {
"name": "openAI-gpt-3.5-turbo",
"function_name": "remote",
Expand All @@ -26,7 +39,7 @@
},
{
"id": "deploy_model",
"type": "deploy_model",
"type": "demo_delay_3",
"inputs": {
"model_id": "abc"
}
Expand All @@ -35,6 +48,30 @@
"edges": [
{
"source": "fetch_model",
"dest": "create_index"
},
{
"source": "create_index",
"dest": "create_ingest_pipeline"
},
{
"source": "fetch_model",
"dest": "create_search_pipeline"
},
{
"source": "create_ingest_pipeline",
"dest": "create_neural_search_index"
},
{
"source": "create_search_pipeline",
"dest": "create_neural_search_index"
},
{
"source": "create_neural_search_index",
"dest": "register_model"
},
{
"source": "register_model",
"dest": "deploy_model"
}
]
Expand Down
Loading

0 comments on commit ee6ddc3

Please sign in to comment.