Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds deploy model flag support for local model registration, fixes integration tests #350

Merged
merged 24 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
31e1ee9
Fixing local model integration test
joshpalis Jan 2, 2024
51d42ac
Merge branch 'main' into local
joshpalis Jan 2, 2024
ba8fdff
Added deploy model flag support for local model registration, added a…
joshpalis Jan 2, 2024
d94cbcf
Fixing comment
joshpalis Jan 3, 2024
8ee5189
Fixing deprovision workflow transport action, removing use of templat…
joshpalis Jan 3, 2024
a5310d9
Removing rest status checks for deprovision API tests
joshpalis Jan 3, 2024
7414a96
Increasing wait time for deprovision status
joshpalis Jan 3, 2024
46bd85a
Removing deprovision status checks for model deployment tests
joshpalis Jan 3, 2024
f840e21
increasing timeout for local model registration test template
joshpalis Jan 3, 2024
aabe87f
Reverting timeout increase, setting ML Commons native memory threshol…
joshpalis Jan 3, 2024
5c98e43
Passing an action listener to retryableGetMlTask
joshpalis Jan 3, 2024
2067d83
Merge branch 'main' into local
joshpalis Jan 3, 2024
b279f70
Addressing PR comments, preserving order of resource map
joshpalis Jan 4, 2024
7c6ca49
Testing if a wait time after deprovisioning will mitigate circuit bre…
joshpalis Jan 4, 2024
1da5351
Increasing mlconfig index creation wait time
joshpalis Jan 4, 2024
91fbfe8
Combining local model registration tests into one
joshpalis Jan 4, 2024
a703684
removing resource map from deprovision workflow transport action
joshpalis Jan 4, 2024
030dfe9
Merge branch 'main' into local
joshpalis Jan 4, 2024
24f61ac
Fixing getResourceFromDeprovisionNode and tests
joshpalis Jan 4, 2024
9eec3dd
Separating out local model registration tests, using ml jvm heap memo…
joshpalis Jan 4, 2024
ebc0bcf
Testing : removing second local model registration test
joshpalis Jan 4, 2024
b970504
Reducing model registration tests, testing local model registration w…
joshpalis Jan 4, 2024
46f15dc
Removing suffix from simulated deploy model step
joshpalis Jan 5, 2024
669db42
merge main
joshpalis Jan 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
Expand All @@ -24,31 +24,25 @@
import org.opensearch.flowframework.model.ProvisioningProgress;
import org.opensearch.flowframework.model.ResourceCreated;
import org.opensearch.flowframework.model.State;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD;
import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD;
import static org.opensearch.flowframework.common.WorkflowResources.getDeprovisionStepByWorkflowStep;
Expand All @@ -65,84 +59,50 @@ public class DeprovisionWorkflowTransportAction extends HandledTransportAction<W

private final ThreadPool threadPool;
private final Client client;
private final WorkflowProcessSorter workflowProcessSorter;
private final WorkflowStepFactory workflowStepFactory;
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private final EncryptorUtils encryptorUtils;

/**
* Instantiates a new DeprovisionWorkflowTransportAction
* @param transportService The TransportService
* @param actionFilters action filters
* @param threadPool The OpenSearch thread pool
* @param client The node client to retrieve a stored use case template
* @param workflowProcessSorter Utility class to generate a topologically sorted list of Process nodes
* @param workflowStepFactory The factory instantiating workflow steps
* @param flowFrameworkIndicesHandler Class to handle all internal system indices actions
* @param encryptorUtils Utility class to handle encryption/decryption
*/
@Inject
public DeprovisionWorkflowTransportAction(
TransportService transportService,
ActionFilters actionFilters,
ThreadPool threadPool,
Client client,
WorkflowProcessSorter workflowProcessSorter,
WorkflowStepFactory workflowStepFactory,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
EncryptorUtils encryptorUtils
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
super(DeprovisionWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new);
this.threadPool = threadPool;
this.client = client;
this.workflowProcessSorter = workflowProcessSorter;
this.workflowStepFactory = workflowStepFactory;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.encryptorUtils = encryptorUtils;
}

@Override
protected void doExecute(Task task, WorkflowRequest request, ActionListener<WorkflowResponse> listener) {
// Retrieve use case template from global context
String workflowId = request.getWorkflowId();
GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId);
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true);

