Skip to content

Commit

Permalink
fixing UT
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz committed Dec 1, 2023
1 parent 72122e1 commit 2baf8e7
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.action.update.UpdateResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.flowframework.exception.FlowFrameworkException;
Expand All @@ -33,6 +35,12 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import static org.mockito.ArgumentMatchers.anyString;


import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.action.DocWriteResponse.Result.UPDATED;

public class CreateConnectorStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

Expand Down Expand Up @@ -87,6 +95,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr
return null;
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any());

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

CompletableFuture<WorkflowData> future = createConnectorStep.execute(
inputData.getNodeId(),
inputData,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.client.IndicesAdminClient;
Expand All @@ -21,10 +23,12 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -36,12 +40,16 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;

@SuppressWarnings("deprecation")
public class CreateIndexStepTests extends OpenSearchTestCase {
Expand Down Expand Up @@ -95,7 +103,15 @@ public void setUp() throws Exception {
CreateIndexStep.indexMappingUpdated = indexMappingUpdated;
}

public void testCreateIndexStep() throws ExecutionException, InterruptedException {
public void testCreateIndexStep() throws ExecutionException, InterruptedException, IOException {


doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

@SuppressWarnings({ "unchecked" })
ArgumentCaptor<ActionListener<CreateIndexResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
CompletableFuture<WorkflowData> future = createIndexStep.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@

import org.opensearch.action.ingest.PutPipelineRequest;
import org.opensearch.action.support.master.AcknowledgedResponse;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
Expand All @@ -25,10 +28,14 @@
import org.mockito.ArgumentCaptor;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;

@SuppressWarnings("deprecation")
public class CreateIngestPipelineStepTests extends OpenSearchTestCase {
Expand Down Expand Up @@ -69,10 +76,16 @@ public void setUp() throws Exception {
when(adminClient.cluster()).thenReturn(clusterAdminClient);
}

public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException {
public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException, IOException {

CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler);

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<AcknowledgedResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
CompletableFuture<WorkflowData> future = createIngestPipelineStep.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
package org.opensearch.flowframework.workflow;

import com.google.common.collect.ImmutableList;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -31,9 +33,12 @@
import org.mockito.MockitoAnnotations;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;

public class ModelGroupStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;
Expand Down Expand Up @@ -79,6 +84,12 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep
return null;
}).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture());

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

CompletableFuture<WorkflowData> future = modelGroupStep.execute(
inputData.getNodeId(),
inputData,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

import org.opensearch.action.update.UpdateResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand All @@ -35,8 +37,11 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.ArgumentMatchers.anyString;
import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -134,6 +139,12 @@ public void testRegisterLocalModelSuccess() throws Exception {
return null;
}).when(machineLearningNodeClient).getTask(any(), any());

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

CompletableFuture<WorkflowData> future = registerLocalModelStep.execute(
workflowData.getNodeId(),
workflowData,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

import org.opensearch.action.update.UpdateResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand All @@ -27,13 +29,16 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.ArgumentMatchers.anyString;
import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class RegisterRemoteModelStepTests extends OpenSearchTestCase {
Expand Down Expand Up @@ -76,6 +81,12 @@ public void testRegisterRemoteModelSuccess() throws Exception {
return null;
}).when(mlNodeClient).register(any(MLRegisterModelInput.class), any());

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

CompletableFuture<WorkflowData> future = this.registerRemoteModelStep.execute(
workflowData.getNodeId(),
workflowData,
Expand Down

0 comments on commit 2baf8e7

Please sign in to comment.