From 1b055e8ee0c061302a9081f2799896fa600d5400 Mon Sep 17 00:00:00 2001 From: Junwei Dai <59641585+junweid62@users.noreply.github.com> Date: Fri, 6 Sep 2024 11:03:07 -0700 Subject: [PATCH] Fix and Refactor Test with Builder Pattern for Issue #2825 (#859) * Fix: Use @Builder for constructor in test Signed-off-by: Junwei Dai * Fix: change parameter for different test case Signed-off-by: Junwei Dai --------- Signed-off-by: Junwei Dai Co-authored-by: Junwei Dai --- .../workflow/DeployModelStepTests.java | 34 +---------- .../RegisterLocalCustomModelStepTests.java | 57 +++---------------- ...RegisterLocalPretrainedModelStepTests.java | 40 +++---------- ...sterLocalSparseEncodingModelStepTests.java | 40 +++---------- 4 files changed, 27 insertions(+), 144 deletions(-) diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index b05e43ed4..9b945a05c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -127,22 +127,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I // Stub getTask for success case doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.COMPLETED, - null, - null, - null, - null, - null, - null, - null, - null, - false - ); + MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build(); actionListener.onResponse(output); return null; }).when(machineLearningNodeClient).getTask(any(), any()); @@ -215,22 +200,7 @@ public void testDeployModelTaskFailure() throws IOException, InterruptedExceptio // Stub getTask for success case doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.FAILED, - null, - null, - null, - null, - null, - null, - "error", - null, - false - ); + MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.FAILED).async(false).error("error").build(); actionListener.onResponse(output); return null; }).when(machineLearningNodeClient).getTask(any(), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index 061d9f8c8..6a6809d07 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -144,22 +144,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { // Stub getTask for success case doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.COMPLETED, - null, - null, - null, - null, - null, - null, - null, - null, - false - ); + MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build(); actionListener.onResponse(output); return null; }).when(machineLearningNodeClient).getTask(any(), any()); @@ -239,22 +224,7 @@ public void testRegisterLocalCustomModelDeployStateUpdateFail() throws Exception // Stub getTask for success case doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.COMPLETED, - null, - null, - null, - null, - null, - null, - null, - null, - false - ); + MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build(); actionListener.onResponse(output); return null; }).when(machineLearningNodeClient).getTask(any(), any()); @@ -342,22 +312,13 @@ public void testRegisterLocalCustomModelTaskFailure() { // Stub get ml task for failure case doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.FAILED, - null, - null, - null, - null, - null, - null, - testErrorMessage, - null, - false - ); + MLTask output = MLTask.builder() + .taskId(taskId) + .modelId(modelId) + .state(MLTaskState.FAILED) + .error(testErrorMessage) + .async(false) + .build(); actionListener.onResponse(output); return null; }).when(machineLearningNodeClient).getTask(any(), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index afe97bacb..162b97dba 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -137,22 +137,7 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception { // Stub getTask for success case doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.COMPLETED, - null, - null, - null, - null, - null, - null, - null, - null, - false - ); + MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build(); actionListener.onResponse(output); return null; }).when(machineLearningNodeClient).getTask(any(), any()); @@ -247,22 +232,13 @@ public void testRegisterLocalPretrainedModelTaskFailure() { // Stub get ml task for failure case doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.FAILED, - null, - null, - null, - null, - null, - null, - testErrorMessage, - null, - false - ); + MLTask output = MLTask.builder() + .taskId(taskId) + .modelId(modelId) + .state(MLTaskState.FAILED) + .error(testErrorMessage) + .async(false) + .build(); actionListener.onResponse(output); return null; }).when(machineLearningNodeClient).getTask(any(), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index 9c35af33c..79d7bb883 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -140,22 +140,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception { // Stub getTask for success case doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.COMPLETED, - null, - null, - null, - null, - null, - null, - null, - null, - false - ); + MLTask output = MLTask.builder().taskId(taskId).modelId(modelId).state(MLTaskState.COMPLETED).async(false).build(); actionListener.onResponse(output); return null; }).when(machineLearningNodeClient).getTask(any(), any()); @@ -252,22 +237,13 @@ public void testRegisterLocalSparseEncodingModelTaskFailure() { // Stub get ml task for failure case doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.FAILED, - null, - null, - null, - null, - null, - null, - testErrorMessage, - null, - false - ); + MLTask output = MLTask.builder() + .taskId(taskId) + .modelId(modelId) + .state(MLTaskState.FAILED) + .error(testErrorMessage) + .async(false) + .build(); actionListener.onResponse(output); return null; }).when(machineLearningNodeClient).getTask(any(), any());