// Stash thread context to interact with system index
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(response -> {
client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> {
context.restore();

if (!response.isExists()) {
listener.onFailure(
new FlowFrameworkException(
"Failed to retrieve template (" + workflowId + ") from global context.",
RestStatus.NOT_FOUND
)
);
return;
}

// Parse template from document source
Template template = Template.parse(response.getSourceAsString());

// Decrypt template
template = encryptorUtils.decryptTemplateCredentials(template);

// Sort and validate graph
Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId);
workflowProcessSorter.validate(provisionProcessSequence);

// We have a valid template and sorted nodes, get the created resources
getResourcesAndExecute(request.getWorkflowId(), provisionProcessSequence, listener);
// Retrieve resources from workflow state and deprovision
executeDeprovisionSequence(workflowId, response.getWorkflowState().resourcesCreated(), listener);
}, exception -> {
if (exception instanceof FlowFrameworkException) {
logger.error("Workflow validation failed for workflow : " + workflowId);
listener.onFailure(exception);
} else {
logger.error("Failed to retrieve template from global context.", exception);
listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
}
String message = "Failed to get workflow state for workflow " + workflowId;
logger.error(message, exception);
listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(exception)));
}));
} catch (Exception e) {
String message = "Failed to retrieve template from global context.";
Expand All @@ -151,64 +111,38 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
}
}

/**
 * Fetches the workflow state for the given workflow id, builds a map from
 * workflow-step id to the resource that step created, and then hands off to
 * the deprovision sequence with that map.
 * @param workflowId the id of the workflow whose created resources are being deprovisioned
 * @param provisionProcessSequence the sorted provision process nodes from which deprovision steps are derived
 * @param listener notified of the deprovision outcome, or of a failure to retrieve the workflow state
 */
private void getResourcesAndExecute(
String workflowId,
List<ProcessNode> provisionProcessSequence,
ActionListener<WorkflowResponse> listener
) {
// 'true' requests the full state including the created-resources list — TODO confirm flag semantics
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true);
client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> {
// Get a map of step id to created resources
// NOTE(review): Collectors.toMap throws IllegalStateException on duplicate keys —
// assumes each workflow step recorded at most one resource; verify upstream invariant
final Map<String, ResourceCreated> resourceMap = response.getWorkflowState()
.resourcesCreated()
.stream()
.collect(Collectors.toMap(ResourceCreated::workflowStepId, Function.identity()));

// Now finally do the deprovision
executeDeprovisionSequence(workflowId, resourceMap, provisionProcessSequence, listener);
}, exception -> {
// State retrieval failed: log and surface as a FlowFrameworkException with a
// REST status derived from the underlying exception
String message = "Failed to get workflow state for workflow " + workflowId;
logger.error(message, exception);
listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(exception)));
}));
}

