diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index dbfc17891..774660bfd 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -77,6 +77,8 @@ private CommonValue() {} public static final String INDEX_NAME = "index_name"; /** Type field */ public static final String TYPE = "type"; + /** default_mapping_option filed */ + public static final String DEFAULT_MAPPING_OPTION = "default_mapping_option"; /** ID Field */ public static final String ID = "id"; /** Pipeline Id field */ @@ -103,6 +105,8 @@ private CommonValue() {} public static final String MODEL_VERSION = "model_version"; /** Model Group Id field */ public static final String MODEL_GROUP_ID = "model_group_id"; + /** Model Group Id field */ + public static final String MODEL_GROUP_STATUS = "model_group_status"; /** Description field */ public static final String DESCRIPTION_FIELD = "description"; /** Connector Id field */ @@ -158,10 +162,10 @@ private CommonValue() {} public static final String USER_OUTPUTS_FIELD = "user_outputs"; /** The template field name for template resources created */ public static final String RESOURCES_CREATED_FIELD = "resources_created"; - /** The field name for the ResourceCreated's resource ID */ - public static final String RESOURCE_ID_FIELD = "resource_id"; - /** The field name for the ResourceCreated's resource name */ + /** The field name for the step name where a resource is created */ public static final String WORKFLOW_STEP_NAME = "workflow_step_name"; + /** The field name for the step ID where a resource is created */ + public static final String WORKFLOW_STEP_ID = "workflow_step_id"; /** LLM Name for registering an agent */ public static final String LLM_FIELD = "llm"; /** The tools' field for an agent */ diff --git a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java new file mode 100644 index 000000000..04a8650b2 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.common; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; + +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Enum encapsulating the different step names and the resources they create + */ +public enum WorkflowResources { + + /** official workflow step name for creating a connector and associated created resource */ + CREATE_CONNECTOR("create_connector", "connector_id"), + /** official workflow step name for registering a remote model and associated created resource */ + REGISTER_REMOTE_MODEL("register_remote_model", "model_id"), + /** official workflow step name for registering a local model and associated created resource */ + REGISTER_LOCAL_MODEL("register_local_model", "model_id"), + /** official workflow step name for registering a model group and associated created resource */ + REGISTER_MODEL_GROUP("register_model_group", "model_group_id"), + /** official workflow step name for creating an ingest-pipeline and associated created resource */ + CREATE_INGEST_PIPELINE("create_ingest_pipeline", "pipeline_id"), + /** official workflow step name for creating an index and associated created resource */ + CREATE_INDEX("create_index", "index_name"); + + private final String workflowStep; + private final String resourceCreated; + private static final Logger logger = LogManager.getLogger(WorkflowResources.class); + private static final Set allResources = Stream.of(values()) + .map(WorkflowResources::getResourceCreated) + .collect(Collectors.toSet()); + + WorkflowResources(String workflowStep, String resourceCreated) { + this.workflowStep = workflowStep; + this.resourceCreated = resourceCreated; + } + + /** + * Returns the workflowStep for the given enum Constant + * @return the workflowStep of this data. + */ + public String getWorkflowStep() { + return workflowStep; + } + + /** + * Returns the resourceCreated for the given enum Constant + * @return the resourceCreated of this data. + */ + public String getResourceCreated() { + return resourceCreated; + } + + /** + * gets the resources created type based on the workflowStep + * @param workflowStep workflow step name + * @return the resource that will be created + * @throws FlowFrameworkException if workflow step doesn't exist in enum + */ + public static String getResourceByWorkflowStep(String workflowStep) throws FlowFrameworkException { + if (workflowStep != null && !workflowStep.isEmpty()) { + for (WorkflowResources mapping : values()) { + if (mapping.getWorkflowStep().equals(workflowStep)) { + return mapping.getResourceCreated(); + } + } + } + logger.error("Unable to find resource type for step: " + workflowStep); + throw new FlowFrameworkException("Unable to find resource type for step: " + workflowStep, RestStatus.BAD_REQUEST); + } + + /** + * Returns all the possible resource created types in enum + * @return a set of all the resource created types + */ + public static Set getAllResourcesCreated() { + return allResources; + } +} diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index cce4ba839..63df7824c 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -32,14 +32,17 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.ProvisioningProgress; +import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; import java.io.IOException; import java.net.URL; @@ -435,6 +438,7 @@ public void updateFlowFrameworkSystemIndexDoc( updatedContent.putAll(updatedFields); updateRequest.doc(updatedContent); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + updateRequest.retryOnConflict(3); // TODO: decide what condition can be considered as an update conflict and add retry strategy client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { @@ -468,7 +472,8 @@ public void updateFlowFrameworkSystemIndexDocWithScript( // TODO: Also add ability to change other fields at the same time when adding detailed provision progress updateRequest.script(script); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - // TODO: decide what condition can be considered as an update conflict and add retry strategy + updateRequest.retryOnConflict(3); + // TODO: Implement our own concurrency control to improve on retry mechanism client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage()); @@ -478,4 +483,38 @@ public void updateFlowFrameworkSystemIndexDocWithScript( } } } + + /** + * Creates a new ResourceCreated object and a script to update the state index + * @param workflowId workflowId for the relevant step + * @param nodeId WorkflowData object with relevent step information + * @param workflowStepName the workflowstep name that created the resource + * @param resourceId the id of the newly created resource + * @param listener the ActionListener for this step to handle completing the future after update + * @throws IOException if parsing fails on new resource + */ + public void updateResourceInStateIndex( + String workflowId, + String nodeId, + String workflowStepName, + String resourceId, + ActionListener listener + ) throws IOException { + ResourceCreated newResource = new ResourceCreated(workflowStepName, nodeId, resourceId); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + newResource.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + + // The script to append a new object to the resources_created array + Script script = new Script( + ScriptType.INLINE, + "painless", + "ctx._source.resources_created.add(params.newResource)", + Collections.singletonMap("newResource", newResource) + ); + + updateFlowFrameworkSystemIndexDocWithScript(WORKFLOW_STATE_INDEX, workflowId, script, ActionListener.wrap(updateResponse -> { + logger.info("updated resources created of {}", workflowId); + listener.onResponse(updateResponse); + }, exception -> { listener.onFailure(exception); })); + } } diff --git a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java index 0ec8f34d5..d039e2f8c 100644 --- a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java +++ b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java @@ -8,17 +8,22 @@ */ package org.opensearch.flowframework.model; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.common.WorkflowResources; +import org.opensearch.flowframework.exception.FlowFrameworkException; import java.io.IOException; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.CommonValue.RESOURCE_ID_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP_NAME; /** @@ -27,16 +32,21 @@ // TODO: create an enum to add the resource name itself for each step example (create_connector_step -> connector) public class ResourceCreated implements ToXContentObject, Writeable { + private static final Logger logger = LogManager.getLogger(ResourceCreated.class); + private final String workflowStepName; + private final String workflowStepId; private final String resourceId; /** - * Create this resources created object with given resource name and ID. + * Create this resources created object with given workflow step name, ID and resource ID. * @param workflowStepName The workflow step name associating to the step where it was created + * @param workflowStepId The workflow step ID associating to the step where it was created * @param resourceId The resources ID for relating to the created resource */ - public ResourceCreated(String workflowStepName, String resourceId) { + public ResourceCreated(String workflowStepName, String workflowStepId, String resourceId) { this.workflowStepName = workflowStepName; + this.workflowStepId = workflowStepId; this.resourceId = resourceId; } @@ -47,6 +57,7 @@ public ResourceCreated(String workflowStepName, String resourceId) { */ public ResourceCreated(StreamInput input) throws IOException { this.workflowStepName = input.readString(); + this.workflowStepId = input.readString(); this.resourceId = input.readString(); } @@ -54,13 +65,15 @@ public ResourceCreated(StreamInput input) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { XContentBuilder xContentBuilder = builder.startObject() .field(WORKFLOW_STEP_NAME, workflowStepName) - .field(RESOURCE_ID_FIELD, resourceId); + .field(WORKFLOW_STEP_ID, workflowStepId) + .field(WorkflowResources.getResourceByWorkflowStep(workflowStepName), resourceId); return xContentBuilder.endObject(); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(workflowStepName); + out.writeString(workflowStepId); out.writeString(resourceId); } @@ -82,6 +95,15 @@ public String workflowStepName() { return workflowStepName; } + /** + * Gets the workflow step id associated to the created resource + * + * @return the workflowStepId. + */ + public String workflowStepId() { + return workflowStepId; + } + /** * Parse raw JSON content into a ResourceCreated instance. * @@ -91,6 +113,7 @@ public String workflowStepName() { */ public static ResourceCreated parse(XContentParser parser) throws IOException { String workflowStepName = null; + String workflowStepId = null; String resourceId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -102,22 +125,50 @@ public static ResourceCreated parse(XContentParser parser) throws IOException { case WORKFLOW_STEP_NAME: workflowStepName = parser.text(); break; - case RESOURCE_ID_FIELD: - resourceId = parser.text(); + case WORKFLOW_STEP_ID: + workflowStepId = parser.text(); break; default: - throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object."); + if (!isValidFieldName(fieldName)) { + throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object."); + } else { + if (fieldName.equals(WorkflowResources.getResourceByWorkflowStep(workflowStepName))) { + resourceId = parser.text(); + } + break; + } } } - if (workflowStepName == null || resourceId == null) { - throw new IOException("A ResourceCreated object requires both a workflowStepName and resourceId."); + if (workflowStepName == null) { + logger.error("Resource created object failed parsing: workflowStepName: {}", workflowStepName); + throw new FlowFrameworkException("A ResourceCreated object requires workflowStepName", RestStatus.BAD_REQUEST); + } + if (workflowStepId == null) { + logger.error("Resource created object failed parsing: workflowStepId: {}", workflowStepId); + throw new FlowFrameworkException("A ResourceCreated object requires workflowStepId", RestStatus.BAD_REQUEST); + } + if (resourceId == null) { + logger.error("Resource created object failed parsing: resourceId: {}", resourceId); + throw new FlowFrameworkException("A ResourceCreated object requires resourceId", RestStatus.BAD_REQUEST); } - return new ResourceCreated(workflowStepName, resourceId); + return new ResourceCreated(workflowStepName, workflowStepId, resourceId); + } + + private static boolean isValidFieldName(String fieldName) { + return (WORKFLOW_STEP_NAME.equals(fieldName) + || WORKFLOW_STEP_ID.equals(fieldName) + || WorkflowResources.getAllResourcesCreated().contains(fieldName)); } @Override public String toString() { - return "resources_Created [resource_name=" + workflowStepName + ", id=" + resourceId + "]"; + return "resources_Created [workflow_step_name= " + + workflowStepName + + ", workflow_step_id= " + + workflowStepName + + ", resource_id= " + + resourceId + + "]"; } } diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java index e49d7d68a..eb1779e93 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java @@ -17,7 +17,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; /** - * This represents the an object of workflow steps json which maps each step to expected inputs and outputs + * This represents an object of workflow steps json which maps each step to expected inputs and outputs */ public class WorkflowStepValidator { diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index b381b41ec..da9643cb5 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -217,10 +217,10 @@ private void executeWorkflow(List workflowSequence, String workflow ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); - }, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage()); }) + }, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage(), exception); }) ); } catch (Exception ex) { - logger.error("Provisioning failed for workflow {} : {}", workflowId, ex); + logger.error("Provisioning failed for workflow: {}", workflowId, ex); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, ImmutableMap.of( @@ -235,7 +235,7 @@ private void executeWorkflow(List workflowSequence, String workflow ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); - }, exceptionState -> { logger.error("Failed to update workflow state : {}", exceptionState.getMessage()); }) + }, exceptionState -> { logger.error("Failed to update workflow state : {}", exceptionState.getMessage(), ex); }) ); } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index bc4132087..dc4c83d4e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -11,21 +11,16 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; -import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; -import org.opensearch.script.Script; -import org.opensearch.script.ScriptType; import java.io.IOException; import java.security.AccessController; @@ -48,7 +43,6 @@ import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; /** @@ -61,7 +55,7 @@ public class CreateConnectorStep implements WorkflowStep { private MachineLearningNodeClient mlClient; private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; - static final String NAME = "create_connector"; + static final String NAME = WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(); /** * Instantiate this class @@ -87,44 +81,35 @@ public CompletableFuture execute( @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { - String workflowId = currentNodeInputs.getWorkflowId(); - createConnectorFuture.complete( - new WorkflowData( - Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())), - workflowId, - currentNodeInputs.getNodeId() - ) - ); + try { + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); logger.info("Created connector successfully"); - String workflowStepName = getName(); - ResourceCreated newResource = new ResourceCreated(workflowStepName, mlCreateConnectorResponse.getConnectorId()); - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - newResource.toXContent(builder, ToXContentObject.EMPTY_PARAMS); - - // The script to append a new object to the resources_created array - Script script = new Script( - ScriptType.INLINE, - "painless", - "ctx._source.resources_created.add(params.newResource)", - Collections.singletonMap("newResource", newResource) - ); - - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDocWithScript( - WORKFLOW_STATE_INDEX, - workflowId, - script, - ActionListener.wrap(updateResponse -> { - logger.info("updated resources created of {}", workflowId); + flowFrameworkIndicesHandler.updateResourceInStateIndex( + currentNodeInputs.getWorkflowId(), + currentNodeId, + getName(), + mlCreateConnectorResponse.getConnectorId(), + ActionListener.wrap(response -> { + logger.info("successfully updated resources created in state index: {}", response.getIndex()); + createConnectorFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(resourceName, mlCreateConnectorResponse.getConnectorId())), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); }, exception -> { + logger.error("Failed to update new created resource", exception); createConnectorFuture.completeExceptionally( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) ); - logger.error("Failed to update workflow state with newly created resource: {}", exception); }) ); - } catch (IOException e) { - logger.error("Failed to parse new created resource", e); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 07246134a..0ace57dc3 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.client.Client; @@ -17,6 +18,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.common.WorkflowResources; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import java.util.ArrayList; @@ -26,8 +29,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; -import static org.opensearch.flowframework.common.CommonValue.INDEX_NAME; -import static org.opensearch.flowframework.common.CommonValue.TYPE; +import static org.opensearch.flowframework.common.CommonValue.DEFAULT_MAPPING_OPTION; /** * Step to create an index @@ -37,21 +39,23 @@ public class CreateIndexStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateIndexStep.class); private final ClusterService clusterService; private final Client client; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ - static final String NAME = "create_index"; + static final String NAME = WorkflowResources.CREATE_INDEX.getWorkflowStep(); static Map indexMappingUpdated = new HashMap<>(); - private static final Map indexSettings = Map.of("index.auto_expand_replicas", "0-1"); /** * Instantiate this class * * @param clusterService The OpenSearch cluster service * @param client Client to create an index + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public CreateIndexStep(ClusterService clusterService, Client client) { + public CreateIndexStep(ClusterService clusterService, Client client, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.clusterService = clusterService; this.client = client; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -61,30 +65,50 @@ public CompletableFuture execute( Map outputs, Map previousNodeInputs ) { - CompletableFuture future = new CompletableFuture<>(); + CompletableFuture createIndexFuture = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(CreateIndexResponse createIndexResponse) { - logger.info("created index: {}", createIndexResponse.index()); - future.complete( - new WorkflowData( - Map.of(INDEX_NAME, createIndexResponse.index()), + try { + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + logger.info("created index: {}", createIndexResponse.index()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + currentNodeId, + getName(), + createIndexResponse.index(), + ActionListener.wrap(response -> { + logger.info("successfully updated resource created in state index: {}", response.getIndex()); + createIndexFuture.complete( + new WorkflowData( + Map.of(resourceName, createIndexResponse.index()), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + createIndexFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + createIndexFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } } @Override public void onFailure(Exception e) { logger.error("Failed to create an index", e); - future.completeExceptionally(e); + createIndexFuture.completeExceptionally(e); } }; String index = null; - String type = null; + String defaultMappingOption = null; Settings settings = null; // TODO: Recreating the list to get this compiling @@ -93,13 +117,18 @@ public void onFailure(Exception e) { data.add(currentNodeInputs); data.addAll(outputs.values()); - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - index = (String) content.get(INDEX_NAME); - type = (String) content.get(TYPE); - if (index != null && type != null && settings != null) { - break; + try { + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + index = (String) content.get(WorkflowResources.getResourceByWorkflowStep(getName())); + defaultMappingOption = (String) content.get(DEFAULT_MAPPING_OPTION); + if (index != null && defaultMappingOption != null && settings != null) { + break; + } } + } catch (Exception e) { + logger.error("Failed to find the correct resource for the workflow step", e); + createIndexFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } // TODO: @@ -107,7 +136,7 @@ public void onFailure(Exception e) { try { CreateIndexRequest request = new CreateIndexRequest(index).mapping( - FlowFrameworkIndicesHandler.getIndexMappings("mappings/" + type + ".json"), + FlowFrameworkIndicesHandler.getIndexMappings("mappings/" + defaultMappingOption + ".json"), JsonXContent.jsonXContent.mediaType() ); client.admin().indices().create(request, actionListener); @@ -115,7 +144,7 @@ public void onFailure(Exception e) { logger.error("Failed to find the right mapping for the index", e); } - return future; + return createIndexFuture; } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index 77dae29eb..352772a49 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.ingest.PutPipelineRequest; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; @@ -18,6 +19,9 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.WorkflowResources; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import java.io.IOException; import java.util.ArrayList; @@ -33,7 +37,6 @@ import static org.opensearch.flowframework.common.CommonValue.INPUT_FIELD_NAME; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.OUTPUT_FIELD_NAME; -import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID; import static org.opensearch.flowframework.common.CommonValue.PROCESSORS; import static org.opensearch.flowframework.common.CommonValue.TYPE; @@ -45,18 +48,21 @@ public class CreateIngestPipelineStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateIngestPipelineStep.class); /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ - static final String NAME = "create_ingest_pipeline"; + static final String NAME = WorkflowResources.CREATE_INGEST_PIPELINE.getWorkflowStep(); // Client to store a pipeline in the cluster state private final ClusterAdminClient clusterAdminClient; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + /** * Instantiates a new CreateIngestPipelineStep - * * @param client The client to create a pipeline and store workflow data into the global context index + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public CreateIngestPipelineStep(Client client) { + public CreateIngestPipelineStep(Client client, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.clusterAdminClient = client.admin().cluster(); + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -136,16 +142,38 @@ public CompletableFuture execute( clusterAdminClient.putPipeline(putPipelineRequest, ActionListener.wrap(response -> { logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); - // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead - createIngestPipelineFuture.complete( - new WorkflowData( - Map.of(PIPELINE_ID, putPipelineRequest.getId()), + try { + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + currentNodeId, + getName(), + putPipelineRequest.getId(), + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead + // TODO: revisit this concept of pipeline_id to be consistent with what makes most sense to end user here + createIngestPipelineFuture.complete( + new WorkflowData( + Map.of(resourceName, putPipelineRequest.getId()), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + createIngestPipelineFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); - // TODO : Use node client to index response data to global context (pending global context index implementation) + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + createIngestPipelineFuture.completeExceptionally( + new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)) + ); + } }, exception -> { logger.error("Failed to create ingest pipeline : " + exception.getMessage()); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 22c6ae810..50ae30986 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -13,7 +13,9 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; @@ -31,6 +33,7 @@ import static org.opensearch.flowframework.common.CommonValue.BACKEND_ROLES_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.MODEL_ACCESS_MODE; +import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; /** @@ -42,14 +45,18 @@ public class ModelGroupStep implements WorkflowStep { private final MachineLearningNodeClient mlClient; - static final String NAME = "model_group"; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + + static final String NAME = WorkflowResources.REGISTER_MODEL_GROUP.getWorkflowStep(); /** * Instantiate this class * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public ModelGroupStep(MachineLearningNodeClient mlClient) { + public ModelGroupStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -65,17 +72,38 @@ public CompletableFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) { - logger.info("Model group registration successful"); - registerModelGroupFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), - Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) - ), + try { + logger.info("Remote Model registration successful"); + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + currentNodeId, + getName(), + mlRegisterModelGroupResponse.getModelGroupId(), + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + registerModelGroupFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, mlRegisterModelGroupResponse.getModelGroupId()), + Map.entry(MODEL_GROUP_STATUS, mlRegisterModelGroupResponse.getStatus()) + ), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + registerModelGroupFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index cc5645306..3dc730b54 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -16,7 +16,9 @@ import org.opensearch.common.util.concurrent.FutureUtils; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; @@ -42,7 +44,6 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; -import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -58,17 +59,26 @@ public class RegisterLocalModelStep extends AbstractRetryableWorkflowStep { private final MachineLearningNodeClient mlClient; - static final String NAME = "register_local_model"; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + + static final String NAME = WorkflowResources.REGISTER_LOCAL_MODEL.getWorkflowStep(); /** * Instantiate this class * @param settings The OpenSearch settings * @param clusterService The cluster service * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public RegisterLocalModelStep(Settings settings, ClusterService clusterService, MachineLearningNodeClient mlClient) { + public RegisterLocalModelStep( + Settings settings, + ClusterService clusterService, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { super(settings, clusterService); this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -218,7 +228,7 @@ public String getName() { * Retryable get ml task * @param workflowId the workflow id * @param nodeId the workflow node id - * @param getMLTaskFuture the workflow step future + * @param registerLocalModelFuture the workflow step future * @param taskId the ml task id * @param retries the current number of request retries */ @@ -242,17 +252,38 @@ void retryableGetMlTask( throw new IllegalStateException("Local model registration is not yet completed"); } } else { - logger.info("Local model registeration successful"); - registerLocalModelFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(MODEL_ID, response.getModelId()), - Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) - ), + try { + logger.info("Local Model registration successful"); + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( workflowId, - nodeId - ) - ); + nodeId, + getName(), + response.getTaskId(), + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + registerLocalModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, response.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) + ), + workflowId, + nodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + registerLocalModelFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } } }, exception -> { if (retries < maxRetry) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 27a77cb98..7e33937bc 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -13,7 +13,9 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -32,7 +34,6 @@ import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; -import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -45,14 +46,18 @@ public class RegisterRemoteModelStep implements WorkflowStep { private final MachineLearningNodeClient mlClient; - static final String NAME = "register_remote_model"; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + + static final String NAME = WorkflowResources.REGISTER_REMOTE_MODEL.getWorkflowStep(); /** * Instantiate this class * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public RegisterRemoteModelStep(MachineLearningNodeClient mlClient) { + public RegisterRemoteModelStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -68,17 +73,39 @@ public CompletableFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { - logger.info("Remote Model registration successful"); - registerRemoteModelFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()), - Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) - ), + + try { + logger.info("Remote Model registration successful"); + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + currentNodeId, + getName(), + mlRegisterModelResponse.getModelId(), + ActionListener.wrap(response -> { + logger.info("successfully updated resources created in state index: {}", response.getIndex()); + registerRemoteModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, mlRegisterModelResponse.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) + ), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + registerRemoteModelFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index bac65c23a..ce0b24d24 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -43,14 +43,17 @@ public WorkflowStepFactory( FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { stepMap.put(NoOpStep.NAME, new NoOpStep()); - stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); - stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); - stepMap.put(RegisterLocalModelStep.NAME, new RegisterLocalModelStep(settings, clusterService, mlClient)); - stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient)); + stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client, flowFrameworkIndicesHandler)); + stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler)); + stepMap.put( + RegisterLocalModelStep.NAME, + new RegisterLocalModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler) + ); + stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient)); - stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); + stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ToolStep.NAME, new ToolStep()); stepMap.put(RegisterAgentStep.NAME, new RegisterAgentStep(mlClient)); } diff --git a/src/main/resources/mappings/workflow-state.json b/src/main/resources/mappings/workflow-state.json index 21df5ccd6..86fbeef6e 100644 --- a/src/main/resources/mappings/workflow-state.json +++ b/src/main/resources/mappings/workflow-state.json @@ -31,15 +31,7 @@ "type": "object" }, "resources_created": { - "type": "nested", - "properties": { - "workflow_step_name": { - "type": "keyword" - }, - "resource_id": { - "type": "keyword" - } - } + "type": "object" }, "ui_metadata": { "type": "object", diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 5840b0906..eb92ccd5e 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -6,7 +6,7 @@ "create_index": { "inputs":[ "index_name", - "type" + "default_mapping_option" ], "outputs":[ "index_name" @@ -83,7 +83,7 @@ "deploy_model_status" ] }, - "model_group": { + "register_model_group": { "inputs":[ "name" ], diff --git a/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java b/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java index a8536dd43..216c18c9e 100644 --- a/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java +++ b/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.model; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -20,17 +21,21 @@ public void setUp() throws Exception { } public void testParseFeature() throws IOException { - ResourceCreated ResourceCreated = new ResourceCreated("A", "B"); - assertEquals(ResourceCreated.workflowStepName(), "A"); - assertEquals(ResourceCreated.resourceId(), "B"); - - String expectedJson = "{\"workflow_step_name\":\"A\",\"resource_id\":\"B\"}"; + String workflowStepName = WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(); + ResourceCreated ResourceCreated = new ResourceCreated(workflowStepName, "workflow_step_1", "L85p1IsBbfF"); + assertEquals(ResourceCreated.workflowStepName(), workflowStepName); + assertEquals(ResourceCreated.workflowStepId(), "workflow_step_1"); + assertEquals(ResourceCreated.resourceId(), "L85p1IsBbfF"); + + String expectedJson = + "{\"workflow_step_name\":\"create_connector\",\"workflow_step_id\":\"workflow_step_1\",\"connector_id\":\"L85p1IsBbfF\"}"; String json = TemplateTestJsonUtil.parseToJson(ResourceCreated); assertEquals(expectedJson, json); ResourceCreated ResourceCreatedTwo = ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(json)); - assertEquals("A", ResourceCreatedTwo.workflowStepName()); - assertEquals("B", ResourceCreatedTwo.resourceId()); + assertEquals(workflowStepName, ResourceCreatedTwo.workflowStepName()); + assertEquals("workflow_step_1", ResourceCreatedTwo.workflowStepId()); + assertEquals("L85p1IsBbfF", ResourceCreatedTwo.resourceId()); } public void testExceptions() throws IOException { @@ -40,7 +45,7 @@ public void testExceptions() throws IOException { String missingJson = "{\"resource_id\":\"B\"}"; e = assertThrows(IOException.class, () -> ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(missingJson))); - assertEquals("A ResourceCreated object requires both a workflowStepName and resourceId.", e.getMessage()); + assertEquals("Unable to parse field [resource_id] in a resources_created object.", e.getMessage()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 1135a0ca6..09c6c3c68 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -8,7 +8,9 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -28,7 +30,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -87,6 +92,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr return null; }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = createConnectorStep.execute( inputData.getNodeId(), inputData, diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 8be5c5787..9d4e79335 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -10,6 +10,7 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; @@ -21,9 +22,12 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -35,8 +39,12 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -64,12 +72,18 @@ public class CreateIndexStepTests extends OpenSearchTestCase { private ThreadPool threadPool; @Mock IndexMetadata indexMetadata; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; @Override public void setUp() throws Exception { super.setUp(); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id", "test-node-id"); + inputData = new WorkflowData( + Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("default_mapping_option", "knn")), + "test-id", + "test-node-id" + ); clusterService = mock(ClusterService.class); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -84,11 +98,18 @@ public void setUp() throws Exception { when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test cluster")).build()); when(metadata.indices()).thenReturn(Map.of(GLOBAL_CONTEXT_INDEX, indexMetadata)); - createIndexStep = new CreateIndexStep(clusterService, client); + createIndexStep = new CreateIndexStep(clusterService, client, flowFrameworkIndicesHandler); CreateIndexStep.indexMappingUpdated = indexMappingUpdated; } - public void testCreateIndexStep() throws ExecutionException, InterruptedException { + public void testCreateIndexStep() throws ExecutionException, InterruptedException, IOException { + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = createIndexStep.execute( diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index f0a970758..1c7940949 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -10,12 +10,16 @@ import org.opensearch.action.ingest.PutPipelineRequest; import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -23,7 +27,11 @@ import org.mockito.ArgumentCaptor; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -37,10 +45,12 @@ public class CreateIngestPipelineStepTests extends OpenSearchTestCase { private Client client; private AdminClient adminClient; private ClusterAdminClient clusterAdminClient; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; @Override public void setUp() throws Exception { super.setUp(); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); inputData = new WorkflowData( Map.ofEntries( @@ -66,9 +76,15 @@ public void setUp() throws Exception { when(adminClient.cluster()).thenReturn(clusterAdminClient); } - public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException { + public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException, IOException { - CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -91,7 +107,7 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio public void testCreateIngestPipelineStepFailure() throws InterruptedException { - CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -116,7 +132,7 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { } public void testMissingData() throws InterruptedException { - CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); // Data with missing input and output fields WorkflowData incorrectData = new WorkflowData( diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index bc914baa7..d78a97e8a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -9,9 +9,12 @@ package org.opensearch.flowframework.workflow; import com.google.common.collect.ImmutableList; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLTaskState; @@ -29,8 +32,12 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; public class ModelGroupStepTests extends OpenSearchTestCase { @@ -40,10 +47,12 @@ public class ModelGroupStepTests extends OpenSearchTestCase { @Mock MachineLearningNodeClient machineLearningNodeClient; + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + @Override public void setUp() throws Exception { super.setUp(); - + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); inputData = new WorkflowData( Map.ofEntries( @@ -63,7 +72,7 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep String modelGroupId = "model_group_id"; String status = MLTaskState.CREATED.name(); - ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -75,6 +84,12 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep return null; }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = modelGroupStep.execute( inputData.getNodeId(), inputData, @@ -90,8 +105,8 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep } - public void testRegisterModelGroupFailure() throws ExecutionException, InterruptedException, IOException { - ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + public void testRegisterModelGroupFailure() throws IOException { + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -119,7 +134,7 @@ public void testRegisterModelGroupFailure() throws ExecutionException, Interrupt } public void testRegisterModelGroupWithNoName() throws IOException { - ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); CompletableFuture future = modelGroupStep.execute( inputDataWithNoName.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index d169812a9..c38f8a120 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -10,12 +10,15 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; @@ -34,10 +37,13 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -49,6 +55,7 @@ public class RegisterLocalModelStepTests extends OpenSearchTestCase { private RegisterLocalModelStep registerLocalModelStep; private WorkflowData workflowData; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -56,7 +63,7 @@ public class RegisterLocalModelStepTests extends OpenSearchTestCase { @Override public void setUp() throws Exception { super.setUp(); - + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); ClusterService clusterService = mock(ClusterService.class); final Set> settingsSet = Stream.concat( @@ -69,7 +76,12 @@ public void setUp() throws Exception { ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - this.registerLocalModelStep = new RegisterLocalModelStep(testMaxRetrySetting, clusterService, machineLearningNodeClient); + this.registerLocalModelStep = new RegisterLocalModelStep( + testMaxRetrySetting, + clusterService, + machineLearningNodeClient, + flowFrameworkIndicesHandler + ); this.workflowData = new WorkflowData( Map.ofEntries( @@ -127,6 +139,12 @@ public void testRegisterLocalModelSuccess() throws Exception { return null; }).when(machineLearningNodeClient).getTask(any(), any()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = registerLocalModelStep.execute( workflowData.getNodeId(), workflowData, diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index cde194326..a83443f05 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -10,8 +10,11 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -26,10 +29,14 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -38,6 +45,7 @@ public class RegisterRemoteModelStepTests extends OpenSearchTestCase { private RegisterRemoteModelStep registerRemoteModelStep; private WorkflowData workflowData; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; @Mock MachineLearningNodeClient mlNodeClient; @@ -45,9 +53,9 @@ public class RegisterRemoteModelStepTests extends OpenSearchTestCase { @Override public void setUp() throws Exception { super.setUp(); - + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); - this.registerRemoteModelStep = new RegisterRemoteModelStep(mlNodeClient); + this.registerRemoteModelStep = new RegisterRemoteModelStep(mlNodeClient, flowFrameworkIndicesHandler); this.workflowData = new WorkflowData( Map.ofEntries( Map.entry("function_name", "remote"), @@ -73,6 +81,12 @@ public void testRegisterRemoteModelSuccess() throws Exception { return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = this.registerRemoteModelStep.execute( workflowData.getNodeId(), workflowData,