Skip to content

Commit

Permalink
Refactor workflow step resource updates to eliminate duplication (#796)
Browse files Browse the repository at this point in the history
* Refactor workflow step resource updates to eliminate duplication

Signed-off-by: Daniel Widdis <[email protected]>

* Add coverage and changelog

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis authored Aug 14, 2024
1 parent 268f746 commit dc20feb
Show file tree
Hide file tree
Showing 30 changed files with 515 additions and 438 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
### Documentation
### Maintenance
### Refactoring
- Refactor workflow step resource updates to eliminate duplication ([#796](https://github.com/opensearch-project/flow-framework/pull/796))
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptType;

Expand Down Expand Up @@ -666,13 +667,13 @@ public void updateFlowFrameworkSystemIndexDocWithScript(
/**
* Creates a new ResourceCreated object and a script to update the state index
* @param workflowId workflowId for the relevant step
* @param nodeId WorkflowData object with relevent step information
* @param nodeId current process node (workflow step) id
* @param workflowStepName the workflowstep name that created the resource
* @param resourceId the id of the newly created resource
* @param listener the ActionListener for this step to handle completing the future after update
* @throws IOException if parsing fails on new resource
*/
public void updateResourceInStateIndex(
private void updateResourceInStateIndex(
String workflowId,
String nodeId,
String workflowStepName,
Expand All @@ -697,6 +698,44 @@ public void updateResourceInStateIndex(
updateFlowFrameworkSystemIndexDocWithScript(WORKFLOW_STATE_INDEX, workflowId, script, ActionListener.wrap(updateResponse -> {
logger.info("updated resources created of {}", workflowId);
listener.onResponse(updateResponse);
}, exception -> { listener.onFailure(exception); }));
}, listener::onFailure));
}

/**
* Adds a resource to the state index, including common exception handling
* @param currentNodeInputs Inputs to the current node
* @param nodeId current process node (workflow step) id
* @param workflowStepName the workflow step name that created the resource
* @param resourceId the id of the newly created resource
* @param listener the ActionListener for this step to handle completing the future after update
*/
public void addResourceToStateIndex(
WorkflowData currentNodeInputs,
String nodeId,
String workflowStepName,
String resourceId,
ActionListener<WorkflowData> listener
) {
String resourceName = getResourceByWorkflowStep(workflowStepName);
try {
updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
nodeId,
workflowStepName,
resourceId,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
listener.onResponse(new WorkflowData(Map.of(resourceName, resourceId), currentNodeInputs.getWorkflowId(), nodeId));
}, exception -> {
String errorMessage = "Failed to update new created " + nodeId + " resource " + workflowStepName + " id " + resourceId;
logger.error(errorMessage, exception);
listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)));
})
);
} catch (Exception e) {
String errorMessage = "Failed to parse and update new created resource";
logger.error(errorMessage, e);
listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID;
import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep;
import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException;

/**
Expand Down Expand Up @@ -98,43 +97,14 @@ public PlainActionFuture<WorkflowData> execute(

@Override
public void onResponse(AcknowledgedResponse acknowledgedResponse) {
String resourceName = getResourceByWorkflowStep(getName());
try {
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
getName(),
pipelineId,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
// PutPipelineRequest returns only an AcknowledgeResponse, saving pipelineId instead
// TODO: revisit this concept of pipeline_id to be consistent with what makes most sense to end user here
createPipelineFuture.onResponse(
new WorkflowData(
Map.of(resourceName, pipelineId),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}, exception -> {
String errorMessage = "Failed to update new created "
+ currentNodeId
+ " resource "
+ getName()
+ " id "
+ pipelineId;
logger.error(errorMessage, exception);
createPipelineFuture.onFailure(
new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))
);
})
);

} catch (Exception e) {
String errorMessage = "Failed to parse and update new created resource";
logger.error(errorMessage, e);
createPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
}
// PutPipelineRequest returns only an AcknowledgeResponse, saving pipelineId instead
flowFrameworkIndicesHandler.addResourceToStateIndex(
currentNodeInputs,
currentNodeId,
getName(),
pipelineId,
createPipelineFuture
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.URL;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID;
Expand Down Expand Up @@ -189,58 +188,38 @@ public PlainActionFuture<WorkflowData> execute(

// Attempt to retrieve the model ID
retryableGetMlTask(
currentNodeInputs.getWorkflowId(),
currentNodeInputs,
currentNodeId,
registerLocalModelFuture,
taskId,
"Local model registration",
ActionListener.wrap(mlTask -> {

ActionListener.wrap(mlTaskWorkflowData -> {
// Registered Model Resource has been updated
String resourceName = getResourceByWorkflowStep(getName());
String id = getResourceId(mlTask);

if (Boolean.TRUE.equals(deploy)) {

// Simulate Model deployment step and update resources created
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
DeployModelStep.NAME,
id,
ActionListener.wrap(deployUpdateResponse -> {
logger.info(
"successfully updated resources created in state index: {}",
deployUpdateResponse.getIndex()
);
registerLocalModelFuture.onResponse(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, id),
Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name())
),
currentNodeInputs.getWorkflowId(),
currentNodeId
)
);
}, deployUpdateException -> {
String id = (String) mlTaskWorkflowData.getContent().get(resourceName);
ActionListener<WorkflowData> deployUpdateListener = ActionListener.wrap(
deployUpdateResponse -> registerLocalModelFuture.onResponse(mlTaskWorkflowData),
deployUpdateException -> {
String errorMessage = "Failed to update simulated deploy step resource " + id;
logger.error(errorMessage, deployUpdateException);
registerLocalModelFuture.onFailure(
new FlowFrameworkException(errorMessage, ExceptionsHelper.status(deployUpdateException))
);
})
}
);
} else {
registerLocalModelFuture.onResponse(
new WorkflowData(
Map.ofEntries(Map.entry(resourceName, id), Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name())),
currentNodeInputs.getWorkflowId(),
currentNodeId
)
// Simulate Model deployment step and update resources created
flowFrameworkIndicesHandler.addResourceToStateIndex(
currentNodeInputs,
currentNodeId,
DeployModelStep.NAME,
id,
deployUpdateListener
);
} else {
registerLocalModelFuture.onResponse(mlTaskWorkflowData);
}
}, exception -> { registerLocalModelFuture.onFailure(exception); })
}, registerLocalModelFuture::onFailure)
);
}, exception -> {
Exception e = getSafeException(exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.FutureUtils;
Expand All @@ -24,8 +23,11 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.threadpool.ThreadPool;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL;
import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep;

Expand Down Expand Up @@ -60,20 +62,20 @@ protected AbstractRetryableWorkflowStep(

/**
* Retryable get ml task
* @param workflowId the workflow id
* @param currentNodeInputs the current Node Inputs
* @param nodeId the workflow node id
* @param future the workflow step future
* @param taskId the ml task id
* @param workflowStep the workflow step which requires a retry get ml task functionality
* @param mlTaskListener the ML Task Listener
*/
protected void retryableGetMlTask(
String workflowId,
WorkflowData currentNodeInputs,
String nodeId,
PlainActionFuture<WorkflowData> future,
String taskId,
String workflowStep,
ActionListener<MLTask> mlTaskListener
ActionListener<WorkflowData> mlTaskListener
) {
CompletableFuture.runAsync(() -> {
do {
Expand All @@ -82,34 +84,13 @@ protected void retryableGetMlTask(
String id = getResourceId(response);
switch (response.getState()) {
case COMPLETED:
try {
logger.info("{} successful for {} and {} {}", workflowStep, workflowId, resourceName, id);
flowFrameworkIndicesHandler.updateResourceInStateIndex(
workflowId,
nodeId,
getName(),
id,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
mlTaskListener.onResponse(response);
}, exception -> {
String errorMessage = "Failed to update new created "
+ nodeId
+ " resource "
+ getName()
+ " id "
+ id;
logger.error(errorMessage, exception);
mlTaskListener.onFailure(
new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))
);
})
);
} catch (Exception e) {
String errorMessage = "Failed to parse and update new created resource " + resourceName + " id " + id;
logger.error(errorMessage, e);
mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
}
logger.info("{} successful for {} and {} {}", workflowStep, currentNodeInputs, resourceName, id);
ActionListener<WorkflowData> resourceListener = ActionListener.wrap(r -> {
Map<String, Object> content = new HashMap<>(r.getContent());
content.put(REGISTER_MODEL_STATUS, response.getState().toString());
mlTaskListener.onResponse(new WorkflowData(content, r.getWorkflowId(), r.getNodeId()));
}, mlTaskListener::onFailure);
flowFrameworkIndicesHandler.addResourceToStateIndex(currentNodeInputs, nodeId, getName(), id, resourceListener);
break;
case FAILED:
case COMPLETED_WITH_ERROR:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;
import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep;
import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException;
import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap;

Expand Down Expand Up @@ -85,40 +84,14 @@ public PlainActionFuture<WorkflowData> execute(

@Override
public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) {
String resourceName = getResourceByWorkflowStep(getName());
try {
logger.info("Created connector successfully");
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
getName(),
mlCreateConnectorResponse.getConnectorId(),
ActionListener.wrap(response -> {
logger.info("successfully updated resources created in state index: {}", response.getIndex());
createConnectorFuture.onResponse(
new WorkflowData(
Map.ofEntries(Map.entry(resourceName, mlCreateConnectorResponse.getConnectorId())),
currentNodeInputs.getWorkflowId(),
currentNodeId
)
);
}, exception -> {
String errorMessage = "Failed to update new created "
+ currentNodeId
+ " resource "
+ getName()
+ " id "
+ mlCreateConnectorResponse.getConnectorId();
logger.error(errorMessage, exception);
createConnectorFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)));
})
);

} catch (Exception e) {
String errorMessage = "Failed to parse and update new created resource";
logger.error(errorMessage, e);
createConnectorFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
}
logger.info("Created connector successfully");
flowFrameworkIndicesHandler.addResourceToStateIndex(
currentNodeInputs,
currentNodeId,
getName(),
mlCreateConnectorResponse.getConnectorId(),
createConnectorFuture
);
}

@Override
Expand Down
Loading

0 comments on commit dc20feb

Please sign in to comment.