diff --git a/CHANGELOG.md b/CHANGELOG.md index bdaed0ea0..f5d147ad8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,3 +24,4 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Documentation ### Maintenance ### Refactoring +- Refactor workflow step resource updates to eliminate duplication ([#796](https://github.com/opensearch-project/flow-framework/pull/796)) diff --git a/build.gradle b/build.gradle index cee971302..8825d097c 100644 --- a/build.gradle +++ b/build.gradle @@ -164,7 +164,7 @@ configurations { dependencies { implementation "org.opensearch:opensearch:${opensearch_version}" - implementation 'org.junit.jupiter:junit-jupiter:5.10.3' + implementation 'org.junit.jupiter:junit-jupiter:5.11.0' api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" api group: 'org.opensearch.client', name: 'opensearch-rest-client', version: "${opensearch_version}" implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.16.0' @@ -185,7 +185,7 @@ dependencies { configurations.all { resolutionStrategy { - force("com.google.guava:guava:33.2.1-jre") // CVE for 31.1, keep to force transitive dependencies + force("com.google.guava:guava:33.3.0-jre") // CVE for 31.1, keep to force transitive dependencies force("com.fasterxml.jackson.core:jackson-core:2.17.2") // Dependency Jar Hell force("org.apache.httpcomponents.core5:httpcore5:5.2.5") // Dependency Jar Hell } diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 63ac6f7d4..02ef8a825 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -44,6 +44,7 @@ import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.script.Script; import org.opensearch.script.ScriptType; @@ -666,13 +667,13 @@ public void updateFlowFrameworkSystemIndexDocWithScript( /** * Creates a new ResourceCreated object and a script to update the state index * @param workflowId workflowId for the relevant step - * @param nodeId WorkflowData object with relevent step information + * @param nodeId current process node (workflow step) id * @param workflowStepName the workflowstep name that created the resource * @param resourceId the id of the newly created resource * @param listener the ActionListener for this step to handle completing the future after update * @throws IOException if parsing fails on new resource */ - public void updateResourceInStateIndex( + private void updateResourceInStateIndex( String workflowId, String nodeId, String workflowStepName, @@ -697,6 +698,44 @@ public void updateResourceInStateIndex( updateFlowFrameworkSystemIndexDocWithScript(WORKFLOW_STATE_INDEX, workflowId, script, ActionListener.wrap(updateResponse -> { logger.info("updated resources created of {}", workflowId); listener.onResponse(updateResponse); - }, exception -> { listener.onFailure(exception); })); + }, listener::onFailure)); + } + + /** + * Adds a resource to the state index, including common exception handling + * @param currentNodeInputs Inputs to the current node + * @param nodeId current process node (workflow step) id + * @param workflowStepName the workflow 
step name that created the resource + * @param resourceId the id of the newly created resource + * @param listener the ActionListener for this step to handle completing the future after update + */ + public void addResourceToStateIndex( + WorkflowData currentNodeInputs, + String nodeId, + String workflowStepName, + String resourceId, + ActionListener listener + ) { + String resourceName = getResourceByWorkflowStep(workflowStepName); + try { + updateResourceInStateIndex( + currentNodeInputs.getWorkflowId(), + nodeId, + workflowStepName, + resourceId, + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + listener.onResponse(new WorkflowData(Map.of(resourceName, resourceId), currentNodeInputs.getWorkflowId(), nodeId)); + }, exception -> { + String errorMessage = "Failed to update new created " + nodeId + " resource " + workflowStepName + " id " + resourceId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + }) + ); + } catch (Exception e) { + String errorMessage = "Failed to parse and update new created resource"; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + } } } diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index f4d2ae600..4c8c2c9ef 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -8,12 +8,12 @@ */ package org.opensearch.flowframework.model; -import org.apache.logging.log4j.util.Strings; import org.opensearch.Version; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.common.xcontent.yaml.YamlXContent; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; @@ -372,10 +372,10 @@ public static Template updateExistingTemplate(Template existingTemplate, Templat if (templateWithNewFields.name() != null) { builder.name(templateWithNewFields.name()); } - if (!Strings.isBlank(templateWithNewFields.description())) { + if (Strings.hasText(templateWithNewFields.description())) { builder.description(templateWithNewFields.description()); } - if (!Strings.isBlank(templateWithNewFields.useCase())) { + if (Strings.hasText(templateWithNewFields.useCase())) { builder.useCase(templateWithNewFields.useCase()); } if (templateWithNewFields.templateVersion() != null) { diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 032b4b898..8acfab16a 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -138,6 +138,13 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); return processError(ffe, params, request); } + if (reprovision && !params.isEmpty()) { + FlowFrameworkException ffe = new FlowFrameworkException( + "Only the parameters " + request.consumedParams() + " are permitted unless the provision parameter 
is set to true.", + RestStatus.BAD_REQUEST + ); + return processError(ffe, params, request); + } try { Template template; Map useCaseDefaultsMap = Collections.emptyMap(); diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index e20b2ed3b..16e8b25e1 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -533,4 +533,21 @@ public static void flattenSettings(String prefix, Map settings, } } } + + /** + * Ensures index is prepended to flattened setting keys + * @param originalSettings the original settings map + * @return new map with keys prepended with index + */ + public static Map prependIndexToSettings(Map originalSettings) { + Map newSettings = new HashMap<>(); + originalSettings.entrySet().stream().forEach(x -> { + if (!x.getKey().startsWith("index.")) { + newSettings.put("index." + x.getKey(), x.getValue()); + } else { + newSettings.put(x.getKey(), x.getValue()); + } + }); + return newSettings; + } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java index e23d88b63..bbaa77204 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java @@ -33,7 +33,6 @@ import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** @@ -98,43 +97,14 @@ public PlainActionFuture execute( @Override public void onResponse(AcknowledgedResponse acknowledgedResponse) { - String resourceName = getResourceByWorkflowStep(getName()); - try { - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - getName(), - pipelineId, - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - // PutPipelineRequest returns only an AcknowledgeResponse, saving pipelineId instead - // TODO: revisit this concept of pipeline_id to be consistent with what makes most sense to end user here - createPipelineFuture.onResponse( - new WorkflowData( - Map.of(resourceName, pipelineId), - currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + pipelineId; - logger.error(errorMessage, exception); - createPipelineFuture.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - createPipelineFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + // PutPipelineRequest returns only an AcknowledgeResponse, saving pipelineId instead + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + 
getName(), + pipelineId, + createPipelineFuture + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java index 10c6a884b..3f6d19219 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java @@ -49,7 +49,6 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; -import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.URL; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; @@ -189,58 +188,38 @@ public PlainActionFuture execute( // Attempt to retrieve the model ID retryableGetMlTask( - currentNodeInputs.getWorkflowId(), + currentNodeInputs, currentNodeId, registerLocalModelFuture, taskId, "Local model registration", - ActionListener.wrap(mlTask -> { - + ActionListener.wrap(mlTaskWorkflowData -> { // Registered Model Resource has been updated String resourceName = getResourceByWorkflowStep(getName()); - String id = getResourceId(mlTask); - if (Boolean.TRUE.equals(deploy)) { - - // Simulate Model deployment step and update resources created - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - DeployModelStep.NAME, - id, - ActionListener.wrap(deployUpdateResponse -> { - logger.info( - "successfully updated resources created in state index: {}", - deployUpdateResponse.getIndex() - ); - registerLocalModelFuture.onResponse( - new WorkflowData( - Map.ofEntries( - Map.entry(resourceName, id), - Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name()) - ), - currentNodeInputs.getWorkflowId(), - currentNodeId - ) - ); - }, deployUpdateException -> { + String id = (String) mlTaskWorkflowData.getContent().get(resourceName); + ActionListener deployUpdateListener = ActionListener.wrap( + deployUpdateResponse -> registerLocalModelFuture.onResponse(mlTaskWorkflowData), + deployUpdateException -> { String errorMessage = "Failed to update simulated deploy step resource " + id; logger.error(errorMessage, deployUpdateException); registerLocalModelFuture.onFailure( new FlowFrameworkException(errorMessage, ExceptionsHelper.status(deployUpdateException)) ); - }) + } ); - } else { - registerLocalModelFuture.onResponse( - new WorkflowData( - Map.ofEntries(Map.entry(resourceName, id), Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name())), - currentNodeInputs.getWorkflowId(), - currentNodeId - ) + // Simulate Model deployment step and update resources created + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + DeployModelStep.NAME, + id, + deployUpdateListener ); + } else { + registerLocalModelFuture.onResponse(mlTaskWorkflowData); } - }, exception -> { registerLocalModelFuture.onFailure(exception); }) + }, registerLocalModelFuture::onFailure) ); }, exception -> { Exception e = getSafeException(exception); diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java 
b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index c1933b0c4..3e6dc9a85 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.FutureUtils; @@ -24,8 +23,11 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.threadpool.ThreadPool; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; @@ -60,7 +62,7 @@ protected AbstractRetryableWorkflowStep( /** * Retryable get ml task - * @param workflowId the workflow id + * @param currentNodeInputs the current Node Inputs * @param nodeId the workflow node id * @param future the workflow step future * @param taskId the ml task id @@ -68,12 +70,12 @@ protected AbstractRetryableWorkflowStep( * @param mlTaskListener the ML Task Listener */ protected void retryableGetMlTask( - String workflowId, + WorkflowData currentNodeInputs, String nodeId, PlainActionFuture future, String taskId, String workflowStep, - ActionListener mlTaskListener + ActionListener mlTaskListener ) { CompletableFuture.runAsync(() -> { do { @@ -82,34 +84,13 @@ protected void retryableGetMlTask( String id = getResourceId(response); switch (response.getState()) { case COMPLETED: - try { - logger.info("{} successful for {} and {} {}", workflowStep, workflowId, resourceName, id); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - workflowId, - nodeId, - getName(), - id, - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - mlTaskListener.onResponse(response); - }, exception -> { - String errorMessage = "Failed to update new created " - + nodeId - + " resource " - + getName() - + " id " - + id; - logger.error(errorMessage, exception); - mlTaskListener.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - }) - ); - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource " + resourceName + " id " + id; - logger.error(errorMessage, e); - mlTaskListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + logger.info("{} successful for {} and {} {}", workflowStep, currentNodeInputs, resourceName, id); + ActionListener resourceListener = ActionListener.wrap(r -> { + Map content = new HashMap<>(r.getContent()); + content.put(REGISTER_MODEL_STATUS, response.getState().toString()); + mlTaskListener.onResponse(new WorkflowData(content, r.getWorkflowId(), r.getNodeId())); + }, mlTaskListener::onFailure); + flowFrameworkIndicesHandler.addResourceToStateIndex(currentNodeInputs, nodeId, getName(), id, resourceListener); break; case FAILED: case COMPLETED_WITH_ERROR: diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java 
b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 484807ce3..cd9d2e2ac 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -43,7 +43,6 @@ import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; @@ -85,40 +84,14 @@ public PlainActionFuture execute( @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { - String resourceName = getResourceByWorkflowStep(getName()); - try { - logger.info("Created connector successfully"); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - getName(), - mlCreateConnectorResponse.getConnectorId(), - ActionListener.wrap(response -> { - logger.info("successfully updated resources created in state index: {}", response.getIndex()); - createConnectorFuture.onResponse( - new WorkflowData( - Map.ofEntries(Map.entry(resourceName, mlCreateConnectorResponse.getConnectorId())), - currentNodeInputs.getWorkflowId(), - currentNodeId - ) - ); - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + mlCreateConnectorResponse.getConnectorId(); - logger.error(errorMessage, exception); - createConnectorFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - }) - ); - - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - createConnectorFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + logger.info("Created connector successfully"); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + mlCreateConnectorResponse.getConnectorId(), + createConnectorFuture + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 32ca9e9f6..3ac82cb7a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -21,13 +21,11 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.index.mapper.MapperService; -import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.HashMap; @@ -37,7 +35,6 @@ import static java.util.Collections.singletonMap; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static 
org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** @@ -108,35 +105,14 @@ public PlainActionFuture execute( } client.admin().indices().create(createIndexRequest, ActionListener.wrap(acknowledgedResponse -> { - String resourceName = getResourceByWorkflowStep(getName()); logger.info("Created index: {}", indexName); - try { - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - getName(), - indexName, - ActionListener.wrap(response -> { - logger.info("successfully updated resource created in state index: {}", response.getIndex()); - createIndexFuture.onResponse( - new WorkflowData(Map.of(resourceName, indexName), currentNodeInputs.getWorkflowId(), currentNodeId) - ); - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + indexName; - logger.error(errorMessage, exception); - createIndexFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - }) - ); - } catch (IOException ex) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, ex); - createIndexFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(ex))); - } + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + indexName, + createIndexFuture + ); }, ex -> { Exception e = getSafeException(ex); String errorMessage = (e == null ? "Failed to create the index " + indexName : e.getMessage()); @@ -155,7 +131,7 @@ public PlainActionFuture execute( // to encounter the same behavior and not suddenly have to add `_doc` while using our create_index step // related bug: https://github.com/opensearch-project/OpenSearch/issues/12775 private static Map prepareMappings(Map source) { - if (source.containsKey("mappings") == false || (source.get("mappings") instanceof Map) == false) { + if (!source.containsKey("mappings") || !(source.get("mappings") instanceof Map)) { return source; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index c89b25d16..13f1b6fe5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -8,8 +8,6 @@ */ package org.opensearch.flowframework.workflow; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -17,8 +15,6 @@ * Step to create an ingest pipeline */ public class CreateIngestPipelineStep extends AbstractCreatePipelineStep { - private static final Logger logger = LogManager.getLogger(CreateIngestPipelineStep.class); - /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ public static final String NAME = "create_ingest_pipeline"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStep.java index 1dc8fc745..67bc5fc0d 100644 --- 
a/src/main/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStep.java @@ -8,8 +8,6 @@ */ package org.opensearch.flowframework.workflow; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -17,8 +15,6 @@ * Step to create a search pipeline */ public class CreateSearchPipelineStep extends AbstractCreatePipelineStep { - private static final Logger logger = LogManager.getLogger(CreateSearchPipelineStep.class); - /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ public static final String NAME = "create_search_pipeline"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 56c2a6181..e51100f04 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -28,7 +28,6 @@ import java.util.Set; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** @@ -93,24 +92,16 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { // Attempt to retrieve the model ID retryableGetMlTask( - currentNodeInputs.getWorkflowId(), + currentNodeInputs, currentNodeId, deployModelFuture, taskId, "Deploy model", - ActionListener.wrap(mlTask -> { - // Deployed Model Resource has been updated - String resourceName = getResourceByWorkflowStep(getName()); - String id = getResourceId(mlTask); - deployModelFuture.onResponse( - new WorkflowData(Map.of(resourceName, id), currentNodeInputs.getWorkflowId(), currentNodeId) - ); - }, - e -> { - deployModelFuture.onFailure( - new FlowFrameworkException("Failed to deploy model", ExceptionsHelper.status(e)) - ); - } + ActionListener.wrap( + deployModelFuture::onResponse, + e -> deployModelFuture.onFailure( + new FlowFrameworkException("Failed to deploy model", ExceptionsHelper.status(e)) + ) ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java index af2aea1e3..cf03de4bb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java @@ -62,6 +62,7 @@ public PlainActionFuture execute( throw new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST); } catch (InterruptedException e) { FutureUtils.cancel(future); + Thread.currentThread().interrupt(); } future.onResponse(WorkflowData.EMPTY); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index d485f6f5c..5eed7a864 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -47,7 +47,6 @@ import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; import static 
org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; @@ -95,42 +94,14 @@ public PlainActionFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { - try { - String resourceName = getResourceByWorkflowStep(getName()); - logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId()); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - workflowId, - currentNodeId, - getName(), - mlRegisterAgentResponse.getAgentId(), - ActionListener.wrap(response -> { - logger.info("successfully updated resources created in state index: {}", response.getIndex()); - registerAgentModelFuture.onResponse( - new WorkflowData( - Map.ofEntries(Map.entry(resourceName, mlRegisterAgentResponse.getAgentId())), - workflowId, - currentNodeId - ) - ); - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + mlRegisterAgentResponse.getAgentId(); - logger.error(errorMessage, exception); - registerAgentModelFuture.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - registerAgentModelFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId()); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + mlRegisterAgentResponse.getAgentId(), + registerAgentModelFuture + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java index 8fe1271df..6cc990429 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java @@ -26,6 +26,7 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -37,7 +38,6 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_ACCESS_MODE; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; -import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** @@ -77,45 +77,19 @@ public PlainActionFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) { - try { - logger.info("Model group registration successful"); - String resourceName = getResourceByWorkflowStep(getName()); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - 
getName(), - mlRegisterModelGroupResponse.getModelGroupId(), - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - registerModelGroupFuture.onResponse( - new WorkflowData( - Map.ofEntries( - Map.entry(resourceName, mlRegisterModelGroupResponse.getModelGroupId()), - Map.entry(MODEL_GROUP_STATUS, mlRegisterModelGroupResponse.getStatus()) - ), - currentNodeInputs.getWorkflowId(), - currentNodeId - ) - ); - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + mlRegisterModelGroupResponse.getModelGroupId(); - logger.error(errorMessage, exception); - registerModelGroupFuture.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - registerModelGroupFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + logger.info("Model group registration successful"); + ActionListener resourceListener = ActionListener.wrap(r -> { + Map content = new HashMap<>(r.getContent()); + content.put(MODEL_GROUP_STATUS, mlRegisterModelGroupResponse.getStatus()); + registerModelGroupFuture.onResponse(new WorkflowData(content, r.getWorkflowId(), r.getNodeId())); + }, registerModelGroupFuture::onFailure); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + mlRegisterModelGroupResponse.getModelGroupId(), + resourceListener + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 25c0de272..72b56ab79 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -12,7 +12,6 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; @@ -145,60 +144,51 @@ public PlainActionFuture execute( mlClient.register(mlInput, new ActionListener() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { + logger.info("Remote Model registration successful"); + String resourceName = getResourceByWorkflowStep(getName()); + ActionListener registerUpdateListener = ActionListener.wrap(registerUpdateResponse -> { + if (Boolean.TRUE.equals(deploy)) { + updateDeployResource(resourceName, mlRegisterModelResponse); + } else { + completeRegisterFuture(resourceName, mlRegisterModelResponse); + } + }, registerUpdateException -> { + String errorMessage = "Failed to update new created " + + currentNodeId + + " resource " + + getName() + + " id " + + mlRegisterModelResponse.getModelId(); + completeRegisterFutureExceptionally(errorMessage, registerUpdateException); + }); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + getName(), + mlRegisterModelResponse.getModelId(), + registerUpdateListener + ); + } - try { - logger.info("Remote Model registration successful"); - String resourceName = 
getResourceByWorkflowStep(getName()); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - getName(), - mlRegisterModelResponse.getModelId(), - ActionListener.wrap(response -> { - // If we deployed, simulate the deploy step has been called - if (Boolean.TRUE.equals(deploy)) { - flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - currentNodeId, - DeployModelStep.NAME, - mlRegisterModelResponse.getModelId(), - ActionListener.wrap(deployUpdateResponse -> { - completeRegisterFuture(deployUpdateResponse, resourceName, mlRegisterModelResponse); - }, deployUpdateException -> { - String errorMessage = "Failed to update simulated deploy step resource " - + mlRegisterModelResponse.getModelId(); - logger.error(errorMessage, deployUpdateException); - registerRemoteModelFuture.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(deployUpdateException)) - ); - }) - ); - } else { - completeRegisterFuture(response, resourceName, mlRegisterModelResponse); - } - }, exception -> { - String errorMessage = "Failed to update new created " - + currentNodeId - + " resource " - + getName() - + " id " - + mlRegisterModelResponse.getModelId(); - logger.error(errorMessage, exception); - registerRemoteModelFuture.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - registerRemoteModelFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + private void updateDeployResource(String resourceName, MLRegisterModelResponse mlRegisterModelResponse) { + ActionListener deployUpdateListener = ActionListener.wrap( + deployUpdateResponse -> completeRegisterFuture(resourceName, mlRegisterModelResponse), + deployUpdateException -> { + String errorMessage = "Failed to update simulated deploy step resource " + mlRegisterModelResponse.getModelId(); + completeRegisterFutureExceptionally(errorMessage, deployUpdateException); + } + ); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + currentNodeId, + DeployModelStep.NAME, + mlRegisterModelResponse.getModelId(), + deployUpdateListener + ); } - void completeRegisterFuture(UpdateResponse response, String resourceName, MLRegisterModelResponse mlRegisterModelResponse) { - logger.info("successfully updated resources created in state index: {}", response.getIndex()); + private void completeRegisterFuture(String resourceName, MLRegisterModelResponse mlRegisterModelResponse) { + logger.info("successfully updated resources created in state index"); registerRemoteModelFuture.onResponse( new WorkflowData( Map.ofEntries( @@ -211,6 +201,11 @@ void completeRegisterFuture(UpdateResponse response, String resourceName, MLRegi ); } + private void completeRegisterFutureExceptionally(String errorMessage, Exception exception) { + logger.error(errorMessage, exception); + registerRemoteModelFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + @Override public void onFailure(Exception ex) { Exception e = getSafeException(ex); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ReindexStep.java b/src/main/java/org/opensearch/flowframework/workflow/ReindexStep.java index b46ddecab..18ce60f14 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ReindexStep.java +++ 
b/src/main/java/org/opensearch/flowframework/workflow/ReindexStep.java @@ -38,8 +38,6 @@ public class ReindexStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(ReindexStep.class); private final Client client; - private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; - /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ public static final String NAME = "reindex"; /** The refresh field for reindex */ @@ -61,7 +59,6 @@ public class ReindexStep implements WorkflowStep { */ public ReindexStep(Client client, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.client = client; - this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/UpdateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/UpdateIndexStep.java index 9d35a32ce..719ef7237 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/UpdateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/UpdateIndexStep.java @@ -113,7 +113,10 @@ public PlainActionFuture execute( if (updatedSettings.containsKey("index")) { ParseUtils.flattenSettings("", updatedSettings, flattenedSettings); } else { - flattenedSettings.putAll(updatedSettings); + // Create index setting configuration can be a mix of flattened or expanded settings + // prepend index. to ensure successful setting comparison + + flattenedSettings.putAll(ParseUtils.prependIndexToSettings(updatedSettings)); } Map filteredSettings = new HashMap<>(); @@ -133,35 +136,39 @@ public PlainActionFuture execute( filteredSettings.put(e.getKey(), e.getValue()); } } + + // Create and send the update settings request + updateSettingsRequest.settings(filteredSettings); + if (updateSettingsRequest.settings().size() == 0) { + String errorMessage = "Failed to update index settings for index " + + indexName + + ", no settings have been updated"; + updateIndexFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST)); + } else { + client.admin().indices().updateSettings(updateSettingsRequest, ActionListener.wrap(acknowledgedResponse -> { + String resourceName = getResourceByWorkflowStep(getName()); + logger.info("Updated index settings for index {}", indexName); + updateIndexFuture.onResponse( + new WorkflowData(Map.of(resourceName, indexName), currentNodeInputs.getWorkflowId(), currentNodeId) + ); + + }, ex -> { + Exception e = getSafeException(ex); + String errorMessage = (e == null + ? "Failed to update the index settings for index " + indexName + : e.getMessage()); + logger.error(errorMessage, e); + updateIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); + })); + } }, ex -> { Exception e = getSafeException(ex); String errorMessage = (e == null ? 
"Failed to retrieve the index settings for index " + indexName : e.getMessage()); logger.error(errorMessage, e); updateIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); })); - - updateSettingsRequest.settings(filteredSettings); } } - - if (updateSettingsRequest.settings().size() == 0) { - String errorMessage = "Failed to update index settings for index " + indexName + ", no settings have been updated"; - throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); - } else { - client.admin().indices().updateSettings(updateSettingsRequest, ActionListener.wrap(acknowledgedResponse -> { - String resourceName = getResourceByWorkflowStep(getName()); - logger.info("Updated index settings for index {}", indexName); - updateIndexFuture.onResponse( - new WorkflowData(Map.of(resourceName, indexName), currentNodeInputs.getWorkflowId(), currentNodeId) - ); - - }, ex -> { - Exception e = getSafeException(ex); - String errorMessage = (e == null ? "Failed to update the index settings for index " + indexName : e.getMessage()); - logger.error(errorMessage, e); - updateIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); - })); - } } catch (Exception e) { updateIndexFuture.onFailure(new WorkflowStepException(e.getMessage(), ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index ba19823a7..dad2d65c9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -74,7 +74,7 @@ public Map getContent() { */ public Map getParams() { return this.params; - }; + } /** * Returns the workflowId associated with this data. @@ -83,7 +83,7 @@ public Map getParams() { @Nullable public String getWorkflowId() { return this.workflowId; - }; + } /** * Returns the nodeId associated with this data. 
@@ -92,5 +92,5 @@ public String getWorkflowId() { @Nullable public String getNodeId() { return this.nodeId; - }; + } } diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index dc25f44fa..877b6292a 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -415,6 +415,25 @@ protected Response createWorkflowValidation(RestClient client, Template template return TestHelpers.makeRequest(client, "POST", WORKFLOW_URI, Collections.emptyMap(), template.toJson(), null); } + /** + * Helper method to invoke the Reprovision Workflow API + * @param client the rest client + * @param workflowId the document id + * @param template the template to reprovision + * @throws Exception if the request fails + * @return a rest response + */ + protected Response reprovisionWorkflow(RestClient client, String workflowId, Template template) throws Exception { + return TestHelpers.makeRequest( + client, + "PUT", + String.format(Locale.ROOT, "%s/%s?reprovision=true", WORKFLOW_URI, workflowId), + Collections.emptyMap(), + template.toJson(), + null + ); + } + /** * Helper method to invoke the Update Workflow API * @param client the rest client diff --git a/src/test/java/org/opensearch/flowframework/TestHelpers.java b/src/test/java/org/opensearch/flowframework/TestHelpers.java index 6c4f3534b..6d136f7a6 100644 --- a/src/test/java/org/opensearch/flowframework/TestHelpers.java +++ b/src/test/java/org/opensearch/flowframework/TestHelpers.java @@ -11,7 +11,6 @@ import org.apache.hc.core5.http.Header; import org.apache.hc.core5.http.HttpEntity; import org.apache.hc.core5.http.io.entity.StringEntity; -import org.apache.logging.log4j.util.Strings; import org.opensearch.client.Request; import org.opensearch.client.RequestOptions; import org.opensearch.client.Response; @@ -24,6 +23,7 @@ import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -74,7 +74,7 @@ public static Response makeRequest( String jsonEntity, List
headers ) throws IOException { - HttpEntity httpEntity = Strings.isBlank(jsonEntity) ? null : new StringEntity(jsonEntity, APPLICATION_JSON); + HttpEntity httpEntity = !Strings.hasText(jsonEntity) ? null : new StringEntity(jsonEntity, APPLICATION_JSON); return makeRequest(client, method, endpoint, params, httpEntity, headers); } diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index f55cb3bc1..ecaec46b5 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -37,13 +37,16 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.CreateConnectorStep; import org.opensearch.flowframework.workflow.CreateIndexStep; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.index.get.GetResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -488,4 +491,52 @@ public void testDeleteFlowFrameworkSystemIndexDoc() throws IOException { exceptionCaptor.getValue().getMessage() ); } + + public void testAddResourceToStateIndex() throws IOException { + ClusterState mockClusterState = mock(ClusterState.class); + Metadata mockMetaData = mock(Metadata.class); + when(clusterService.state()).thenReturn(mockClusterState); + when(mockClusterState.metadata()).thenReturn(mockMetaData); + when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(true); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + // test success + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "this_id", -2, 0, 0, Result.UPDATED)); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.addResourceToStateIndex( + new WorkflowData(Collections.emptyMap(), null, null), + "node_id", + CreateConnectorStep.NAME, + "this_id", + listener + ); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowData.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals("this_id", responseCaptor.getValue().getContent().get(WorkflowResources.CONNECTOR_ID)); + + // test failure + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to update state")); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.addResourceToStateIndex( + new WorkflowData(Collections.emptyMap(), null, null), + "node_id", + CreateConnectorStep.NAME, + "this_id", + listener + ); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, 
times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to update new created node_id resource create_connector id this_id", exceptionCaptor.getValue().getMessage()); + } } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index 6a454dc75..d176adc3b 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -37,9 +37,7 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_KEY; @@ -48,15 +46,12 @@ import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; public class FlowFrameworkRestApiIT extends FlowFrameworkRestTestCase { - private static AtomicBoolean waitToStart = new AtomicBoolean(true); @Before public void waitToStart() throws Exception { // ML Commons cron job runs every 10 seconds and takes 20+ seconds to initialize .plugins-ml-config index - // Delay on the first attempt for 25 seconds to allow this initialization and prevent flaky tests - if (waitToStart.getAndSet(false)) { - CountDownLatch latch = new CountDownLatch(1); - latch.await(25, TimeUnit.SECONDS); + if (!indexExistsWithAdminClient(".plugins-ml-config")) { + assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); } } @@ -93,14 +88,7 @@ public void testFailedUpdateWorkflow() throws Exception { Map responseMap = entityAsMap(response); String workflowId = (String) responseMap.get(WORKFLOW_ID); - // Ensure Ml config index is initialized as creating a connector requires this, then hit Provision API and assert status - Response provisionResponse; - if (!indexExistsWithAdminClient(".plugins-ml-config")) { - assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); - provisionResponse = provisionWorkflow(client(), workflowId); - } else { - provisionResponse = provisionWorkflow(client(), workflowId); - } + Response provisionResponse = provisionWorkflow(client(), workflowId); assertEquals(RestStatus.OK, TestHelpers.restStatus(provisionResponse)); getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); @@ -122,14 +110,7 @@ public void testUpdateWorkflowUsingFields() throws Exception { Map responseMap = entityAsMap(response); String workflowId = (String) responseMap.get(WORKFLOW_ID); - // Ensure Ml config index is initialized as creating a connector requires this, then hit Provision API and assert status - Response provisionResponse; - if (!indexExistsWithAdminClient(".plugins-ml-config")) { - assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); - provisionResponse = provisionWorkflow(client(), workflowId); - } else { - provisionResponse = provisionWorkflow(client(), workflowId); - } + Response provisionResponse = provisionWorkflow(client(), workflowId); assertEquals(RestStatus.OK, TestHelpers.restStatus(provisionResponse)); getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); @@ -259,14 +240,7 @@ public void 
testCreateAndProvisionRemoteModelWorkflow() throws Exception { String workflowId = (String) responseMap.get(WORKFLOW_ID); getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); - // Ensure Ml config index is initialized as creating a connector requires this, then hit Provision API and assert status - if (!indexExistsWithAdminClient(".plugins-ml-config")) { - assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); - response = provisionWorkflow(client(), workflowId); - } else { - response = provisionWorkflow(client(), workflowId); - } - + response = provisionWorkflow(client(), workflowId); assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); @@ -294,13 +268,7 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { Template template = TestHelpers.createTemplateFromFile("agent-framework.json"); // Hit Create Workflow API to create agent-framework template, with template validation check and provision parameter - Response response; - if (!indexExistsWithAdminClient(".plugins-ml-config")) { - assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); - response = createWorkflowWithProvision(client(), template); - } else { - response = createWorkflowWithProvision(client(), template); - } + Response response = createWorkflowWithProvision(client(), template); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); Map responseMap = entityAsMap(response); String workflowId = (String) responseMap.get(WORKFLOW_ID); @@ -363,6 +331,233 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS); } + public void testReprovisionWorkflow() throws Exception { + // Begin with a template to register a local pretrained model + Template template = TestHelpers.createTemplateFromFile("registerremotemodel.json"); + + Response response = createWorkflowWithProvision(client(), template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + // wait and ensure state is completed/done + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.COMPLETED, ProvisioningProgress.DONE); }, + 120, + TimeUnit.SECONDS + ); + + // Wait until provisioning has completed successfully before attempting to retrieve created resources + List resourcesCreated = getResourcesCreated(client(), workflowId, 30); + assertEquals(3, resourcesCreated.size()); + Map resourceMap = resourcesCreated.stream() + .collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r)); + assertTrue(resourceMap.containsKey("create_connector")); + assertTrue(resourceMap.containsKey("register_remote_model")); + + // Reprovision template to add ingest pipeline which uses the model ID + template = TestHelpers.createTemplateFromFile("registerremotemodel-ingestpipeline.json"); + response = reprovisionWorkflow(client(), workflowId, template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + resourcesCreated = getResourcesCreated(client(), workflowId, 30); + assertEquals(4, resourcesCreated.size()); + resourceMap = 
resourcesCreated.stream().collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r)); + assertTrue(resourceMap.containsKey("create_connector")); + assertTrue(resourceMap.containsKey("register_remote_model")); + assertTrue(resourceMap.containsKey("create_ingest_pipeline")); + + // Retrieve pipeline by ID to ensure model ID is set correctly + String modelId = resourceMap.get("register_remote_model").resourceId(); + String pipelineId = resourceMap.get("create_ingest_pipeline").resourceId(); + GetPipelineResponse getPipelineResponse = getPipelines(pipelineId); + assertEquals(1, getPipelineResponse.pipelines().size()); + assertTrue(getPipelineResponse.pipelines().get(0).getConfigAsMap().toString().contains(modelId)); + + // Reprovision template to add index which uses default ingest pipeline + template = TestHelpers.createTemplateFromFile("registerremotemodel-ingestpipeline-createindex.json"); + response = reprovisionWorkflow(client(), workflowId, template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + resourcesCreated = getResourcesCreated(client(), workflowId, 30); + assertEquals(5, resourcesCreated.size()); + resourceMap = resourcesCreated.stream().collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r)); + assertTrue(resourceMap.containsKey("create_connector")); + assertTrue(resourceMap.containsKey("register_remote_model")); + assertTrue(resourceMap.containsKey("create_ingest_pipeline")); + assertTrue(resourceMap.containsKey("create_index")); + + // Retrieve index settings to ensure pipeline ID is set correctly + String indexName = resourceMap.get("create_index").resourceId(); + Map indexSettings = getIndexSettingsAsMap(indexName); + assertEquals(pipelineId, indexSettings.get("index.default_pipeline")); + + // Reprovision template to remove default ingest pipeline + template = TestHelpers.createTemplateFromFile("registerremotemodel-ingestpipeline-updateindex.json"); + response = reprovisionWorkflow(client(), workflowId, template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + resourcesCreated = getResourcesCreated(client(), workflowId, 30); + // resource count should remain unchanged when updating an existing node + assertEquals(5, resourcesCreated.size()); + + // Retrieve index settings to ensure default pipeline has been updated correctly + indexSettings = getIndexSettingsAsMap(indexName); + assertEquals("_none", indexSettings.get("index.default_pipeline")); + + // Deprovision and delete all resources + Response deprovisionResponse = deprovisionWorkflowWithAllowDelete(client(), workflowId, pipelineId + "," + indexName); + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 60, + TimeUnit.SECONDS + ); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse)); + + // Hit Delete API + Response deleteResponse = deleteWorkflow(client(), workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); + } + + public void testReprovisionWorkflowMidNodeAddition() throws Exception { + // Begin with a template to register a local pretrained model and create an index, no edges + Template template = TestHelpers.createTemplateFromFile("registerremotemodel-createindex.json"); + + Response response = createWorkflowWithProvision(client(), template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String workflowId = (String) 
responseMap.get(WORKFLOW_ID); + // wait and ensure state is completed/done + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.COMPLETED, ProvisioningProgress.DONE); }, + 120, + TimeUnit.SECONDS + ); + + // Wait until provisioning has completed successfully before attempting to retrieve created resources + List resourcesCreated = getResourcesCreated(client(), workflowId, 30); + assertEquals(4, resourcesCreated.size()); + Map resourceMap = resourcesCreated.stream() + .collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r)); + assertTrue(resourceMap.containsKey("create_connector")); + assertTrue(resourceMap.containsKey("register_remote_model")); + assertTrue(resourceMap.containsKey("create_index")); + + // Reprovision template to add ingest pipeline which uses the model ID + template = TestHelpers.createTemplateFromFile("registerremotemodel-ingestpipeline-createindex.json"); + response = reprovisionWorkflow(client(), workflowId, template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + resourcesCreated = getResourcesCreated(client(), workflowId, 30); + assertEquals(5, resourcesCreated.size()); + resourceMap = resourcesCreated.stream().collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r)); + assertTrue(resourceMap.containsKey("create_connector")); + assertTrue(resourceMap.containsKey("register_remote_model")); + assertTrue(resourceMap.containsKey("create_ingest_pipeline")); + assertTrue(resourceMap.containsKey("create_index")); + + // Ensure ingest pipeline configuration contains the model id and index settings have the ingest pipeline as default + String modelId = resourceMap.get("register_remote_model").resourceId(); + String pipelineId = resourceMap.get("create_ingest_pipeline").resourceId(); + GetPipelineResponse getPipelineResponse = getPipelines(pipelineId); + assertEquals(1, getPipelineResponse.pipelines().size()); + assertTrue(getPipelineResponse.pipelines().get(0).getConfigAsMap().toString().contains(modelId)); + + String indexName = resourceMap.get("create_index").resourceId(); + Map indexSettings = getIndexSettingsAsMap(indexName); + assertEquals(pipelineId, indexSettings.get("index.default_pipeline")); + + // Deprovision and delete all resources + Response deprovisionResponse = deprovisionWorkflowWithAllowDelete(client(), workflowId, pipelineId + "," + indexName); + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 60, + TimeUnit.SECONDS + ); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse)); + + // Hit Delete API + Response deleteResponse = deleteWorkflow(client(), workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); + } + + public void testReprovisionWithNoChange() throws Exception { + Template template = TestHelpers.createTemplateFromFile("registerremotemodel-ingestpipeline-createindex.json"); + + Response response = createWorkflowWithProvision(client(), template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + // wait and ensure state is completed/done + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.COMPLETED, ProvisioningProgress.DONE); }, + 120, + TimeUnit.SECONDS + ); + + // Attempt to reprovision the same template with no changes + ResponseException exception = expectThrows(ResponseException.class, () 
-> reprovisionWorkflow(client(), workflowId, template)); + assertEquals(RestStatus.BAD_REQUEST.getStatus(), exception.getResponse().getStatusLine().getStatusCode()); + assertTrue(exception.getMessage().contains("Template does not contain any modifications")); + + // Deprovision and delete all resources + Response deprovisionResponse = deprovisionWorkflowWithAllowDelete( + client(), + workflowId, + "nlp-ingest-pipeline" + "," + "my-nlp-index" + ); + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 60, + TimeUnit.SECONDS + ); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse)); + + // Hit Delete API + Response deleteResponse = deleteWorkflow(client(), workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); + } + + public void testReprovisionWithDeletion() throws Exception { + Template template = TestHelpers.createTemplateFromFile("registerremotemodel-ingestpipeline-createindex.json"); + + Response response = createWorkflowWithProvision(client(), template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + // wait and ensure state is completed/done + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.COMPLETED, ProvisioningProgress.DONE); }, + 120, + TimeUnit.SECONDS + ); + + // Attempt to reprovision template without ingest pipeline node + Template templateWithoutIngestPipeline = TestHelpers.createTemplateFromFile("registerremotemodel-createindex.json"); + ResponseException exception = expectThrows( + ResponseException.class, + () -> reprovisionWorkflow(client(), workflowId, templateWithoutIngestPipeline) + ); + assertEquals(RestStatus.BAD_REQUEST.getStatus(), exception.getResponse().getStatusLine().getStatusCode()); + assertTrue(exception.getMessage().contains("Workflow Step deletion is not supported when reprovisioning a template.")); + + // Deprovision and delete all resources + Response deprovisionResponse = deprovisionWorkflowWithAllowDelete( + client(), + workflowId, + "nlp-ingest-pipeline" + "," + "my-nlp-index" + ); + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 60, + TimeUnit.SECONDS + ); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse)); + + // Hit Delete API + Response deleteResponse = deleteWorkflow(client(), workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); + } + public void testTimestamps() throws Exception { Template noopTemplate = TestHelpers.createTemplateFromFile("noop.json"); // Create the template, should have created and updated matching @@ -421,14 +616,7 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception { String workflowId = (String) responseMap.get(WORKFLOW_ID); getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); - // Ensure Ml config index is initialized as creating a connector requires this, then hit Provision API and assert status - if (!indexExistsWithAdminClient(".plugins-ml-config")) { - assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); - response = provisionWorkflow(client(), workflowId); - } else { - response = provisionWorkflow(client(), workflowId); - } - + response = provisionWorkflow(client(), workflowId); 
assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); @@ -475,14 +663,7 @@ public void testDefaultCohereUseCase() throws Exception { String workflowId = (String) responseMap.get(WORKFLOW_ID); getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); - // Ensure Ml config index is initialized as creating a connector requires this, then hit Provision API and assert status - if (!indexExistsWithAdminClient(".plugins-ml-config")) { - assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); - response = provisionWorkflow(client(), workflowId); - } else { - response = provisionWorkflow(client(), workflowId); - } - + response = provisionWorkflow(client(), workflowId); assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); @@ -526,14 +707,7 @@ public void testDefaultSemanticSearchUseCaseWithFailureExpected() throws Excepti String workflowId = (String) responseMap.get(WORKFLOW_ID); getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); - // Ensure Ml config index is initialized as creating a connector requires this, then hit Provision API and assert status - if (!indexExistsWithAdminClient(".plugins-ml-config")) { - assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); - response = provisionWorkflow(client(), workflowId); - } else { - response = provisionWorkflow(client(), workflowId); - } - + response = provisionWorkflow(client(), workflowId); assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); getAndAssertWorkflowStatus(client(), workflowId, State.FAILED, ProvisioningProgress.FAILED); } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index e4f22e947..f6b1a5fc7 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -178,6 +178,20 @@ public void testCreateWorkflowRequestWithCreateAndReprovision() throws Exception ); } + public void testCreateWorkflowRequestWithReprovisionAndSubstitutionParams() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.ofEntries(Map.entry(REPROVISION_WORKFLOW, "true"), Map.entry("open_ai_key", "1234"))) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue( + channel.capturedResponse().content().utf8ToString().contains("are permitted unless the provision parameter is set to true.") + ); + } + public void testCreateWorkflowRequestWithUpdateAndParams() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java 
b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 1cdb0c50e..8237a7a93 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -330,4 +330,21 @@ public void testFlattenSettings() throws Exception { assertTrue(flattenedSettings.entrySet().stream().allMatch(x -> x.getKey().startsWith("index."))); } + + public void testPrependIndexToSettings() throws Exception { + + Map indexSettingsMap = Map.ofEntries( + Map.entry("knn", "true"), + Map.entry("number_of_shards", "2"), + Map.entry("number_of_replicas", "1"), + Map.entry("index.default_pipeline", "_none"), + Map.entry("search", Map.of("default_pipeline", "_none")) + ); + Map prependedSettings = ParseUtils.prependIndexToSettings(indexSettingsMap); + assertEquals(5, prependedSettings.size()); + + // every setting should start with index + assertTrue(prependedSettings.entrySet().stream().allMatch(x -> x.getKey().startsWith("index."))); + + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 86dc4af47..dd9eb369d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -9,9 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -30,8 +28,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -94,10 +90,10 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(CONNECTOR_ID, connectorId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = createConnectorStep.execute( inputData.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 538747b94..b9a9dfe88 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ 
b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -12,7 +12,6 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; @@ -21,7 +20,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -36,10 +34,8 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -96,10 +92,10 @@ public void setUp() throws Exception { public void testCreateIndexStep() throws ExecutionException, InterruptedException, IOException { doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(INDEX_NAME, "demo"), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index c7912d494..efd9275b5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -11,12 +11,10 @@ import org.opensearch.action.ingest.PutPipelineRequest; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; @@ -27,9 +25,7 @@ import org.mockito.ArgumentCaptor; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static 
org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; import static org.mockito.ArgumentMatchers.any; @@ -78,10 +74,10 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(PIPELINE_ID, "pipelineId"), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStepTests.java index 9a54f2af6..ac1fda8d8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateSearchPipelineStepTests.java @@ -11,12 +11,10 @@ import org.opensearch.action.search.PutSearchPipelineRequest; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; @@ -27,9 +25,7 @@ import org.mockito.ArgumentCaptor; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; import static org.mockito.ArgumentMatchers.any; @@ -43,7 +39,7 @@ public class CreateSearchPipelineStepTests extends OpenSearchTestCase { private WorkflowData inputData; - private WorkflowData outpuData; + private WorkflowData outputData; private Client client; private AdminClient adminClient; private ClusterAdminClient clusterAdminClient; @@ -63,7 +59,7 @@ public void setUp() throws Exception { ); // Set output data to returned pipelineId - outpuData = new WorkflowData(Map.ofEntries(Map.entry(PIPELINE_ID, "pipelineId")), "test-id", "test-node-id"); + outputData = new WorkflowData(Map.ofEntries(Map.entry(PIPELINE_ID, "pipelineId")), "test-id", "test-node-id"); client = mock(Client.class); adminClient = 
mock(AdminClient.class); @@ -78,10 +74,10 @@ public void testCreateSearchPipelineStep() throws InterruptedException, Executio CreateSearchPipelineStep createSearchPipelineStep = new CreateSearchPipelineStep(client, flowFrameworkIndicesHandler); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(PIPELINE_ID, "pipelineId"), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -100,7 +96,7 @@ public void testCreateSearchPipelineStep() throws InterruptedException, Executio actionListenerCaptor.getValue().onResponse(new AcknowledgedResponse(true)); assertTrue(future.isDone()); - assertEquals(outpuData.getContent(), future.get().getContent()); + assertEquals(outputData.getContent(), future.get().getContent()); } public void testCreateSearchPipelineStepFailure() throws InterruptedException { diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 60a6b7f48..b05e43ed4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -11,12 +11,10 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -42,10 +40,8 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; @@ -152,10 +148,10 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I }).when(machineLearningNodeClient).getTask(any(), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = 
invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = deployModel.execute( inputData.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java index 4aabee215..b53574cfe 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java @@ -15,6 +15,9 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.opensearch.flowframework.common.CommonValue.DELAY_FIELD; @@ -50,6 +53,49 @@ public void testNoOpStepDelay() throws IOException, InterruptedException { assertTrue(System.nanoTime() - start > 900_000_000L); } + public void testNoOpStepInterrupt() throws IOException, InterruptedException { + NoOpStep noopStep = new NoOpStep(); + WorkflowData delayData = new WorkflowData(Map.of(DELAY_FIELD, "5s"), null, null); + + CountDownLatch latch = new CountDownLatch(1); + // Fetch errors from the separate thread + AtomicReference assertionError = new AtomicReference<>(); + + Thread testThread = new Thread(() -> { + try { + PlainActionFuture future = noopStep.execute( + "nodeId", + delayData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + try { + future.actionGet(); + } catch (Exception e) { + // Ignore the IllegalStateException/InterruptedException + } + assertTrue(future.isDone()); + assertTrue(future.isCancelled()); + assertTrue(Thread.currentThread().isInterrupted()); + } catch (AssertionError e) { + assertionError.set(e); + } finally { + latch.countDown(); + } + }); + + testThread.start(); + Thread.sleep(100); + testThread.interrupt(); + + latch.await(1, TimeUnit.SECONDS); + + if (assertionError.get() != null) { + throw assertionError.get(); + } + } + public void testNoOpStepParse() throws IOException { NoOpStep noopStep = new NoOpStep(); WorkflowData delayData = new WorkflowData(Map.of(DELAY_FIELD, "foo"), null, null); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index 9005052ae..626dfdfa1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -9,9 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -33,8 +31,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static 
org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -101,10 +97,10 @@ public void testRegisterAgent() throws IOException, ExecutionException, Interrup }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(AGENT_ID, agentId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = registerAgentStep.execute( inputData.getNodeId(), @@ -134,10 +130,10 @@ public void testRegisterAgentFailure() throws IOException { }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(AGENT_ID, agentId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = registerAgentStep.execute( inputData.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index f8f2bce8f..061d9f8c8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -11,12 +11,10 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -38,16 +36,15 @@ import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static 
org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; @@ -168,10 +165,10 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { }).when(machineLearningNodeClient).getTask(any(), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = registerLocalModelStep.execute( workflowData.getNodeId(), @@ -225,6 +222,88 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } + // This method tests code in the abstract parent + public void testRegisterLocalCustomModelDeployStateUpdateFail() throws Exception { + String taskId = "abcd"; + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + AtomicInteger invocationCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + if (invocationCount.getAndIncrement() == 0) { + // succeed on first call (update register) + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); + } else { + // fail on second call (update deploy) + updateResponseListener.onFailure(new RuntimeException("Failed to update deploy resource")); + } + return null; + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + 
Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("model_type", "bert"), + Map.entry("embedding_dimension", "384"), + Map.entry("framework_type", "sentence_transformers"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "true") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to update simulated deploy step resource model-id", ex.getCause().getMessage()); + } + public void testRegisterLocalCustomModelFailure() { doAnswer(invocation -> { diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index dcd2098d5..afe97bacb 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -11,12 +11,10 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -42,12 +40,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; @@ -162,10 +158,10 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception { }).when(machineLearningNodeClient).getTask(any(), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - 
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = registerLocalPretrainedModelStep.execute( workflowData.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index df6fd94e2..9c35af33c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -11,12 +11,10 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -42,12 +40,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; @@ -165,10 +161,10 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception { }).when(machineLearningNodeClient).getTask(any(), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = registerLocalSparseEncodingModelStep.execute( workflowData.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java index 921c7ac15..7f7adf44b 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java +++ 
b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java @@ -9,9 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; @@ -33,8 +31,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -111,10 +108,10 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_GROUP_ID, modelGroupId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = modelGroupStep.execute( inputData.getNodeId(), @@ -140,7 +137,7 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep assertTrue(future.isDone()); assertEquals(modelGroupId, future.get().getContent().get(MODEL_GROUP_ID)); - assertEquals(status, future.get().getContent().get("model_group_status")); + assertEquals(status, future.get().getContent().get(MODEL_GROUP_STATUS)); } public void testRegisterModelGroupFailure() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index 603bfde57..0e2ab91e9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -12,9 +12,7 @@ import org.opensearch.ResourceNotFoundException; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; @@ -30,15 +28,14 @@ import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; import org.mockito.Mock; import org.mockito.MockitoAnnotations; 
-import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; @@ -94,10 +91,10 @@ public void testRegisterRemoteModelSuccess() throws Exception { }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); PlainActionFuture future = this.registerRemoteModelStep.execute( workflowData.getNodeId(), @@ -109,7 +106,13 @@ public void testRegisterRemoteModelSuccess() throws Exception { verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); // only updates register resource - verify(flowFrameworkIndicesHandler, times(1)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + verify(flowFrameworkIndicesHandler, times(1)).addResourceToStateIndex( + any(WorkflowData.class), + anyString(), + anyString(), + anyString(), + any() + ); assertTrue(future.isDone()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); @@ -130,10 +133,10 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); WorkflowData deployWorkflowData = new WorkflowData( Map.ofEntries( @@ -156,7 +159,13 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); // updates both register and deploy resources - verify(flowFrameworkIndicesHandler, times(2)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + verify(flowFrameworkIndicesHandler, times(2)).addResourceToStateIndex( + any(WorkflowData.class), + anyString(), + anyString(), + anyString(), + any() + ); assertTrue(future.isDone()); 
assertEquals(modelId, future.get().getContent().get(MODEL_ID)); @@ -182,7 +191,13 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { verify(mlNodeClient, times(2)).register(any(MLRegisterModelInput.class), any()); // updates both register and deploy resources - verify(flowFrameworkIndicesHandler, times(4)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + verify(flowFrameworkIndicesHandler, times(4)).addResourceToStateIndex( + any(WorkflowData.class), + anyString(), + anyString(), + anyString(), + any() + ); assertTrue(future.isDone()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); @@ -210,6 +225,99 @@ public void testRegisterRemoteModelFailure() { } + public void testRegisterRemoteModelUpdateFailure() { + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onFailure(new RuntimeException("Failed to update register resource")); + return null; + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); + + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, true) + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to update new created test-node-id resource register_remote_model id efgh", ex.getCause().getMessage()); + } + + public void testRegisterRemoteModelDeployUpdateFailure() { + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + AtomicInteger invocationCount = new AtomicInteger(0); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + if (invocationCount.getAndIncrement() == 0) { + // succeed on first call (update register) + updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id")); + } else { + // fail on second call (update deploy) + updateResponseListener.onFailure(new RuntimeException("Failed to update deploy resource")); + } + return null; + }).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any()); + + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), 
+ Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, true) + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to update simulated deploy step resource efgh", ex.getCause().getMessage()); + } + public void testReisterRemoteModelInterfaceFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ReindexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ReindexStepTests.java index 97eff365a..a195025ea 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ReindexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ReindexStepTests.java @@ -11,12 +11,10 @@ import org.apache.lucene.tests.util.LuceneTestCase; import org.opensearch.OpenSearchException; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.common.Randomness; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.BulkByScrollTask; @@ -36,16 +34,12 @@ import static java.lang.Math.abs; import static java.util.stream.Collectors.toList; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.common.unit.TimeValue.timeValueMillis; import static org.opensearch.flowframework.common.CommonValue.DESTINATION_INDEX; import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.workflow.ReindexStep.NAME; import static org.apache.lucene.tests.util.TestUtil.randomSimpleString; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -96,12 +90,6 @@ public void testReindexStep() throws ExecutionException, InterruptedException, I @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); - return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - PlainActionFuture future = reIndexStep.execute( inputData.getNodeId(), inputData, diff --git a/src/test/java/org/opensearch/flowframework/workflow/UpdateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/UpdateIndexStepTests.java index 7dade5607..e4ea939ea 100644 
--- a/src/test/java/org/opensearch/flowframework/workflow/UpdateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/UpdateIndexStepTests.java @@ -12,6 +12,7 @@ import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; @@ -78,6 +79,12 @@ public void testUpdateIndexStepWithUpdatedSettings() throws ExecutionException, return null; }).when(indicesAdminClient).getSettings(any(), any()); + doAnswer(invocation -> { + ActionListener ackResponseListener = invocation.getArgument(1); + ackResponseListener.onResponse(new AcknowledgedResponse(true)); + return null; + }).when(indicesAdminClient).updateSettings(any(), any()); + // validate update settings request content @SuppressWarnings({ "unchecked" }) ArgumentCaptor updateSettingsRequestCaptor = ArgumentCaptor.forClass(UpdateSettingsRequest.class); @@ -105,6 +112,63 @@ public void testUpdateIndexStepWithUpdatedSettings() throws ExecutionException, assertEquals(2, settingsToUpdate.size()); assertEquals("_none", settingsToUpdate.get("index.default_pipeline")); assertEquals("_none", settingsToUpdate.get("index.search.default_pipeline")); + + assertTrue(future.isDone()); + WorkflowData returnedData = (WorkflowData) future.get(); + assertEquals(Map.ofEntries(Map.entry(INDEX_NAME, indexName)), returnedData.getContent()); + assertEquals(data.getWorkflowId(), returnedData.getWorkflowId()); + assertEquals(data.getNodeId(), returnedData.getNodeId()); + } + + public void testFailedToUpdateIndexSettings() throws ExecutionException, InterruptedException, IOException { + + UpdateIndexStep updateIndexStep = new UpdateIndexStep(client); + + String indexName = "test-index"; + + // Create existing settings for default pipelines + Settings.Builder builder = Settings.builder(); + builder.put("index.number_of_shards", 2); + builder.put("index.number_of_replicas", 1); + builder.put("index.knn", true); + builder.put("index.default_pipeline", "ingest_pipeline_id"); + builder.put("index.search.default_pipeline", "search_pipeline_id"); + Map indexToSettings = new HashMap<>(); + indexToSettings.put(indexName, builder.build()); + + // Stub get index settings request/response + doAnswer(invocation -> { + ActionListener getSettingsResponseListener = invocation.getArgument(1); + getSettingsResponseListener.onResponse(new GetSettingsResponse(indexToSettings, indexToSettings)); + return null; + }).when(indicesAdminClient).getSettings(any(), any()); + + doAnswer(invocation -> { + ActionListener ackResponseListener = invocation.getArgument(1); + ackResponseListener.onFailure(new Exception("")); + return null; + }).when(indicesAdminClient).updateSettings(any(), any()); + + // Configurations has updated search/ingest pipeline default values of _none + String configurations = + "{\"settings\":{\"index\":{\"knn\":true,\"number_of_shards\":2,\"number_of_replicas\":1,\"default_pipeline\":\"_none\",\"search\":{\"default_pipeline\":\"_none\"}}},\"mappings\":{\"properties\":{\"age\":{\"type\":\"integer\"}}},\"aliases\":{\"sample-alias1\":{}}}"; + WorkflowData data = new WorkflowData( + Map.ofEntries(Map.entry(INDEX_NAME, indexName), Map.entry(CONFIGURATIONS, configurations)), + "test-id", + "test-node-id" + ); + 
+        PlainActionFuture<WorkflowData> future = updateIndexStep.execute(
+            data.getNodeId(),
+            data,
+            Collections.emptyMap(),
+            Collections.emptyMap(),
+            Collections.emptyMap()
+        );
+
+        assertTrue(future.isDone());
+        ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get());
+        assertTrue(exception.getCause() instanceof Exception);
+        assertEquals("Failed to update the index settings for index test-index", exception.getCause().getMessage());
     }
 
     public void testMissingSettings() throws InterruptedException {
@@ -136,6 +200,55 @@ public void testMissingSettings() throws InterruptedException {
         );
     }
 
+    public void testUpdateMixedSettings() throws InterruptedException {
+        UpdateIndexStep updateIndexStep = new UpdateIndexStep(client);
+
+        String indexName = "test-index";
+
+        // Create existing settings for default pipelines
+        Settings.Builder builder = Settings.builder();
+        builder.put("index.number_of_shards", 2);
+        builder.put("index.number_of_replicas", 1);
+        builder.put("index.knn", true);
+        builder.put("index.default_pipeline", "ingest_pipeline_id");
+        Map<String, Settings> indexToSettings = new HashMap<>();
+        indexToSettings.put(indexName, builder.build());
+
+        // Stub get index settings request/response
+        doAnswer(invocation -> {
+            ActionListener<GetSettingsResponse> getSettingsResponseListener = invocation.getArgument(1);
+            getSettingsResponseListener.onResponse(new GetSettingsResponse(indexToSettings, indexToSettings));
+            return null;
+        }).when(indicesAdminClient).getSettings(any(), any());
+
+        // validate update settings request content
+        @SuppressWarnings({ "unchecked" })
+        ArgumentCaptor<UpdateSettingsRequest> updateSettingsRequestCaptor = ArgumentCaptor.forClass(UpdateSettingsRequest.class);
+
+        // Configurations has updated ingest pipeline default values of _none. Settings have regular and full names
+        String configurations =
+            "{\"settings\":{\"index.knn\":true,\"default_pipeline\":\"_none\",\"index.number_of_shards\":2,\"index.number_of_replicas\":1},\"mappings\":{\"properties\":{\"age\":{\"type\":\"integer\"}}},\"aliases\":{\"sample-alias1\":{}}}";
+        WorkflowData data = new WorkflowData(
+            Map.ofEntries(Map.entry(INDEX_NAME, indexName), Map.entry(CONFIGURATIONS, configurations)),
+            "test-id",
+            "test-node-id"
+        );
+        PlainActionFuture<WorkflowData> future = updateIndexStep.execute(
+            data.getNodeId(),
+            data,
+            Collections.emptyMap(),
+            Collections.emptyMap(),
+            Collections.emptyMap()
+        );
+
+        verify(indicesAdminClient, times(1)).getSettings(any(GetSettingsRequest.class), any());
+        verify(indicesAdminClient, times(1)).updateSettings(updateSettingsRequestCaptor.capture(), any());
+
+        Settings settingsToUpdate = updateSettingsRequestCaptor.getValue().settings();
+        assertEquals(1, settingsToUpdate.size());
+        assertEquals("_none", settingsToUpdate.get("index.default_pipeline"));
+    }
+
     public void testEmptyConfiguration() throws InterruptedException {
         UpdateIndexStep updateIndexStep = new UpdateIndexStep(client);
diff --git a/src/test/resources/template/registerremotemodel-createindex.json b/src/test/resources/template/registerremotemodel-createindex.json
new file mode 100644
index 000000000..3005eeed1
--- /dev/null
+++ b/src/test/resources/template/registerremotemodel-createindex.json
@@ -0,0 +1,84 @@
+{
+  "name": "semantic search with local pretrained model",
+  "description": "Setting up semantic search, with a local pretrained embedding model",
+  "use_case": "SEMANTIC_SEARCH",
+  "version": {
+    "template": "1.0.0",
+    "compatibility": [
+      "2.12.0",
+      "3.0.0"
+    ]
+  },
+  "workflows": {
+    "provision": {
+      "nodes": [
+        {
"create_openai_connector", + "type": "create_connector", + "user_inputs": { + "name": "OpenAI Chat Connector", + "description": "The connector to public OpenAI model service for text embedding model", + "version": "1", + "protocol": "http", + "parameters": { + "endpoint": "api.openai.com", + "model": "gpt-3.5-turbo", + "response_filter": "$.choices[0].message.content" + }, + "credential": { + "openAI_key": "12345" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions" + } + ] + } + }, + { + "id": "register_openai_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_openai_connector": "connector_id" + }, + "user_inputs": { + "name": "openAI-gpt-3.5-turbo", + "deploy": true + } + }, + { + "id": "create_index", + "type": "create_index", + "user_inputs": { + "index_name": "my-nlp-index", + "configurations": { + "settings": { + "index.knn": true, + "index.number_of_shards": "2" + }, + "mappings": { + "properties": { + "passage_embedding": { + "type": "knn_vector", + "dimension": "768", + "method": { + "engine": "lucene", + "space_type": "l2", + "name": "hnsw", + "parameters": {} + } + }, + "passage_text": { + "type": "text" + } + } + } + } + } + } + ] + } + } + } diff --git a/src/test/resources/template/registerremotemodel-ingestpipeline-createindex.json b/src/test/resources/template/registerremotemodel-ingestpipeline-createindex.json new file mode 100644 index 000000000..767da07b6 --- /dev/null +++ b/src/test/resources/template/registerremotemodel-ingestpipeline-createindex.json @@ -0,0 +1,111 @@ +{ + "name": "semantic search with local pretrained model", + "description": "Setting up semantic search, with a local pretrained embedding model", + "use_case": "SEMANTIC_SEARCH", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_openai_connector", + "type": "create_connector", + "user_inputs": { + "name": "OpenAI Chat Connector", + "description": "The connector to public OpenAI model service for text embedding model", + "version": "1", + "protocol": "http", + "parameters": { + "endpoint": "api.openai.com", + "model": "gpt-3.5-turbo", + "response_filter": "$.choices[0].message.content" + }, + "credential": { + "openAI_key": "12345" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions" + } + ] + } + }, + { + "id": "register_openai_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_openai_connector": "connector_id" + }, + "user_inputs": { + "name": "openAI-gpt-3.5-turbo", + "deploy": true + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "previous_node_inputs": { + "register_openai_model": "model_id" + }, + "user_inputs": { + "pipeline_id": "nlp-ingest-pipeline", + "configurations": { + "description": "A text embedding pipeline", + "processors": [ + { + "text_embedding": { + "model_id": "${{register_openai_model.model_id}}", + "field_map": { + "passage_text": "passage_embedding" + } + } + } + ] + } + } + }, + { + "id": "create_index", + "type": "create_index", + "previous_node_inputs": { + "create_ingest_pipeline": "pipeline_id" + }, + "user_inputs": { + "index_name": "my-nlp-index", + "configurations": { + "settings": { + "index.knn": true, + "default_pipeline": "${{create_ingest_pipeline.pipeline_id}}", + "index.number_of_shards": "2" + }, + 
"mappings": { + "properties": { + "passage_embedding": { + "type": "knn_vector", + "dimension": "768", + "method": { + "engine": "lucene", + "space_type": "l2", + "name": "hnsw", + "parameters": {} + } + }, + "passage_text": { + "type": "text" + } + } + } + } + } + } + ] + } + } +} diff --git a/src/test/resources/template/registerremotemodel-ingestpipeline-updateindex.json b/src/test/resources/template/registerremotemodel-ingestpipeline-updateindex.json new file mode 100644 index 000000000..fc873ae66 --- /dev/null +++ b/src/test/resources/template/registerremotemodel-ingestpipeline-updateindex.json @@ -0,0 +1,111 @@ +{ + "name": "semantic search with local pretrained model", + "description": "Setting up semantic search, with a local pretrained embedding model", + "use_case": "SEMANTIC_SEARCH", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_openai_connector", + "type": "create_connector", + "user_inputs": { + "name": "OpenAI Chat Connector", + "description": "The connector to public OpenAI model service for text embedding model", + "version": "1", + "protocol": "http", + "parameters": { + "endpoint": "api.openai.com", + "model": "gpt-3.5-turbo", + "response_filter": "$.choices[0].message.content" + }, + "credential": { + "openAI_key": "12345" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions" + } + ] + } + }, + { + "id": "register_openai_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_openai_connector": "connector_id" + }, + "user_inputs": { + "name": "openAI-gpt-3.5-turbo", + "deploy": true + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "previous_node_inputs": { + "register_openai_model": "model_id" + }, + "user_inputs": { + "pipeline_id": "nlp-ingest-pipeline", + "configurations": { + "description": "A text embedding pipeline", + "processors": [ + { + "text_embedding": { + "model_id": "${{register_openai_model.model_id}}", + "field_map": { + "passage_text": "passage_embedding" + } + } + } + ] + } + } + }, + { + "id": "create_index", + "type": "create_index", + "previous_node_inputs": { + "create_ingest_pipeline": "pipeline_id" + }, + "user_inputs": { + "index_name": "my-nlp-index", + "configurations": { + "settings": { + "index.knn": true, + "default_pipeline": "_none", + "index.number_of_shards": "2" + }, + "mappings": { + "properties": { + "passage_embedding": { + "type": "knn_vector", + "dimension": "768", + "method": { + "engine": "lucene", + "space_type": "l2", + "name": "hnsw", + "parameters": {} + } + }, + "passage_text": { + "type": "text" + } + } + } + } + } + } + ] + } + } +} diff --git a/src/test/resources/template/registerremotemodel-ingestpipeline.json b/src/test/resources/template/registerremotemodel-ingestpipeline.json new file mode 100644 index 000000000..dede163c1 --- /dev/null +++ b/src/test/resources/template/registerremotemodel-ingestpipeline.json @@ -0,0 +1,77 @@ +{ + "name": "semantic search with local pretrained model", + "description": "Setting up semantic search, with a local pretrained embedding model", + "use_case": "SEMANTIC_SEARCH", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_openai_connector", + "type": "create_connector", + "user_inputs": { + "name": "OpenAI Chat Connector", + 
"description": "The connector to public OpenAI model service for text embedding model", + "version": "1", + "protocol": "http", + "parameters": { + "endpoint": "api.openai.com", + "model": "gpt-3.5-turbo", + "response_filter": "$.choices[0].message.content" + }, + "credential": { + "openAI_key": "12345" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions" + } + ] + } + }, + { + "id": "register_openai_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_openai_connector": "connector_id" + }, + "user_inputs": { + "name": "openAI-gpt-3.5-turbo", + "deploy": true + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "previous_node_inputs": { + "register_openai_model": "model_id" + }, + "user_inputs": { + "pipeline_id": "nlp-ingest-pipeline", + "configurations": { + "description": "A text embedding pipeline", + "processors": [ + { + "text_embedding": { + "model_id": "${{register_openai_model.model_id}}", + "field_map": { + "passage_text": "passage_embedding" + } + } + } + ] + } + } + } + ] + } + } +} diff --git a/src/test/resources/template/registerremotemodel.json b/src/test/resources/template/registerremotemodel.json new file mode 100644 index 000000000..58c520af4 --- /dev/null +++ b/src/test/resources/template/registerremotemodel.json @@ -0,0 +1,54 @@ +{ + "name": "semantic search with local pretrained model", + "description": "Setting up semantic search, with a local pretrained embedding model", + "use_case": "SEMANTIC_SEARCH", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_openai_connector", + "type": "create_connector", + "user_inputs": { + "name": "OpenAI Chat Connector", + "description": "The connector to public OpenAI model service for text embedding model", + "version": "1", + "protocol": "http", + "parameters": { + "endpoint": "api.openai.com", + "model": "gpt-3.5-turbo", + "response_filter": "$.choices[0].message.content" + }, + "credential": { + "openAI_key": "12345" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions" + } + ] + } + }, + { + "id": "register_openai_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_openai_connector": "connector_id" + }, + "user_inputs": { + "name": "openAI-gpt-3.5-turbo", + "deploy": true + } + } + ] + } + } + }