From efad064df9b473220007322a250b9c8a14c63add Mon Sep 17 00:00:00 2001
From: Joshua Palis
Date: Fri, 5 Jan 2024 11:48:19 -0800
Subject: [PATCH] [Backport 2.x] Manual Backport of #350 (#372)

Adds deploy model flag support for local model registration, fixes integration tests (#350)

* Fixing local model integration test

* Added deploy model flag support for local model registration, added associated integration test

* Fixing comment

* Fixing deprovision workflow transport action, removing use of template, ascertaining deprovision sequence from created resources

* Removing rest status checks for deprovision API tests

* Increasing wait time for deprovision status

* Removing deprovision status checks for model deployment tests

* Increasing timeout for local model registration test template

* Reverting timeout increase, setting ML Commons native memory threshold to 100 to avoid opening circuit breaker

* Passing an action listener to retryableGetMlTask

* Addressing PR comments, preserving order of resource map

* Testing if a wait time after deprovisioning will mitigate circuit breaker issues

* Increasing mlconfig index creation wait time

* Combining local model registration tests into one

* Removing resource map from deprovision workflow transport action

* Fixing getResourceFromDeprovisionNode and tests

* Separating out local model registration tests, using ml jvm heap memory setting instead of native memory heap setting

* Testing: removing second local model registration test

* Reducing model registration tests, testing local model registration with deployed flag, testing remote model registration with deploy step

* Removing suffix from simulated deploy model step

---------

Signed-off-by: Joshua Palis
---
 .../DeprovisionWorkflowTransportAction.java | 146 ++++++------------
 .../AbstractRetryableWorkflowStep.java | 29 ++--
 .../workflow/DeployModelStep.java | 26 +++-
 .../workflow/RegisterLocalModelStep.java | 103 ++++++++----
 .../FlowFrameworkRestTestCase.java | 23 ++-
 .../rest/FlowFrameworkRestApiIT.java | 63 +++-----
 ...provisionWorkflowTransportActionTests.java | 86 +---------
 .../registerlocalmodel-deployflag.json | 36 +++++
 .../registerlocalmodel-deploymodel.json | 89 +++++------
 9 files changed, 279 insertions(+), 322 deletions(-)
 create mode 100644 src/test/resources/template/registerlocalmodel-deployflag.json

diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java
index 1777a2457..068ac7976 100644
--- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java
+++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java
@@ -11,11 +11,11 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.opensearch.ExceptionsHelper;
-import org.opensearch.action.get.GetRequest;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.HandledTransportAction;
 import org.opensearch.client.Client;
 import org.opensearch.common.inject.Inject;
+import org.opensearch.common.unit.TimeValue;
 import org.opensearch.common.util.concurrent.ThreadContext;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.rest.RestStatus;
@@ -24,31 +24,25 @@
 import org.opensearch.flowframework.model.ProvisioningProgress;
 import org.opensearch.flowframework.model.ResourceCreated;
 import 
org.opensearch.flowframework.model.State; -import org.opensearch.flowframework.model.Template; -import org.opensearch.flowframework.model.Workflow; -import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import java.time.Instant; +import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.CompletableFuture; -import java.util.function.Function; import java.util.stream.Collectors; -import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; -import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; import static org.opensearch.flowframework.common.WorkflowResources.getDeprovisionStepByWorkflowStep; @@ -65,10 +59,8 @@ public class DeprovisionWorkflowTransportAction extends HandledTransportAction listener) { - // Retrieve use case template from global context String workflowId = request.getWorkflowId(); - GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); // Stash thread context to interact with system index try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.get(getRequest, ActionListener.wrap(response -> { + client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { context.restore(); - if (!response.isExists()) { - listener.onFailure( - new FlowFrameworkException( - "Failed to retrieve template (" + workflowId + ") from global context.", - RestStatus.NOT_FOUND - ) - ); - return; - } - - // Parse template from document source - Template template = Template.parse(response.getSourceAsString()); - - // Decrypt template - template = encryptorUtils.decryptTemplateCredentials(template); - - // Sort and validate graph - Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW); - List provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); - workflowProcessSorter.validate(provisionProcessSequence); - - // We have a valid template and sorted nodes, get the created resources - getResourcesAndExecute(request.getWorkflowId(), provisionProcessSequence, listener); + // Retrieve resources from workflow state and deprovision + executeDeprovisionSequence(workflowId, response.getWorkflowState().resourcesCreated(), listener); }, exception -> { - if (exception instanceof FlowFrameworkException) { - logger.error("Workflow validation failed for workflow : " + workflowId); - listener.onFailure(exception); - } else { - logger.error("Failed to retrieve template from global context.", exception); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), 
ExceptionsHelper.status(exception))); - } + String message = "Failed to get workflow state for workflow " + workflowId; + logger.error(message, exception); + listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(exception))); })); } catch (Exception e) { String message = "Failed to retrieve template from global context."; @@ -151,64 +111,38 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence, - ActionListener listener - ) { - GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); - client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { - // Get a map of step id to created resources - final Map resourceMap = response.getWorkflowState() - .resourcesCreated() - .stream() - .collect(Collectors.toMap(ResourceCreated::workflowStepId, Function.identity())); - - // Now finally do the deprovision - executeDeprovisionSequence(workflowId, resourceMap, provisionProcessSequence, listener); - }, exception -> { - String message = "Failed to get workflow state for workflow " + workflowId; - logger.error(message, exception); - listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(exception))); - })); - } - private void executeDeprovisionSequence( String workflowId, - Map resourceMap, - List provisionProcessSequence, + List resourcesCreated, ActionListener listener ) { + // Create a list of ProcessNodes with the corresponding deprovision workflow steps - List deprovisionProcessSequence = provisionProcessSequence.stream() - // Only include nodes that created a resource - .filter(pn -> resourceMap.containsKey(pn.id())) - // Create a new ProcessNode with a deprovision step - .map(pn -> { - String stepName = pn.workflowStep().getName(); - String deprovisionStep = getDeprovisionStepByWorkflowStep(stepName); - // Unimplemented steps presently return null, so skip - if (deprovisionStep == null) { - return null; - } - // New ID is old ID with deprovision added - String deprovisionStepId = pn.id() + DEPROVISION_SUFFIX; - return new ProcessNode( + List deprovisionProcessSequence = new ArrayList<>(); + for (ResourceCreated resource : resourcesCreated) { + String workflowStepId = resource.workflowStepId(); + + String stepName = resource.workflowStepName(); + String deprovisionStep = getDeprovisionStepByWorkflowStep(stepName); + // Unimplemented steps presently return null, so skip + if (deprovisionStep == null) { + continue; + } + // New ID is old ID with deprovision added + String deprovisionStepId = workflowStepId + DEPROVISION_SUFFIX; + deprovisionProcessSequence.add( + new ProcessNode( deprovisionStepId, workflowStepFactory.createStep(deprovisionStep), Collections.emptyMap(), - new WorkflowData( - Map.of(getResourceByWorkflowStep(stepName), resourceMap.get(pn.id()).resourceId()), - workflowId, - deprovisionStepId - ), + new WorkflowData(Map.of(getResourceByWorkflowStep(stepName), resource.resourceId()), workflowId, deprovisionStepId), Collections.emptyList(), this.threadPool, - pn.nodeTimeout() - ); - }) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + TimeValue.ZERO + ) + ); + } + // Deprovision in reverse order of provisioning to minimize risk of dependencies Collections.reverse(deprovisionProcessSequence); logger.info("Deprovisioning steps: {}", deprovisionProcessSequence.stream().map(ProcessNode::id).collect(Collectors.joining(", "))); @@ -219,7 +153,7 @@ private void executeDeprovisionSequence( Iterator iter = 
deprovisionProcessSequence.iterator(); while (iter.hasNext()) { ProcessNode deprovisionNode = iter.next(); - ResourceCreated resource = getResourceFromDeprovisionNode(deprovisionNode, resourceMap); + ResourceCreated resource = getResourceFromDeprovisionNode(deprovisionNode, resourcesCreated); String resourceNameAndId = getResourceNameAndId(resource); CompletableFuture deprovisionFuture = deprovisionNode.execute(); try { @@ -265,7 +199,7 @@ private void executeDeprovisionSequence( } // Get corresponding resources List remainingResources = deprovisionProcessSequence.stream() - .map(pn -> getResourceFromDeprovisionNode(pn, resourceMap)) + .map(pn -> getResourceFromDeprovisionNode(pn, resourcesCreated)) .collect(Collectors.toList()); logger.info("Resources remaining: {}", remainingResources); updateWorkflowState(workflowId, remainingResources, listener); @@ -322,10 +256,18 @@ private void updateWorkflowState( } } - private static ResourceCreated getResourceFromDeprovisionNode(ProcessNode deprovisionNode, Map resourceMap) { + private static ResourceCreated getResourceFromDeprovisionNode(ProcessNode deprovisionNode, List resourcesCreated) { String deprovisionId = deprovisionNode.id(); int pos = deprovisionId.indexOf(DEPROVISION_SUFFIX); - return pos > 0 ? resourceMap.get(deprovisionId.substring(0, pos)) : null; + ResourceCreated resource = null; + if (pos > 0) { + for (ResourceCreated resourceCreated : resourcesCreated) { + if (resourceCreated.workflowStepId().equals(deprovisionId.substring(0, pos))) { + resource = resourceCreated; + } + } + } + return resource; } private static String getResourceNameAndId(ResourceCreated resource) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index 3d17fcf26..5f2c118e5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -21,12 +21,10 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.threadpool.ThreadPool; -import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; -import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; /** @@ -66,13 +64,15 @@ protected AbstractRetryableWorkflowStep( * @param future the workflow step future * @param taskId the ml task id * @param workflowStep the workflow step which requires a retry get ml task functionality + * @param mlTaskListener the ML Task Listener */ protected void retryableGetMlTask( String workflowId, String nodeId, CompletableFuture future, String taskId, - String workflowStep + String workflowStep, + ActionListener mlTaskListener ) { AtomicInteger retries = new AtomicInteger(); CompletableFuture.runAsync(() -> { @@ -91,38 +91,29 @@ protected void retryableGetMlTask( id, ActionListener.wrap(updateResponse -> { logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - future.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(resourceName, id), - Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) - ), - workflowId, - nodeId - ) - ); + mlTaskListener.onResponse(response); }, exception -> { 
logger.error("Failed to update new created resource", exception); - future.completeExceptionally( + mlTaskListener.onFailure( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) ); }) ); } catch (Exception e) { logger.error("Failed to parse and update new created resource", e); - future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + mlTaskListener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } break; case FAILED: case COMPLETED_WITH_ERROR: String errorMessage = workflowStep + " failed with error : " + response.getError(); logger.error(errorMessage); - future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); + mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); break; case CANCELLED: errorMessage = workflowStep + " task was cancelled."; logger.error(errorMessage); - future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT)); + mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT)); break; default: // Task started or running, do nothing @@ -130,7 +121,7 @@ protected void retryableGetMlTask( }, exception -> { String errorMessage = workflowStep + " failed with error : " + exception.getMessage(); logger.error(errorMessage); - future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); + mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); })); // Wait long enough for future to possibly complete try { @@ -143,7 +134,7 @@ protected void retryableGetMlTask( if (!future.isDone()) { String errorMessage = workflowStep + " did not complete after " + maxRetry + " retries"; logger.error(errorMessage); - future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT)); + mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT)); } }, threadPool.executor(PROVISION_THREAD_POOL)); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index c7c38c30f..2a52d41a1 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -26,7 +26,9 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; +import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; /** * Step to deploy a model @@ -75,7 +77,29 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { String taskId = mlDeployModelResponse.getTaskId(); // Attempt to retrieve the model ID - retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, deployModelFuture, taskId, "Deploy model"); + retryableGetMlTask( + currentNodeInputs.getWorkflowId(), + currentNodeId, + deployModelFuture, + taskId, + "Deploy model", + ActionListener.wrap(mlTask -> { + // Deployed Model Resource has been updated + String resourceName = getResourceByWorkflowStep(getName()); + String id = getResourceId(mlTask); + deployModelFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(resourceName, id), 
Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name())), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); + }, + e -> { + deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + ) + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index 4213d99cd..8f1127312 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -26,7 +26,6 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; -import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.threadpool.ThreadPool; import java.util.Map; @@ -34,6 +33,7 @@ import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.ALL_CONFIG; +import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; @@ -42,9 +42,11 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.URL; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; +import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; /** * Step to register a local model @@ -88,30 +90,6 @@ public CompletableFuture execute( CompletableFuture registerLocalModelFuture = new CompletableFuture<>(); - ActionListener actionListener = new ActionListener<>() { - @Override - public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { - logger.info("Local Model registration task creation successful"); - - String taskId = mlRegisterModelResponse.getTaskId(); - - // Attempt to retrieve the model ID - retryableGetMlTask( - currentNodeInputs.getWorkflowId(), - currentNodeId, - registerLocalModelFuture, - taskId, - "Local model registration" - ); - } - - @Override - public void onFailure(Exception e) { - logger.error("Failed to register local model"); - registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); - } - }; - Set requiredKeys = Set.of( NAME_FIELD, VERSION_FIELD, @@ -122,7 +100,7 @@ public void onFailure(Exception e) { MODEL_CONTENT_HASH_VALUE, URL ); - Set optionalKeys = Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, FUNCTION_NAME); + Set optionalKeys = Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, FUNCTION_NAME, DEPLOY_FIELD); try { Map inputs = ParseUtils.getInputsFromPreviousSteps( @@ -145,6 +123,7 @@ public void onFailure(Exception e) { String allConfig = (String) 
inputs.get(ALL_CONFIG); String url = (String) inputs.get(URL); String functionName = (String) inputs.get(FUNCTION_NAME); + final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); // Create Model configuration TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder() @@ -173,10 +152,80 @@ public void onFailure(Exception e) { if (functionName != null) { mlInputBuilder.functionName(FunctionName.from(functionName)); } + if (deploy != null) { + mlInputBuilder.deployModel(deploy); + } MLRegisterModelInput mlInput = mlInputBuilder.build(); - mlClient.register(mlInput, actionListener); + mlClient.register(mlInput, ActionListener.wrap(response -> { + logger.info("Local Model registration task creation successful"); + + String taskId = response.getTaskId(); + + // Attempt to retrieve the model ID + retryableGetMlTask( + currentNodeInputs.getWorkflowId(), + currentNodeId, + registerLocalModelFuture, + taskId, + "Local model registration", + ActionListener.wrap(mlTask -> { + + // Registered Model Resource has been updated + String resourceName = getResourceByWorkflowStep(getName()); + String id = getResourceId(mlTask); + + if (Boolean.TRUE.equals(deploy)) { + + // Simulate Model deployment step and update resources created + flowFrameworkIndicesHandler.updateResourceInStateIndex( + currentNodeInputs.getWorkflowId(), + currentNodeId, + DeployModelStep.NAME, + id, + ActionListener.wrap(deployUpdateResponse -> { + logger.info( + "successfully updated resources created in state index: {}", + deployUpdateResponse.getIndex() + ); + registerLocalModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, id), + Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name()) + ), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); + }, deployUpdateException -> { + logger.error("Failed to update simulated deploy step resource", deployUpdateException); + registerLocalModelFuture.completeExceptionally( + new FlowFrameworkException( + deployUpdateException.getMessage(), + ExceptionsHelper.status(deployUpdateException) + ) + ); + }) + ); + } else { + registerLocalModelFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(resourceName, id), Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name())), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); + } + }, exception -> { registerLocalModelFuture.completeExceptionally(exception); }) + ); + }, exception -> { + logger.error("Failed to register local model"); + registerLocalModelFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + })); } catch (FlowFrameworkException e) { registerLocalModelFuture.completeExceptionally(e); } diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index 70a322f42..593b2bebe 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -112,8 +112,19 @@ public void setUpSettings() throws Exception { // Need a delay here on 2.x or next line consistently fails tests. 
// TODO: figure out know why we need this and we should pursue a better option that doesn't require HTTP5 Thread.sleep(10000); + // Set ML jvm heap memory threshold to 100 to avoid opening the circuit breaker during tests + response = TestHelpers.makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":100}}", + List.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + // Ensure .plugins-ml-config is created before proceeding with integration tests - assertBusy(() -> { assertTrue(indexExistsWithAdminClient(".plugins-ml-config")); }, 30, TimeUnit.SECONDS); + assertBusy(() -> { assertTrue(indexExistsWithAdminClient(".plugins-ml-config")); }, 60, TimeUnit.SECONDS); } } @@ -287,15 +298,21 @@ protected boolean preserveClusterSettings() { } /** - * Helper method to invoke the Create Workflow Rest Action + * Helper method to invoke the Create Workflow Rest Action without validation * @param template the template to create * @throws Exception if the request fails * @return a rest response */ protected Response createWorkflow(Template template) throws Exception { - return TestHelpers.makeRequest(client(), "POST", WORKFLOW_URI, Collections.emptyMap(), template.toJson(), null); + return TestHelpers.makeRequest(client(), "POST", WORKFLOW_URI + "?validation=off", Collections.emptyMap(), template.toJson(), null); } + /** + * Helper method to invoke the Create Workflow Rest Action with provision + * @param template the template to create + * @throws Exception if the request fails + * @return a rest response + */ protected Response createWorkflowWithProvision(Template template) throws Exception { return TestHelpers.makeRequest(client(), "POST", WORKFLOW_URI + "?provision=true", Collections.emptyMap(), template.toJson(), null); } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index f09cb4718..13eb20af1 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -23,6 +23,7 @@ import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.model.WorkflowState; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -68,15 +69,11 @@ public void testSearchWorkflows() throws Exception { } public void testCreateAndProvisionLocalModelWorkflow() throws Exception { - /*- Local model registration is not yet fully complete. Commenting this test out until it is. 
- * https://github.com/opensearch-project/flow-framework/issues/305 + // Using a 1 step template to register a local model and deploy model + Template template = TestHelpers.createTemplateFromFile("registerlocalmodel-deployflag.json"); - // Using a 3 step template to create a model group, register a remote model and deploy model - Template template = TestHelpers.createTemplateFromFile("registerlocalmodel-deploymodel.json"); - - // Remove deploy model input to test validation + // Remove register model input to test validation Workflow originalWorkflow = template.workflows().get(PROVISION_WORKFLOW); - List modifiednodes = originalWorkflow.nodes() .stream() .map( @@ -85,9 +82,7 @@ public void testCreateAndProvisionLocalModelWorkflow() throws Exception { : n ) .collect(Collectors.toList()); - Workflow missingInputs = new Workflow(originalWorkflow.userParams(), modifiednodes, originalWorkflow.edges()); - Template templateWithMissingInputs = new Template.Builder().name(template.name()) .description(template.description()) .useCase(template.useCase()) @@ -125,9 +120,18 @@ public void testCreateAndProvisionLocalModelWorkflow() throws Exception { // Wait until provisioning has completed successfully before attempting to retrieve created resources List resourcesCreated = getResourcesCreated(workflowId, 100); - // TODO: This template should create 2 resources, registered_model_id and deployed model_id - assertEquals(0, resourcesCreated.size()); - */ + // This template should create 2 resources, registered_model_id and deployed model_id + assertEquals(2, resourcesCreated.size()); + assertEquals("register_local_model", resourcesCreated.get(0).workflowStepName()); + assertNotNull(resourcesCreated.get(0).resourceId()); + assertEquals("deploy_model", resourcesCreated.get(1).workflowStepName()); + assertNotNull(resourcesCreated.get(1).resourceId()); + + // Deprovision the workflow to avoid opening circut breaker when running additional tests + Response deprovisionResponse = deprovisionWorkflow(workflowId); + + // wait for deprovision to complete + Thread.sleep(5000); } public void testCreateAndProvisionRemoteModelWorkflow() throws Exception { @@ -184,36 +188,12 @@ public void testCreateAndProvisionRemoteModelWorkflow() throws Exception { assertNotNull(resourcesCreated.get(1).resourceId()); assertEquals("deploy_model", resourcesCreated.get(2).workflowStepName()); assertNotNull(resourcesCreated.get(2).resourceId()); - } - - public void testCreateAndProvisionDeployedRemoteModelWorkflow() throws Exception { - - // Using a 2 step template to create a connector, register remote model with deploy=true param set - Template template = TestHelpers.createTemplateFromFile("createconnector-registerdeployremotemodel.json"); - Response response = createWorkflow(template); - assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); - - Map responseMap = entityAsMap(response); - String workflowId = (String) responseMap.get(WORKFLOW_ID); - getAndAssertWorkflowStatus(workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); - - // Hit Provision API and assert status - response = provisionWorkflow(workflowId); - assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); - getAndAssertWorkflowStatus(workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); - - // Wait until provisioning has completed successfully before attempting to retrieve created resources - List resourcesCreated = getResourcesCreated(workflowId, 30); + // Deprovision the workflow to avoid opening circut breaker when running 
additional tests + Response deprovisionResponse = deprovisionWorkflow(workflowId); - // This template should create 3 resources, connector_id, registered model_id and deployed model_id - assertEquals(3, resourcesCreated.size()); - assertEquals("create_connector", resourcesCreated.get(0).workflowStepName()); - assertNotNull(resourcesCreated.get(0).resourceId()); - assertEquals("register_remote_model", resourcesCreated.get(1).workflowStepName()); - assertNotNull(resourcesCreated.get(1).resourceId()); - assertEquals("deploy_model", resourcesCreated.get(2).workflowStepName()); - assertNotNull(resourcesCreated.get(2).resourceId()); + // wait for deprovision to complete + Thread.sleep(5000); } public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { @@ -250,10 +230,9 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { // Hit Deprovision API Response deprovisionResponse = deprovisionWorkflow(workflowId); - assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse)); assertBusy( () -> { getAndAssertWorkflowStatus(workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, - 30, + 60, TimeUnit.SECONDS ); diff --git a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java index 559aae36c..0369807ef 100644 --- a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java @@ -8,31 +8,20 @@ */ package org.opensearch.flowframework.transport; -import org.opensearch.Version; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.ResourceCreated; -import org.opensearch.flowframework.model.Template; -import org.opensearch.flowframework.model.Workflow; -import org.opensearch.flowframework.model.WorkflowEdge; -import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.model.WorkflowState; -import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.CreateConnectorStep; import org.opensearch.flowframework.workflow.DeleteConnectorStep; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; -import org.opensearch.index.get.GetResult; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; @@ -44,13 +33,11 @@ import java.io.IOException; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import org.mockito.ArgumentCaptor; -import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static 
org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; @@ -66,53 +53,27 @@ public class DeprovisionWorkflowTransportActionTests extends OpenSearchTestCase private static ThreadPool threadPool = new TestThreadPool(DeprovisionWorkflowTransportActionTests.class.getName()); private Client client; - private WorkflowProcessSorter workflowProcessSorter; private WorkflowStepFactory workflowStepFactory; private DeleteConnectorStep deleteConnectorStep; private DeprovisionWorkflowTransportAction deprovisionWorkflowTransportAction; - private Template template; - private GetResult getResult; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; - private EncryptorUtils encryptorUtils; @Override public void setUp() throws Exception { super.setUp(); this.client = mock(Client.class); - this.workflowProcessSorter = mock(WorkflowProcessSorter.class); this.workflowStepFactory = mock(WorkflowStepFactory.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - this.encryptorUtils = mock(EncryptorUtils.class); this.deprovisionWorkflowTransportAction = new DeprovisionWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), threadPool, client, - workflowProcessSorter, workflowStepFactory, - flowFrameworkIndicesHandler, - encryptorUtils + flowFrameworkIndicesHandler ); - Version templateVersion = Version.fromString("1.0.0"); - List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode node = new WorkflowNode("step_1", "create_connector", Collections.emptyMap(), Collections.emptyMap()); - List nodes = List.of(node); - List edges = Collections.emptyList(); - Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); - this.template = new Template( - "test", - "description", - "use case", - templateVersion, - compatibilityVersions, - Map.of(PROVISION_WORKFLOW, workflow), - Collections.emptyMap(), - TestHelpers.randomUser() - ); - this.getResult = mock(GetResult.class); - MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); ProcessNode processNode = mock(ProcessNode.class); when(processNode.id()).thenReturn("step_1"); @@ -120,7 +81,6 @@ public void setUp() throws Exception { when(processNode.previousNodeInputs()).thenReturn(Collections.emptyMap()); when(processNode.input()).thenReturn(WorkflowData.EMPTY); when(processNode.nodeTimeout()).thenReturn(TimeValue.timeValueSeconds(5)); - when(this.workflowProcessSorter.sortProcessNodes(any(Workflow.class), any(String.class))).thenReturn(List.of(processNode)); this.deleteConnectorStep = mock(DeleteConnectorStep.class); when(this.workflowStepFactory.createStep("delete_connector")).thenReturn(deleteConnectorStep); @@ -141,17 +101,6 @@ public void testDeprovisionWorkflow() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); - when(getResult.sourceAsString()).thenReturn(this.template.toJson()); - - doAnswer(invocation -> { - ActionListener responseListener = invocation.getArgument(1); - - when(getResult.isExists()).thenReturn(true); - responseListener.onResponse(new GetResponse(getResult)); - return null; - }).when(client).get(any(GetRequest.class), any()); - - when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); doAnswer(invocation -> { 
ActionListener responseListener = invocation.getArgument(2); @@ -174,44 +123,11 @@ public void testDeprovisionWorkflow() throws IOException { assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); } - public void testFailedToRetrieveTemplateFromGlobalContext() { - String workflowId = "1"; - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); - when(getResult.sourceAsString()).thenReturn(this.template.toJson()); - - doAnswer(invocation -> { - ActionListener responseListener = invocation.getArgument(1); - - when(getResult.isExists()).thenReturn(false); - responseListener.onResponse(new GetResponse(getResult)); - return null; - }).when(client).get(any(GetRequest.class), any()); - - deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); - - verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("Failed to retrieve template (1) from global context.", exceptionCaptor.getValue().getMessage()); - } - public void testFailToDeprovision() throws IOException { String workflowId = "1"; @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); - when(getResult.sourceAsString()).thenReturn(this.template.toJson()); - - doAnswer(invocation -> { - ActionListener responseListener = invocation.getArgument(1); - - when(getResult.isExists()).thenReturn(true); - responseListener.onResponse(new GetResponse(getResult)); - return null; - }).when(client).get(any(GetRequest.class), any()); - - when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(2); diff --git a/src/test/resources/template/registerlocalmodel-deployflag.json b/src/test/resources/template/registerlocalmodel-deployflag.json new file mode 100644 index 000000000..6b4f00a73 --- /dev/null +++ b/src/test/resources/template/registerlocalmodel-deployflag.json @@ -0,0 +1,36 @@ +{ + "name": "registerlocalmodel-deployflag", + "description": "test case", + "use_case": "TEST_CASE", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "workflow_step_1", + "type": "register_local_model", + "user_inputs": { + "node_timeout": "60s", + "name": "all-MiniLM-L6-v2", + "version": "1.0.0", + "description": "test model", + "model_format": "TORCH_SCRIPT", + "model_content_hash_value": "c15f0d2e62d872be5b5bc6c84d2e0f4921541e29fefbef51d59cc10a8ae30e0f", + "model_type": "bert", + "embedding_dimension": "384", + "framework_type": "sentence_transformers", + "all_config": "{\"_name_or_path\":\"nreimers/MiniLM-L6-H384-uncased\",\"architectures\":[\"BertModel\"],\"attention_probs_dropout_prob\":0.1,\"gradient_checkpointing\":false,\"hidden_act\":\"gelu\",\"hidden_dropout_prob\":0.1,\"hidden_size\":384,\"initializer_range\":0.02,\"intermediate_size\":1536,\"layer_norm_eps\":1e-12,\"max_position_embeddings\":512,\"model_type\":\"bert\",\"num_attention_heads\":12,\"num_hidden_layers\":6,\"pad_token_id\":0,\"position_embedding_type\":\"absolute\",\"transformers_version\":\"4.8.2\",\"type_vocab_size\":2,\"use_cache\":true,\"vocab_size\":30522}", + "url": 
"https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L6-v2/1.0.1/torch_script/sentence-transformers_all-MiniLM-L6-v2-1.0.1-torch_script.zip", + "deploy": true + } + } + ] + } + } + } diff --git a/src/test/resources/template/registerlocalmodel-deploymodel.json b/src/test/resources/template/registerlocalmodel-deploymodel.json index 55bf6f21b..040a73074 100644 --- a/src/test/resources/template/registerlocalmodel-deploymodel.json +++ b/src/test/resources/template/registerlocalmodel-deploymodel.json @@ -1,48 +1,51 @@ { - "name": "registerlocalmodel-deploymodel", - "description": "test case", - "use_case": "TEST_CASE", - "version": { - "template": "1.0.0", - "compatibility": [ - "2.12.0", - "3.0.0" - ] - }, - "workflows": { - "provision": { - "nodes": [ - { - "id": "workflow_step_1", - "type": "register_local_model", - "user_inputs": { - "node_timeout": "60s", - "name": "all-MiniLM-L6-v2", - "version": "1.0.0", - "description": "test model", - "model_format": "TORCH_SCRIPT", - "model_content_hash_value": "c15f0d2e62d872be5b5bc6c84d2e0f4921541e29fefbef51d59cc10a8ae30e0f", - "model_type": "bert", - "embedding_dimension": "384", - "framework_type": "sentence_transformers", - "all_config": "{\"_name_or_path\":\"nreimers/MiniLM-L6-H384-uncased\",\"architectures\":[\"BertModel\"],\"attention_probs_dropout_prob\":0.1,\"gradient_checkpointing\":false,\"hidden_act\":\"gelu\",\"hidden_dropout_prob\":0.1,\"hidden_size\":384,\"initializer_range\":0.02,\"intermediate_size\":1536,\"layer_norm_eps\":1e-12,\"max_position_embeddings\":512,\"model_type\":\"bert\",\"num_attention_heads\":12,\"num_hidden_layers\":6,\"pad_token_id\":0,\"position_embedding_type\":\"absolute\",\"transformers_version\":\"4.8.2\",\"type_vocab_size\":2,\"use_cache\":true,\"vocab_size\":30522}", - "url": "https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L6-v2/1.0.1/torch_script/sentence-transformers_all-MiniLM-L6-v2-1.0.1-torch_script.zip" - } - }, - { - "id": "workflow_step_2", - "type": "deploy_model", - "previous_node_inputs": { - "workflow_step_2": "model_id" - } + "name": "registerlocalmodel-deploymodel", + "description": "test case", + "use_case": "TEST_CASE", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "workflow_step_1", + "type": "register_local_model", + "user_inputs": { + "node_timeout": "60s", + "name": "all-MiniLM-L6-v2", + "version": "1.0.0", + "description": "test model", + "model_format": "TORCH_SCRIPT", + "model_content_hash_value": "c15f0d2e62d872be5b5bc6c84d2e0f4921541e29fefbef51d59cc10a8ae30e0f", + "model_type": "bert", + "embedding_dimension": "384", + "framework_type": "sentence_transformers", + "all_config": "{\"_name_or_path\":\"nreimers/MiniLM-L6-H384-uncased\",\"architectures\":[\"BertModel\"],\"attention_probs_dropout_prob\":0.1,\"gradient_checkpointing\":false,\"hidden_act\":\"gelu\",\"hidden_dropout_prob\":0.1,\"hidden_size\":384,\"initializer_range\":0.02,\"intermediate_size\":1536,\"layer_norm_eps\":1e-12,\"max_position_embeddings\":512,\"model_type\":\"bert\",\"num_attention_heads\":12,\"num_hidden_layers\":6,\"pad_token_id\":0,\"position_embedding_type\":\"absolute\",\"transformers_version\":\"4.8.2\",\"type_vocab_size\":2,\"use_cache\":true,\"vocab_size\":30522}", + "url": 
"https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L6-v2/1.0.1/torch_script/sentence-transformers_all-MiniLM-L6-v2-1.0.1-torch_script.zip" } - ], - "edges": [ - { - "source": "workflow_step_1", - "dest": "workflow_step_2" + }, + { + "id": "workflow_step_2", + "type": "deploy_model", + "user_inputs": { + "node_timeout": "60s" + }, + "previous_node_inputs": { + "workflow_step_1": "model_id" } - ] - } + } + ], + "edges": [ + { + "source": "workflow_step_1", + "dest": "workflow_step_2" + } + ] } } +}