From dc20feb9d25b7c60bd76c8f5cc66a5a1eed860e2 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 14 Aug 2024 09:24:27 -0700 Subject: [PATCH] Refactor workflow step resource updates to eliminate duplication (#796) * Refactor workflow step resource updates to eliminate duplication Signed-off-by: Daniel Widdis * Add coverage and changelog Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- CHANGELOG.md | 1 + .../indices/FlowFrameworkIndicesHandler.java | 45 +++++- .../workflow/AbstractCreatePipelineStep.java | 46 ++---- .../AbstractRegisterLocalModelStep.java | 55 +++---- .../AbstractRetryableWorkflowStep.java | 45 ++---- .../workflow/CreateConnectorStep.java | 43 ++---- .../workflow/CreateIndexStep.java | 40 ++---- .../workflow/CreateIngestPipelineStep.java | 4 - .../workflow/CreateSearchPipelineStep.java | 4 - .../workflow/DeployModelStep.java | 21 +-- .../flowframework/workflow/NoOpStep.java | 1 + .../workflow/RegisterAgentStep.java | 45 ++---- .../workflow/RegisterModelGroupStep.java | 54 ++----- .../workflow/RegisterRemoteModelStep.java | 99 ++++++------- .../flowframework/workflow/ReindexStep.java | 3 - .../flowframework/workflow/WorkflowData.java | 6 +- .../FlowFrameworkIndicesHandlerTests.java | 51 +++++++ .../workflow/CreateConnectorStepTests.java | 10 +- .../workflow/CreateIndexStepTests.java | 10 +- .../CreateIngestPipelineStepTests.java | 10 +- .../CreateSearchPipelineStepTests.java | 16 +-- .../workflow/DeployModelStepTests.java | 10 +- .../flowframework/workflow/NoOpStepTests.java | 46 ++++++ .../workflow/RegisterAgentTests.java | 16 +-- .../RegisterLocalCustomModelStepTests.java | 93 +++++++++++- ...RegisterLocalPretrainedModelStepTests.java | 10 +- ...sterLocalSparseEncodingModelStepTests.java | 10 +- .../workflow/RegisterModelGroupStepTests.java | 13 +- .../RegisterRemoteModelStepTests.java | 134 ++++++++++++++++-- .../workflow/ReindexStepTests.java | 12 -- 30 files changed, 515 insertions(+), 438 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bdaed0ea0..f5d147ad8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,3 +24,4 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Documentation ### Maintenance ### Refactoring +- Refactor workflow step resource updates to eliminate duplication ([#796](https://github.com/opensearch-project/flow-framework/pull/796)) diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 63ac6f7d4..02ef8a825 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -44,6 +44,7 @@ import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.script.Script; import org.opensearch.script.ScriptType; @@ -666,13 +667,13 @@ public void updateFlowFrameworkSystemIndexDocWithScript( /** * Creates a new ResourceCreated object and a script to update the state index * @param workflowId workflowId for the relevant step - * @param nodeId WorkflowData object with relevent step information + * @param nodeId current process node (workflow step) id * @param workflowStepName the workflowstep name that created the resource * @param resourceId the id of the newly created resource * @param listener the ActionListener for this step to handle completing the future after update * @throws IOException if parsing fails on new resource */ - public void updateResourceInStateIndex( + private void updateResourceInStateIndex( String workflowId, String nodeId, String workflowStepName, @@ -697,6 +698,44 @@ public void updateResourceInStateIndex( updateFlowFrameworkSystemIndexDocWithScript(WORKFLOW_STATE_INDEX, workflowId, script, ActionListener.wrap(updateResponse -> { logger.info("updated resources created of {}", workflowId); listener.onResponse(updateResponse); - }, exception -> { listener.onFailure(exception); })); + }, listener::onFailure)); + } + + /** + * Adds a resource to the state index, including common exception handling + * @param currentNodeInputs Inputs to the current node + * @param nodeId current process node (workflow step) id + * @param workflowStepName the workflow step name that created the resource + * @param resourceId the id of the newly created resource + * @param listener the ActionListener for this step to handle completing the future after update + */ + public void addResourceToStateIndex( + WorkflowData currentNodeInputs, + String nodeId, + String workflowStepName, + String resourceId, + ActionListener listener + ) { + String resourceName = getResourceByWorkflowStep(workflowStepName); + try { + updateResourceInStateIndex( + currentNodeInputs.getWorkflowId(), + nodeId, + workflowStepName, + resourceId, + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + listener.onResponse(new WorkflowData(Map.of(resourceName, resourceId), currentNodeInputs.getWorkflowId(), nodeId)); + }, exception -> { + String errorMessage = "Failed to update new created " + nodeId + " resource " + workflowStepName + " id " + resourceId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + }) + ); + } catch (Exception e) { + String errorMessage = "Failed to parse and update new created resource"; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + } } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java index e23d88b63..bbaa77204 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java @@ -33,7 +33,6 @@ import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** @@ -98,43 +97,14 @@ public PlainActionFuture execute( @Override public void onResponse(AcknowledgedResponse acknowledgedResponse) { - String resourceName = getResourceByWorkflowStep(getName()); - try { - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - getName(), - pipelineId, - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - // PutPipelineRequest returns only an AcknowledgeResponse, saving pipelineId instead - // TODO: revisit this concept of pipeline_id to be consistent with what makes most sense to end user here - createPipelineFuture.onResponse( - new WorkflowData( - Map.of(resourceName, pipelineId), - currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + pipelineId; - logger.error(errorMessage, exception); - createPipelineFuture.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - createPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + // PutPipelineRequest returns only an AcknowledgeResponse, saving pipelineId instead + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + pipelineId, + createPipelineFuture + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java index 10c6a884b..3f6d19219 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java @@ -49,7 +49,6 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; -import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.URL; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; @@ -189,58 +188,38 @@ public PlainActionFuture execute( // Attempt to retrieve the model ID retryableGetMlTask( - currentNodeInputs.getWorkflowId(), + currentNodeInputs, currentNodeId, registerLocalModelFuture, taskId, "Local model registration", - ActionListener.wrap(mlTask -> { - + ActionListener.wrap(mlTaskWorkflowData -> { // Registered Model Resource has been updated String resourceName = getResourceByWorkflowStep(getName()); - String id = getResourceId(mlTask); - if (Boolean.TRUE.equals(deploy)) { - - // Simulate Model deployment step and update resources created - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - DeployModelStep.NAME, - id, - ActionListener.wrap(deployUpdateResponse -> { - logger.info( - "successfully updated resources created in state index: {}", - deployUpdateResponse.getIndex() - ); - registerLocalModelFuture.onResponse( - new WorkflowData( - Map.ofEntries( - Map.entry(resourceName, id), - Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name()) - ), - currentNodeInputs.getWorkflowId(), - currentNodeId - ) - ); - }, deployUpdateException -> { + String id = (String) mlTaskWorkflowData.getContent().get(resourceName); + ActionListener deployUpdateListener = ActionListener.wrap( + deployUpdateResponse -> registerLocalModelFuture.onResponse(mlTaskWorkflowData), + deployUpdateException -> { String errorMessage = "Failed to update simulated deploy step resource " + id; logger.error(errorMessage, deployUpdateException); registerLocalModelFuture.onFailure( new FlowFrameworkException(errorMessage, ExceptionsHelper.status(deployUpdateException)) ); - }) + } ); - } else { - registerLocalModelFuture.onResponse( - new WorkflowData( - Map.ofEntries(Map.entry(resourceName, id), Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name())), - currentNodeInputs.getWorkflowId(), - currentNodeId - ) + // Simulate Model deployment step and update resources created + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + DeployModelStep.NAME, + id, + deployUpdateListener ); + } else { + registerLocalModelFuture.onResponse(mlTaskWorkflowData); } - }, exception -> { registerLocalModelFuture.onFailure(exception); }) + }, registerLocalModelFuture::onFailure) ); }, exception -> { Exception e = getSafeException(exception); diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index c1933b0c4..3e6dc9a85 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.FutureUtils; @@ -24,8 +23,11 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.threadpool.ThreadPool; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; @@ -60,7 +62,7 @@ protected AbstractRetryableWorkflowStep( /** * Retryable get ml task - * @param workflowId the workflow id + * @param currentNodeInputs the current Node Inputs * @param nodeId the workflow node id * @param future the workflow step future * @param taskId the ml task id @@ -68,12 +70,12 @@ protected AbstractRetryableWorkflowStep( * @param mlTaskListener the ML Task Listener */ protected void retryableGetMlTask( - String workflowId, + WorkflowData currentNodeInputs, String nodeId, PlainActionFuture future, String taskId, String workflowStep, - ActionListener mlTaskListener + ActionListener mlTaskListener ) { CompletableFuture.runAsync(() -> { do { @@ -82,34 +84,13 @@ protected void retryableGetMlTask( String id = getResourceId(response); switch (response.getState()) { case COMPLETED: - try { - logger.info("{} successful for {} and {} {}", workflowStep, workflowId, resourceName, id); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - workflowId, - nodeId, - getName(), - id, - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - mlTaskListener.onResponse(response); - }, exception -> { - String errorMessage = "Failed to update new created " - + nodeId - + " resource " - + getName() - + " id " - + id; - logger.error(errorMessage, exception); - mlTaskListener.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - }) - ); - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource " + resourceName + " id " + id; - logger.error(errorMessage, e); - mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + logger.info("{} successful for {} and {} {}", workflowStep, currentNodeInputs, resourceName, id); + ActionListener resourceListener = ActionListener.wrap(r -> { + Map content = new HashMap<>(r.getContent()); + content.put(REGISTER_MODEL_STATUS, response.getState().toString()); + mlTaskListener.onResponse(new WorkflowData(content, r.getWorkflowId(), r.getNodeId())); + }, mlTaskListener::onFailure); + flowFrameworkIndicesHandler.addResourceToStateIndex(currentNodeInputs, nodeId, getName(), id, resourceListener); break; case FAILED: case COMPLETED_WITH_ERROR: diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 484807ce3..cd9d2e2ac 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -43,7 +43,6 @@ import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; @@ -85,40 +84,14 @@ public PlainActionFuture execute( @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { - String resourceName = getResourceByWorkflowStep(getName()); - try { - logger.info("Created connector successfully"); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - getName(), - mlCreateConnectorResponse.getConnectorId(), - ActionListener.wrap(response -> { - logger.info("successfully updated resources created in state index: {}", response.getIndex()); - createConnectorFuture.onResponse( - new WorkflowData( - Map.ofEntries(Map.entry(resourceName, mlCreateConnectorResponse.getConnectorId())), - currentNodeInputs.getWorkflowId(), - currentNodeId - ) - ); - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + mlCreateConnectorResponse.getConnectorId(); - logger.error(errorMessage, exception); - createConnectorFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - }) - ); - - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - createConnectorFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + logger.info("Created connector successfully"); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + mlCreateConnectorResponse.getConnectorId(), + createConnectorFuture + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 32ca9e9f6..3ac82cb7a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -21,13 +21,11 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.index.mapper.MapperService; -import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.HashMap; @@ -37,7 +35,6 @@ import static java.util.Collections.singletonMap; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** @@ -108,35 +105,14 @@ public PlainActionFuture execute( } client.admin().indices().create(createIndexRequest, ActionListener.wrap(acknowledgedResponse -> { - String resourceName = getResourceByWorkflowStep(getName()); logger.info("Created index: {}", indexName); - try { - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - getName(), - indexName, - ActionListener.wrap(response -> { - logger.info("successfully updated resource created in state index: {}", response.getIndex()); - createIndexFuture.onResponse( - new WorkflowData(Map.of(resourceName, indexName), currentNodeInputs.getWorkflowId(), currentNodeId) - ); - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + indexName; - logger.error(errorMessage, exception); - createIndexFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - }) - ); - } catch (IOException ex) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, ex); - createIndexFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(ex))); - } + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + indexName, + createIndexFuture + ); }, ex -> { Exception e = getSafeException(ex); String errorMessage = (e == null ? "Failed to create the index " + indexName : e.getMessage()); @@ -155,7 +131,7 @@ public PlainActionFuture execute( // to encounter the same behavior and not suddenly have to add `_doc` while using our create_index step // related bug: https://github.com/opensearch-project/OpenSearch/issues/12775 private static Map prepareMappings(Map source) { - if (source.containsKey("mappings") == false || (source.get("mappings") instanceof Map) == false) { + if (!source.containsKey("mappings") || !(source.get("mappings") instanceof Map)) { return source; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index c89b25d16..13f1b6fe5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -8,8 +8,6 @@ */ package org.opensearch.flowframework.workflow; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -17,8 +15,6 @@ * Step to create an ingest pipeline */ public class CreateIngestPipelineStep extends AbstractCreatePipelineStep { - private static final Logger logger = LogManager.getLogger(CreateIngestPipelineStep.class); - /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ public static final String NAME = "create_ingest_pipeline"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStep.java index 1dc8fc745..67bc5fc0d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStep.java @@ -8,8 +8,6 @@ */ package org.opensearch.flowframework.workflow; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -17,8 +15,6 @@ * Step to create a search pipeline */ public class CreateSearchPipelineStep extends AbstractCreatePipelineStep { - private static final Logger logger = LogManager.getLogger(CreateSearchPipelineStep.class); - /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ public static final String NAME = "create_search_pipeline"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 56c2a6181..e51100f04 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -28,7 +28,6 @@ import java.util.Set; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** @@ -93,24 +92,16 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { // Attempt to retrieve the model ID retryableGetMlTask( - currentNodeInputs.getWorkflowId(), + currentNodeInputs, currentNodeId, deployModelFuture, taskId, "Deploy model", - ActionListener.wrap(mlTask -> { - // Deployed Model Resource has been updated - String resourceName = getResourceByWorkflowStep(getName()); - String id = getResourceId(mlTask); - deployModelFuture.onResponse( - new WorkflowData(Map.of(resourceName, id), currentNodeInputs.getWorkflowId(), currentNodeId) - ); - }, - e -> { - deployModelFuture.onFailure( - new FlowFrameworkException("Failed to deploy model", ExceptionsHelper.status(e)) - ); - } + ActionListener.wrap( + deployModelFuture::onResponse, + e -> deployModelFuture.onFailure( + new FlowFrameworkException("Failed to deploy model", ExceptionsHelper.status(e)) + ) ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java index af2aea1e3..cf03de4bb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java @@ -62,6 +62,7 @@ public PlainActionFuture execute( throw new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST); } catch (InterruptedException e) { FutureUtils.cancel(future); + Thread.currentThread().interrupt(); } future.onResponse(WorkflowData.EMPTY); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index d485f6f5c..5eed7a864 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -47,7 +47,6 @@ import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; @@ -95,42 +94,14 @@ public PlainActionFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { - try { - String resourceName = getResourceByWorkflowStep(getName()); - logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId()); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - workflowId, - currentNodeId, - getName(), - mlRegisterAgentResponse.getAgentId(), - ActionListener.wrap(response -> { - logger.info("successfully updated resources created in state index: {}", response.getIndex()); - registerAgentModelFuture.onResponse( - new WorkflowData( - Map.ofEntries(Map.entry(resourceName, mlRegisterAgentResponse.getAgentId())), - workflowId, - currentNodeId - ) - ); - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + mlRegisterAgentResponse.getAgentId(); - logger.error(errorMessage, exception); - registerAgentModelFuture.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - registerAgentModelFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId()); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + mlRegisterAgentResponse.getAgentId(), + registerAgentModelFuture + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java index 8fe1271df..6cc990429 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java @@ -26,6 +26,7 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -37,7 +38,6 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_ACCESS_MODE; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** @@ -77,45 +77,19 @@ public PlainActionFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) { - try { - logger.info("Model group registration successful"); - String resourceName = getResourceByWorkflowStep(getName()); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - getName(), - mlRegisterModelGroupResponse.getModelGroupId(), - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - registerModelGroupFuture.onResponse( - new WorkflowData( - Map.ofEntries( - Map.entry(resourceName, mlRegisterModelGroupResponse.getModelGroupId()), - Map.entry(MODEL_GROUP_STATUS, mlRegisterModelGroupResponse.getStatus()) - ), - currentNodeInputs.getWorkflowId(), - currentNodeId - ) - ); - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + mlRegisterModelGroupResponse.getModelGroupId(); - logger.error(errorMessage, exception); - registerModelGroupFuture.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - registerModelGroupFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + logger.info("Model group registration successful"); + ActionListener resourceListener = ActionListener.wrap(r -> { + Map content = new HashMap<>(r.getContent()); + content.put(MODEL_GROUP_STATUS, mlRegisterModelGroupResponse.getStatus()); + registerModelGroupFuture.onResponse(new WorkflowData(content, r.getWorkflowId(), r.getNodeId())); + }, registerModelGroupFuture::onFailure); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + mlRegisterModelGroupResponse.getModelGroupId(), + resourceListener + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 25c0de272..72b56ab79 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -12,7 +12,6 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; @@ -145,60 +144,51 @@ public PlainActionFuture execute( mlClient.register(mlInput, new ActionListener() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { + logger.info("Remote Model registration successful"); + String resourceName = getResourceByWorkflowStep(getName()); + ActionListener registerUpdateListener = ActionListener.wrap(registerUpdateResponse -> { + if (Boolean.TRUE.equals(deploy)) { + updateDeployResource(resourceName, mlRegisterModelResponse); + } else { + completeRegisterFuture(resourceName, mlRegisterModelResponse); + } + }, registerUpdateException -> { + String errorMessage = "Failed to update new created " + + currentNodeId + + " resource " + + getName() + + " id " + + mlRegisterModelResponse.getModelId(); + completeRegisterFutureExceptionally(errorMessage, registerUpdateException); + }); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + mlRegisterModelResponse.getModelId(), + registerUpdateListener + ); + } - try { - logger.info("Remote Model registration successful"); - String resourceName = getResourceByWorkflowStep(getName()); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - getName(), - mlRegisterModelResponse.getModelId(), - ActionListener.wrap(response -> { - // If we deployed, simulate the deploy step has been called - if (Boolean.TRUE.equals(deploy)) { - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - DeployModelStep.NAME, - mlRegisterModelResponse.getModelId(), - ActionListener.wrap(deployUpdateResponse -> { - completeRegisterFuture(deployUpdateResponse, resourceName, mlRegisterModelResponse); - }, deployUpdateException -> { - String errorMessage = "Failed to update simulated deploy step resource " - + mlRegisterModelResponse.getModelId(); - logger.error(errorMessage, deployUpdateException); - registerRemoteModelFuture.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(deployUpdateException)) - ); - }) - ); - } else { - completeRegisterFuture(response, resourceName, mlRegisterModelResponse); - } - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + mlRegisterModelResponse.getModelId(); - logger.error(errorMessage, exception); - registerRemoteModelFuture.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - registerRemoteModelFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + private void updateDeployResource(String resourceName, MLRegisterModelResponse mlRegisterModelResponse) { + ActionListener deployUpdateListener = ActionListener.wrap( + deployUpdateResponse -> completeRegisterFuture(resourceName, mlRegisterModelResponse), + deployUpdateException -> { + String errorMessage = "Failed to update simulated deploy step resource " + mlRegisterModelResponse.getModelId(); + completeRegisterFutureExceptionally(errorMessage, deployUpdateException); + } + ); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + DeployModelStep.NAME, + mlRegisterModelResponse.getModelId(), + deployUpdateListener + ); } - void completeRegisterFuture(UpdateResponse response, String resourceName, MLRegisterModelResponse mlRegisterModelResponse) { - logger.info("successfully updated resources created in state index: {}", response.getIndex()); + private void completeRegisterFuture(String resourceName, MLRegisterModelResponse mlRegisterModelResponse) { + logger.info("successfully updated resources created in state index"); registerRemoteModelFuture.onResponse( new WorkflowData( Map.ofEntries( @@ -211,6 +201,11 @@ void completeRegisterFuture(UpdateResponse response, String resourceName, MLRegi ); } + private void completeRegisterFutureExceptionally(String errorMessage, Exception exception) { + logger.error(errorMessage, exception); + registerRemoteModelFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + @Override public void onFailure(Exception ex) { Exception e = getSafeException(ex); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ReindexStep.java b/src/main/java/org/opensearch/flowframework/workflow/ReindexStep.java index b46ddecab..18ce60f14 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ReindexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ReindexStep.java @@ -38,8 +38,6 @@ public class ReindexStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(ReindexStep.class); private final Client client; - private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; - /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ public static final String NAME = "reindex"; /** The refresh field for reindex */ @@ -61,7 +59,6 @@ public class ReindexStep implements WorkflowStep { */ public ReindexStep(Client client, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.client = client; - this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index ba19823a7..dad2d65c9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -74,7 +74,7 @@ public Map getContent() { */ public Map getParams() { return this.params; - }; + } /** * Returns the workflowId associated with this data. @@ -83,7 +83,7 @@ public Map getParams() { @Nullable public String getWorkflowId() { return this.workflowId; - }; + } /** * Returns the nodeId associated with this data. @@ -92,5 +92,5 @@ public String getWorkflowId() { @Nullable public String getNodeId() { return this.nodeId; - }; + } } diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index f55cb3bc1..ecaec46b5 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -37,13 +37,16 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.CreateConnectorStep; import org.opensearch.flowframework.workflow.CreateIndexStep; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.index.get.GetResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -488,4 +491,52 @@ public void testDeleteFlowFrameworkSystemIndexDoc() throws IOException { exceptionCaptor.getValue().getMessage() ); } + + public void testAddResourceToStateIndex() throws IOException { + ClusterState mockClusterState = mock(ClusterState.class); + Metadata mockMetaData = mock(Metadata.class); + when(clusterService.state()).thenReturn(mockClusterState); + when(mockClusterState.metadata()).thenReturn(mockMetaData); + when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(true); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + // test success + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "this_id", -2, 0, 0, Result.UPDATED)); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.addResourceToStateIndex( + new WorkflowData(Collections.emptyMap(), null, null), + "node_id", + CreateConnectorStep.NAME, + "this_id", + listener + ); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowData.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals("this_id", responseCaptor.getValue().getContent().get(WorkflowResources.CONNECTOR_ID)); + + // test failure + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to update state")); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.addResourceToStateIndex( + new WorkflowData(Collections.emptyMap(), null, null), + "node_id", + CreateConnectorStep.NAME, + "this_id", + listener + ); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to update new created node_id resource create_connector id this_id", exceptionCaptor.getValue().getMessage()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 86dc4af47..dd9eb369d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -9,9 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -30,8 +28,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -94,10 +90,10 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(CONNECTOR_ID, connectorId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = createConnectorStep.execute( inputData.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 538747b94..b9a9dfe88 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -12,7 +12,6 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; @@ -21,7 +20,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -36,10 +34,8 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -96,10 +92,10 @@ public void setUp() throws Exception { public void testCreateIndexStep() throws ExecutionException, InterruptedException, IOException { doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(INDEX_NAME, "demo"), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index c7912d494..efd9275b5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -11,12 +11,10 @@ import org.opensearch.action.ingest.PutPipelineRequest; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; @@ -27,9 +25,7 @@ import org.mockito.ArgumentCaptor; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; import static org.mockito.ArgumentMatchers.any; @@ -78,10 +74,10 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(PIPELINE_ID, "pipelineId"), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStepTests.java index 9a54f2af6..ac1fda8d8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStepTests.java @@ -11,12 +11,10 @@ import org.opensearch.action.search.PutSearchPipelineRequest; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; @@ -27,9 +25,7 @@ import org.mockito.ArgumentCaptor; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; import static org.mockito.ArgumentMatchers.any; @@ -43,7 +39,7 @@ public class CreateSearchPipelineStepTests extends OpenSearchTestCase { private WorkflowData inputData; - private WorkflowData outpuData; + private WorkflowData outputData; private Client client; private AdminClient adminClient; private ClusterAdminClient clusterAdminClient; @@ -63,7 +59,7 @@ public void setUp() throws Exception { ); // Set output data to returned pipelineId - outpuData = new WorkflowData(Map.ofEntries(Map.entry(PIPELINE_ID, "pipelineId")), "test-id", "test-node-id"); + outputData = new WorkflowData(Map.ofEntries(Map.entry(PIPELINE_ID, "pipelineId")), "test-id", "test-node-id"); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -78,10 +74,10 @@ public void testCreateSearchPipelineStep() throws InterruptedException, Executio CreateSearchPipelineStep createSearchPipelineStep = new CreateSearchPipelineStep(client, flowFrameworkIndicesHandler); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(PIPELINE_ID, "pipelineId"), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -100,7 +96,7 @@ public void testCreateSearchPipelineStep() throws InterruptedException, Executio actionListenerCaptor.getValue().onResponse(new AcknowledgedResponse(true)); assertTrue(future.isDone()); - assertEquals(outpuData.getContent(), future.get().getContent()); + assertEquals(outputData.getContent(), future.get().getContent()); } public void testCreateSearchPipelineStepFailure() throws InterruptedException { diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 60a6b7f48..b05e43ed4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -11,12 +11,10 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -42,10 +40,8 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; @@ -152,10 +148,10 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I }).when(machineLearningNodeClient).getTask(any(), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = deployModel.execute( inputData.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java index 4aabee215..b53574cfe 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java @@ -15,6 +15,9 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.opensearch.flowframework.common.CommonValue.DELAY_FIELD; @@ -50,6 +53,49 @@ public void testNoOpStepDelay() throws IOException, InterruptedException { assertTrue(System.nanoTime() - start > 900_000_000L); } + public void testNoOpStepInterrupt() throws IOException, InterruptedException { + NoOpStep noopStep = new NoOpStep(); + WorkflowData delayData = new WorkflowData(Map.of(DELAY_FIELD, "5s"), null, null); + + CountDownLatch latch = new CountDownLatch(1); + // Fetch errors from the separate thread + AtomicReference assertionError = new AtomicReference<>(); + + Thread testThread = new Thread(() -> { + try { + PlainActionFuture future = noopStep.execute( + "nodeId", + delayData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + try { + future.actionGet(); + } catch (Exception e) { + // Ignore the IllegalStateExcption/InterruptedExcpetion + } + assertTrue(future.isDone()); + assertTrue(future.isCancelled()); + assertTrue(Thread.currentThread().isInterrupted()); + } catch (AssertionError e) { + assertionError.set(e); + } finally { + latch.countDown(); + } + }); + + testThread.start(); + Thread.sleep(100); + testThread.interrupt(); + + latch.await(1, TimeUnit.SECONDS); + + if (assertionError.get() != null) { + throw assertionError.get(); + } + } + public void testNoOpStepParse() throws IOException { NoOpStep noopStep = new NoOpStep(); WorkflowData delayData = new WorkflowData(Map.of(DELAY_FIELD, "foo"), null, null); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index 9005052ae..626dfdfa1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -9,9 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -33,8 +31,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -101,10 +97,10 @@ public void testRegisterAgent() throws IOException, ExecutionException, Interrup }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(AGENT_ID, agentId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = registerAgentStep.execute( inputData.getNodeId(), @@ -134,10 +130,10 @@ public void testRegisterAgentFailure() throws IOException { }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(AGENT_ID, agentId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = registerAgentStep.execute( inputData.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index f8f2bce8f..061d9f8c8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -11,12 +11,10 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -38,16 +36,15 @@ import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; @@ -168,10 +165,10 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { }).when(machineLearningNodeClient).getTask(any(), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = registerLocalModelStep.execute( workflowData.getNodeId(), @@ -225,6 +222,88 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } + // This method tests code in the abstract parent + public void testRegisterLocalCustomModelDeployStateUpdateFail() throws Exception { + String taskId = "abcd"; + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + AtomicInteger invocationCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + if (invocationCount.getAndIncrement() == 0) { + // succeed on first call (update register) + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); + } else { + // fail on second call (update deploy) + updateResponseListener.onFailure(new RuntimeException("Failed to update deploy resource")); + } + return null; + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("model_type", "bert"), + Map.entry("embedding_dimension", "384"), + Map.entry("framework_type", "sentence_transformers"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "true") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to update simulated deploy step resource model-id", ex.getCause().getMessage()); + } + public void testRegisterLocalCustomModelFailure() { doAnswer(invocation -> { diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index dcd2098d5..afe97bacb 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -11,12 +11,10 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -42,12 +40,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; @@ -162,10 +158,10 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception { }).when(machineLearningNodeClient).getTask(any(), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = registerLocalPretrainedModelStep.execute( workflowData.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index df6fd94e2..9c35af33c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -11,12 +11,10 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -42,12 +40,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; @@ -165,10 +161,10 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception { }).when(machineLearningNodeClient).getTask(any(), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = registerLocalSparseEncodingModelStep.execute( workflowData.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java index 921c7ac15..7f7adf44b 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java @@ -9,9 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; @@ -33,8 +31,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -111,10 +108,10 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_GROUP_ID, modelGroupId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = modelGroupStep.execute( inputData.getNodeId(), @@ -140,7 +137,7 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep assertTrue(future.isDone()); assertEquals(modelGroupId, future.get().getContent().get(MODEL_GROUP_ID)); - assertEquals(status, future.get().getContent().get("model_group_status")); + assertEquals(status, future.get().getContent().get(MODEL_GROUP_STATUS)); } public void testRegisterModelGroupFailure() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index 603bfde57..0e2ab91e9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -12,9 +12,7 @@ import org.opensearch.ResourceNotFoundException; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; @@ -30,15 +28,14 @@ import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; @@ -94,10 +91,10 @@ public void testRegisterRemoteModelSuccess() throws Exception { }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = this.registerRemoteModelStep.execute( workflowData.getNodeId(), @@ -109,7 +106,13 @@ public void testRegisterRemoteModelSuccess() throws Exception { verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); // only updates register resource - verify(flowFrameworkIndicesHandler, times(1)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + verify(flowFrameworkIndicesHandler, times(1)).addResourceToStateIndex( + any(WorkflowData.class), + anyString(), + anyString(), + anyString(), + any() + ); assertTrue(future.isDone()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); @@ -130,10 +133,10 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); WorkflowData deployWorkflowData = new WorkflowData( Map.ofEntries( @@ -156,7 +159,13 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); // updates both register and deploy resources - verify(flowFrameworkIndicesHandler, times(2)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + verify(flowFrameworkIndicesHandler, times(2)).addResourceToStateIndex( + any(WorkflowData.class), + anyString(), + anyString(), + anyString(), + any() + ); assertTrue(future.isDone()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); @@ -182,7 +191,13 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { verify(mlNodeClient, times(2)).register(any(MLRegisterModelInput.class), any()); // updates both register and deploy resources - verify(flowFrameworkIndicesHandler, times(4)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + verify(flowFrameworkIndicesHandler, times(4)).addResourceToStateIndex( + any(WorkflowData.class), + anyString(), + anyString(), + anyString(), + any() + ); assertTrue(future.isDone()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); @@ -210,6 +225,99 @@ public void testRegisterRemoteModelFailure() { } + public void testRegisterRemoteModelUpdateFailure() { + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onFailure(new RuntimeException("Failed to update register resource")); + return null; + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); + + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, true) + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to update new created test-node-id resource register_remote_model id efgh", ex.getCause().getMessage()); + } + + public void testRegisterRemoteModelDeployUpdateFailure() { + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + AtomicInteger invocationCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + if (invocationCount.getAndIncrement() == 0) { + // succeed on first call (update register) + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); + } else { + // fail on second call (update deploy) + updateResponseListener.onFailure(new RuntimeException("Failed to update deploy resource")); + } + return null; + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); + + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, true) + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to update simulated deploy step resource efgh", ex.getCause().getMessage()); + } + public void testReisterRemoteModelInterfaceFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ReindexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ReindexStepTests.java index 97eff365a..a195025ea 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ReindexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ReindexStepTests.java @@ -11,12 +11,10 @@ import org.apache.lucene.tests.util.LuceneTestCase; import org.opensearch.OpenSearchException; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.common.Randomness; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.BulkByScrollTask; @@ -36,16 +34,12 @@ import static java.lang.Math.abs; import static java.util.stream.Collectors.toList; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.common.unit.TimeValue.timeValueMillis; import static org.opensearch.flowframework.common.CommonValue.DESTINATION_INDEX; import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.workflow.ReindexStep.NAME; import static org.apache.lucene.tests.util.TestUtil.randomSimpleString; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -96,12 +90,6 @@ public void testReindexStep() throws ExecutionException, InterruptedException, I @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); - return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - PlainActionFuture future = reIndexStep.execute( inputData.getNodeId(), inputData,