From 2baf8e7a4505ae9d9d71d4dde90380aef1250b9a Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Fri, 1 Dec 2023 13:56:56 -0800 Subject: [PATCH] fixing UT Signed-off-by: Amit Galitzky --- .../workflow/CreateConnectorStepTests.java | 14 ++++++++++++++ .../workflow/CreateIndexStepTests.java | 18 +++++++++++++++++- .../CreateIngestPipelineStepTests.java | 15 ++++++++++++++- .../workflow/ModelGroupStepTests.java | 11 +++++++++++ .../workflow/RegisterLocalModelStepTests.java | 11 +++++++++++ .../workflow/RegisterRemoteModelStepTests.java | 11 +++++++++++ 6 files changed, 78 insertions(+), 2 deletions(-) diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 1135a0ca6..d58e1734b 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -8,7 +8,9 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -33,6 +35,12 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.ArgumentMatchers.anyString; + + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; + public class CreateConnectorStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; @@ -87,6 +95,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr return null; }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = createConnectorStep.execute( inputData.getNodeId(), inputData, diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 6aec92fa8..bc5bd463b 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -10,6 +10,8 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; @@ -21,10 +23,12 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -36,12 +40,16 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; @SuppressWarnings("deprecation") public class CreateIndexStepTests extends OpenSearchTestCase { @@ -95,7 +103,15 @@ public void setUp() throws Exception { CreateIndexStep.indexMappingUpdated = indexMappingUpdated; } - public void testCreateIndexStep() throws ExecutionException, InterruptedException { + public void testCreateIndexStep() throws ExecutionException, InterruptedException, IOException { + + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = createIndexStep.execute( diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 39cdadc34..d6bd0fcf2 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -10,13 +10,16 @@ import org.opensearch.action.ingest.PutPipelineRequest; import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -25,10 +28,14 @@ import org.mockito.ArgumentCaptor; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; @SuppressWarnings("deprecation") public class CreateIngestPipelineStepTests extends OpenSearchTestCase { @@ -69,10 +76,16 @@ public void setUp() throws Exception { when(adminClient.cluster()).thenReturn(clusterAdminClient); } - public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException { + public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException, IOException { CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = createIngestPipelineStep.execute( diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index b00f2e5bf..41c6aad4d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -9,7 +9,9 @@ package org.opensearch.flowframework.workflow; import com.google.common.collect.ImmutableList; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -31,9 +33,12 @@ import org.mockito.MockitoAnnotations; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; public class ModelGroupStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; @@ -79,6 +84,12 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep return null; }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = modelGroupStep.execute( inputData.getNodeId(), inputData, diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index a2d3939ee..a6fdb60fd 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -10,11 +10,13 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -35,8 +37,11 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.mockito.ArgumentMatchers.anyString; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -134,6 +139,12 @@ public void testRegisterLocalModelSuccess() throws Exception { return null; }).when(machineLearningNodeClient).getTask(any(), any()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = registerLocalModelStep.execute( workflowData.getNodeId(), workflowData, diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index 7b37301a1..ea9c2bd73 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -10,7 +10,9 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -27,6 +29,8 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.mockito.ArgumentMatchers.anyString; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.mockito.ArgumentMatchers.any; @@ -34,6 +38,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class RegisterRemoteModelStepTests extends OpenSearchTestCase { @@ -76,6 +81,12 @@ public void testRegisterRemoteModelSuccess() throws Exception { return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = this.registerRemoteModelStep.execute( workflowData.getNodeId(), workflowData,