private void executeDeprovisionSequence(
String workflowId,
Map<String, ResourceCreated> resourceMap,
List<ProcessNode> provisionProcessSequence,
List<ResourceCreated> resourcesCreated,
ActionListener<WorkflowResponse> listener
) {

// Create a list of ProcessNodes with the corresponding deprovision workflow steps
List<ProcessNode> deprovisionProcessSequence = provisionProcessSequence.stream()
// Only include nodes that created a resource
.filter(pn -> resourceMap.containsKey(pn.id()))
// Create a new ProcessNode with a deprovision step
.map(pn -> {
String stepName = pn.workflowStep().getName();
String deprovisionStep = getDeprovisionStepByWorkflowStep(stepName);
// Unimplemented steps presently return null, so skip
if (deprovisionStep == null) {
return null;
}
// New ID is old ID with deprovision added
String deprovisionStepId = pn.id() + DEPROVISION_SUFFIX;
return new ProcessNode(
List<ProcessNode> deprovisionProcessSequence = new ArrayList<>();
for (ResourceCreated resource : resourcesCreated) {
String workflowStepId = resource.workflowStepId();

String stepName = resource.workflowStepName();
String deprovisionStep = getDeprovisionStepByWorkflowStep(stepName);
// Unimplemented steps presently return null, so skip
if (deprovisionStep == null) {
continue;
}
// New ID is old ID with deprovision added
String deprovisionStepId = workflowStepId + DEPROVISION_SUFFIX;
deprovisionProcessSequence.add(
new ProcessNode(
deprovisionStepId,
workflowStepFactory.createStep(deprovisionStep),
Collections.emptyMap(),
new WorkflowData(
Map.of(getResourceByWorkflowStep(stepName), resourceMap.get(pn.id()).resourceId()),
workflowId,
deprovisionStepId
),
new WorkflowData(Map.of(getResourceByWorkflowStep(stepName), resource.resourceId()), workflowId, deprovisionStepId),
Collections.emptyList(),
this.threadPool,
pn.nodeTimeout()
);
})
.filter(Objects::nonNull)
.collect(Collectors.toList());
TimeValue.ZERO
)
);
}

// Deprovision in reverse order of provisioning to minimize risk of dependencies
Collections.reverse(deprovisionProcessSequence);
logger.info("Deprovisioning steps: {}", deprovisionProcessSequence.stream().map(ProcessNode::id).collect(Collectors.joining(", ")));
Expand All @@ -219,7 +153,7 @@ private void executeDeprovisionSequence(
Iterator<ProcessNode> iter = deprovisionProcessSequence.iterator();
while (iter.hasNext()) {
ProcessNode deprovisionNode = iter.next();
ResourceCreated resource = getResourceFromDeprovisionNode(deprovisionNode, resourceMap);
ResourceCreated resource = getResourceFromDeprovisionNode(deprovisionNode, resourcesCreated);
String resourceNameAndId = getResourceNameAndId(resource);
CompletableFuture<WorkflowData> deprovisionFuture = deprovisionNode.execute();
try {
Expand Down Expand Up @@ -265,7 +199,7 @@ private void executeDeprovisionSequence(
}
// Get corresponding resources
List<ResourceCreated> remainingResources = deprovisionProcessSequence.stream()
.map(pn -> getResourceFromDeprovisionNode(pn, resourceMap))
.map(pn -> getResourceFromDeprovisionNode(pn, resourcesCreated))
.collect(Collectors.toList());
logger.info("Resources remaining: {}", remainingResources);
updateWorkflowState(workflowId, remainingResources, listener);
Expand Down Expand Up @@ -322,10 +256,18 @@ private void updateWorkflowState(
}
}

private static ResourceCreated getResourceFromDeprovisionNode(ProcessNode deprovisionNode, Map<String, ResourceCreated> resourceMap) {
private static ResourceCreated getResourceFromDeprovisionNode(ProcessNode deprovisionNode, List<ResourceCreated> resourcesCreated) {
String deprovisionId = deprovisionNode.id();
int pos = deprovisionId.indexOf(DEPROVISION_SUFFIX);
return pos > 0 ? resourceMap.get(deprovisionId.substring(0, pos)) : null;
ResourceCreated resource = null;
if (pos > 0) {
for (ResourceCreated resourceCreated : resourcesCreated) {
if (resourceCreated.workflowStepId().equals(deprovisionId.substring(0, pos))) {
resource = resourceCreated;
}
}
}
return resource;
}

private static String getResourceNameAndId(ResourceCreated resource) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.threadpool.ThreadPool;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;

import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep;

/**
Expand Down Expand Up @@ -66,13 +64,15 @@ protected AbstractRetryableWorkflowStep(
* @param future the workflow step future
* @param taskId the ml task id
* @param workflowStep the workflow step which requires a retry get ml task functionality
* @param mlTaskListener the ML Task Listener
*/
protected void retryableGetMlTask(
String workflowId,
String nodeId,
CompletableFuture<WorkflowData> future,
String taskId,
String workflowStep
String workflowStep,
ActionListener<MLTask> mlTaskListener
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
) {
AtomicInteger retries = new AtomicInteger();
CompletableFuture.runAsync(() -> {
Expand All @@ -91,46 +91,37 @@ protected void retryableGetMlTask(
id,
ActionListener.wrap(updateResponse -> {
logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex());
future.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(resourceName, id),
Map.entry(REGISTER_MODEL_STATUS, response.getState().name())
),
workflowId,
nodeId
)
);
mlTaskListener.onResponse(response);
}, exception -> {
logger.error("Failed to update new created resource", exception);
future.completeExceptionally(
mlTaskListener.onFailure(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);
} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
mlTaskListener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
break;
case FAILED:
case COMPLETED_WITH_ERROR:
String errorMessage = workflowStep + " failed with error : " + response.getError();
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
break;
case CANCELLED:
errorMessage = workflowStep + " task was cancelled.";
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
break;
default:
// Task started or running, do nothing
}
}, exception -> {
String errorMessage = workflowStep + " failed with error : " + exception.getMessage();
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
}));
// Wait long enough for future to possibly complete
try {
Expand All @@ -143,7 +134,7 @@ protected void retryableGetMlTask(
if (!future.isDone()) {
String errorMessage = workflowStep + " did not complete after " + maxRetry + " retries";
logger.error(errorMessage);
future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT));
}
}, threadPool.executor(PROVISION_THREAD_POOL));
}
Expand Down
Loading
Loading