From 5afa4bc3c33bcc37595f4ba61f25f8116630bd97 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 29 Nov 2023 19:34:57 -0800 Subject: [PATCH 01/27] [Feature/agent_framework] Registers a single agent with multiple tools (#198) * Initial register agent workflow step Signed-off-by: Owais Kazi * Added tools step Signed-off-by: Owais Kazi * Fixed ClassCastException Signed-off-by: Owais Kazi * Handled exception for Instant Signed-off-by: Owais Kazi * Added type Instant for WorklowNode Parser Signed-off-by: Owais Kazi * Removed created and last updated time Signed-off-by: Owais Kazi * Addressed parsing error Signed-off-by: Owais Kazi * Handled parsing of Long values for Instant Signed-off-by: Owais Kazi * Handled nested object for llm key Signed-off-by: Owais Kazi * Handled parsing error Signed-off-by: Owais Kazi * Another attempt to fix parsing error for llm Signed-off-by: Owais Kazi * Another attemp to fix XContent Signed-off-by: Owais Kazi * Fixed Parsing error Signed-off-by: Owais Kazi * Added tests for toolstep and javadocs Signed-off-by: Owais Kazi * Undo CI changes Signed-off-by: Owais Kazi * Addressing PR comments Signed-off-by: Owais Kazi * Addressing PR comments Signed-off-by: Owais Kazi * Handled interface changes Signed-off-by: Owais Kazi * Addressed conflicts Signed-off-by: Owais Kazi * Added TODO Signed-off-by: Owais Kazi --------- Signed-off-by: Owais Kazi --- build.gradle | 1 + .../flowframework/common/CommonValue.java | 17 +- .../flowframework/model/Template.java | 2 +- .../flowframework/model/WorkflowNode.java | 35 ++- .../flowframework/util/ParseUtils.java | 99 +++++++ .../workflow/CreateConnectorStep.java | 9 +- .../workflow/CreateIndexStep.java | 4 +- .../workflow/DeployModelStep.java | 2 +- .../workflow/ModelGroupStep.java | 2 +- .../workflow/RegisterAgentStep.java | 241 ++++++++++++++++++ .../workflow/RegisterLocalModelStep.java | 2 +- .../workflow/RegisterRemoteModelStep.java | 2 +- .../flowframework/workflow/ToolStep.java | 127 +++++++++ 
.../flowframework/workflow/WorkflowData.java | 2 +- .../workflow/WorkflowStepFactory.java | 2 + .../resources/mappings/workflow-steps.json | 24 ++ .../model/WorkflowNodeTests.java | 22 +- .../flowframework/util/ParseUtilsTests.java | 19 ++ .../workflow/RegisterAgentTests.java | 130 ++++++++++ .../flowframework/workflow/ToolStepTests.java | 53 ++++ 20 files changed, 773 insertions(+), 22 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/ToolStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java diff --git a/build.gradle b/build.gradle index 20a9837ca..d6e7afbd8 100644 --- a/build.gradle +++ b/build.gradle @@ -155,6 +155,7 @@ dependencies { implementation "org.opensearch:common-utils:${common_utils_version}" implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' implementation 'org.bouncycastle:bcprov-jdk18on:1.77' + implementation "com.google.code.gson:gson:2.10.1" // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 90f208c8d..dbfc17891 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -162,5 +162,20 @@ private CommonValue() {} public static final String RESOURCE_ID_FIELD = "resource_id"; /** The field name for the ResourceCreated's resource name */ public static final String WORKFLOW_STEP_NAME = "workflow_step_name"; - + /** LLM Name for registering an agent */ + public static final String LLM_FIELD = "llm"; + /** The tools' field for an 
agent */ + public static final String TOOLS_FIELD = "tools"; + /** The memory field for an agent */ + public static final String MEMORY_FIELD = "memory"; + /** The app type field for an agent */ + public static final String APP_TYPE_FIELD = "app_type"; + /** The agent id of an agent */ + public static final String AGENT_ID = "agent_id"; + /** To include field for an agent response */ + public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response"; + /** The created time field for an agent */ + public static final String CREATED_TIME = "created_time"; + /** The last updated time field for an agent */ + public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; } diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index bfb40b696..f4a8b1958 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -229,12 +229,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws for (Entry e : workflows.entrySet()) { xContentBuilder.field(e.getKey(), e.getValue(), params); } + xContentBuilder.endObject(); if (uiMetadata != null && !uiMetadata.isEmpty()) { xContentBuilder.field(UI_METADATA_FIELD, uiMetadata); } - xContentBuilder.endObject(); if (user != null) { xContentBuilder.field(USER_FIELD, user); } diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 7d04a5a3f..999ba460f 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -14,6 +14,7 @@ import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; +import 
org.opensearch.ml.common.agent.LLMSpec; import java.io.IOException; import java.util.ArrayList; @@ -24,7 +25,10 @@ import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD; +import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.parseLLM; import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; /** @@ -34,7 +38,6 @@ * and its inputs are used to populate the {@link WorkflowData} input. */ public class WorkflowNode implements ToXContentObject { - /** The template field name for node id */ public static final String ID_FIELD = "id"; /** The template field name for node type */ @@ -82,7 +85,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.startObject(USER_INPUTS_FIELD); for (Entry e : userInputs.entrySet()) { xContentBuilder.field(e.getKey()); - if (e.getValue() instanceof String) { + if (e.getValue() instanceof String || e.getValue() instanceof Number) { xContentBuilder.value(e.getValue()); } else if (e.getValue() instanceof Map) { buildStringToStringMap(xContentBuilder, (Map) e.getValue()); @@ -98,6 +101,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } } xContentBuilder.endArray(); + } else if (e.getValue() instanceof LLMSpec) { + if (LLM_FIELD.equals(e.getKey())) { + xContentBuilder.startObject(); + buildLLMMap(xContentBuilder, (LLMSpec) e.getValue()); + xContentBuilder.endObject(); + } } } xContentBuilder.endObject(); @@ -141,7 +150,11 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - userInputs.put(inputFieldName, parseStringToStringMap(parser)); + if 
(LLM_FIELD.equals(inputFieldName)) { + userInputs.put(inputFieldName, parseLLM(parser)); + } else { + userInputs.put(inputFieldName, parseStringToStringMap(parser)); + } break; case START_ARRAY: if (PROCESSORS_FIELD.equals(inputFieldName)) { @@ -158,6 +171,22 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, mapList.toArray(new Map[0])); } break; + case VALUE_NUMBER: + switch (parser.numberType()) { + case INT: + userInputs.put(inputFieldName, parser.intValue()); + break; + case LONG: + userInputs.put(inputFieldName, parser.longValue()); + break; + case FLOAT: + userInputs.put(inputFieldName, parser.floatValue()); + break; + case DOUBLE: + userInputs.put(inputFieldName, parser.doubleValue()); + break; + } + break; default: throw new IOException("Unable to parse field [" + inputFieldName + "] in a node object."); } diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 0f725f687..14d113e34 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.util; +import com.google.gson.Gson; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; @@ -21,14 +22,21 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.agent.LLMSpec; import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.time.Instant; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; import static 
org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; /** * Utility methods for Template parsing @@ -36,6 +44,12 @@ public class ParseUtils { private static final Logger logger = LogManager.getLogger(ParseUtils.class); + public static final Gson gson; + + static { + gson = new Gson(); + } + private ParseUtils() {} /** @@ -70,6 +84,21 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map parameters = llm.getParameters(); + xContentBuilder.field(MODEL_ID, modelId); + xContentBuilder.field(PARAMETERS_FIELD); + buildStringToStringMap(xContentBuilder, parameters); + } + /** * Parses an XContent object representing a map of String keys to String values. * @@ -88,6 +117,37 @@ public static Map parseStringToStringMap(XContentParser parser) return map; } + // TODO Figure out a way to use the parse method of LLMSpec of ml-commons + /** + * Parses an XContent object representing the object of LLMSpec + * @param parser An XContent parser whose position is at the start of the map object to parse + * @return instance of {@link org.opensearch.ml.common.agent.LLMSpec} + * @throws IOException parsing error + */ + public static LLMSpec parseLLM(XContentParser parser) throws IOException { + String modelId = null; + Map parameters = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MODEL_ID: + modelId = parser.text(); + break; + case PARAMETERS_FIELD: + parameters = getParameterMap(parser.map()); + break; + default: + parser.skipChildren(); + break; + } + } + return 
LLMSpec.builder().modelId(modelId).parameters(parameters).build(); + } + /** * Parse content parser to {@link java.time.Instant}. * @@ -116,6 +176,31 @@ public static User getUserContext(Client client) { return User.parse(userStr); } + /** + * Generates a parameter map required when the parameter is nested within an object + * @param parameterObjs parameters + * @return a parameters map of type String,String + */ + public static Map getParameterMap(Map parameterObjs) { + Map parameters = new HashMap<>(); + for (String key : parameterObjs.keySet()) { + Object value = parameterObjs.get(key); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + if (value instanceof String) { + parameters.put(key, (String) value); + } else { + parameters.put(key, gson.toJson(value)); + } + return null; + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } + } + return parameters; + } + /** * Creates a XContentParser from a given Registry * @@ -129,4 +214,18 @@ public static XContentParser createXContentParserFromRegistry(NamedXContentRegis return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON); } + /** + * Generates a string to string Map + * @param map content map + * @param fieldName fieldName + * @return instance of the map + */ + @SuppressWarnings("unchecked") + public static Map getStringToStringMap(Object map, String fieldName) { + if (map instanceof Map) { + return (Map) map; + } + throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map."); + } + } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index b00857ff6..bc4132087 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -49,6 +49,7 @@ 
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; /** * Step to create a connector for a remote model @@ -210,14 +211,6 @@ public String getName() { return NAME; } - @SuppressWarnings("unchecked") - private static Map getStringToStringMap(Object map, String fieldName) { - if (map instanceof Map) { - return (Map) map; - } - throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map."); - } - private static Map getParameterMap(Object parameterMap) throws PrivilegedActionException { Map parameters = new HashMap<>(); for (Entry entry : getStringToStringMap(parameterMap, PARAMETERS_FIELD).entrySet()) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index f443e9c2c..07246134a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -35,8 +35,8 @@ public class CreateIndexStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateIndexStep.class); - private ClusterService clusterService; - private Client client; + private final ClusterService clusterService; + private final Client client; /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ static final String NAME = "create_index"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index aa6768605..81409ef77 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ 
b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -30,7 +30,7 @@ public class DeployModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); - private MachineLearningNodeClient mlClient; + private final MachineLearningNodeClient mlClient; static final String NAME = "deploy_model"; /** diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 325c3edb8..22c6ae810 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -40,7 +40,7 @@ public class ModelGroupStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(ModelGroupStep.class); - private MachineLearningNodeClient mlClient; + private final MachineLearningNodeClient mlClient; static final String NAME = "model_group"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java new file mode 100644 index 000000000..44270d8e6 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -0,0 +1,241 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLAgent.MLAgentBuilder; +import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import static org.opensearch.flowframework.common.CommonValue.AGENT_ID; +import static org.opensearch.flowframework.common.CommonValue.APP_TYPE_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CREATED_TIME; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.LAST_UPDATED_TIME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MEMORY_FIELD; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TYPE; +import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; + +/** + * Step to register an agent + */ 
+public class RegisterAgentStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(RegisterAgentStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "register_agent"; + + private List mlToolSpecList; + + /** + * Instantiate this class + * @param mlClient client to instantiate MLClient + */ + public RegisterAgentStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + this.mlToolSpecList = new ArrayList<>(); + } + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + + CompletableFuture registerAgentModelFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { + logger.info("Remote Agent registration successful"); + registerAgentModelFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to register the agent"); + registerAgentModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + String name = null; + String type = null; + String description = null; + LLMSpec llm = null; + List tools = new ArrayList<>(); + Map parameters = Collections.emptyMap(); + MLMemorySpec memory = null; + Instant createdTime = null; + Instant lastUpdateTime = null; + String appType = null; + + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + + for (WorkflowData workflowData : data) { + Map content = 
workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case NAME_FIELD: + name = (String) entry.getValue(); + break; + case DESCRIPTION_FIELD: + description = (String) entry.getValue(); + break; + case TYPE: + type = (String) entry.getValue(); + break; + case LLM_FIELD: + llm = getLLMSpec(entry.getValue()); + break; + case TOOLS_FIELD: + tools = addTools(entry.getValue()); + break; + case PARAMETERS_FIELD: + parameters = getStringToStringMap(entry.getValue(), PARAMETERS_FIELD); + break; + case MEMORY_FIELD: + memory = getMLMemorySpec(entry.getValue()); + break; + case CREATED_TIME: + createdTime = Instant.ofEpochMilli((Long) entry.getValue()); + break; + case LAST_UPDATED_TIME_FIELD: + lastUpdateTime = Instant.ofEpochMilli((Long) entry.getValue()); + break; + case APP_TYPE_FIELD: + appType = (String) entry.getValue(); + break; + default: + break; + } + } + } + + if (Stream.of(name, type, llm, tools, parameters, memory, appType).allMatch(x -> x != null)) { + MLAgentBuilder builder = MLAgent.builder().name(name); + + if (description != null) { + builder.description(description); + } + + builder.type(type) + .llm(llm) + .tools(tools) + .parameters(parameters) + .memory(memory) + .createdTime(createdTime) + .lastUpdateTime(lastUpdateTime) + .appType(appType); + + MLAgent mlAgent = builder.build(); + + mlClient.registerAgent(mlAgent, actionListener); + + } else { + registerAgentModelFuture.completeExceptionally( + new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) + ); + } + + return registerAgentModelFuture; + } + + @Override + public String getName() { + return NAME; + } + + private List addTools(Object tools) { + MLToolSpec mlToolSpec = (MLToolSpec) tools; + mlToolSpecList.add(mlToolSpec); + return mlToolSpecList; + } + + private LLMSpec getLLMSpec(Object llm) { + if (llm instanceof LLMSpec) { + return (LLMSpec) llm; + } + throw new IllegalArgumentException("[" + LLM_FIELD + "] must be 
of type LLMSpec."); + } + + private MLMemorySpec getMLMemorySpec(Object mlMemory) { + + Map map = (Map) mlMemory; + String type = null; + String sessionId = null; + Integer windowSize = null; + type = (String) map.get(MLMemorySpec.MEMORY_TYPE_FIELD); + if (type == null) { + throw new IllegalArgumentException("agent name is null"); + } + sessionId = (String) map.get(MLMemorySpec.SESSION_ID_FIELD); + windowSize = (Integer) map.get(MLMemorySpec.WINDOW_SIZE_FIELD); + + @SuppressWarnings("unchecked") + MLMemorySpec.MLMemorySpecBuilder builder = MLMemorySpec.builder(); + + builder.type(type); + if (sessionId != null) { + builder.sessionId(sessionId); + } + if (windowSize != null) { + builder.windowSize(windowSize); + } + + MLMemorySpec mlMemorySpec = builder.build(); + return mlMemorySpec; + + } + + private Instant getInstant(Object instant, String fieldName) { + if (instant instanceof Instant) { + return (Instant) instant; + } + throw new IllegalArgumentException("[" + fieldName + "] must be of type Instant."); + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index 19229efd1..cc5645306 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -56,7 +56,7 @@ public class RegisterLocalModelStep extends AbstractRetryableWorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterLocalModelStep.class); - private MachineLearningNodeClient mlClient; + private final MachineLearningNodeClient mlClient; static final String NAME = "register_local_model"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index e41323a14..27a77cb98 100644 --- 
a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -43,7 +43,7 @@ public class RegisterRemoteModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterRemoteModelStep.class); - private MachineLearningNodeClient mlClient; + private final MachineLearningNodeClient mlClient; static final String NAME = "register_remote_model"; diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java new file mode 100644 index 000000000..339142139 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.common.agent.MLToolSpec; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TYPE; +import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; + +/** + * Step to register a tool for an agent + */ +public class ToolStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(ToolStep.class); + CompletableFuture toolFuture = new CompletableFuture<>(); + static final String NAME = "create_tool"; + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + String type = null; + String name = null; + String description = null; + Map parameters = Collections.emptyMap(); + Boolean includeOutputInAgentResponse = null; + + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + + for (WorkflowData workflowData : data) { + Map content = 
workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case TYPE: + type = (String) content.get(TYPE); + break; + case NAME_FIELD: + name = (String) content.get(NAME_FIELD); + break; + case DESCRIPTION_FIELD: + description = (String) content.get(DESCRIPTION_FIELD); + break; + case PARAMETERS_FIELD: + parameters = getStringToStringMap(content.get(PARAMETERS_FIELD), PARAMETERS_FIELD); + break; + case INCLUDE_OUTPUT_IN_AGENT_RESPONSE: + includeOutputInAgentResponse = (Boolean) content.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE); + break; + default: + break; + } + + } + + } + + if (type == null) { + toolFuture.completeExceptionally(new FlowFrameworkException("Tool type is not provided", RestStatus.BAD_REQUEST)); + } else { + MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder(); + + builder.type(type); + if (name != null) { + builder.name(name); + } + if (description != null) { + builder.description(description); + } + if (parameters != null) { + builder.parameters(parameters); + } + if (includeOutputInAgentResponse != null) { + builder.includeOutputInAgentResponse(includeOutputInAgentResponse); + } + + MLToolSpec mlToolSpec = builder.build(); + + toolFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(TOOLS_FIELD, mlToolSpec)), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + logger.info("Tool registered successfully {}", type); + return toolFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index a0d901f74..ba19823a7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -66,7 +66,7 @@ public WorkflowData(Map content, Map params, @Nu */ public Map getContent() { return this.content; - }; + } 
/** * Returns a map represents the params associated with a Rest API request, parsed from the URI. diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index b95a0449d..c9e565bba 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -62,6 +62,8 @@ private void populateMap( stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); + stepMap.put(ToolStep.NAME, new ToolStep()); + stepMap.put(RegisterAgentStep.NAME, new RegisterAgentStep(mlClient)); } /** diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 6256189c1..c9794d4ea 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -83,5 +83,29 @@ "model_group_id", "model_group_status" ] + }, + "register_agent": { + "inputs":[ + "name", + "type", + "llm", + "tools", + "parameters", + "memory", + "created_time", + "last_updated_time", + "app_type" + ], + "outputs":[ + "agent_id" + ] + }, + "create_tool": { + "inputs": [ + "type" + ], + "outputs": [ + "tools" + ] } } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java index 700e1d0d2..c0011f7ae 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java @@ -8,9 +8,11 @@ */ package org.opensearch.flowframework.model; +import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.test.OpenSearchTestCase; import 
java.io.IOException; +import java.util.HashMap; import java.util.Map; public class WorkflowNodeTests extends OpenSearchTestCase { @@ -21,6 +23,12 @@ public void setUp() throws Exception { } public void testNode() throws IOException { + Map parameters = new HashMap<>(); + parameters.put("stop", "true"); + parameters.put("max", "5"); + + LLMSpec llmSpec = new LLMSpec("modelId", parameters); + WorkflowNode nodeA = new WorkflowNode( "A", "a-type", @@ -29,7 +37,9 @@ public void testNode() throws IOException { Map.entry("foo", "a string"), Map.entry("bar", Map.of("key", "value")), Map.entry("baz", new Map[] { Map.of("A", "a"), Map.of("B", "b") }), - Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }) + Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }), + Map.entry("llm", llmSpec), + Map.entry("created_time", 1689793598499L) ) ); assertEquals("A", nodeA.id()); @@ -43,6 +53,7 @@ public void testNode() throws IOException { assertEquals(1, pp.length); assertEquals("test-type", pp[0].type()); assertEquals(Map.of("key2", "value2"), pp[0].params()); + assertEquals(1689793598499L, map.get("created_time")); // node equality is based only on ID WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of(), Map.of("bar", "baz")); @@ -52,13 +63,17 @@ public void testNode() throws IOException { assertNotEquals(nodeA, nodeB); String json = TemplateTestJsonUtil.parseToJson(nodeA); - logger.info("TESTING : " + json); + logger.info("JSON : " + json); assertTrue(json.startsWith("{\"id\":\"A\",\"type\":\"a-type\",\"previous_node_inputs\":{\"foo\":\"field\"},")); assertTrue(json.contains("\"user_inputs\":{")); assertTrue(json.contains("\"foo\":\"a string\"")); assertTrue(json.contains("\"baz\":[{\"A\":\"a\"},{\"B\":\"b\"}]")); assertTrue(json.contains("\"bar\":{\"key\":\"value\"}")); 
assertTrue(json.contains("\"processors\":[{\"type\":\"test-type\",\"params\":{\"key2\":\"value2\"}}]")); + assertTrue(json.contains("\"created_time\":1689793598499")); + assertTrue(json.contains("llm\":{")); + assertTrue(json.contains("\"parameters\":{\"stop\":\"true\",\"max\":\"5\"")); + assertTrue(json.contains("\"model_id\":\"modelId\"")); WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json)); assertEquals("A", nodeX.id()); @@ -73,6 +88,9 @@ public void testNode() throws IOException { assertEquals(1, ppX.length); assertEquals("test-type", ppX[0].type()); assertEquals(Map.of("key2", "value2"), ppX[0].params()); + LLMSpec llm = (LLMSpec) mapX.get("llm"); + assertEquals("modelId", llm.getModelId()); + assertEquals(parameters, llm.getParameters()); } public void testExceptions() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index a5c4253b3..76334b52b 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -12,9 +12,12 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; +import org.junit.Assert; import java.io.IOException; import java.time.Instant; +import java.util.HashMap; +import java.util.Map; public class ParseUtilsTests extends OpenSearchTestCase { public void testToInstant() throws IOException { @@ -54,4 +57,20 @@ public void testToInstantWithNotValue() throws IOException { Instant instant = ParseUtils.parseInstant(parser); assertNull(instant); } + + public void testGetParameterMap() { + Map parameters = new HashMap<>(); + parameters.put("key1", "value1"); + parameters.put("key2", 2); + parameters.put("key3", 2.1); + parameters.put("key4", new int[] { 10, 20 }); + parameters.put("key5", new Object[] { 
1.01, "abc" }); + Map parameterMap = ParseUtils.getParameterMap(parameters); + Assert.assertEquals(5, parameterMap.size()); + Assert.assertEquals("value1", parameterMap.get("key1")); + Assert.assertEquals("2", parameterMap.get("key2")); + Assert.assertEquals("2.1", parameterMap.get("key3")); + Assert.assertEquals("[10,20]", parameterMap.get("key4")); + Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5")); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java new file mode 100644 index 000000000..0f4b33471 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -0,0 +1,130 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static 
org.mockito.Mockito.verify; + +public class RegisterAgentTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false); + + LLMSpec llmSpec = new LLMSpec("xyz", Collections.emptyMap()); + + Map mlMemorySpec = Map.ofEntries( + Map.entry(MLMemorySpec.MEMORY_TYPE_FIELD, "type"), + Map.entry(MLMemorySpec.SESSION_ID_FIELD, "abc"), + Map.entry(MLMemorySpec.WINDOW_SIZE_FIELD, 2) + ); + + inputData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "test"), + Map.entry("description", "description"), + Map.entry("type", "type"), + Map.entry("llm", llmSpec), + Map.entry("tools", tools), + Map.entry("parameters", Collections.emptyMap()), + Map.entry("memory", mlMemorySpec), + Map.entry("created_time", 1689793598499L), + Map.entry("last_updated_time", 1689793598499L), + Map.entry("app_type", "app") + ), + "test-id", + "test-node-id" + ); + } + + public void testRegisterAgent() throws IOException, ExecutionException, InterruptedException { + String agentId = "agent_id"; + RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterAgentResponse output = new MLRegisterAgentResponse(agentId); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); + + CompletableFuture future = registerAgentStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + 
verify(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); + + assertTrue(future.isDone()); + assertEquals(agentId, future.get().getContent().get("agent_id")); + } + + public void testRegisterAgentFailure() throws IOException { + String agentId = "agent_id"; + RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to register the agent", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); + + CompletableFuture future = registerAgentStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + verify(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to register the agent", ex.getCause().getMessage()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java new file mode 100644 index 000000000..c7e8df2d8 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +public class ToolStepTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + + @Override + public void setUp() throws Exception { + super.setUp(); + + inputData = new WorkflowData( + Map.ofEntries( + Map.entry("type", "type"), + Map.entry("name", "name"), + Map.entry("description", "description"), + Map.entry("parameters", Collections.emptyMap()), + Map.entry("include_output_in_agent_response", false) + ), + "test-id", + "test-node-id" + ); + } + + public void testTool() throws IOException, ExecutionException, InterruptedException { + ToolStep toolStep = new ToolStep(); + + CompletableFuture future = toolStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass()); + } +} From f9c4622d7556bbc440ababec054829261fe88762 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 1 Dec 2023 10:34:10 -0800 Subject: [PATCH 02/27] [Feature/agent_framework] Add Delete Connector Step (#211) * Add Delete Connector Step Signed-off-by: Daniel Widdis * Add eclipse core runtime version resolution Signed-off-by: Daniel Widdis * Use JDK17 for spotless Signed-off-by: Daniel Widdis * Add Delete Connector Step Signed-off-by: Daniel Widdis * Add eclipse core runtime version resolution Signed-off-by: Daniel Widdis * Use JDK17 for spotless Signed-off-by: Daniel Widdis * Fetch connector ID from appropriate previous node output Signed-off-by: Daniel Widdis * Fix tests Signed-off-by: Daniel Widdis * Test that actual ID is properly passed Signed-off-by: Daniel 
Widdis * Update to current setup-java version Signed-off-by: Daniel Widdis * Remove unneeded argument captors Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- .../workflow/DeleteConnectorStep.java | 105 ++++++++++++++++ .../workflow/WorkflowStepFactory.java | 15 +-- .../resources/mappings/workflow-steps.json | 8 ++ .../workflow/CreateConnectorStepTests.java | 16 +-- .../workflow/DeleteConnectorStepTests.java | 115 ++++++++++++++++++ 5 files changed, 234 insertions(+), 25 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java new file mode 100644 index 000000000..bf0fae33e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; + +/** + * Step to delete a connector for a remote model + */ +public class DeleteConnectorStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(DeleteConnectorStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "delete_connector"; + + /** + * Instantiate this class + * @param mlClient Machine Learning client to perform the deletion + */ + public DeleteConnectorStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + CompletableFuture deleteConnectorFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(DeleteResponse deleteResponse) { + deleteConnectorFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry("connector_id", deleteResponse.getId())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to delete connector"); + deleteConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + 
+ String connectorId = null; + + // Previous Node inputs defines which step the connector ID came from + Optional previousNode = previousNodeInputs.entrySet() + .stream() + .filter(e -> CONNECTOR_ID.equals(e.getValue())) + .map(Map.Entry::getKey) + .findFirst(); + if (previousNode.isPresent()) { + WorkflowData previousNodeOutput = outputs.get(previousNode.get()); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(CONNECTOR_ID)) { + connectorId = previousNodeOutput.getContent().get(CONNECTOR_ID).toString(); + } + } + + if (connectorId != null) { + mlClient.deleteConnector(connectorId, actionListener); + } else { + deleteConnectorFuture.completeExceptionally( + new FlowFrameworkException("Required field " + CONNECTOR_ID + " is not provided", RestStatus.BAD_REQUEST) + ); + } + + return deleteConnectorFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index c9e565bba..bac65c23a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -25,7 +25,6 @@ public class WorkflowStepFactory { private final Map stepMap = new HashMap<>(); - private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** * Instantiate this class. 
@@ -42,17 +41,6 @@ public WorkflowStepFactory( Client client, MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler - ) { - this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; - populateMap(settings, clusterService, client, mlClient, flowFrameworkIndicesHandler); - } - - private void populateMap( - Settings settings, - ClusterService clusterService, - Client client, - MachineLearningNodeClient mlClient, - FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { stepMap.put(NoOpStep.NAME, new NoOpStep()); stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); @@ -61,6 +49,7 @@ private void populateMap( stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); stepMap.put(ToolStep.NAME, new ToolStep()); stepMap.put(RegisterAgentStep.NAME, new RegisterAgentStep(mlClient)); @@ -80,7 +69,7 @@ public WorkflowStep createStep(String type) { /** * Gets the step map - * @return the step map + * @return a read-only copy of the step map */ public Map getStepMap() { return Map.copyOf(this.stepMap); diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index c9794d4ea..5840b0906 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -39,6 +39,14 @@ "connector_id" ] }, + "delete_connector": { + "inputs": [ + "connector_id" + ], + "outputs":[ + "connector_id" + ] + }, "register_local_model": { "inputs":[ "name", diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java 
b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index de3add996..1135a0ca6 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -25,7 +25,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -81,15 +80,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr String connectorId = "connect"; CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler); - @SuppressWarnings("unchecked") - ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); CompletableFuture future = createConnectorStep.execute( inputData.getNodeId(), @@ -98,8 +94,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr Collections.emptyMap() ); - verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); - + verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); assertTrue(future.isDone()); assertEquals(connectorId, future.get().getContent().get("connector_id")); @@ -108,14 +103,11 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr public void testCreateConnectorFailure() throws 
IOException { CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler); - @SuppressWarnings("unchecked") - ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onFailure(new FlowFrameworkException("Failed to create connector", RestStatus.INTERNAL_SERVER_ERROR)); return null; - }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); CompletableFuture future = createConnectorStep.execute( inputData.getNodeId(), @@ -124,7 +116,7 @@ public void testCreateConnectorFailure() throws IOException { Collections.emptyMap() ); - verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java new file mode 100644 index 000000000..3c997a02e --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.CommonValue; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class DeleteConnectorStepTests extends OpenSearchTestCase { + private WorkflowData inputData; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData(Map.of(CommonValue.CONNECTOR_ID, "test"), "test-id", "test-node-id"); + } + + public void testDeleteConnector() throws IOException, ExecutionException, InterruptedException { + + String connectorId = randomAlphaOfLength(5); + DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); + + doAnswer(invocation -> { + String connectorIdArg = invocation.getArgument(0); + ActionListener actionListener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, connectorIdArg, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + 
}).when(machineLearningNodeClient).deleteConnector(any(String.class), any()); + + CompletableFuture future = deleteConnectorStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("connector_id", connectorId), "workflowId", "nodeId")), + Map.of("step_1", "connector_id") + ); + verify(machineLearningNodeClient).deleteConnector(any(String.class), any()); + + assertTrue(future.isDone()); + assertEquals(connectorId, future.get().getContent().get("connector_id")); + } + + public void testNoConnectorIdInOutput() throws IOException { + DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); + + CompletableFuture future = deleteConnectorStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Required field connector_id is not provided", ex.getCause().getMessage()); + } + + public void testDeleteConnectorFailure() throws IOException { + DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to delete connector", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).deleteConnector(any(String.class), any()); + + CompletableFuture future = deleteConnectorStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("connector_id", "test"), "workflowId", "nodeId")), + Map.of("step_1", "connector_id") + ); + + verify(machineLearningNodeClient).deleteConnector(any(String.class), any()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = 
assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to delete connector", ex.getCause().getMessage()); + } +} From 062834be6cea0c1ea4e9e5249abd3c870e4ad155 Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Fri, 1 Dec 2023 15:59:28 -0800 Subject: [PATCH 03/27] [feature/agent_framework] Changing resources created format (#231) * adding new resources created format and adding enum for resource types Signed-off-by: Amit Galitzky * remove spotless from java 17 Signed-off-by: Amit Galitzky * add action listener to update resource created Signed-off-by: Amit Galitzky * fixing UT Signed-off-by: Amit Galitzky * changed exception type Signed-off-by: Amit Galitzky --------- Signed-off-by: Amit Galitzky --- .../flowframework/common/CommonValue.java | 10 +- .../common/WorkflowResources.java | 91 +++++++++++++++++++ .../indices/FlowFrameworkIndicesHandler.java | 41 ++++++++- .../flowframework/model/ResourceCreated.java | 73 ++++++++++++--- .../model/WorkflowStepValidator.java | 2 +- .../ProvisionWorkflowTransportAction.java | 6 +- .../workflow/CreateConnectorStep.java | 61 +++++-------- .../workflow/CreateIndexStep.java | 75 ++++++++++----- .../workflow/CreateIngestPipelineStep.java | 52 ++++++++--- .../workflow/ModelGroupStep.java | 52 ++++++++--- .../workflow/RegisterLocalModelStep.java | 59 +++++++++--- .../workflow/RegisterRemoteModelStep.java | 53 ++++++++--- .../workflow/WorkflowStepFactory.java | 13 ++- .../resources/mappings/workflow-state.json | 10 +- .../resources/mappings/workflow-steps.json | 4 +- .../model/ResourceCreatedTests.java | 25 +++-- .../workflow/CreateConnectorStepTests.java | 11 +++ .../workflow/CreateIndexStepTests.java | 27 +++++- .../CreateIngestPipelineStepTests.java | 24 ++++- .../workflow/ModelGroupStepTests.java | 25 ++++- .../workflow/RegisterLocalModelStepTests.java | 22 ++++- .../RegisterRemoteModelStepTests.java | 18 +++- 22 files 
changed, 581 insertions(+), 173 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/common/WorkflowResources.java diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index dbfc17891..774660bfd 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -77,6 +77,8 @@ private CommonValue() {} public static final String INDEX_NAME = "index_name"; /** Type field */ public static final String TYPE = "type"; + /** default_mapping_option filed */ + public static final String DEFAULT_MAPPING_OPTION = "default_mapping_option"; /** ID Field */ public static final String ID = "id"; /** Pipeline Id field */ @@ -103,6 +105,8 @@ private CommonValue() {} public static final String MODEL_VERSION = "model_version"; /** Model Group Id field */ public static final String MODEL_GROUP_ID = "model_group_id"; + /** Model Group Id field */ + public static final String MODEL_GROUP_STATUS = "model_group_status"; /** Description field */ public static final String DESCRIPTION_FIELD = "description"; /** Connector Id field */ @@ -158,10 +162,10 @@ private CommonValue() {} public static final String USER_OUTPUTS_FIELD = "user_outputs"; /** The template field name for template resources created */ public static final String RESOURCES_CREATED_FIELD = "resources_created"; - /** The field name for the ResourceCreated's resource ID */ - public static final String RESOURCE_ID_FIELD = "resource_id"; - /** The field name for the ResourceCreated's resource name */ + /** The field name for the step name where a resource is created */ public static final String WORKFLOW_STEP_NAME = "workflow_step_name"; + /** The field name for the step ID where a resource is created */ + public static final String WORKFLOW_STEP_ID = "workflow_step_id"; /** LLM Name for registering an agent */ public static 
final String LLM_FIELD = "llm"; /** The tools' field for an agent */ diff --git a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java new file mode 100644 index 000000000..04a8650b2 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.common; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; + +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Enum encapsulating the different step names and the resources they create + */ +public enum WorkflowResources { + + /** official workflow step name for creating a connector and associated created resource */ + CREATE_CONNECTOR("create_connector", "connector_id"), + /** official workflow step name for registering a remote model and associated created resource */ + REGISTER_REMOTE_MODEL("register_remote_model", "model_id"), + /** official workflow step name for registering a local model and associated created resource */ + REGISTER_LOCAL_MODEL("register_local_model", "model_id"), + /** official workflow step name for registering a model group and associated created resource */ + REGISTER_MODEL_GROUP("register_model_group", "model_group_id"), + /** official workflow step name for creating an ingest-pipeline and associated created resource */ + CREATE_INGEST_PIPELINE("create_ingest_pipeline", "pipeline_id"), + /** official workflow step name for creating an index and associated created resource 
*/ + CREATE_INDEX("create_index", "index_name"); + + private final String workflowStep; + private final String resourceCreated; + private static final Logger logger = LogManager.getLogger(WorkflowResources.class); + private static final Set allResources = Stream.of(values()) + .map(WorkflowResources::getResourceCreated) + .collect(Collectors.toSet()); + + WorkflowResources(String workflowStep, String resourceCreated) { + this.workflowStep = workflowStep; + this.resourceCreated = resourceCreated; + } + + /** + * Returns the workflowStep for the given enum Constant + * @return the workflowStep of this data. + */ + public String getWorkflowStep() { + return workflowStep; + } + + /** + * Returns the resourceCreated for the given enum Constant + * @return the resourceCreated of this data. + */ + public String getResourceCreated() { + return resourceCreated; + } + + /** + * gets the resources created type based on the workflowStep + * @param workflowStep workflow step name + * @return the resource that will be created + * @throws FlowFrameworkException if workflow step doesn't exist in enum + */ + public static String getResourceByWorkflowStep(String workflowStep) throws FlowFrameworkException { + if (workflowStep != null && !workflowStep.isEmpty()) { + for (WorkflowResources mapping : values()) { + if (mapping.getWorkflowStep().equals(workflowStep)) { + return mapping.getResourceCreated(); + } + } + } + logger.error("Unable to find resource type for step: " + workflowStep); + throw new FlowFrameworkException("Unable to find resource type for step: " + workflowStep, RestStatus.BAD_REQUEST); + } + + /** + * Returns all the possible resource created types in enum + * @return a set of all the resource created types + */ + public static Set getAllResourcesCreated() { + return allResources; + } +} diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java 
b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index cce4ba839..63df7824c 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -32,14 +32,17 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.ProvisioningProgress; +import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; import java.io.IOException; import java.net.URL; @@ -435,6 +438,7 @@ public void updateFlowFrameworkSystemIndexDoc( updatedContent.putAll(updatedFields); updateRequest.doc(updatedContent); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + updateRequest.retryOnConflict(3); // TODO: decide what condition can be considered as an update conflict and add retry strategy client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { @@ -468,7 +472,8 @@ public void updateFlowFrameworkSystemIndexDocWithScript( // TODO: Also add ability to change other fields at the same time when adding detailed provision progress updateRequest.script(script); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - // TODO: decide what condition can be considered as an update conflict and add retry strategy + updateRequest.retryOnConflict(3); + // TODO: Implement 
our own concurrency control to improve on retry mechanism client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage()); @@ -478,4 +483,38 @@ public void updateFlowFrameworkSystemIndexDocWithScript( } } } + + /** + * Creates a new ResourceCreated object and a script to update the state index + * @param workflowId workflowId for the relevant step + * @param nodeId node ID of the relevant step + * @param workflowStepName the workflowstep name that created the resource + * @param resourceId the id of the newly created resource + * @param listener the ActionListener for this step to handle completing the future after update + * @throws IOException if parsing fails on new resource + */ + public void updateResourceInStateIndex( + String workflowId, + String nodeId, + String workflowStepName, + String resourceId, + ActionListener listener + ) throws IOException { + ResourceCreated newResource = new ResourceCreated(workflowStepName, nodeId, resourceId); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + newResource.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + + // The script to append a new object to the resources_created array + Script script = new Script( + ScriptType.INLINE, + "painless", + "ctx._source.resources_created.add(params.newResource)", + Collections.singletonMap("newResource", newResource) + ); + + updateFlowFrameworkSystemIndexDocWithScript(WORKFLOW_STATE_INDEX, workflowId, script, ActionListener.wrap(updateResponse -> { + logger.info("updated resources created of {}", workflowId); + listener.onResponse(updateResponse); + }, exception -> { listener.onFailure(exception); })); + } } diff --git a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java index
0ec8f34d5..d039e2f8c 100644 --- a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java +++ b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java @@ -8,17 +8,22 @@ */ package org.opensearch.flowframework.model; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.common.WorkflowResources; +import org.opensearch.flowframework.exception.FlowFrameworkException; import java.io.IOException; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.CommonValue.RESOURCE_ID_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP_NAME; /** @@ -27,16 +32,21 @@ // TODO: create an enum to add the resource name itself for each step example (create_connector_step -> connector) public class ResourceCreated implements ToXContentObject, Writeable { + private static final Logger logger = LogManager.getLogger(ResourceCreated.class); + private final String workflowStepName; + private final String workflowStepId; private final String resourceId; /** - * Create this resources created object with given resource name and ID. + * Create this resources created object with given workflow step name, ID and resource ID. 
* @param workflowStepName The workflow step name associating to the step where it was created + * @param workflowStepId The workflow step ID associating to the step where it was created * @param resourceId The resources ID for relating to the created resource */ - public ResourceCreated(String workflowStepName, String resourceId) { + public ResourceCreated(String workflowStepName, String workflowStepId, String resourceId) { this.workflowStepName = workflowStepName; + this.workflowStepId = workflowStepId; this.resourceId = resourceId; } @@ -47,6 +57,7 @@ public ResourceCreated(String workflowStepName, String resourceId) { */ public ResourceCreated(StreamInput input) throws IOException { this.workflowStepName = input.readString(); + this.workflowStepId = input.readString(); this.resourceId = input.readString(); } @@ -54,13 +65,15 @@ public ResourceCreated(StreamInput input) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { XContentBuilder xContentBuilder = builder.startObject() .field(WORKFLOW_STEP_NAME, workflowStepName) - .field(RESOURCE_ID_FIELD, resourceId); + .field(WORKFLOW_STEP_ID, workflowStepId) + .field(WorkflowResources.getResourceByWorkflowStep(workflowStepName), resourceId); return xContentBuilder.endObject(); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(workflowStepName); + out.writeString(workflowStepId); out.writeString(resourceId); } @@ -82,6 +95,15 @@ public String workflowStepName() { return workflowStepName; } + /** + * Gets the workflow step id associated to the created resource + * + * @return the workflowStepId. + */ + public String workflowStepId() { + return workflowStepId; + } + /** * Parse raw JSON content into a ResourceCreated instance. 
* @@ -91,6 +113,7 @@ public String workflowStepName() { */ public static ResourceCreated parse(XContentParser parser) throws IOException { String workflowStepName = null; + String workflowStepId = null; String resourceId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -102,22 +125,50 @@ public static ResourceCreated parse(XContentParser parser) throws IOException { case WORKFLOW_STEP_NAME: workflowStepName = parser.text(); break; - case RESOURCE_ID_FIELD: - resourceId = parser.text(); + case WORKFLOW_STEP_ID: + workflowStepId = parser.text(); break; default: - throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object."); + if (!isValidFieldName(fieldName)) { + throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object."); + } else { + if (fieldName.equals(WorkflowResources.getResourceByWorkflowStep(workflowStepName))) { + resourceId = parser.text(); + } + break; + } } } - if (workflowStepName == null || resourceId == null) { - throw new IOException("A ResourceCreated object requires both a workflowStepName and resourceId."); + if (workflowStepName == null) { + logger.error("Resource created object failed parsing: workflowStepName: {}", workflowStepName); + throw new FlowFrameworkException("A ResourceCreated object requires workflowStepName", RestStatus.BAD_REQUEST); + } + if (workflowStepId == null) { + logger.error("Resource created object failed parsing: workflowStepId: {}", workflowStepId); + throw new FlowFrameworkException("A ResourceCreated object requires workflowStepId", RestStatus.BAD_REQUEST); + } + if (resourceId == null) { + logger.error("Resource created object failed parsing: resourceId: {}", resourceId); + throw new FlowFrameworkException("A ResourceCreated object requires resourceId", RestStatus.BAD_REQUEST); } - return new ResourceCreated(workflowStepName, resourceId); + return new ResourceCreated(workflowStepName, 
workflowStepId, resourceId); + } + + private static boolean isValidFieldName(String fieldName) { + return (WORKFLOW_STEP_NAME.equals(fieldName) + || WORKFLOW_STEP_ID.equals(fieldName) + || WorkflowResources.getAllResourcesCreated().contains(fieldName)); } @Override public String toString() { - return "resources_Created [resource_name=" + workflowStepName + ", id=" + resourceId + "]"; + return "resources_Created [workflow_step_name= " + + workflowStepName + + ", workflow_step_id= " + + workflowStepId + + ", resource_id= " + + resourceId + + "]"; } } diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java index e49d7d68a..eb1779e93 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java @@ -17,7 +17,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; /** - * This represents the an object of workflow steps json which maps each step to expected inputs and outputs + * This represents an object of workflow steps json which maps each step to expected inputs and outputs */ public class WorkflowStepValidator { diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index b381b41ec..da9643cb5 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -217,10 +217,10 @@ private void executeWorkflow(List workflowSequence, String workflow ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); - }, exception -> { logger.error("Failed to update workflow state : {}",
exception.getMessage()); }) + }, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage(), exception); }) ); } catch (Exception ex) { - logger.error("Provisioning failed for workflow {} : {}", workflowId, ex); + logger.error("Provisioning failed for workflow: {}", workflowId, ex); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, ImmutableMap.of( @@ -235,7 +235,7 @@ private void executeWorkflow(List workflowSequence, String workflow ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); - }, exceptionState -> { logger.error("Failed to update workflow state : {}", exceptionState.getMessage()); }) + }, exceptionState -> { logger.error("Failed to update workflow state : {}", exceptionState.getMessage(), ex); }) ); } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index bc4132087..dc4c83d4e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -11,21 +11,16 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; -import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.ml.client.MachineLearningNodeClient; import 
org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; -import org.opensearch.script.Script; -import org.opensearch.script.ScriptType; import java.io.IOException; import java.security.AccessController; @@ -48,7 +43,6 @@ import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; /** @@ -61,7 +55,7 @@ public class CreateConnectorStep implements WorkflowStep { private MachineLearningNodeClient mlClient; private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; - static final String NAME = "create_connector"; + static final String NAME = WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(); /** * Instantiate this class @@ -87,44 +81,35 @@ public CompletableFuture execute( @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { - String workflowId = currentNodeInputs.getWorkflowId(); - createConnectorFuture.complete( - new WorkflowData( - Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())), - workflowId, - currentNodeInputs.getNodeId() - ) - ); + try { + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); logger.info("Created connector successfully"); - String workflowStepName = getName(); - ResourceCreated newResource = new ResourceCreated(workflowStepName, mlCreateConnectorResponse.getConnectorId()); - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - 
newResource.toXContent(builder, ToXContentObject.EMPTY_PARAMS); - - // The script to append a new object to the resources_created array - Script script = new Script( - ScriptType.INLINE, - "painless", - "ctx._source.resources_created.add(params.newResource)", - Collections.singletonMap("newResource", newResource) - ); - - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDocWithScript( - WORKFLOW_STATE_INDEX, - workflowId, - script, - ActionListener.wrap(updateResponse -> { - logger.info("updated resources created of {}", workflowId); + flowFrameworkIndicesHandler.updateResourceInStateIndex( + currentNodeInputs.getWorkflowId(), + currentNodeId, + getName(), + mlCreateConnectorResponse.getConnectorId(), + ActionListener.wrap(response -> { + logger.info("successfully updated resources created in state index: {}", response.getIndex()); + createConnectorFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(resourceName, mlCreateConnectorResponse.getConnectorId())), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); }, exception -> { + logger.error("Failed to update new created resource", exception); createConnectorFuture.completeExceptionally( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) ); - logger.error("Failed to update workflow state with newly created resource: {}", exception); }) ); - } catch (IOException e) { - logger.error("Failed to parse new created resource", e); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 07246134a..0ace57dc3 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ 
b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.client.Client; @@ -17,6 +18,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.common.WorkflowResources; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import java.util.ArrayList; @@ -26,8 +29,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; -import static org.opensearch.flowframework.common.CommonValue.INDEX_NAME; -import static org.opensearch.flowframework.common.CommonValue.TYPE; +import static org.opensearch.flowframework.common.CommonValue.DEFAULT_MAPPING_OPTION; /** * Step to create an index @@ -37,21 +39,23 @@ public class CreateIndexStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateIndexStep.class); private final ClusterService clusterService; private final Client client; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ - static final String NAME = "create_index"; + static final String NAME = WorkflowResources.CREATE_INDEX.getWorkflowStep(); static Map indexMappingUpdated = new HashMap<>(); - private static final Map indexSettings = Map.of("index.auto_expand_replicas", "0-1"); /** * Instantiate this class * * @param clusterService The OpenSearch cluster service * @param client Client to create an index + * @param 
flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public CreateIndexStep(ClusterService clusterService, Client client) { + public CreateIndexStep(ClusterService clusterService, Client client, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.clusterService = clusterService; this.client = client; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -61,30 +65,50 @@ public CompletableFuture execute( Map outputs, Map previousNodeInputs ) { - CompletableFuture future = new CompletableFuture<>(); + CompletableFuture createIndexFuture = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(CreateIndexResponse createIndexResponse) { - logger.info("created index: {}", createIndexResponse.index()); - future.complete( - new WorkflowData( - Map.of(INDEX_NAME, createIndexResponse.index()), + try { + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + logger.info("created index: {}", createIndexResponse.index()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + currentNodeId, + getName(), + createIndexResponse.index(), + ActionListener.wrap(response -> { + logger.info("successfully updated resource created in state index: {}", response.getIndex()); + createIndexFuture.complete( + new WorkflowData( + Map.of(resourceName, createIndexResponse.index()), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + createIndexFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + createIndexFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), 
ExceptionsHelper.status(e))); + } } @Override public void onFailure(Exception e) { logger.error("Failed to create an index", e); - future.completeExceptionally(e); + createIndexFuture.completeExceptionally(e); } }; String index = null; - String type = null; + String defaultMappingOption = null; Settings settings = null; // TODO: Recreating the list to get this compiling @@ -93,13 +117,18 @@ public void onFailure(Exception e) { data.add(currentNodeInputs); data.addAll(outputs.values()); - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - index = (String) content.get(INDEX_NAME); - type = (String) content.get(TYPE); - if (index != null && type != null && settings != null) { - break; + try { + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + index = (String) content.get(WorkflowResources.getResourceByWorkflowStep(getName())); + defaultMappingOption = (String) content.get(DEFAULT_MAPPING_OPTION); + if (index != null && defaultMappingOption != null && settings != null) { + break; + } } + } catch (Exception e) { + logger.error("Failed to find the correct resource for the workflow step", e); + createIndexFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } // TODO: @@ -107,7 +136,7 @@ public void onFailure(Exception e) { try { CreateIndexRequest request = new CreateIndexRequest(index).mapping( - FlowFrameworkIndicesHandler.getIndexMappings("mappings/" + type + ".json"), + FlowFrameworkIndicesHandler.getIndexMappings("mappings/" + defaultMappingOption + ".json"), JsonXContent.jsonXContent.mediaType() ); client.admin().indices().create(request, actionListener); @@ -115,7 +144,7 @@ public void onFailure(Exception e) { logger.error("Failed to find the right mapping for the index", e); } - return future; + return createIndexFuture; } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java 
b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index 77dae29eb..352772a49 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.ingest.PutPipelineRequest; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; @@ -18,6 +19,9 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.WorkflowResources; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import java.io.IOException; import java.util.ArrayList; @@ -33,7 +37,6 @@ import static org.opensearch.flowframework.common.CommonValue.INPUT_FIELD_NAME; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.OUTPUT_FIELD_NAME; -import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID; import static org.opensearch.flowframework.common.CommonValue.PROCESSORS; import static org.opensearch.flowframework.common.CommonValue.TYPE; @@ -45,18 +48,21 @@ public class CreateIngestPipelineStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateIngestPipelineStep.class); /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ - static final String NAME = "create_ingest_pipeline"; + static final String NAME = WorkflowResources.CREATE_INGEST_PIPELINE.getWorkflowStep(); // Client to store a pipeline in the cluster state private final ClusterAdminClient 
clusterAdminClient; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + /** * Instantiates a new CreateIngestPipelineStep - * * @param client The client to create a pipeline and store workflow data into the global context index + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public CreateIngestPipelineStep(Client client) { + public CreateIngestPipelineStep(Client client, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.clusterAdminClient = client.admin().cluster(); + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -136,16 +142,38 @@ public CompletableFuture execute( clusterAdminClient.putPipeline(putPipelineRequest, ActionListener.wrap(response -> { logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); - // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead - createIngestPipelineFuture.complete( - new WorkflowData( - Map.of(PIPELINE_ID, putPipelineRequest.getId()), + try { + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + currentNodeId, + getName(), + putPipelineRequest.getId(), + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead + // TODO: revisit this concept of pipeline_id to be consistent with what makes most sense to end user here + createIngestPipelineFuture.complete( + new WorkflowData( + Map.of(resourceName, putPipelineRequest.getId()), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + 
createIngestPipelineFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); - // TODO : Use node client to index response data to global context (pending global context index implementation) + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + createIngestPipelineFuture.completeExceptionally( + new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)) + ); + } }, exception -> { logger.error("Failed to create ingest pipeline : " + exception.getMessage()); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 22c6ae810..50ae30986 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -13,7 +13,9 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; @@ -31,6 +33,7 @@ import static org.opensearch.flowframework.common.CommonValue.BACKEND_ROLES_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.MODEL_ACCESS_MODE; +import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; /** @@ -42,14 +45,18 @@ public class ModelGroupStep implements WorkflowStep { private final 
MachineLearningNodeClient mlClient; - static final String NAME = "model_group"; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + + static final String NAME = WorkflowResources.REGISTER_MODEL_GROUP.getWorkflowStep(); /** * Instantiate this class * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public ModelGroupStep(MachineLearningNodeClient mlClient) { + public ModelGroupStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -65,17 +72,38 @@ public CompletableFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) { - logger.info("Model group registration successful"); - registerModelGroupFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), - Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) - ), + try { + logger.info("Model group registration successful"); + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + currentNodeId, + getName(), + mlRegisterModelGroupResponse.getModelGroupId(), + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + registerModelGroupFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, mlRegisterModelGroupResponse.getModelGroupId()), + Map.entry(MODEL_GROUP_STATUS, mlRegisterModelGroupResponse.getStatus()) + ), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); + },
exception -> { + logger.error("Failed to update new created resource", exception); + registerModelGroupFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index cc5645306..3dc730b54 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -16,7 +16,9 @@ import org.opensearch.common.util.concurrent.FutureUtils; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; @@ -42,7 +44,6 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; -import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -58,17 +59,26 @@ public class 
RegisterLocalModelStep extends AbstractRetryableWorkflowStep { private final MachineLearningNodeClient mlClient; - static final String NAME = "register_local_model"; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + + static final String NAME = WorkflowResources.REGISTER_LOCAL_MODEL.getWorkflowStep(); /** * Instantiate this class * @param settings The OpenSearch settings * @param clusterService The cluster service * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public RegisterLocalModelStep(Settings settings, ClusterService clusterService, MachineLearningNodeClient mlClient) { + public RegisterLocalModelStep( + Settings settings, + ClusterService clusterService, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { super(settings, clusterService); this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -218,7 +228,7 @@ public String getName() { * Retryable get ml task * @param workflowId the workflow id * @param nodeId the workflow node id - * @param getMLTaskFuture the workflow step future + * @param registerLocalModelFuture the workflow step future * @param taskId the ml task id * @param retries the current number of request retries */ @@ -242,17 +252,38 @@ void retryableGetMlTask( throw new IllegalStateException("Local model registration is not yet completed"); } } else { - logger.info("Local model registeration successful"); - registerLocalModelFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(MODEL_ID, response.getModelId()), - Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) - ), + try { + logger.info("Local Model registration successful"); + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( workflowId, - nodeId - ) - ); + 
nodeId, + getName(), + response.getTaskId(), + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + registerLocalModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, response.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) + ), + workflowId, + nodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + registerLocalModelFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } } }, exception -> { if (retries < maxRetry) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 27a77cb98..7e33937bc 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -13,7 +13,9 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -32,7 +34,6 @@ import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static 
org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; -import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -45,14 +46,18 @@ public class RegisterRemoteModelStep implements WorkflowStep { private final MachineLearningNodeClient mlClient; - static final String NAME = "register_remote_model"; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + + static final String NAME = WorkflowResources.REGISTER_REMOTE_MODEL.getWorkflowStep(); /** * Instantiate this class * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public RegisterRemoteModelStep(MachineLearningNodeClient mlClient) { + public RegisterRemoteModelStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -68,17 +73,39 @@ public CompletableFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { - logger.info("Remote Model registration successful"); - registerRemoteModelFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()), - Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) - ), + + try { + logger.info("Remote Model registration successful"); + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + currentNodeId, + getName(), + 
mlRegisterModelResponse.getModelId(), + ActionListener.wrap(response -> { + logger.info("successfully updated resources created in state index: {}", response.getIndex()); + registerRemoteModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, mlRegisterModelResponse.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) + ), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + registerRemoteModelFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index bac65c23a..ce0b24d24 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -43,14 +43,17 @@ public WorkflowStepFactory( FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { stepMap.put(NoOpStep.NAME, new NoOpStep()); - stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); - stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); - stepMap.put(RegisterLocalModelStep.NAME, new RegisterLocalModelStep(settings, clusterService, mlClient)); - stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient)); + stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client, flowFrameworkIndicesHandler)); + stepMap.put(CreateIngestPipelineStep.NAME, new 
CreateIngestPipelineStep(client, flowFrameworkIndicesHandler)); + stepMap.put( + RegisterLocalModelStep.NAME, + new RegisterLocalModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler) + ); + stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient)); - stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); + stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ToolStep.NAME, new ToolStep()); stepMap.put(RegisterAgentStep.NAME, new RegisterAgentStep(mlClient)); } diff --git a/src/main/resources/mappings/workflow-state.json b/src/main/resources/mappings/workflow-state.json index 21df5ccd6..86fbeef6e 100644 --- a/src/main/resources/mappings/workflow-state.json +++ b/src/main/resources/mappings/workflow-state.json @@ -31,15 +31,7 @@ "type": "object" }, "resources_created": { - "type": "nested", - "properties": { - "workflow_step_name": { - "type": "keyword" - }, - "resource_id": { - "type": "keyword" - } - } + "type": "object" }, "ui_metadata": { "type": "object", diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 5840b0906..eb92ccd5e 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -6,7 +6,7 @@ "create_index": { "inputs":[ "index_name", - "type" + "default_mapping_option" ], "outputs":[ "index_name" @@ -83,7 +83,7 @@ "deploy_model_status" ] }, - "model_group": { + "register_model_group": { "inputs":[ "name" ], diff --git a/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java 
b/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java index c38d6f81c..216c18c9e 100644 --- a/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java +++ b/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.model; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -20,17 +21,21 @@ public void setUp() throws Exception { } public void testParseFeature() throws IOException { - ResourceCreated resourceCreated = new ResourceCreated("A", "B"); - assertEquals(resourceCreated.workflowStepName(), "A"); - assertEquals(resourceCreated.resourceId(), "B"); - - String expectedJson = "{\"workflow_step_name\":\"A\",\"resource_id\":\"B\"}"; - String json = TemplateTestJsonUtil.parseToJson(resourceCreated); + String workflowStepName = WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(); + ResourceCreated ResourceCreated = new ResourceCreated(workflowStepName, "workflow_step_1", "L85p1IsBbfF"); + assertEquals(ResourceCreated.workflowStepName(), workflowStepName); + assertEquals(ResourceCreated.workflowStepId(), "workflow_step_1"); + assertEquals(ResourceCreated.resourceId(), "L85p1IsBbfF"); + + String expectedJson = + "{\"workflow_step_name\":\"create_connector\",\"workflow_step_id\":\"workflow_step_1\",\"connector_id\":\"L85p1IsBbfF\"}"; + String json = TemplateTestJsonUtil.parseToJson(ResourceCreated); assertEquals(expectedJson, json); - ResourceCreated resourceCreatedTwo = ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(json)); - assertEquals("A", resourceCreatedTwo.workflowStepName()); - assertEquals("B", resourceCreatedTwo.resourceId()); + ResourceCreated ResourceCreatedTwo = ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(json)); + assertEquals(workflowStepName, ResourceCreatedTwo.workflowStepName()); + assertEquals("workflow_step_1", 
ResourceCreatedTwo.workflowStepId()); + assertEquals("L85p1IsBbfF", ResourceCreatedTwo.resourceId()); } public void testExceptions() throws IOException { @@ -40,7 +45,7 @@ public void testExceptions() throws IOException { String missingJson = "{\"resource_id\":\"B\"}"; e = assertThrows(IOException.class, () -> ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(missingJson))); - assertEquals("A ResourceCreated object requires both a workflowStepName and resourceId.", e.getMessage()); + assertEquals("Unable to parse field [resource_id] in a resources_created object.", e.getMessage()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 1135a0ca6..09c6c3c68 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -8,7 +8,9 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -28,7 +30,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -87,6 +92,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr return null; 
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = createConnectorStep.execute( inputData.getNodeId(), inputData, diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 3b9233c95..a4c8ae92c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -10,6 +10,7 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; @@ -21,9 +22,12 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -35,8 +39,12 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; 
+import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -60,12 +68,18 @@ public class CreateIndexStepTests extends OpenSearchTestCase { private ThreadPool threadPool; @Mock IndexMetadata indexMetadata; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; @Override public void setUp() throws Exception { super.setUp(); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id", "test-node-id"); + inputData = new WorkflowData( + Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("default_mapping_option", "knn")), + "test-id", + "test-node-id" + ); clusterService = mock(ClusterService.class); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -80,11 +94,18 @@ public void setUp() throws Exception { when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test cluster")).build()); when(metadata.indices()).thenReturn(Map.of(GLOBAL_CONTEXT_INDEX, indexMetadata)); - createIndexStep = new CreateIndexStep(clusterService, client); + createIndexStep = new CreateIndexStep(clusterService, client, flowFrameworkIndicesHandler); CreateIndexStep.indexMappingUpdated = indexMappingUpdated; } - public void testCreateIndexStep() throws ExecutionException, InterruptedException { + public void testCreateIndexStep() throws ExecutionException, InterruptedException, IOException { + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 
1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = createIndexStep.execute( diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index f0a970758..1c7940949 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -10,12 +10,16 @@ import org.opensearch.action.ingest.PutPipelineRequest; import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -23,7 +27,11 @@ import org.mockito.ArgumentCaptor; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -37,10 +45,12 @@ public class CreateIngestPipelineStepTests extends OpenSearchTestCase { 
private Client client; private AdminClient adminClient; private ClusterAdminClient clusterAdminClient; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; @Override public void setUp() throws Exception { super.setUp(); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); inputData = new WorkflowData( Map.ofEntries( @@ -66,9 +76,15 @@ public void setUp() throws Exception { when(adminClient.cluster()).thenReturn(clusterAdminClient); } - public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException { + public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException, IOException { - CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -91,7 +107,7 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio public void testCreateIngestPipelineStepFailure() throws InterruptedException { - CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -116,7 +132,7 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { } public void 
testMissingData() throws InterruptedException { - CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler); // Data with missing input and output fields WorkflowData incorrectData = new WorkflowData( diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index bc914baa7..d78a97e8a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -9,9 +9,12 @@ package org.opensearch.flowframework.workflow; import com.google.common.collect.ImmutableList; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLTaskState; @@ -29,8 +32,12 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; public class ModelGroupStepTests extends OpenSearchTestCase { @@ -40,10 +47,12 @@ public class ModelGroupStepTests extends OpenSearchTestCase { @Mock MachineLearningNodeClient machineLearningNodeClient; + 
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + @Override public void setUp() throws Exception { super.setUp(); - + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); inputData = new WorkflowData( Map.ofEntries( @@ -63,7 +72,7 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep String modelGroupId = "model_group_id"; String status = MLTaskState.CREATED.name(); - ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -75,6 +84,12 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep return null; }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = modelGroupStep.execute( inputData.getNodeId(), inputData, @@ -90,8 +105,8 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep } - public void testRegisterModelGroupFailure() throws ExecutionException, InterruptedException, IOException { - ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + public void testRegisterModelGroupFailure() throws IOException { + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = 
ArgumentCaptor.forClass(ActionListener.class); @@ -119,7 +134,7 @@ public void testRegisterModelGroupFailure() throws ExecutionException, Interrupt } public void testRegisterModelGroupWithNoName() throws IOException { - ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); + ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); CompletableFuture future = modelGroupStep.execute( inputDataWithNoName.getNodeId(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index d169812a9..c38f8a120 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -10,12 +10,15 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; @@ -34,10 +37,13 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static 
org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -49,6 +55,7 @@ public class RegisterLocalModelStepTests extends OpenSearchTestCase { private RegisterLocalModelStep registerLocalModelStep; private WorkflowData workflowData; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -56,7 +63,7 @@ public class RegisterLocalModelStepTests extends OpenSearchTestCase { @Override public void setUp() throws Exception { super.setUp(); - + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); ClusterService clusterService = mock(ClusterService.class); final Set> settingsSet = Stream.concat( @@ -69,7 +76,12 @@ public void setUp() throws Exception { ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - this.registerLocalModelStep = new RegisterLocalModelStep(testMaxRetrySetting, clusterService, machineLearningNodeClient); + this.registerLocalModelStep = new RegisterLocalModelStep( + testMaxRetrySetting, + clusterService, + machineLearningNodeClient, + flowFrameworkIndicesHandler + ); this.workflowData = new WorkflowData( Map.ofEntries( @@ -127,6 +139,12 @@ public void testRegisterLocalModelSuccess() throws Exception { return null; }).when(machineLearningNodeClient).getTask(any(), any()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, 
UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = registerLocalModelStep.execute( workflowData.getNodeId(), workflowData, diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index cde194326..a83443f05 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -10,8 +10,11 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -26,10 +29,14 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -38,6 +45,7 @@ public class RegisterRemoteModelStepTests extends OpenSearchTestCase { private 
RegisterRemoteModelStep registerRemoteModelStep; private WorkflowData workflowData; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; @Mock MachineLearningNodeClient mlNodeClient; @@ -45,9 +53,9 @@ public class RegisterRemoteModelStepTests extends OpenSearchTestCase { @Override public void setUp() throws Exception { super.setUp(); - + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); - this.registerRemoteModelStep = new RegisterRemoteModelStep(mlNodeClient); + this.registerRemoteModelStep = new RegisterRemoteModelStep(mlNodeClient, flowFrameworkIndicesHandler); this.workflowData = new WorkflowData( Map.ofEntries( Map.entry("function_name", "remote"), @@ -73,6 +81,12 @@ public void testRegisterRemoteModelSuccess() throws Exception { return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = this.registerRemoteModelStep.execute( workflowData.getNodeId(), workflowData, From 9cb621d31f0efccfc2bc2c2942f102c87c60cda1 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Fri, 1 Dec 2023 16:53:47 -0800 Subject: [PATCH 04/27] [Feature/agent_framework] Fetches modelID for RegisterAgent and Tools workflow steps (#235) * Flattened llm field of register agent Signed-off-by: Owais Kazi * Handled optional modelId Signed-off-by: Owais Kazi * Handled modelId for llm Signed-off-by: Owais Kazi * Parsing for parameters field of tools Signed-off-by: Owais Kazi * Handled test case failures Signed-off-by: Owais Kazi * Fixed spotless failure Signed-off-by: Owais Kazi --------- Signed-off-by: Owais Kazi --- 
build.gradle | 1 - .../flowframework/common/CommonValue.java | 2 - .../flowframework/model/WorkflowNode.java | 16 +--- .../flowframework/util/ParseUtils.java | 67 ---------------- .../workflow/RegisterAgentStep.java | 79 ++++++++++++++----- .../flowframework/workflow/ToolStep.java | 37 +++++++-- .../resources/mappings/workflow-steps.json | 1 - .../model/WorkflowNodeTests.java | 14 ---- .../flowframework/util/ParseUtilsTests.java | 19 ----- .../workflow/RegisterAgentTests.java | 3 +- 10 files changed, 95 insertions(+), 144 deletions(-) diff --git a/build.gradle b/build.gradle index d6e7afbd8..20a9837ca 100644 --- a/build.gradle +++ b/build.gradle @@ -155,7 +155,6 @@ dependencies { implementation "org.opensearch:common-utils:${common_utils_version}" implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' implementation 'org.bouncycastle:bcprov-jdk18on:1.77' - implementation "com.google.code.gson:gson:2.10.1" // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 774660bfd..2343cd305 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -166,8 +166,6 @@ private CommonValue() {} public static final String WORKFLOW_STEP_NAME = "workflow_step_name"; /** The field name for the step ID where a resource is created */ public static final String WORKFLOW_STEP_ID = "workflow_step_id"; - /** LLM Name for registering an agent */ - public static final String LLM_FIELD = "llm"; /** The tools' field for an agent */ public static final String TOOLS_FIELD = "tools"; /** The memory field for an agent */ diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java 
b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 999ba460f..b942ccb16 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -14,7 +14,6 @@ import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; -import org.opensearch.ml.common.agent.LLMSpec; import java.io.IOException; import java.util.ArrayList; @@ -25,10 +24,7 @@ import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD; -import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; -import static org.opensearch.flowframework.util.ParseUtils.parseLLM; import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; /** @@ -101,12 +97,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } } xContentBuilder.endArray(); - } else if (e.getValue() instanceof LLMSpec) { - if (LLM_FIELD.equals(e.getKey())) { - xContentBuilder.startObject(); - buildLLMMap(xContentBuilder, (LLMSpec) e.getValue()); - xContentBuilder.endObject(); - } } } xContentBuilder.endObject(); @@ -150,11 +140,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - if (LLM_FIELD.equals(inputFieldName)) { - userInputs.put(inputFieldName, parseLLM(parser)); - } else { - userInputs.put(inputFieldName, parseStringToStringMap(parser)); - } + userInputs.put(inputFieldName, parseStringToStringMap(parser)); break; case START_ARRAY: if (PROCESSORS_FIELD.equals(inputFieldName)) { diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java 
b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 14d113e34..e22017eaf 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -8,7 +8,6 @@ */ package org.opensearch.flowframework.util; -import com.google.gson.Gson; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; @@ -25,9 +24,6 @@ import org.opensearch.ml.common.agent.LLMSpec; import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; import java.time.Instant; import java.util.HashMap; import java.util.Map; @@ -36,7 +32,6 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; -import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; /** * Utility methods for Template parsing @@ -44,12 +39,6 @@ public class ParseUtils { private static final Logger logger = LogManager.getLogger(ParseUtils.class); - public static final Gson gson; - - static { - gson = new Gson(); - } - private ParseUtils() {} /** @@ -117,37 +106,6 @@ public static Map parseStringToStringMap(XContentParser parser) return map; } - // TODO Figure out a way to use the parse method of LLMSpec of ml-commons - /** - * Parses an XContent object representing the object of LLMSpec - * @param parser An XContent parser whose position is at the start of the map object to parse - * @return instance of {@link org.opensearch.ml.common.agent.LLMSpec} - * @throws IOException parsing error - */ - public static LLMSpec parseLLM(XContentParser parser) throws IOException { - String modelId = null; - Map parameters = null; - - ensureExpectedToken(XContentParser.Token.START_OBJECT, 
parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - - switch (fieldName) { - case MODEL_ID: - modelId = parser.text(); - break; - case PARAMETERS_FIELD: - parameters = getParameterMap(parser.map()); - break; - default: - parser.skipChildren(); - break; - } - } - return LLMSpec.builder().modelId(modelId).parameters(parameters).build(); - } - /** * Parse content parser to {@link java.time.Instant}. * @@ -176,31 +134,6 @@ public static User getUserContext(Client client) { return User.parse(userStr); } - /** - * Generates a parameter map required when the parameter is nested within an object - * @param parameterObjs parameters - * @return a parameters map of type String,String - */ - public static Map getParameterMap(Map parameterObjs) { - Map parameters = new HashMap<>(); - for (String key : parameterObjs.keySet()) { - Object value = parameterObjs.get(key); - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - if (value instanceof String) { - parameters.put(key, (String) value); - } else { - parameters.put(key, gson.toJson(value)); - } - return null; - }); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } - } - return parameters; - } - /** * Creates a XContentParser from a given Registry * diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 44270d8e6..022d46c22 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; @@ -37,8 +38,8 @@ import static 
org.opensearch.flowframework.common.CommonValue.CREATED_TIME; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.LAST_UPDATED_TIME_FIELD; -import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD; import static org.opensearch.flowframework.common.CommonValue.MEMORY_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; @@ -56,6 +57,9 @@ public class RegisterAgentStep implements WorkflowStep { static final String NAME = "register_agent"; + private static final String LLM_MODEL_ID = "llm.model_id"; + private static final String LLM_PARAMETERS = "llm.parameters"; + private List mlToolSpecList; /** @@ -80,7 +84,7 @@ public CompletableFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { - logger.info("Remote Agent registration successful"); + logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId()); registerAgentModelFuture.complete( new WorkflowData( Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())), @@ -100,7 +104,8 @@ public void onFailure(Exception e) { String name = null; String type = null; String description = null; - LLMSpec llm = null; + String llmModelId = null; + Map llmParameters = Collections.emptyMap(); List tools = new ArrayList<>(); Map parameters = Collections.emptyMap(); MLMemorySpec memory = null; @@ -128,8 +133,11 @@ public void onFailure(Exception e) { case TYPE: type = (String) entry.getValue(); break; - case LLM_FIELD: - llm = getLLMSpec(entry.getValue()); + case LLM_MODEL_ID: + llmModelId = (String) entry.getValue(); + break; 
+ case LLM_PARAMETERS: + llmParameters = getStringToStringMap(entry.getValue(), LLM_PARAMETERS); break; case TOOLS_FIELD: tools = addTools(entry.getValue()); @@ -155,7 +163,22 @@ public void onFailure(Exception e) { } } - if (Stream.of(name, type, llm, tools, parameters, memory, appType).allMatch(x -> x != null)) { + // Case when modelId is present in previous node inputs + if (llmModelId == null) { + llmModelId = getLlmModelId(previousNodeInputs, outputs); + } + + // Case when modelId is not present at all + if (llmModelId == null) { + registerAgentModelFuture.completeExceptionally( + new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST) + ); + return registerAgentModelFuture; + } + + LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters); + + if (Stream.of(name, type, llmSpec).allMatch(x -> x != null)) { MLAgentBuilder builder = MLAgent.builder().name(name); if (description != null) { @@ -163,7 +186,7 @@ public void onFailure(Exception e) { } builder.type(type) - .llm(llm) + .llm(llmSpec) .tools(tools) .parameters(parameters) .memory(memory) @@ -195,11 +218,38 @@ private List addTools(Object tools) { return mlToolSpecList; } - private LLMSpec getLLMSpec(Object llm) { - if (llm instanceof LLMSpec) { - return (LLMSpec) llm; + private String getLlmModelId(Map previousNodeInputs, Map outputs) { + // Case when modelId is already pass in the template + String llmModelId = null; + + // Case when modelId is passed through previousSteps + Optional previousNode = previousNodeInputs.entrySet() + .stream() + .filter(e -> MODEL_ID.equals(e.getValue())) + .map(Map.Entry::getKey) + .findFirst(); + + if (previousNode.isPresent()) { + WorkflowData previousNodeOutput = outputs.get(previousNode.get()); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) { + llmModelId = previousNodeOutput.getContent().get(MODEL_ID).toString(); + } } - throw new IllegalArgumentException("[" + LLM_FIELD + "] must be of type 
LLMSpec."); + return llmModelId; + } + + private LLMSpec getLLMSpec(String llmModelId, Map llmParameters) { + if (llmModelId == null) { + throw new IllegalArgumentException("model id for llm is null"); + } + LLMSpec.LLMSpecBuilder builder = LLMSpec.builder(); + builder.modelId(llmModelId); + if (llmParameters != null) { + builder.parameters(llmParameters); + } + + LLMSpec llmSpec = builder.build(); + return llmSpec; } private MLMemorySpec getMLMemorySpec(Object mlMemory) { @@ -231,11 +281,4 @@ private MLMemorySpec getMLMemorySpec(Object mlMemory) { } - private Instant getInstant(Object instant, String fieldName) { - if (instant instanceof Instant) { - return (Instant) instant; - } - throw new IllegalArgumentException("[" + fieldName + "] must be of type Instant."); - } - } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 339142139..af8556289 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -20,15 +20,16 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; -import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; /** * Step to register a tool for an agent @@ -64,19 +65,19 @@ public CompletableFuture 
execute( for (Entry entry : content.entrySet()) { switch (entry.getKey()) { case TYPE: - type = (String) content.get(TYPE); + type = (String) entry.getValue(); break; case NAME_FIELD: - name = (String) content.get(NAME_FIELD); + name = (String) entry.getValue(); break; case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); + description = (String) entry.getValue(); break; case PARAMETERS_FIELD: - parameters = getStringToStringMap(content.get(PARAMETERS_FIELD), PARAMETERS_FIELD); + parameters = getToolsParametersMap(entry.getValue(), previousNodeInputs, outputs); break; case INCLUDE_OUTPUT_IN_AGENT_RESPONSE: - includeOutputInAgentResponse = (Boolean) content.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE); + includeOutputInAgentResponse = (Boolean) entry.getValue(); break; default: break; @@ -124,4 +125,28 @@ public CompletableFuture execute( public String getName() { return NAME; } + + private Map getToolsParametersMap( + Object parameters, + Map previousNodeInputs, + Map outputs + ) { + Map parametersMap = (Map) parameters; + Optional previousNode = previousNodeInputs.entrySet() + .stream() + .filter(e -> MODEL_ID.equals(e.getValue())) + .map(Map.Entry::getKey) + .findFirst(); + // Case when modelId is passed through previousSteps and not present already in parameters + if (previousNode.isPresent() && !parametersMap.containsKey(MODEL_ID)) { + WorkflowData previousNodeOutput = outputs.get(previousNode.get()); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) { + parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString()); + return parametersMap; + } + } + // For other cases where modelId is already present in the parameters or not return the parametersMap + return parametersMap; + } + } diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index eb92ccd5e..b5d09e8cb 100644 --- a/src/main/resources/mappings/workflow-steps.json 
+++ b/src/main/resources/mappings/workflow-steps.json @@ -96,7 +96,6 @@ "inputs":[ "name", "type", - "llm", "tools", "parameters", "memory", diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java index c0011f7ae..b9620c214 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java @@ -8,11 +8,9 @@ */ package org.opensearch.flowframework.model; -import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.HashMap; import java.util.Map; public class WorkflowNodeTests extends OpenSearchTestCase { @@ -23,11 +21,6 @@ public void setUp() throws Exception { } public void testNode() throws IOException { - Map parameters = new HashMap<>(); - parameters.put("stop", "true"); - parameters.put("max", "5"); - - LLMSpec llmSpec = new LLMSpec("modelId", parameters); WorkflowNode nodeA = new WorkflowNode( "A", @@ -38,7 +31,6 @@ public void testNode() throws IOException { Map.entry("bar", Map.of("key", "value")), Map.entry("baz", new Map[] { Map.of("A", "a"), Map.of("B", "b") }), Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }), - Map.entry("llm", llmSpec), Map.entry("created_time", 1689793598499L) ) ); @@ -71,9 +63,6 @@ public void testNode() throws IOException { assertTrue(json.contains("\"bar\":{\"key\":\"value\"}")); assertTrue(json.contains("\"processors\":[{\"type\":\"test-type\",\"params\":{\"key2\":\"value2\"}}]")); assertTrue(json.contains("\"created_time\":1689793598499")); - assertTrue(json.contains("llm\":{")); - assertTrue(json.contains("\"parameters\":{\"stop\":\"true\",\"max\":\"5\"")); - assertTrue(json.contains("\"model_id\":\"modelId\"")); WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json)); 
assertEquals("A", nodeX.id()); @@ -88,9 +77,6 @@ public void testNode() throws IOException { assertEquals(1, ppX.length); assertEquals("test-type", ppX[0].type()); assertEquals(Map.of("key2", "value2"), ppX[0].params()); - LLMSpec llm = (LLMSpec) mapX.get("llm"); - assertEquals("modelId", llm.getModelId()); - assertEquals(parameters, llm.getParameters()); } public void testExceptions() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 76334b52b..a5c4253b3 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -12,12 +12,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; -import org.junit.Assert; import java.io.IOException; import java.time.Instant; -import java.util.HashMap; -import java.util.Map; public class ParseUtilsTests extends OpenSearchTestCase { public void testToInstant() throws IOException { @@ -57,20 +54,4 @@ public void testToInstantWithNotValue() throws IOException { Instant instant = ParseUtils.parseInstant(parser); assertNull(instant); } - - public void testGetParameterMap() { - Map parameters = new HashMap<>(); - parameters.put("key1", "value1"); - parameters.put("key2", 2); - parameters.put("key3", 2.1); - parameters.put("key4", new int[] { 10, 20 }); - parameters.put("key5", new Object[] { 1.01, "abc" }); - Map parameterMap = ParseUtils.getParameterMap(parameters); - Assert.assertEquals(5, parameterMap.size()); - Assert.assertEquals("value1", parameterMap.get("key1")); - Assert.assertEquals("2", parameterMap.get("key2")); - Assert.assertEquals("2.1", parameterMap.get("key3")); - Assert.assertEquals("[10,20]", parameterMap.get("key4")); - Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5")); - } } diff 
--git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index 0f4b33471..c393be5e4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -60,7 +60,8 @@ public void setUp() throws Exception { Map.entry("name", "test"), Map.entry("description", "description"), Map.entry("type", "type"), - Map.entry("llm", llmSpec), + Map.entry("llm.model_id", "xyz"), + Map.entry("llm.parameters", Collections.emptyMap()), Map.entry("tools", tools), Map.entry("parameters", Collections.emptyMap()), Map.entry("memory", mlMemorySpec), From 83af2bac02ee1653e5e4253512a442a4003f9bcb Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 1 Dec 2023 22:08:44 -0800 Subject: [PATCH 05/27] Add Util method to fetch inputs from parameters, content, and previous step output (#234) * Util method to get required inputs Signed-off-by: Daniel Widdis * Implement parsing in some of the steps Signed-off-by: Daniel Widdis * Handle parsing exceptions in the future Signed-off-by: Daniel Widdis * Improve exception handling Signed-off-by: Daniel Widdis * More steps using the new input parsing Signed-off-by: Daniel Widdis * Update Delete Connector Step with parsing util Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- .../flowframework/util/ParseUtils.java | 102 +++++++++++++++ .../workflow/CreateConnectorStep.java | 91 ++++++-------- .../workflow/DeleteConnectorStep.java | 39 +++--- .../workflow/DeployModelStep.java | 35 +++--- .../workflow/ModelGroupStep.java | 63 +++------- .../workflow/RegisterLocalModelStep.java | 117 ++++++------------ .../workflow/RegisterRemoteModelStep.java | 70 +++-------- .../flowframework/workflow/WorkflowStep.java | 2 +- .../flowframework/util/ParseUtilsTests.java | 51 ++++++++ .../workflow/DeleteConnectorStepTests.java | 5 +- 
.../workflow/ModelGroupStepTests.java | 8 +- .../workflow/RegisterLocalModelStepTests.java | 18 ++- .../RegisterRemoteModelStepTests.java | 8 +- 13 files changed, 327 insertions(+), 282 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index e22017eaf..9e3b8d067 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -18,16 +18,24 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.ml.common.agent.LLMSpec; import java.io.IOException; import java.time.Instant; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; @@ -39,6 +47,9 @@ public class ParseUtils { private static final Logger logger = LogManager.getLogger(ParseUtils.class); + // Matches ${{ foo.bar }} (whitespace optional) with capturing groups 1=foo, 2=bar + private static final Pattern SUBSTITUTION_PATTERN = Pattern.compile("\\$\\{\\{\\s*(.+)\\.(.+?)\\s*\\}\\}"); + private ParseUtils() {} /** @@ -161,4 +172,95 @@ public static Map getStringToStringMap(Object map, String fieldN throw new IllegalArgumentException("[" + fieldName + "] must be a 
key-value map."); } + /** + * Creates a map containing the specified input keys, with values derived from template data or previous node + * output. + * + * @param requiredInputKeys A set of keys that must be present, or will cause an exception to be thrown + * @param optionalInputKeys A set of keys that may be present, or will be absent in the returned map + * @param currentNodeInputs Input params and content for this node, from workflow parsing + * @param outputs WorkflowData content of previous steps + * @param previousNodeInputs Input params for this node that come from previous steps + * @return A map containing the requiredInputKeys with their corresponding values, + * and optionalInputKeys with their corresponding values if present. + * Throws a {@link FlowFrameworkException} if a required key is not present. + */ + public static Map getInputsFromPreviousSteps( + Set requiredInputKeys, + Set optionalInputKeys, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { + // Mutable set to ensure all required keys are used + Set requiredKeys = new HashSet<>(requiredInputKeys); + // Merge input sets to add all requested keys + Set keys = new HashSet<>(requiredInputKeys); + keys.addAll(optionalInputKeys); + // Initialize return map + Map inputs = new HashMap<>(); + for (String key : keys) { + Object value = null; + // Priority 1: specifically named prior step inputs + // ... parse the previousNodeInputs map and fill in the specified keys + Optional previousNodeForKey = previousNodeInputs.entrySet() + .stream() + .filter(e -> key.equals(e.getValue())) + .map(Map.Entry::getKey) + .findAny(); + if (previousNodeForKey.isPresent()) { + WorkflowData previousNodeOutput = outputs.get(previousNodeForKey.get()); + if (previousNodeOutput != null) { + value = previousNodeOutput.getContent().get(key); + } + } + // Priority 2: inputs specified in template + // ... 
fetch from currentNodeInputs (params take precedence) + if (value == null) { + value = currentNodeInputs.getParams().get(key); + } + if (value == null) { + value = currentNodeInputs.getContent().get(key); + } + // Priority 3: other inputs + if (value == null) { + Optional matchedValue = outputs.values() + .stream() + .map(WorkflowData::getContent) + .filter(m -> m.containsKey(key)) + .map(m -> m.get(key)) + .findAny(); + if (matchedValue.isPresent()) { + value = matchedValue.get(); + } + } + // Check for substitution + if (value != null) { + Matcher m = SUBSTITUTION_PATTERN.matcher(value.toString()); + if (m.matches()) { + WorkflowData data = outputs.get(m.group(1)); + if (data != null && data.getContent().containsKey(m.group(2))) { + value = data.getContent().get(m.group(2)); + } + } + inputs.put(key, value); + requiredKeys.remove(key); + } + } + // After iterating is complete, throw exception if requiredKeys is not empty + if (!requiredKeys.isEmpty()) { + throw new FlowFrameworkException( + "Missing required inputs " + + requiredKeys + + " in workflow [" + + currentNodeInputs.getWorkflowId() + + "] node [" + + currentNodeInputs.getNodeId() + + "]", + RestStatus.BAD_REQUEST + ); + } + // Finally return the map + return inputs; + } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index dc4c83d4e..bca1c8856 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -16,6 +16,7 @@ import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import 
org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; @@ -33,8 +34,8 @@ import java.util.Locale; import java.util.Map; import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; @@ -120,57 +121,44 @@ public void onFailure(Exception e) { } }; - String name = null; - String description = null; - String version = null; - String protocol = null; - Map parameters = Collections.emptyMap(); - Map credentials = Collections.emptyMap(); - List actions = Collections.emptyList(); - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); + Set requiredKeys = Set.of( + NAME_FIELD, + DESCRIPTION_FIELD, + VERSION_FIELD, + PROTOCOL_FIELD, + PARAMETERS_FIELD, + CREDENTIAL_FIELD, + ACTIONS_FIELD + ); + Set optionalKeys = Collections.emptySet(); try { - for (WorkflowData workflowData : data) { - for (Entry entry : workflowData.getContent().entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - name = (String) entry.getValue(); - break; - case DESCRIPTION_FIELD: - description = (String) entry.getValue(); - break; - case VERSION_FIELD: - version = (String) entry.getValue(); - break; - case PROTOCOL_FIELD: - protocol = (String) entry.getValue(); - break; - case PARAMETERS_FIELD: - parameters = getParameterMap(entry.getValue()); - break; - case CREDENTIAL_FIELD: - credentials = getStringToStringMap(entry.getValue(), CREDENTIAL_FIELD); - break; - case ACTIONS_FIELD: - actions = getConnectorActionList(entry.getValue()); - break; - } - } + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + 
currentNodeInputs, + outputs, + previousNodeInputs + ); + + String name = (String) inputs.get(NAME_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + String version = (String) inputs.get(VERSION_FIELD); + String protocol = (String) inputs.get(PROTOCOL_FIELD); + Map parameters; + Map credentials; + List actions; + + try { + parameters = getParameterMap(inputs.get(PARAMETERS_FIELD)); + credentials = getStringToStringMap(inputs.get(CREDENTIAL_FIELD), CREDENTIAL_FIELD); + actions = getConnectorActionList(inputs.get(ACTIONS_FIELD)); + } catch (IllegalArgumentException iae) { + throw new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST); + } catch (PrivilegedActionException pae) { + throw new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED); } - } catch (IllegalArgumentException iae) { - createConnectorFuture.completeExceptionally(new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST)); - return createConnectorFuture; - } catch (PrivilegedActionException pae) { - createConnectorFuture.completeExceptionally(new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED)); - return createConnectorFuture; - } - if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) { MLCreateConnectorInput mlInput = MLCreateConnectorInput.builder() .name(name) .description(description) @@ -182,12 +170,9 @@ public void onFailure(Exception e) { .build(); mlClient.createConnector(mlInput, actionListener); - } else { - createConnectorFuture.completeExceptionally( - new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) - ); + } catch (FlowFrameworkException e) { + createConnectorFuture.completeExceptionally(e); } - return createConnectorFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java index 
bf0fae33e..517e484a7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java @@ -13,13 +13,14 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import java.io.IOException; +import java.util.Collections; import java.util.Map; -import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; @@ -72,29 +73,23 @@ public void onFailure(Exception e) { } }; - String connectorId = null; - - // Previous Node inputs defines which step the connector ID came from - Optional previousNode = previousNodeInputs.entrySet() - .stream() - .filter(e -> CONNECTOR_ID.equals(e.getValue())) - .map(Map.Entry::getKey) - .findFirst(); - if (previousNode.isPresent()) { - WorkflowData previousNodeOutput = outputs.get(previousNode.get()); - if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(CONNECTOR_ID)) { - connectorId = previousNodeOutput.getContent().get(CONNECTOR_ID).toString(); - } - } + Set requiredKeys = Set.of(CONNECTOR_ID); + Set optionalKeys = Collections.emptySet(); - if (connectorId != null) { - mlClient.deleteConnector(connectorId, actionListener); - } else { - deleteConnectorFuture.completeExceptionally( - new FlowFrameworkException("Required field " + CONNECTOR_ID + " is not provided", RestStatus.BAD_REQUEST) + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs ); - } + String connectorId = (String) 
inputs.get(CONNECTOR_ID); + mlClient.deleteConnector(connectorId, actionListener); + } catch (FlowFrameworkException e) { + deleteConnectorFuture.completeExceptionally(e); + } return deleteConnectorFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 81409ef77..f878fbdc2 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -12,14 +12,14 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; -import java.util.ArrayList; -import java.util.List; +import java.util.Collections; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; @@ -71,27 +71,24 @@ public void onFailure(Exception e) { } }; - String modelId = null; + Set requiredKeys = Set.of(MODEL_ID); + Set optionalKeys = Collections.emptySet(); - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); - for (WorkflowData workflowData : data) { - if (workflowData.getContent().containsKey(MODEL_ID)) { - modelId = (String) workflowData.getContent().get(MODEL_ID); - break; - } - } + String 
modelId = (String) inputs.get(MODEL_ID); - if (modelId != null) { mlClient.deploy(modelId, actionListener); - } else { - deployModelFuture.completeExceptionally(new FlowFrameworkException("Model ID is not provided", RestStatus.BAD_REQUEST)); + } catch (FlowFrameworkException e) { + deployModelFuture.completeExceptionally(e); } - return deployModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 50ae30986..e2aea19df 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -12,10 +12,10 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; @@ -23,10 +23,9 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.ADD_ALL_BACKEND_ROLES; @@ -113,49 +112,23 @@ public void onFailure(Exception e) { } }; - String modelGroupName = null; - String description = null; - List backendRoles = new ArrayList<>(); - AccessMode modelAccessMode = null; - Boolean isAddAllBackendRoles = null; - - // TODO: 
Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - modelGroupName = (String) content.get(NAME_FIELD); - break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case BACKEND_ROLES_FIELD: - backendRoles = getBackendRoles(content); - break; - case MODEL_ACCESS_MODE: - modelAccessMode = (AccessMode) content.get(MODEL_ACCESS_MODE); - break; - case ADD_ALL_BACKEND_ROLES: - isAddAllBackendRoles = (Boolean) content.get(ADD_ALL_BACKEND_ROLES); - break; - default: - break; - } - } - } + Set requiredKeys = Set.of(NAME_FIELD); + Set optionalKeys = Set.of(DESCRIPTION_FIELD, BACKEND_ROLES_FIELD, MODEL_ACCESS_MODE, ADD_ALL_BACKEND_ROLES); - if (modelGroupName == null) { - registerModelGroupFuture.completeExceptionally( - new FlowFrameworkException("Model group name is not provided", RestStatus.BAD_REQUEST) + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs ); - } else { + String modelGroupName = (String) inputs.get(NAME_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + List backendRoles = getBackendRoles(inputs); + AccessMode modelAccessMode = (AccessMode) inputs.get(MODEL_ACCESS_MODE); + Boolean isAddAllBackendRoles = (Boolean) inputs.get(ADD_ALL_BACKEND_ROLES); + MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder(); builder.name(modelGroupName); if (description != null) { @@ -173,6 +146,8 @@ public void onFailure(Exception e) { MLRegisterModelGroupInput mlInput = builder.build(); mlClient.registerModelGroup(mlInput, actionListener); + } catch 
(FlowFrameworkException e) { + registerModelGroupFuture.completeExceptionally(e); } return registerModelGroupFuture; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index 3dc730b54..fb1d383b5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -19,6 +19,7 @@ import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; @@ -30,10 +31,8 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.ArrayList; -import java.util.List; import java.util.Map; -import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; @@ -91,12 +90,6 @@ public CompletableFuture execute( CompletableFuture registerLocalModelFuture = new CompletableFuture<>(); - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @@ -115,76 +108,41 @@ public void onFailure(Exception e) { } }; - String modelName = null; - String modelVersion = null; - String description = null; - MLModelFormat 
modelFormat = null; - String modelGroupId = null; - String modelContentHashValue = null; - String modelType = null; - String embeddingDimension = null; - FrameworkType frameworkType = null; - String allConfig = null; - String url = null; - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - modelName = (String) content.get(NAME_FIELD); - break; - case VERSION_FIELD: - modelVersion = (String) content.get(VERSION_FIELD); - break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case MODEL_FORMAT: - modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); - break; - case MODEL_GROUP_ID: - modelGroupId = (String) content.get(MODEL_GROUP_ID); - break; - case MODEL_TYPE: - modelType = (String) content.get(MODEL_TYPE); - break; - case EMBEDDING_DIMENSION: - embeddingDimension = (String) content.get(EMBEDDING_DIMENSION); - break; - case FRAMEWORK_TYPE: - frameworkType = FrameworkType.from((String) content.get(FRAMEWORK_TYPE)); - break; - case ALL_CONFIG: - allConfig = (String) content.get(ALL_CONFIG); - break; - case MODEL_CONTENT_HASH_VALUE: - modelContentHashValue = (String) content.get(MODEL_CONTENT_HASH_VALUE); - break; - case URL: - url = (String) content.get(URL); - break; - default: - break; + Set requiredKeys = Set.of( + NAME_FIELD, + VERSION_FIELD, + MODEL_FORMAT, + MODEL_GROUP_ID, + MODEL_TYPE, + EMBEDDING_DIMENSION, + FRAMEWORK_TYPE, + MODEL_CONTENT_HASH_VALUE, + URL + ); + Set optionalKeys = Set.of(DESCRIPTION_FIELD, ALL_CONFIG); - } - } - } + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); - if (Stream.of( - modelName, - modelVersion, - modelFormat, - modelGroupId, - modelType, - embeddingDimension, - frameworkType, - modelContentHashValue, - url - ).allMatch(x -> x != null)) { 
+ String modelName = (String) inputs.get(NAME_FIELD); + String modelVersion = (String) inputs.get(VERSION_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + MLModelFormat modelFormat = MLModelFormat.from((String) inputs.get(MODEL_FORMAT)); + String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); + String modelContentHashValue = (String) inputs.get(MODEL_CONTENT_HASH_VALUE); + String modelType = (String) inputs.get(MODEL_TYPE); + String embeddingDimension = (String) inputs.get(EMBEDDING_DIMENSION); + FrameworkType frameworkType = FrameworkType.from((String) inputs.get(FRAMEWORK_TYPE)); + String allConfig = (String) inputs.get(ALL_CONFIG); + String url = (String) inputs.get(URL); - // Create Model configudation + // Create Model configuration TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder() .modelType(modelType) .embeddingDimension(Integer.valueOf(embeddingDimension)) @@ -210,12 +168,9 @@ public void onFailure(Exception e) { MLRegisterModelInput mlInput = mlInputBuilder.build(); mlClient.register(mlInput, actionListener); - } else { - registerLocalModelFuture.completeExceptionally( - new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) - ); + } catch (FlowFrameworkException e) { + registerLocalModelFuture.completeExceptionally(e); } - return registerLocalModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 7e33937bc..cde546d32 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -12,23 +12,20 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import 
org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.ArrayList; -import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; @@ -115,48 +112,23 @@ public void onFailure(Exception e) { } }; - String modelName = null; - FunctionName functionName = null; - String modelGroupId = null; - String description = null; - String connectorId = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - // TODO : Handle inline connector configuration : https://github.com/opensearch-project/flow-framework/issues/149 - - for (WorkflowData workflowData : data) { - - Map content = workflowData.getContent(); - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - modelName = (String) content.get(NAME_FIELD); - break; - case FUNCTION_NAME: - functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); - break; - case MODEL_GROUP_ID: - modelGroupId = (String) content.get(MODEL_GROUP_ID); 
- break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case CONNECTOR_ID: - connectorId = (String) content.get(CONNECTOR_ID); - break; - default: - break; + Set requiredKeys = Set.of(NAME_FIELD, FUNCTION_NAME, CONNECTOR_ID); + Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD); - } - } - } + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); - if (Stream.of(modelName, functionName, connectorId).allMatch(x -> x != null)) { + String modelName = (String) inputs.get(NAME_FIELD); + FunctionName functionName = FunctionName.from(((String) inputs.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); + String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); + String description = (String) inputs.get(DESCRIPTION_FIELD); + String connectorId = (String) inputs.get(CONNECTOR_ID); MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() .functionName(functionName) @@ -172,12 +144,10 @@ public void onFailure(Exception e) { MLRegisterModelInput mlInput = builder.build(); mlClient.register(mlInput, actionListener); - } else { - registerRemoteModelFuture.completeExceptionally( - new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) - ); - } + } catch (FlowFrameworkException e) { + registerRemoteModelFuture.completeExceptionally(e); + } return registerRemoteModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index f106ee652..738a31497 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -21,8 +21,8 @@ public interface WorkflowStep { * Triggers the actual processing of the building block. 
* @param currentNodeId The id of the node executing this step * @param currentNodeInputs Input params and content for this node, from workflow parsing - * @param previousNodeInputs Input params for this node that come from previous steps * @param outputs WorkflowData content of previous steps. + * @param previousNodeInputs Input params for this node that come from previous steps * @return A CompletableFuture of the building block. This block should return immediately, but not be completed until the step executes, containing either the step's output data or {@link WorkflowData#EMPTY} which may be passed to follow-on steps. * @throws IOException on a failure. */ diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index a5c4253b3..94fe7b01e 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -9,12 +9,17 @@ package org.opensearch.flowframework.util; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; import java.time.Instant; +import java.util.Map; +import java.util.Set; public class ParseUtilsTests extends OpenSearchTestCase { public void testToInstant() throws IOException { @@ -54,4 +59,50 @@ public void testToInstantWithNotValue() throws IOException { Instant instant = ParseUtils.parseInstant(parser); assertNull(instant); } + + public void testGetInputsFromPreviousSteps() { + WorkflowData currentNodeInputs = new WorkflowData( + Map.ofEntries(Map.entry("content1", 1), Map.entry("param1", 2), 
Map.entry("content3", "${{step1.output1}}")), + Map.of("param1", "value1"), + "workflowId", + "nodeId" + ); + Map outputs = Map.ofEntries( + Map.entry( + "step1", + new WorkflowData( + Map.ofEntries(Map.entry("output1", "outputvalue1"), Map.entry("output2", "step1outputvalue2")), + "workflowId", + "step1" + ) + ), + Map.entry("step2", new WorkflowData(Map.of("output2", "step2outputvalue2"), "workflowId", "step2")) + ); + Map previousNodeInputs = Map.of("step2", "output2"); + Set requiredKeys = Set.of("param1", "content1"); + Set optionalKeys = Set.of("output1", "output2", "content3", "no-output"); + + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + + assertEquals("value1", inputs.get("param1")); + assertEquals(1, inputs.get("content1")); + assertEquals("outputvalue1", inputs.get("output1")); + assertEquals("step2outputvalue2", inputs.get("output2")); + assertEquals("outputvalue1", inputs.get("content3")); + assertNull(inputs.get("no-output")); + + Set missingRequiredKeys = Set.of("not-here"); + FlowFrameworkException e = assertThrows( + FlowFrameworkException.class, + () -> ParseUtils.getInputsFromPreviousSteps(missingRequiredKeys, optionalKeys, currentNodeInputs, outputs, previousNodeInputs) + ); + assertEquals("Missing required inputs [not-here] in workflow [workflowId] node [nodeId]", e.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, e.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java index 3c997a02e..a766d51c9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java @@ -13,7 +13,6 @@ import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import 
org.opensearch.core.rest.RestStatus; -import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; @@ -44,7 +43,7 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.of(CommonValue.CONNECTOR_ID, "test"), "test-id", "test-node-id"); + inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); } public void testDeleteConnector() throws IOException, ExecutionException, InterruptedException { @@ -86,7 +85,7 @@ public void testNoConnectorIdInOutput() throws IOException { assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Required field connector_id is not provided", ex.getCause().getMessage()); + assertEquals("Missing required inputs [connector_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); } public void testDeleteConnectorFailure() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index d78a97e8a..cc5acbc30 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -41,8 +41,8 @@ import static org.mockito.Mockito.verify; public class ModelGroupStepTests extends OpenSearchTestCase { - private WorkflowData inputData = WorkflowData.EMPTY; - private WorkflowData inputDataWithNoName = WorkflowData.EMPTY; + private WorkflowData inputData; + private WorkflowData inputDataWithNoName; @Mock MachineLearningNodeClient machineLearningNodeClient; @@ 
-65,7 +65,7 @@ public void setUp() throws Exception { "test-id", "test-node-id" ); - + inputDataWithNoName = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); } public void testRegisterModelGroup() throws ExecutionException, InterruptedException, IOException { @@ -146,7 +146,7 @@ public void testRegisterModelGroupWithNoName() throws IOException { assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Model group name is not provided", ex.getCause().getMessage()); + assertEquals("Missing required inputs [name] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index c38f8a120..ffa6d82d1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -237,7 +237,7 @@ public void testRegisterLocalModelTaskFailure() { public void testMissingInputs() { CompletableFuture future = registerLocalModelStep.execute( "nodeId", - WorkflowData.EMPTY, + new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), Collections.emptyMap() ); @@ -245,7 +245,19 @@ public void testMissingInputs() { assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Required fields are not provided", ex.getCause().getMessage()); + assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); + for (String s : new String[] { + "model_format", + "name", + "model_type", + 
"embedding_dimension", + "framework_type", + "model_group_id", + "version", + "url", + "model_content_hash_value" }) { + assertTrue(ex.getCause().getMessage().contains(s)); + } + assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } - } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index a83443f05..865526b79 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -127,7 +127,7 @@ public void testRegisterRemoteModelFailure() { public void testMissingInputs() { CompletableFuture future = this.registerRemoteModelStep.execute( "nodeId", - WorkflowData.EMPTY, + new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), Collections.emptyMap() ); @@ -135,7 +135,11 @@ public void testMissingInputs() { assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Required fields are not provided", ex.getCause().getMessage()); + assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); + for (String s : new String[] { "name", "function_name", "connector_id" }) { + assertTrue(ex.getCause().getMessage().contains(s)); + } + assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } } From 1d98b079d46af7960877bde5fe2640407b535ded Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Sat, 2 Dec 2023 13:45:38 -0800 Subject: [PATCH 06/27] [Feature/agent_framework] Add Undeploy Model Step (#236) Signed-off-by: Daniel Widdis --- .../flowframework/common/CommonValue.java | 2 + .../workflow/UndeployModelStep.java | 114 
+++++++++++++++ .../workflow/WorkflowStepFactory.java | 1 + .../resources/mappings/workflow-steps.json | 8 ++ .../workflow/UndeployModelStepTests.java | 131 ++++++++++++++++++ 5 files changed, 256 insertions(+) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 2343cd305..0863565c0 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -73,6 +73,8 @@ private CommonValue() {} /** The provision workflow thread pool name */ public static final String PROVISION_THREAD_POOL = "opensearch_workflow_provision"; + /** Success name field */ + public static final String SUCCESS = "success"; /** Index name field */ public static final String INDEX_NAME = "index_name"; /** Type field */ diff --git a/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java new file mode 100644 index 000000000..cfb683648 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchException; +import org.opensearch.action.FailedNodeException; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.SUCCESS; + +/** + * Step to undeploy model + */ +public class UndeployModelStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(UndeployModelStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "undeploy_model"; + + /** + * Instantiate this class + * @param mlClient Machine Learning client to perform the undeploy + */ + public UndeployModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + CompletableFuture undeployModelFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(MLUndeployModelsResponse mlUndeployModelsResponse) { + List failures = mlUndeployModelsResponse.getResponse().failures(); + if (failures.isEmpty()) { + undeployModelFuture.complete( + new WorkflowData( + 
Map.ofEntries(Map.entry(SUCCESS, !mlUndeployModelsResponse.getResponse().hasFailures())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } else { + List failedNodes = failures.stream().map(FailedNodeException::nodeId).collect(Collectors.toList()); + String message = "Failed to undeploy model on nodes " + failedNodes; + logger.error(message); + undeployModelFuture.completeExceptionally(new OpenSearchException(message)); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to unldeploy model"); + undeployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + Set requiredKeys = Set.of(MODEL_ID); + Set optionalKeys = Collections.emptySet(); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + + String modelId = inputs.get(MODEL_ID).toString(); + + mlClient.undeploy(new String[] { modelId }, null, actionListener); + } catch (FlowFrameworkException e) { + undeployModelFuture.completeExceptionally(e); + } + return undeployModelFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index ce0b24d24..4b197d99b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -51,6 +51,7 @@ public WorkflowStepFactory( ); stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); + stepMap.put(UndeployModelStep.NAME, new UndeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, 
flowFrameworkIndicesHandler)); stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index b5d09e8cb..d0b05e9fa 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -83,6 +83,14 @@ "deploy_model_status" ] }, + "undeploy_model": { + "inputs":[ + "model_id" + ], + "outputs":[ + "success" + ] + }, "register_model_group": { "inputs":[ "name" diff --git a/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java new file mode 100644 index 000000000..1a5fef445 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.FailedNodeException; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.SUCCESS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class UndeployModelStepTests extends OpenSearchTestCase { + private WorkflowData inputData; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); + } + + public void testUndeployModel() throws IOException, ExecutionException, InterruptedException { + + String modelId = randomAlphaOfLength(5); + UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient); + + doAnswer(invocation -> { + ClusterName clusterName = new ClusterName("clusterName"); + ActionListener actionListener = invocation.getArgument(2); + MLUndeployModelNodesResponse mlUndeployModelNodesResponse 
= new MLUndeployModelNodesResponse( + clusterName, + Collections.emptyList(), + Collections.emptyList() + ); + MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + + CompletableFuture future = UndeployModelStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, modelId), "workflowId", "nodeId")), + Map.of("step_1", MODEL_ID) + ); + verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + + assertTrue(future.isDone()); + assertTrue((boolean) future.get().getContent().get(SUCCESS)); + } + + public void testNoModelIdInOutput() throws IOException { + UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient); + + CompletableFuture future = UndeployModelStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Missing required inputs [model_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); + } + + public void testUndeployModelFailure() throws IOException { + UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient); + + doAnswer(invocation -> { + ClusterName clusterName = new ClusterName("clusterName"); + ActionListener actionListener = invocation.getArgument(2); + MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse( + clusterName, + Collections.emptyList(), + List.of(new FailedNodeException("failed-node", "Test message", null)) + ); + MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse); 
+ actionListener.onResponse(output); + + actionListener.onFailure(new FlowFrameworkException("Failed to undeploy model", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + + CompletableFuture future = UndeployModelStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, "test"), "workflowId", "nodeId")), + Map.of("step_1", MODEL_ID) + ); + + verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof OpenSearchException); + assertEquals("Failed to undeploy model on nodes [failed-node]", ex.getCause().getMessage()); + } +} From ecf4a60437dae1ad18ece8101be28294799fb3a2 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Sat, 2 Dec 2023 18:59:57 -0800 Subject: [PATCH 07/27] [Feature/agent_framework] Actually make the WorkflowStepFactory a Factory (#243) Actually make the WorkflowStepFactory a Factory Signed-off-by: Daniel Widdis --- .../workflow/WorkflowStepFactory.java | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 4b197d99b..1b1875177 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -18,13 +18,14 @@ import java.util.HashMap; import java.util.Map; +import java.util.function.Supplier; /** * Generates instances implementing {@link WorkflowStep}. */ public class WorkflowStepFactory { - private final Map stepMap = new HashMap<>(); + private final Map> stepMap = new HashMap<>(); /** * Instantiate this class. 
@@ -42,21 +43,21 @@ public WorkflowStepFactory( MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { - stepMap.put(NoOpStep.NAME, new NoOpStep()); - stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client, flowFrameworkIndicesHandler)); - stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler)); + stepMap.put(NoOpStep.NAME, NoOpStep::new); + stepMap.put(CreateIndexStep.NAME, () -> new CreateIndexStep(clusterService, client, flowFrameworkIndicesHandler)); + stepMap.put(CreateIngestPipelineStep.NAME, () -> new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler)); stepMap.put( RegisterLocalModelStep.NAME, - new RegisterLocalModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler) + () -> new RegisterLocalModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler) ); - stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); - stepMap.put(UndeployModelStep.NAME, new UndeployModelStep(mlClient)); - stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient)); - stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(ToolStep.NAME, new ToolStep()); - stepMap.put(RegisterAgentStep.NAME, new RegisterAgentStep(mlClient)); + stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeployModelStep.NAME, () -> new DeployModelStep(mlClient)); + stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); + stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); + 
stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient)); + stepMap.put(ModelGroupStep.NAME, () -> new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(ToolStep.NAME, ToolStep::new); + stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient)); } /** @@ -66,7 +67,7 @@ public WorkflowStepFactory( */ public WorkflowStep createStep(String type) { if (stepMap.containsKey(type)) { - return stepMap.get(type); + return stepMap.get(type).get(); } throw new FlowFrameworkException("Workflow step type [" + type + "] is not implemented.", RestStatus.NOT_IMPLEMENTED); } @@ -75,7 +76,7 @@ public WorkflowStep createStep(String type) { * Gets the step map * @return a read-only copy of the step map */ - public Map getStepMap() { + public Map> getStepMap() { return Map.copyOf(this.stepMap); } } From 664b41aa0023deb437062a61516c11ed343da6b6 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 4 Dec 2023 09:54:28 -0800 Subject: [PATCH 08/27] [Feature/agent_framework] Add Delete Model Step (#237) Add Delete Model Step Signed-off-by: Daniel Widdis --- .../workflow/DeleteModelStep.java | 101 ++++++++++++++++ .../workflow/WorkflowStepFactory.java | 1 + .../resources/mappings/workflow-steps.json | 8 ++ .../workflow/DeleteConnectorStepTests.java | 3 +- .../workflow/DeleteModelStepTests.java | 113 ++++++++++++++++++ 5 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java new file mode 100644 index 000000000..44fc5c8d7 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.ml.client.MachineLearningNodeClient; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; + +/** + * Step to delete a model for a remote model + */ +public class DeleteModelStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(DeleteModelStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "delete_model"; + + /** + * Instantiate this class + * @param mlClient Machine Learning client to perform the deletion + */ + public DeleteModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + CompletableFuture deleteModelFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(DeleteResponse deleteResponse) { + deleteModelFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(MODEL_ID, deleteResponse.getId())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + @Override + public void 
onFailure(Exception e) { + logger.error("Failed to delete model"); + deleteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + Set requiredKeys = Set.of(MODEL_ID); + Set optionalKeys = Collections.emptySet(); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + + String modelId = inputs.get(MODEL_ID).toString(); + + mlClient.deleteModel(modelId, actionListener); + } catch (FlowFrameworkException e) { + deleteModelFuture.completeExceptionally(e); + } + return deleteModelFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 1b1875177..3e8ef2981 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -51,6 +51,7 @@ public WorkflowStepFactory( () -> new RegisterLocalModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler) ); stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, () -> new DeployModelStep(mlClient)); stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index d0b05e9fa..e3263d9a2 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -75,6 +75,14 @@ "register_model_status" ] }, + 
"delete_model": { + "inputs": [ + "model_id" + ], + "outputs":[ + "model_id" + ] + }, "deploy_model": { "inputs":[ "model_id" diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java index a766d51c9..d94f3d793 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java @@ -15,7 +15,6 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -92,7 +91,7 @@ public void testDeleteConnectorFailure() throws IOException { DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(1); actionListener.onFailure(new FlowFrameworkException("Failed to delete connector", RestStatus.INTERNAL_SERVER_ERROR)); return null; }).when(machineLearningNodeClient).deleteConnector(any(String.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java new file mode 100644 index 000000000..59d92e94b --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class DeleteModelStepTests extends OpenSearchTestCase { + private WorkflowData inputData; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); + } + + public void testDeleteModel() throws IOException, ExecutionException, InterruptedException { + + String modelId = randomAlphaOfLength(5); + DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient); + + doAnswer(invocation -> { + String modelIdArg = invocation.getArgument(0); + ActionListener actionListener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, modelIdArg, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).deleteModel(any(String.class), any()); + + CompletableFuture future = deleteModelStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new 
WorkflowData(Map.of("model_id", modelId), "workflowId", "nodeId")), + Map.of("step_1", "model_id") + ); + verify(machineLearningNodeClient).deleteModel(any(String.class), any()); + + assertTrue(future.isDone()); + assertEquals(modelId, future.get().getContent().get("model_id")); + } + + public void testNoModelIdInOutput() throws IOException { + DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient); + + CompletableFuture future = deleteModelStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Missing required inputs [model_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); + } + + public void testDeleteModelFailure() throws IOException { + DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to delete model", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).deleteModel(any(String.class), any()); + + CompletableFuture future = deleteModelStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("model_id", "test"), "workflowId", "nodeId")), + Map.of("step_1", "model_id") + ); + + verify(machineLearningNodeClient).deleteModel(any(String.class), any()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to delete model", ex.getCause().getMessage()); + } +} From 5ce79e0666349a78b3107ccf9b8866f32a4bf10f Mon 
Sep 17 00:00:00 2001 From: Owais Kazi Date: Mon, 4 Dec 2023 09:56:45 -0800 Subject: [PATCH 09/27] [Feature/agent_framework] Registers root agent with an agentId in ToolSteps (#242) * Add agentId to parameters map for root agent Signed-off-by: Owais Kazi * Modified ToolStep with new Util method Signed-off-by: Owais Kazi * Integrated RegisterAgentStep with new Util method Signed-off-by: Owais Kazi * Spotless fixes Signed-off-by: Owais Kazi * Removed TODO Signed-off-by: Owais Kazi --------- Signed-off-by: Owais Kazi --- .../workflow/RegisterAgentStep.java | 155 ++++++++---------- .../flowframework/workflow/ToolStep.java | 94 +++++------ 2 files changed, 112 insertions(+), 137 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 022d46c22..e055433b0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -14,6 +14,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; @@ -25,13 +26,12 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; +import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.AGENT_ID; import static org.opensearch.flowframework.common.CommonValue.APP_TYPE_FIELD; @@ -101,84 +101,55 @@ public 
void onFailure(Exception e) { } }; - String name = null; - String type = null; - String description = null; - String llmModelId = null; - Map llmParameters = Collections.emptyMap(); - List tools = new ArrayList<>(); - Map parameters = Collections.emptyMap(); - MLMemorySpec memory = null; - Instant createdTime = null; - Instant lastUpdateTime = null; - String appType = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - name = (String) entry.getValue(); - break; - case DESCRIPTION_FIELD: - description = (String) entry.getValue(); - break; - case TYPE: - type = (String) entry.getValue(); - break; - case LLM_MODEL_ID: - llmModelId = (String) entry.getValue(); - break; - case LLM_PARAMETERS: - llmParameters = getStringToStringMap(entry.getValue(), LLM_PARAMETERS); - break; - case TOOLS_FIELD: - tools = addTools(entry.getValue()); - break; - case PARAMETERS_FIELD: - parameters = getStringToStringMap(entry.getValue(), PARAMETERS_FIELD); - break; - case MEMORY_FIELD: - memory = getMLMemorySpec(entry.getValue()); - break; - case CREATED_TIME: - createdTime = Instant.ofEpochMilli((Long) entry.getValue()); - break; - case LAST_UPDATED_TIME_FIELD: - lastUpdateTime = Instant.ofEpochMilli((Long) entry.getValue()); - break; - case APP_TYPE_FIELD: - appType = (String) entry.getValue(); - break; - default: - break; - } - } - } + Set requiredKeys = Set.of(NAME_FIELD, TYPE); + Set optionalKeys = Set.of( + DESCRIPTION_FIELD, + LLM_MODEL_ID, + LLM_PARAMETERS, + TOOLS_FIELD, + PARAMETERS_FIELD, + MEMORY_FIELD, + CREATED_TIME, + LAST_UPDATED_TIME_FIELD, + APP_TYPE_FIELD + ); + + try { + Map inputs = 
ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); - // Case when modelId is present in previous node inputs - if (llmModelId == null) { - llmModelId = getLlmModelId(previousNodeInputs, outputs); - } + String type = (String) inputs.get(TYPE); + String name = (String) inputs.get(NAME_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + String llmModelId = (String) inputs.get(LLM_MODEL_ID); + Map llmParameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), LLM_PARAMETERS); + List tools = getTools(previousNodeInputs, outputs); + Map parameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), PARAMETERS_FIELD); + MLMemorySpec memory = getMLMemorySpec(inputs.get(MEMORY_FIELD)); + Instant createdTime = Instant.ofEpochMilli((Long) inputs.get(CREATED_TIME)); + Instant lastUpdateTime = Instant.ofEpochMilli((Long) inputs.get(LAST_UPDATED_TIME_FIELD)); + String appType = (String) inputs.get(APP_TYPE_FIELD); + + // Case when modelId is present in previous node inputs + if (llmModelId == null) { + llmModelId = getLlmModelId(previousNodeInputs, outputs); + } - // Case when modelId is not present at all - if (llmModelId == null) { - registerAgentModelFuture.completeExceptionally( - new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST) - ); - return registerAgentModelFuture; - } + // Case when modelId is not present at all + if (llmModelId == null) { + registerAgentModelFuture.completeExceptionally( + new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST) + ); + return registerAgentModelFuture; + } - LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters); + LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters); - if (Stream.of(name, type, llmSpec).allMatch(x -> x != null)) { MLAgentBuilder builder = MLAgent.builder().name(name); if (description != null) { @@ -198,12 +169,9 @@ public void onFailure(Exception e) { 
mlClient.registerAgent(mlAgent, actionListener); - } else { - registerAgentModelFuture.completeExceptionally( - new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) - ); + } catch (FlowFrameworkException e) { + registerAgentModelFuture.completeExceptionally(e); } - return registerAgentModelFuture; } @@ -212,9 +180,24 @@ public String getName() { return NAME; } - private List addTools(Object tools) { - MLToolSpec mlToolSpec = (MLToolSpec) tools; - mlToolSpecList.add(mlToolSpec); + private List getTools(Map previousNodeInputs, Map outputs) { + List mlToolSpecList = new ArrayList<>(); + List previousNodes = previousNodeInputs.entrySet() + .stream() + .filter(e -> TOOLS_FIELD.equals(e.getValue())) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + + if (previousNodes != null) { + previousNodes.forEach((previousNode) -> { + WorkflowData previousNodeOutput = outputs.get(previousNode); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(TOOLS_FIELD)) { + MLToolSpec mlToolSpec = (MLToolSpec) previousNodeOutput.getContent().get(TOOLS_FIELD); + logger.info("Tool added {}", mlToolSpec.getType()); + mlToolSpecList.add(mlToolSpec); + } + }); + } return mlToolSpecList; } @@ -240,7 +223,7 @@ private String getLlmModelId(Map previousNodeInputs, Map llmParameters) { if (llmModelId == null) { - throw new IllegalArgumentException("model id for llm is null"); + throw new FlowFrameworkException("model id for llm is null", RestStatus.BAD_REQUEST); } LLMSpec.LLMSpecBuilder builder = LLMSpec.builder(); builder.modelId(llmModelId); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index af8556289..f12d9848e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -10,19 +10,17 @@ import org.apache.logging.log4j.LogManager; 
import org.apache.logging.log4j.Logger; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.common.agent.MLToolSpec; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; +import static org.opensearch.flowframework.common.CommonValue.AGENT_ID; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; @@ -47,49 +45,25 @@ public CompletableFuture execute( Map outputs, Map previousNodeInputs ) throws IOException { - String type = null; - String name = null; - String description = null; - Map parameters = Collections.emptyMap(); - Boolean includeOutputInAgentResponse = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case TYPE: - type = (String) entry.getValue(); - break; - case NAME_FIELD: - name = (String) entry.getValue(); - break; - case DESCRIPTION_FIELD: - description = (String) entry.getValue(); - break; - case PARAMETERS_FIELD: - parameters = getToolsParametersMap(entry.getValue(), previousNodeInputs, outputs); - break; - case INCLUDE_OUTPUT_IN_AGENT_RESPONSE: - includeOutputInAgentResponse = (Boolean) entry.getValue(); - break; - default: - break; - } - } + Set requiredKeys = Set.of(TYPE); + 
Set optionalKeys = Set.of(NAME_FIELD, DESCRIPTION_FIELD, PARAMETERS_FIELD, INCLUDE_OUTPUT_IN_AGENT_RESPONSE); - } + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + + String type = (String) inputs.get(TYPE); + String name = (String) inputs.get(NAME_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + Boolean includeOutputInAgentResponse = (Boolean) inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE); + Map parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs); - if (type == null) { - toolFuture.completeExceptionally(new FlowFrameworkException("Tool type is not provided", RestStatus.BAD_REQUEST)); - } else { MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder(); builder.type(type); @@ -115,9 +89,12 @@ public CompletableFuture execute( currentNodeInputs.getNodeId() ) ); - } - logger.info("Tool registered successfully {}", type); + logger.info("Tool registered successfully {}", type); + + } catch (FlowFrameworkException e) { + toolFuture.completeExceptionally(e); + } return toolFuture; } @@ -132,19 +109,34 @@ private Map getToolsParametersMap( Map outputs ) { Map parametersMap = (Map) parameters; - Optional previousNode = previousNodeInputs.entrySet() + Optional previousNodeModel = previousNodeInputs.entrySet() .stream() .filter(e -> MODEL_ID.equals(e.getValue())) .map(Map.Entry::getKey) .findFirst(); + + Optional previousNodeAgent = previousNodeInputs.entrySet() + .stream() + .filter(e -> AGENT_ID.equals(e.getValue())) + .map(Map.Entry::getKey) + .findFirst(); + // Case when modelId is passed through previousSteps and not present already in parameters - if (previousNode.isPresent() && !parametersMap.containsKey(MODEL_ID)) { - WorkflowData previousNodeOutput = outputs.get(previousNode.get()); + if (previousNodeModel.isPresent() && !parametersMap.containsKey(MODEL_ID)) { + WorkflowData previousNodeOutput = 
outputs.get(previousNodeModel.get()); if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) { parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString()); - return parametersMap; } } + + // Case when agentId is passed through previousSteps and not present already in parameters + if (previousNodeAgent.isPresent() && !parametersMap.containsKey(AGENT_ID)) { + WorkflowData previousNodeOutput = outputs.get(previousNodeAgent.get()); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(AGENT_ID)) { + parametersMap.put(AGENT_ID, previousNodeOutput.getContent().get(AGENT_ID).toString()); + } + } + // For other cases where modelId is already present in the parameters or not return the parametersMap return parametersMap; } From a047038157f7280e06ee48f366ace72f64d7f5ed Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Mon, 4 Dec 2023 15:48:26 -0800 Subject: [PATCH 10/27] Updating state index after register agent (#250) Adding state index update on agent Signed-off-by: Amit Galitzky --- .../common/WorkflowResources.java | 4 +- .../workflow/RegisterAgentStep.java | 39 ++++++++++++++++++- .../workflow/WorkflowStepFactory.java | 2 +- .../workflow/RegisterAgentTests.java | 27 +++++++++++-- 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java index 04a8650b2..126bfbeb7 100644 --- a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java +++ b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java @@ -33,7 +33,9 @@ public enum WorkflowResources { /** official workflow step name for creating an ingest-pipeline and associated created resource */ CREATE_INGEST_PIPELINE("create_ingest_pipeline", "pipeline_id"), /** official workflow step name for creating an index and associated created resource */ - 
CREATE_INDEX("create_index", "index_name"); + CREATE_INDEX("create_index", "index_name"), + /** official workflow step name for register an agent and the associated created resource */ + REGISTER_AGENT("register_agent", "agent_id"); private final String workflowStep; private final String resourceCreated; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index e055433b0..06c97f8d4 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -13,7 +13,9 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.agent.LLMSpec; @@ -54,8 +56,9 @@ public class RegisterAgentStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterAgentStep.class); private MachineLearningNodeClient mlClient; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; - static final String NAME = "register_agent"; + static final String NAME = WorkflowResources.REGISTER_AGENT.getWorkflowStep(); private static final String LLM_MODEL_ID = "llm.model_id"; private static final String LLM_PARAMETERS = "llm.parameters"; @@ -65,10 +68,12 @@ public class RegisterAgentStep implements WorkflowStep { /** * Instantiate this class * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public 
RegisterAgentStep(MachineLearningNodeClient mlClient) { + public RegisterAgentStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.mlClient = mlClient; this.mlToolSpecList = new ArrayList<>(); + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -92,6 +97,36 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { currentNodeInputs.getNodeId() ) ); + + try { + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + logger.info("Created connector successfully"); + flowFrameworkIndicesHandler.updateResourceInStateIndex( + currentNodeInputs.getWorkflowId(), + currentNodeId, + getName(), + mlRegisterAgentResponse.getAgentId(), + ActionListener.wrap(response -> { + logger.info("successfully updated resources created in state index: {}", response.getIndex()); + registerAgentModelFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(resourceName, mlRegisterAgentResponse.getAgentId())), + currentNodeInputs.getWorkflowId(), + currentNodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + registerAgentModelFuture.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + registerAgentModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 3e8ef2981..b6803b664 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -58,7 +58,7 @@ public WorkflowStepFactory( 
stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient)); stepMap.put(ModelGroupStep.NAME, () -> new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ToolStep.NAME, ToolStep::new); - stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient)); + stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)); } /** diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index c393be5e4..115729e9c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -8,9 +8,12 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; @@ -29,8 +32,12 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; public class RegisterAgentTests extends OpenSearchTestCase { @@ -39,10 +46,12 @@ public class RegisterAgentTests extends OpenSearchTestCase { @Mock MachineLearningNodeClient 
machineLearningNodeClient; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + @Override public void setUp() throws Exception { super.setUp(); - + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false); @@ -76,7 +85,7 @@ public void setUp() throws Exception { public void testRegisterAgent() throws IOException, ExecutionException, InterruptedException { String agentId = "agent_id"; - RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient); + RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -88,6 +97,12 @@ public void testRegisterAgent() throws IOException, ExecutionException, Interrup return null; }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = registerAgentStep.execute( inputData.getNodeId(), inputData, @@ -103,7 +118,7 @@ public void testRegisterAgent() throws IOException, ExecutionException, Interrup public void testRegisterAgentFailure() throws IOException { String agentId = "agent_id"; - RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient); + RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") ArgumentCaptor> 
actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -114,6 +129,12 @@ public void testRegisterAgentFailure() throws IOException { return null; }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + CompletableFuture future = registerAgentStep.execute( inputData.getNodeId(), inputData, From e13c8a9de9b97bb1cc2b11eabff9e98bb39b40ee Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Mon, 4 Dec 2023 16:28:01 -0800 Subject: [PATCH 11/27] [Feature/agent_framework] Added Retry functionality for Deploy Model (#245) * Added retry functionality for DeployModel Signed-off-by: Owais Kazi * Fixed timeout and exception issues Signed-off-by: Owais Kazi * Addressed PR Comments Signed-off-by: Owais Kazi --------- Signed-off-by: Owais Kazi --- .../common/WorkflowResources.java | 2 + .../flowframework/model/WorkflowNode.java | 2 +- .../AbstractRetryableWorkflowStep.java | 110 ++++++++++++- .../workflow/DeployModelStep.java | 29 ++-- .../workflow/ModelGroupStep.java | 10 +- .../workflow/RegisterLocalModelStep.java | 96 ++---------- .../workflow/WorkflowStepFactory.java | 2 +- .../workflow/DeployModelStepTests.java | 148 ++++++++++++++++-- .../workflow/RegisterLocalModelStepTests.java | 2 +- .../workflow/WorkflowProcessSorterTests.java | 2 +- 10 files changed, 287 insertions(+), 116 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java index 126bfbeb7..d43a9e0f9 100644 --- a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java 
+++ b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java @@ -30,6 +30,8 @@ public enum WorkflowResources { REGISTER_LOCAL_MODEL("register_local_model", "model_id"), /** official workflow step name for registering a model group and associated created resource */ REGISTER_MODEL_GROUP("register_model_group", "model_group_id"), + /** official workflow step name for deploying a model and associated created resource */ + DEPLOY_MODEL("deploy_model", "model_id"), /** official workflow step name for creating an ingest-pipeline and associated created resource */ CREATE_INGEST_PIPELINE("create_ingest_pipeline", "pipeline_id"), /** official workflow step name for creating an index and associated created resource */ diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index b942ccb16..42d59e07f 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -47,7 +47,7 @@ public class WorkflowNode implements ToXContentObject { /** The field defining the timeout value for this node */ public static final String NODE_TIMEOUT_FIELD = "node_timeout"; /** The default timeout value if the template doesn't override it */ - public static final String NODE_TIMEOUT_DEFAULT_VALUE = "10s"; + public static final String NODE_TIMEOUT_DEFAULT_VALUE = "15s"; private final String id; // unique id private final String type; // maps to a WorkflowStep diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index 799edabb9..48b7a0042 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -8,27 +8,133 @@ */ package 
org.opensearch.flowframework.workflow; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.FutureUtils; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.WorkflowResources; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTaskState; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; /** * Abstract retryable workflow step */ public abstract class AbstractRetryableWorkflowStep implements WorkflowStep { - + private static final Logger logger = LogManager.getLogger(AbstractRetryableWorkflowStep.class); /** The maximum number of transport request retries */ protected volatile Integer maxRetry; + private final MachineLearningNodeClient mlClient; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** * Instantiates a new Retryable workflow step * @param settings Environment settings * @param clusterService the cluster service + * @param mlClient machine learning client + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public AbstractRetryableWorkflowStep(Settings settings, ClusterService clusterService) { + public AbstractRetryableWorkflowStep( + Settings settings, + ClusterService clusterService, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler 
flowFrameworkIndicesHandler + ) { this.maxRetry = MAX_GET_TASK_REQUEST_RETRY.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_GET_TASK_REQUEST_RETRY, it -> maxRetry = it); + this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + } + + /** + * Retryable get ml task + * @param workflowId the workflow id + * @param nodeId the workflow node id + * @param future the workflow step future + * @param taskId the ml task id + * @param retries the current number of request retries + * @param workflowStep the workflow step which requires a retry get ml task functionality + */ + void retryableGetMlTask( + String workflowId, + String nodeId, + CompletableFuture future, + String taskId, + int retries, + String workflowStep + ) { + mlClient.getTask(taskId, ActionListener.wrap(response -> { + MLTaskState currentState = response.getState(); + if (currentState != MLTaskState.COMPLETED) { + if (Stream.of(MLTaskState.FAILED, MLTaskState.COMPLETED_WITH_ERROR).anyMatch(x -> x == currentState)) { + // Model registration failed or completed with errors + String errorMessage = workflowStep + " failed with error : " + response.getError(); + logger.error(errorMessage); + future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); + } else { + // Task still in progress, attempt retry + throw new IllegalStateException(workflowStep + " is not yet completed"); + } + } else { + try { + logger.info(workflowStep + " successful for {} and modelId {}", workflowId, response.getModelId()); + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( + workflowId, + nodeId, + getName(), + response.getTaskId(), + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + future.complete( + new WorkflowData( + Map.ofEntries( + 
Map.entry(resourceName, response.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) + ), + workflowId, + nodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + future.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + } + }, exception -> { + if (retries < maxRetry) { + // Sleep thread prior to retrying request + try { + Thread.sleep(5000); + } catch (Exception e) { + FutureUtils.cancel(future); + } + retryableGetMlTask(workflowId, nodeId, future, taskId, retries + 1, workflowStep); + } else { + logger.error("Failed to retrieve" + workflowStep + ",maximum retries exceeded"); + future.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + } + })); } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index f878fbdc2..b9307b046 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -11,8 +11,11 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import 
org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -27,18 +30,29 @@ /** * Step to deploy a model */ -public class DeployModelStep implements WorkflowStep { +public class DeployModelStep extends AbstractRetryableWorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); private final MachineLearningNodeClient mlClient; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; static final String NAME = "deploy_model"; /** * Instantiate this class + * @param settings The OpenSearch settings + * @param clusterService The cluster service * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public DeployModelStep(MachineLearningNodeClient mlClient) { + public DeployModelStep( + Settings settings, + ClusterService clusterService, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { + super(settings, clusterService, mlClient, flowFrameworkIndicesHandler); this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -55,13 +69,10 @@ public CompletableFuture execute( @Override public void onResponse(MLDeployModelResponse mlDeployModelResponse) { logger.info("Model deployment state {}", mlDeployModelResponse.getStatus()); - deployModelFuture.complete( - new WorkflowData( - Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())), - currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + String taskId = mlDeployModelResponse.getTaskId(); + + // Attempt to retrieve the model ID + retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, deployModelFuture, taskId, 0, "Deploy model"); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 
e2aea19df..fbf907776 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -72,7 +73,7 @@ public CompletableFuture execute( @Override public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) { try { - logger.info("Remote Model registration successful"); + logger.info("Model group registration successful"); String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), @@ -134,7 +135,7 @@ public void onFailure(Exception e) { if (description != null) { builder.description(description); } - if (!backendRoles.isEmpty()) { + if (!CollectionUtils.isEmpty(backendRoles)) { builder.backendRoles(backendRoles); } if (modelAccessMode != null) { @@ -160,6 +161,9 @@ public String getName() { @SuppressWarnings("unchecked") private List getBackendRoles(Map content) { - return (List) content.get(BACKEND_ROLES_FIELD); + if (content.containsKey(BACKEND_ROLES_FIELD)) { + return (List) content.get(BACKEND_ROLES_FIELD); + } + return null; } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index fb1d383b5..4c01e8fb8 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -13,15 +13,12 @@ import 
org.opensearch.ExceptionsHelper; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.FutureUtils; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -34,7 +31,6 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.ALL_CONFIG; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; @@ -45,7 +41,6 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; -import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.URL; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; @@ -75,7 +70,7 @@ public RegisterLocalModelStep( MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { - super(settings, clusterService); + super(settings, clusterService, mlClient, flowFrameworkIndicesHandler); this.mlClient = mlClient; this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @@ -98,7 +93,14 @@ public void 
onResponse(MLRegisterModelResponse mlRegisterModelResponse) { String taskId = mlRegisterModelResponse.getTaskId(); // Attempt to retrieve the model ID - retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, registerLocalModelFuture, taskId, 0); + retryableGetMlTask( + currentNodeInputs.getWorkflowId(), + currentNodeId, + registerLocalModelFuture, + taskId, + 0, + "Local model registration" + ); } @Override @@ -178,84 +180,4 @@ public void onFailure(Exception e) { public String getName() { return NAME; } - - /** - * Retryable get ml task - * @param workflowId the workflow id - * @param nodeId the workflow node id - * @param registerLocalModelFuture the workflow step future - * @param taskId the ml task id - * @param retries the current number of request retries - */ - void retryableGetMlTask( - String workflowId, - String nodeId, - CompletableFuture registerLocalModelFuture, - String taskId, - int retries - ) { - mlClient.getTask(taskId, ActionListener.wrap(response -> { - MLTaskState currentState = response.getState(); - if (currentState != MLTaskState.COMPLETED) { - if (Stream.of(MLTaskState.FAILED, MLTaskState.COMPLETED_WITH_ERROR).anyMatch(x -> x == currentState)) { - // Model registration failed or completed with errors - String errorMessage = "Local model registration failed with error : " + response.getError(); - logger.error(errorMessage); - registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); - } else { - // Task still in progress, attempt retry - throw new IllegalStateException("Local model registration is not yet completed"); - } - } else { - try { - logger.info("Local Model registration successful"); - String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - workflowId, - nodeId, - getName(), - response.getTaskId(), - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated 
resources created in state index: {}", updateResponse.getIndex()); - registerLocalModelFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(resourceName, response.getModelId()), - Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) - ), - workflowId, - nodeId - ) - ); - }, exception -> { - logger.error("Failed to update new created resource", exception); - registerLocalModelFuture.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - logger.error("Failed to parse and update new created resource", e); - registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); - } - } - }, exception -> { - if (retries < maxRetry) { - // Sleep thread prior to retrying request - try { - Thread.sleep(5000); - } catch (Exception e) { - FutureUtils.cancel(registerLocalModelFuture); - } - final int retryAdd = retries + 1; - retryableGetMlTask(workflowId, nodeId, registerLocalModelFuture, taskId, retryAdd); - } else { - logger.error("Failed to retrieve local model registration task, maximum retries exceeded"); - registerLocalModelFuture.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); - } - })); - } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index b6803b664..3a2ccb4e1 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -52,7 +52,7 @@ public WorkflowStepFactory( ); stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient)); - stepMap.put(DeployModelStep.NAME, () -> 
new DeployModelStep(mlClient)); + stepMap.put(DeployModelStep.NAME, () -> new DeployModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler)); stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 670933373..fa27142f1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -10,27 +10,49 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.Collections; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; +import 
java.util.stream.Stream; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.mockito.ArgumentMatchers.eq; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class DeployModelStepTests extends OpenSearchTestCase { @@ -40,22 +62,37 @@ public class DeployModelStepTests extends OpenSearchTestCase { @Mock MachineLearningNodeClient machineLearningNodeClient; + private DeployModelStep deployModel; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + @Override public void setUp() throws Exception { super.setUp(); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id", "test-node-id"); + ClusterService clusterService = mock(ClusterService.class); + final Set> settingsSet = Stream.concat( + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Stream.of(MAX_GET_TASK_REQUEST_RETRY) + ).collect(Collectors.toSet()); - MockitoAnnotations.openMocks(this); + // Set max request retry setting to 0 to avoid sleeping the thread during unit test failure cases + Settings testMaxRetrySetting = 
Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 0).build(); + ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + this.deployModel = new DeployModelStep(testMaxRetrySetting, clusterService, machineLearningNodeClient, flowFrameworkIndicesHandler); + this.inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id", "test-node-id"); } - public void testDeployModel() throws ExecutionException, InterruptedException { + public void testDeployModel() throws ExecutionException, InterruptedException, IOException { + String modelId = "modelId"; String taskId = "taskId"; - String status = MLTaskState.CREATED.name(); - MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; - DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient); + String status = MLTaskState.COMPLETED.name(); + MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -65,7 +102,36 @@ public void testDeployModel() throws ExecutionException, InterruptedException { MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).deploy(eq(modelId), actionListenerCaptor.capture()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + doAnswer(invocation -> { + ActionListener 
updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); CompletableFuture future = deployModel.execute( inputData.getNodeId(), @@ -74,15 +140,19 @@ public void testDeployModel() throws ExecutionException, InterruptedException { Collections.emptyMap() ); - verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + verify(machineLearningNodeClient, times(1)).deploy(any(String.class), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), any()); assertTrue(future.isDone()); - assertEquals(status, future.get().getContent().get("deploy_model_status")); + assertFalse(future.isCompletedExceptionally()); + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } public void testDeployModelFailure() { - DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient); + String modelId = "modelId"; + String taskId = "taskId"; @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -106,4 +176,60 @@ public void testDeployModelFailure() { assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Failed to deploy model", ex.getCause().getMessage()); } + + public void testDeployModelTaskFailure() throws IOException { + String modelId = "modelId"; + String taskId = "taskId"; + + String status = MLTaskState.RUNNING.name(); + MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; + String testErrorMessage = "error"; + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = 
invocation.getArgument(1); + MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).deploy(eq(modelId), actionListenerCaptor.capture()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.FAILED, + null, + null, + null, + null, + null, + null, + testErrorMessage, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + CompletableFuture future = this.deployModel.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Deploy model failed with error : " + testErrorMessage, ex.getCause().getMessage()); + + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index ffa6d82d1..afd90786f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -156,7 +156,7 @@ public void testRegisterLocalModelSuccess() throws Exception { verify(machineLearningNodeClient, times(1)).getTask(any(), any()); assertTrue(future.isDone()); - assertTrue(!future.isCompletedExceptionally()); + assertFalse(future.isCompletedExceptionally()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); diff 
--git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index e9e792add..8103f4fbf 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -118,7 +118,7 @@ public void testNodeDetails() throws IOException { ProcessNode node = workflow.get(0); assertEquals("default_timeout", node.id()); assertEquals(CreateIngestPipelineStep.class, node.workflowStep().getClass()); - assertEquals(10, node.nodeTimeout().seconds()); + assertEquals(15, node.nodeTimeout().seconds()); node = workflow.get(1); assertEquals("custom_timeout", node.id()); assertEquals(CreateIndexStep.class, node.workflowStep().getClass()); From d82da9ae85cb753462cccb78a8224b79e7498470 Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Mon, 4 Dec 2023 17:14:57 -0800 Subject: [PATCH 12/27] Only update state index on register agent (#253) * fixing status api bug with new deploy task Signed-off-by: Amit Galitzky * changed to getName() Signed-off-by: Amit Galitzky --------- Signed-off-by: Amit Galitzky --- .../ProvisionWorkflowTransportAction.java | 2 +- .../AbstractRetryableWorkflowStep.java | 63 +++++++++++-------- .../workflow/RegisterAgentStep.java | 12 +--- 3 files changed, 40 insertions(+), 37 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index da9643cb5..cd4a54a57 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -234,7 +234,7 @@ private void executeWorkflow(List workflowSequence, String workflow Instant.now().toEpochMilli() ), 
ActionListener.wrap(updateResponse -> { - logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + logger.info("updated workflow {} state to {}", workflowId, State.FAILED); }, exceptionState -> { logger.error("Failed to update workflow state : {}", exceptionState.getMessage(), ex); }) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index 48b7a0042..f807c752a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import java.util.Map; @@ -58,6 +59,24 @@ public AbstractRetryableWorkflowStep( this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } + /** + * Completes the future for either deploy or register local model step + * @param resourceName resource name for the given step + * @param nodeId node ID of the given step + * @param workflowId workflow ID of the given workflow + * @param response Response from ml commons get Task API + * @param future CompletableFuture of the given step + */ + public void completeFuture(String resourceName, String nodeId, String workflowId, MLTask response, CompletableFuture future) { + future.complete( + new WorkflowData( + Map.ofEntries(Map.entry(resourceName, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())), + workflowId, + nodeId + ) + ); + } + /** * Retryable get ml task * @param workflowId the workflow id @@ -91,31 +110,25 @@ void retryableGetMlTask( try { logger.info(workflowStep 
+ " successful for {} and modelId {}", workflowId, response.getModelId()); String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - workflowId, - nodeId, - getName(), - response.getTaskId(), - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - future.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(resourceName, response.getModelId()), - Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) - ), - workflowId, - nodeId - ) - ); - }, exception -> { - logger.error("Failed to update new created resource", exception); - future.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); - }) - ); - + if (getName().equals(WorkflowResources.DEPLOY_MODEL.getWorkflowStep())) { + completeFuture(resourceName, nodeId, workflowId, response, future); + } else { + flowFrameworkIndicesHandler.updateResourceInStateIndex( + workflowId, + nodeId, + getName(), + response.getTaskId(), + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + completeFuture(resourceName, nodeId, workflowId, response, future); + }, exception -> { + logger.error("Failed to update new created resource", exception); + future.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + } } catch (Exception e) { logger.error("Failed to parse and update new created resource", e); future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 06c97f8d4..0e3a1c7c6 100644 --- 
a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -35,7 +35,6 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; -import static org.opensearch.flowframework.common.CommonValue.AGENT_ID; import static org.opensearch.flowframework.common.CommonValue.APP_TYPE_FIELD; import static org.opensearch.flowframework.common.CommonValue.CREATED_TIME; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; @@ -89,18 +88,9 @@ public CompletableFuture execute( ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { - logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId()); - registerAgentModelFuture.complete( - new WorkflowData( - Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())), - currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); - try { String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); - logger.info("Created connector successfully"); + logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId()); flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), currentNodeId, From ac507afc2cc8093e9b73f0ed4f89922f020dc036 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Tue, 5 Dec 2023 21:08:44 -0800 Subject: [PATCH 13/27] [Feature/agent_framework] Add Delete Agent Step (#246) Delete Agent Step Signed-off-by: Daniel Widdis --- .../workflow/DeleteAgentStep.java | 100 ++++++++++++++++ .../workflow/WorkflowStepFactory.java | 1 + .../resources/mappings/workflow-steps.json | 8 ++ .../workflow/DeleteAgentStepTests.java | 113 ++++++++++++++++++ 4 files changed, 222 insertions(+) create mode 100644 
src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java new file mode 100644 index 000000000..d97b4ed28 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.ml.client.MachineLearningNodeClient; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.AGENT_ID; + +/** + * Step to delete a agent for a remote model + */ +public class DeleteAgentStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(DeleteAgentStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "delete_agent"; + + /** + * Instantiate this class + * @param mlClient Machine Learning client to perform the deletion + */ + public DeleteAgentStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute( + String 
currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + CompletableFuture deleteAgentFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(DeleteResponse deleteResponse) { + deleteAgentFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry("agent_id", deleteResponse.getId())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to delete agent"); + deleteAgentFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + Set requiredKeys = Set.of(AGENT_ID); + Set optionalKeys = Collections.emptySet(); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + String agentId = (String) inputs.get(AGENT_ID); + + mlClient.deleteAgent(agentId, actionListener); + } catch (FlowFrameworkException e) { + deleteAgentFuture.completeExceptionally(e); + } + return deleteAgentFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 3a2ccb4e1..c2e55b100 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -59,6 +59,7 @@ public WorkflowStepFactory( stepMap.put(ModelGroupStep.NAME, () -> new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ToolStep.NAME, ToolStep::new); stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient)); } /** 
diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index e3263d9a2..149b1cfce 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -123,6 +123,14 @@ "agent_id" ] }, + "delete_agent": { + "inputs": [ + "agent_id" + ], + "outputs":[ + "agent_id" + ] + }, "create_tool": { "inputs": [ "type" diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java new file mode 100644 index 000000000..a893b8928 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class DeleteAgentStepTests extends OpenSearchTestCase { + private WorkflowData inputData; + + @Mock + MachineLearningNodeClient 
machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); + } + + public void testDeleteAgent() throws IOException, ExecutionException, InterruptedException { + + String agentId = randomAlphaOfLength(5); + DeleteAgentStep deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient); + + doAnswer(invocation -> { + String agentIdArg = invocation.getArgument(0); + ActionListener actionListener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, agentIdArg, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).deleteAgent(any(String.class), any()); + + CompletableFuture future = deleteAgentStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("agent_id", agentId), "workflowId", "nodeId")), + Map.of("step_1", "agent_id") + ); + verify(machineLearningNodeClient).deleteAgent(any(String.class), any()); + + assertTrue(future.isDone()); + assertEquals(agentId, future.get().getContent().get("agent_id")); + } + + public void testNoAgentIdInOutput() throws IOException { + DeleteAgentStep deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient); + + CompletableFuture future = deleteAgentStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Missing required inputs [agent_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); + } + + public void testDeleteAgentFailure() throws IOException { + DeleteAgentStep 
deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to delete agent", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).deleteAgent(any(String.class), any()); + + CompletableFuture future = deleteAgentStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("agent_id", "test"), "workflowId", "nodeId")), + Map.of("step_1", "agent_id") + ); + + verify(machineLearningNodeClient).deleteAgent(any(String.class), any()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to delete agent", ex.getCause().getMessage()); + } +} From fe43586b1d3fe4ea5cfba2b37054185f11bd566c Mon Sep 17 00:00:00 2001 From: Jackie Han Date: Thu, 7 Dec 2023 16:17:40 -0800 Subject: [PATCH 14/27] Include workflow id and current node id in the exception message (#262) Include workflow id and current node id in the exception message during register agent step Signed-off-by: Jackie Han --- .../workflow/RegisterAgentStep.java | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 0e3a1c7c6..f9f17b4f0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -83,6 +83,8 @@ public CompletableFuture execute( Map previousNodeInputs ) throws IOException { + String workflowId = currentNodeInputs.getWorkflowId(); + CompletableFuture registerAgentModelFuture = new CompletableFuture<>();
ActionListener actionListener = new ActionListener<>() { @@ -92,7 +94,7 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId()); flowFrameworkIndicesHandler.updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), + workflowId, currentNodeId, getName(), mlRegisterAgentResponse.getAgentId(), @@ -101,7 +103,7 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { registerAgentModelFuture.complete( new WorkflowData( Map.ofEntries(Map.entry(resourceName, mlRegisterAgentResponse.getAgentId())), - currentNodeInputs.getWorkflowId(), + workflowId, currentNodeId ) ); @@ -168,12 +170,15 @@ public void onFailure(Exception e) { // Case when modelId is not present at all if (llmModelId == null) { registerAgentModelFuture.completeExceptionally( - new FlowFrameworkException("llm model id is not provided", RestStatus.BAD_REQUEST) + new FlowFrameworkException( + "llm model id is not provided for workflow: " + workflowId + " on node: " + currentNodeId, + RestStatus.BAD_REQUEST + ) ); return registerAgentModelFuture; } - LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters); + LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters, workflowId, currentNodeId); MLAgentBuilder builder = MLAgent.builder().name(name); @@ -246,9 +251,12 @@ private String getLlmModelId(Map previousNodeInputs, Map llmParameters) { + private LLMSpec getLLMSpec(String llmModelId, Map llmParameters, String workflowId, String currentNodeId) { if (llmModelId == null) { - throw new FlowFrameworkException("model id for llm is null", RestStatus.BAD_REQUEST); + throw new FlowFrameworkException( + "model id for llm is null for workflow: " + workflowId + " on node: " + currentNodeId, + RestStatus.BAD_REQUEST + ); } LLMSpec.LLMSpecBuilder builder = LLMSpec.builder(); 
builder.modelId(llmModelId); From 1b37073cfd39b44c36a608c9dad3971ce27822f5 Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Mon, 11 Dec 2023 16:04:00 -0800 Subject: [PATCH 15/27] Change thread queue to 100 and fix headers parsing bug (#265) change thread queue to 100 and fix headers bug Signed-off-by: Amit Galitzky --- .../flowframework/FlowFrameworkPlugin.java | 2 +- .../flowframework/model/WorkflowNode.java | 8 ++-- .../flowframework/util/ParseUtils.java | 43 +++++++++++++++++++ .../flowframework/util/ParseUtilsTests.java | 10 +++++ 4 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 14df7e17e..a1c75043d 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -158,7 +158,7 @@ public List> getExecutorBuilders(Settings settings) { settings, PROVISION_THREAD_POOL, OpenSearchExecutors.allocatedProcessors(settings), - 10, + 100, FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL ) ); diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 42d59e07f..706cd2c62 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -24,7 +24,9 @@ import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.parseStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; /** @@ -93,7 +95,7 @@ public XContentBuilder 
toXContent(XContentBuilder builder, Params params) throws } } else { for (Map map : (Map[]) e.getValue()) { - buildStringToStringMap(xContentBuilder, map); + buildStringToObjectMap(xContentBuilder, map); } } xContentBuilder.endArray(); @@ -150,9 +152,9 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { } userInputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0])); } else { - List> mapList = new ArrayList<>(); + List> mapList = new ArrayList<>(); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - mapList.add(parseStringToStringMap(parser)); + mapList.add(parseStringToObjectMap(parser)); } userInputs.put(inputFieldName, mapList.toArray(new Map[0])); } diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 9e3b8d067..6e1a506a1 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -84,6 +84,25 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map map) throws IOException { + xContentBuilder.startObject(); + for (Entry e : map.entrySet()) { + if (e.getValue() instanceof String) { + xContentBuilder.field((String) e.getKey(), (String) e.getValue()); + } else { + xContentBuilder.field((String) e.getKey(), e.getValue()); + } + } + xContentBuilder.endObject(); + } + /** * Builds an XContent object representing a LLMSpec. * @@ -117,6 +136,30 @@ public static Map parseStringToStringMap(XContentParser parser) return map; } + /** + * Parses an XContent object representing a map of String keys to Object values. 
+ * The Object value here can either be a string or a map + * @param parser An XContent parser whose position is at the start of the map object to parse + * @return A map as identified by the key-value pairs in the XContent + * @throws IOException on a parse failure + */ + public static Map parseStringToObjectMap(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + Map map = new HashMap<>(); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if (parser.currentToken() == XContentParser.Token.START_OBJECT) { + // If the current token is a START_OBJECT, parse it as Map + map.put(fieldName, parseStringToStringMap(parser)); + } else { + // Otherwise, parse it as a string + map.put(fieldName, parser.text()); + } + } + return map; + } + /** * Parse content parser to {@link java.time.Instant}. * diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 94fe7b01e..02222b9aa 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -60,6 +60,16 @@ public void testToInstantWithNotValue() throws IOException { assertNull(instant); } + public void testBuildAndParseStringToStringMap() throws IOException { + Map stringMap = Map.ofEntries(Map.entry("one", "two")); + XContentBuilder builder = XContentFactory.jsonBuilder(); + ParseUtils.buildStringToStringMap(builder, stringMap); + XContentParser parser = this.createParser(builder); + parser.nextToken(); + Map parsedMap = ParseUtils.parseStringToStringMap(parser); + assertEquals(stringMap.get("one"), parsedMap.get("one")); + } + public void testGetInputsFromPreviousSteps() { WorkflowData currentNodeInputs = new WorkflowData( Map.ofEntries(Map.entry("content1", 1), 
Map.entry("param1", 2), Map.entry("content3", "${{step1.output1}}")), From 9c230923c850cb6eb394b22b387f409bf6315dcd Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Mon, 11 Dec 2023 17:46:56 -0800 Subject: [PATCH 16/27] Update resources_created with deploy model: (#275) add deploy model resource Signed-off-by: Amit Galitzky --- .../AbstractRetryableWorkflowStep.java | 62 ++++++++----------- 1 file changed, 27 insertions(+), 35 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index f807c752a..121f477bb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -20,7 +20,6 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import java.util.Map; @@ -59,24 +58,6 @@ public AbstractRetryableWorkflowStep( this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } - /** - * Completes the future for either deploy or register local model step - * @param resourceName resource name for the given step - * @param nodeId node ID of the given step - * @param workflowId workflow ID of the given workflow - * @param response Response from ml commons get Task API - * @param future CompletableFuture of the given step - */ - public void completeFuture(String resourceName, String nodeId, String workflowId, MLTask response, CompletableFuture future) { - future.complete( - new WorkflowData( - Map.ofEntries(Map.entry(resourceName, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())), - workflowId, - nodeId - ) - ); - } - /** * 
Retryable get ml task * @param workflowId the workflow id @@ -110,25 +91,36 @@ void retryableGetMlTask( try { logger.info(workflowStep + " successful for {} and modelId {}", workflowId, response.getModelId()); String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + String id; if (getName().equals(WorkflowResources.DEPLOY_MODEL.getWorkflowStep())) { - completeFuture(resourceName, nodeId, workflowId, response, future); + id = response.getModelId(); } else { - flowFrameworkIndicesHandler.updateResourceInStateIndex( - workflowId, - nodeId, - getName(), - response.getTaskId(), - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - completeFuture(resourceName, nodeId, workflowId, response, future); - }, exception -> { - logger.error("Failed to update new created resource", exception); - future.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); - }) - ); + id = response.getTaskId(); } + flowFrameworkIndicesHandler.updateResourceInStateIndex( + workflowId, + nodeId, + getName(), + id, + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + future.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, response.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) + ), + workflowId, + nodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + future.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); } catch (Exception e) { logger.error("Failed to parse and update new created resource", e); future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); From 16b0a59729a1367ce97b8353b4cf71ee7f230451 Mon Sep 17 
00:00:00 2001 From: Daniel Widdis Date: Mon, 11 Dec 2023 19:04:56 -0800 Subject: [PATCH 17/27] [feature/agent_framework] Changing resources created format (#231) * adding new resources created format and adding enum for resource types Signed-off-by: Amit Galitzky * remove spotless from java 17 Signed-off-by: Amit Galitzky * add action listener to update resource created Signed-off-by: Amit Galitzky * fixing UT Signed-off-by: Amit Galitzky * changed exception type Signed-off-by: Amit Galitzky --------- Signed-off-by: Daniel Widdis --- .../flowframework/FlowFrameworkPlugin.java | 4 ++- .../common/FlowFrameworkSettings.java | 12 ++++++++ .../workflow/WorkflowProcessSorter.java | 29 ++++++++++++++++++- .../FlowFrameworkPluginTests.java | 5 ++-- .../CreateWorkflowTransportActionTests.java | 20 ++++++++++++- .../workflow/WorkflowProcessSorterTests.java | 25 ++++++++++++++-- 6 files changed, 87 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index a1c75043d..513984c68 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -63,6 +63,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; /** @@ -106,7 +107,7 @@ public Collection createComponents( mlClient, flowFrameworkIndicesHandler ); - WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); + WorkflowProcessSorter 
workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, clusterService, settings); return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); } @@ -144,6 +145,7 @@ public List> getSettings() { List> settings = ImmutableList.of( FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, + MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY ); diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java index 1824197e8..536fa2c73 100644 --- a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java @@ -18,6 +18,8 @@ private FlowFrameworkSettings() {} /** The upper limit of max workflows that can be created */ public static final int MAX_WORKFLOWS_LIMIT = 10000; + /** The upper limit of max workflow steps that can be in a single workflow */ + public static final int MAX_WORKFLOW_STEPS_LIMIT = 500; /** This setting sets max workflows that can be created */ public static final Setting MAX_WORKFLOWS = Setting.intSetting( @@ -29,6 +31,16 @@ private FlowFrameworkSettings() {} Setting.Property.Dynamic ); + /** This setting sets the max number of workflow steps that can be in a single workflow */ + public static final Setting MAX_WORKFLOW_STEPS = Setting.intSetting( + "plugins.flow_framework.max_workflow_steps", + 50, + 1, + MAX_WORKFLOW_STEPS_LIMIT, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + /** This setting sets the timeout for the request */ public static final Setting WORKFLOW_REQUEST_TIMEOUT = Setting.positiveTimeSetting( "plugins.flow_framework.request_timeout", diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 3e8b77f9d..da362383b 100644 ---
a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -10,6 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -32,6 +34,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; import static org.opensearch.flowframework.model.WorkflowNode.USER_INPUTS_FIELD; @@ -45,16 +48,26 @@ public class WorkflowProcessSorter { private WorkflowStepFactory workflowStepFactory; private ThreadPool threadPool; + private Integer maxWorkflowSteps; /** * Instantiate this class. * * @param workflowStepFactory The factory which matches template step types to instances. * @param threadPool The OpenSearch Thread pool to pass to process nodes. + * @param clusterService The OpenSearch cluster service. 
+ * @param settings OpenSearch settings */ - public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) { + public WorkflowProcessSorter( + WorkflowStepFactory workflowStepFactory, + ThreadPool threadPool, + ClusterService clusterService, + Settings settings + ) { this.workflowStepFactory = workflowStepFactory; this.threadPool = threadPool; + this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it); } /** @@ -64,6 +77,20 @@ public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. */ public List sortProcessNodes(Workflow workflow, String workflowId) { + if (workflow.nodes().size() > this.maxWorkflowSteps) { + throw new FlowFrameworkException( + "Workflow " + + workflowId + + " has " + + workflow.nodes().size() + + " nodes, which exceeds the maximum of " + + this.maxWorkflowSteps + + ". 
Change the setting [" + + MAX_WORKFLOW_STEPS.getKey() + + "] to increase this.", + RestStatus.BAD_REQUEST + ); + } List sortedNodes = topologicalSort(workflow.nodes(), workflow.edges()); List nodes = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index e3827e0b3..2585ffb09 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -29,6 +29,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -62,7 +63,7 @@ public void setUp() throws Exception { final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); clusterService = mock(ClusterService.class); @@ -84,7 +85,7 @@ public void testPlugin() throws IOException { assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); assertEquals(4, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); - assertEquals(4, ffp.getSettings().size()); + assertEquals(5, 
ffp.getSettings().size()); } } } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index bec06b3a8..6856a2122 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -13,6 +13,9 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; @@ -33,14 +36,20 @@ import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.mockito.ArgumentCaptor; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; @@ -60,6 +69,8 @@ 
public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private Template template; private Client client = mock(Client.class); private ThreadPool threadPool; + private ClusterSettings clusterSettings; + private ClusterService clusterService; private Settings settings; @Override @@ -70,8 +81,15 @@ public void setUp() throws Exception { .put("plugins.flow_framework.max_workflows", 2) .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) .build(); + final Set> settingsSet = Stream.concat( + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + ).collect(Collectors.toSet()); + clusterSettings = new ClusterSettings(settings, settingsSet); + clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool); + this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool, clusterService, settings); this.createWorkflowTransportAction = spy( new CreateWorkflowTransportAction( mock(TransportService.class), diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 8103f4fbf..d1590acd8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -16,6 +16,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import 
org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.TemplateTestJsonUtil; @@ -32,6 +33,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -41,6 +43,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node; @@ -79,11 +82,12 @@ public static void setup() { MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + Settings settings = Settings.builder().put("plugins.flow_framework.max_workflow_steps", 5).build(); final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); + ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); 
when(client.admin()).thenReturn(adminClient); @@ -96,7 +100,7 @@ public static void setup() { mlClient, flowFrameworkIndicesHandler ); - workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); + workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, settings); } @AfterClass @@ -245,6 +249,21 @@ public void testExceptions() throws IOException { ex = assertThrows(FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("A")), Collections.emptyList()))); assertEquals("Duplicate node id A.", ex.getMessage()); assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); + + ex = assertThrows( + FlowFrameworkException.class, + () -> parse(workflow(List.of(node("A"), node("B"), node("C"), node("D"), node("E"), node("F")), Collections.emptyList())) + ); + String message = String.format( + Locale.ROOT, + "Workflow %s has %d nodes, which exceeds the maximum of %d. Change the setting [%s] to increase this.", + "123", + 6, + 5, + FlowFrameworkSettings.MAX_WORKFLOW_STEPS.getKey() + ); + assertEquals(message, ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); } public void testSuccessfulGraphValidation() throws Exception { From 53070a1e44bfcb8c9eb52ac1f7268bc828a597c3 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Tue, 12 Dec 2023 17:17:12 -0800 Subject: [PATCH 18/27] [Feature/agent_framework] Add Get Workflow API to retrieve a stored template by workflow id (#273) * renaming status API implementation Signed-off-by: Joshua Palis * Adding GetWorkflow API Signed-off-by: Joshua Palis * addressing PR comments Signed-off-by: Joshua Palis * Adding todo reminder Signed-off-by: Joshua Palis --------- Signed-off-by: Joshua Palis --- .../flowframework/FlowFrameworkPlugin.java | 5 + .../rest/RestGetWorkflowAction.java | 26 ++- .../rest/RestGetWorkflowStateAction.java | 109 ++++++++++++ .../transport/GetWorkflowResponse.java | 45 +++-- 
.../transport/GetWorkflowStateAction.java | 29 +++ ...uest.java => GetWorkflowStateRequest.java} | 12 +- .../transport/GetWorkflowStateResponse.java | 67 +++++++ .../GetWorkflowStateTransportAction.java | 99 +++++++++++ .../transport/GetWorkflowTransportAction.java | 97 +++++----- .../FlowFrameworkPluginTests.java | 4 +- .../rest/RestGetWorkflowActionTests.java | 38 ++-- .../rest/RestGetWorkflowStateActionTests.java | 104 +++++++++++ .../GetWorkflowStateTransportActionTests.java | 127 ++++++++++++++ .../GetWorkflowTransportActionTests.java | 166 +++++++++++------- 14 files changed, 742 insertions(+), 186 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateAction.java rename src/main/java/org/opensearch/flowframework/transport/{GetWorkflowRequest.java => GetWorkflowStateRequest.java} (83%) create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java create mode 100644 src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java create mode 100644 src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 513984c68..ec9eb40da 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -29,11 +29,14 @@ import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.rest.RestCreateWorkflowAction; import org.opensearch.flowframework.rest.RestGetWorkflowAction; +import 
org.opensearch.flowframework.rest.RestGetWorkflowStateAction; import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; import org.opensearch.flowframework.rest.RestSearchWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; import org.opensearch.flowframework.transport.GetWorkflowAction; +import org.opensearch.flowframework.transport.GetWorkflowStateAction; +import org.opensearch.flowframework.transport.GetWorkflowStateTransportAction; import org.opensearch.flowframework.transport.GetWorkflowTransportAction; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; @@ -126,6 +129,7 @@ public List getRestHandlers( new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService), new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting), + new RestGetWorkflowStateAction(flowFrameworkFeatureEnabledSetting), new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting) ); } @@ -136,6 +140,7 @@ public List getRestHandlers( new ActionHandler<>(CreateWorkflowAction.INSTANCE, CreateWorkflowTransportAction.class), new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class), new ActionHandler<>(SearchWorkflowAction.INSTANCE, SearchWorkflowTransportAction.class), + new ActionHandler<>(GetWorkflowStateAction.INSTANCE, GetWorkflowStateTransportAction.class), new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java index 6d9d5e3b5..5a92e9c0e 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java +++ 
b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java @@ -20,7 +20,7 @@ import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.GetWorkflowAction; -import org.opensearch.flowframework.transport.GetWorkflowRequest; +import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -34,7 +34,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; /** - * Rest Action to facilitate requests to get a workflow status + * Rest Action to facilitate requests to get a stored template */ public class RestGetWorkflowAction extends BaseRestHandler { @@ -55,6 +55,11 @@ public String getName() { return GET_WORKFLOW_ACTION; } + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, WORKFLOW_ID))); + } + @Override protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { @@ -68,7 +73,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request // Validate content if (request.hasContent()) { - throw new FlowFrameworkException("No request body present", RestStatus.BAD_REQUEST); + throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); } // Validate params String workflowId = request.param(WORKFLOW_ID); @@ -76,9 +81,8 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } - boolean all = request.paramAsBoolean("all", false); - GetWorkflowRequest getWorkflowRequest = new GetWorkflowRequest(workflowId, all); - return channel -> 
client.execute(GetWorkflowAction.INSTANCE, getWorkflowRequest, ActionListener.wrap(response -> { + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + return channel -> client.execute(GetWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); }, exception -> { @@ -88,7 +92,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); } catch (IOException e) { - logger.error("Failed to send back provision workflow exception", e); + logger.error("Failed to send back get workflow exception", e); channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage())); } })); @@ -99,12 +103,4 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request ); } } - - @Override - public List routes() { - return ImmutableList.of( - // Provision workflow from indexed use case template - new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_status")) - ); - } } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java new file mode 100644 index 000000000..ab7335b2d --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.GetWorkflowStateAction; +import org.opensearch.flowframework.transport.GetWorkflowStateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; + +/** + * Rest Action to facilitate requests to get a workflow status + */ +public class RestGetWorkflowStateAction extends BaseRestHandler { + + private static final String GET_WORKFLOW_STATE_ACTION = "get_workflow_state"; + private static final Logger logger = LogManager.getLogger(RestGetWorkflowStateAction.class); + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + /** + * Instantiates a new RestGetWorkflowStateAction + * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled + */ + public RestGetWorkflowStateAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) { + this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; + } + + @Override + public String 
getName() { + return GET_WORKFLOW_STATE_ACTION; + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + try { + if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { + throw new FlowFrameworkException( + "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", + RestStatus.FORBIDDEN + ); + } + + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("No request body present", RestStatus.BAD_REQUEST); + } + // Validate params + String workflowId = request.param(WORKFLOW_ID); + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + + boolean all = request.paramAsBoolean("all", false); + GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest(workflowId, all); + return channel -> client.execute(GetWorkflowStateAction.INSTANCE, getWorkflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + }, exception -> { + try { + FlowFrameworkException ex = new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + + } catch (IOException e) { + logger.error("Failed to send back provision workflow exception", e); + channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage())); + } + })); + + } catch (FlowFrameworkException ex) { + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } + } + + @Override + public List 
routes() { + return ImmutableList.of( + new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_status")) + ); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java index 922a8a3f5..db70d2cb2 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java @@ -13,55 +13,52 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.model.Template; import java.io.IOException; /** - * Transport Response from getting a workflow status + * Transport Response from getting a template */ public class GetWorkflowResponse extends ActionResponse implements ToXContentObject { - /** The workflow state */ - public WorkflowState workflowState; - /** Flag to indicate if the entire state should be returned */ - public boolean allStatus; + /** The template */ + private Template template; /** * Instantiates a new GetWorkflowResponse from an input stream * @param in the input stream to read from - * @throws IOException if the workflowId cannot be read from the input stream + * @throws IOException if the template json cannot be read from the input stream */ public GetWorkflowResponse(StreamInput in) throws IOException { super(in); - workflowState = new WorkflowState(in); - allStatus = in.readBoolean(); + this.template = Template.parse(in.readString()); } /** - * Instatiates a new GetWorkflowResponse from an input stream - * @param workflowState the workflow state object - * @param allStatus whether to return all fields in state index + * Instantiates a new GetWorkflowResponse + * @param template the template 
*/ - public GetWorkflowResponse(WorkflowState workflowState, boolean allStatus) { - if (allStatus) { - this.workflowState = workflowState; - } else { - this.workflowState = new WorkflowState.Builder().workflowId(workflowState.getWorkflowId()) - .error(workflowState.getError()) - .state(workflowState.getState()) - .resourcesCreated(workflowState.resourcesCreated()) - .build(); - } + public GetWorkflowResponse(Template template) { + this.template = template; } @Override public void writeTo(StreamOutput out) throws IOException { - workflowState.writeTo(out); + out.writeString(template.toJson()); } @Override public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { - return workflowState.toXContent(xContentBuilder, params); + return this.template.toXContent(xContentBuilder, params); } + + /** + * Gets the template + * @return the template + */ + public Template getTemplate() { + return this.template; + } + } diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateAction.java new file mode 100644 index 000000000..b8a713685 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateAction.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External Action for public facing RestGetWorkflowStateAction + */ +public class GetWorkflowStateAction extends ActionType { + // TODO : If the template body is returned as part of the GetWorkflowStateAction, + // it is necessary to ensure the user has permissions for workflow/get + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow_state/get"; + /** An instance of this action */ + public static final GetWorkflowStateAction INSTANCE = new GetWorkflowStateAction(); + + private GetWorkflowStateAction() { + super(NAME, GetWorkflowStateResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java similarity index 83% rename from src/main/java/org/opensearch/flowframework/transport/GetWorkflowRequest.java rename to src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java index c7594eb77..7fd546c25 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java @@ -17,9 +17,9 @@ import java.io.IOException; /** - * Transport Request to get a workflow or workflow status + * Transport Request to get a workflow status */ -public class GetWorkflowRequest extends ActionRequest { +public class GetWorkflowStateRequest extends ActionRequest { /** * The documentId of the workflow entry within the Global Context index @@ -33,21 +33,21 @@ public class GetWorkflowRequest extends ActionRequest { private boolean all; /** - * Instantiates a new GetWorkflowRequest + * Instantiates a new GetWorkflowStateRequest * @param workflowId the documentId of the workflow * 
@param all whether the get request is looking for all fields in status */ - public GetWorkflowRequest(@Nullable String workflowId, boolean all) { + public GetWorkflowStateRequest(@Nullable String workflowId, boolean all) { this.workflowId = workflowId; this.all = all; } /** - * Instantiates a new GetWorkflowRequest request + * Instantiates a new GetWorkflowStateRequest request * @param in The input stream to read from * @throws IOException If the stream cannot be read properly */ - public GetWorkflowRequest(StreamInput in) throws IOException { + public GetWorkflowStateRequest(StreamInput in) throws IOException { super(in); this.workflowId = in.readString(); this.all = in.readBoolean(); diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java new file mode 100644 index 000000000..fe155237e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.WorkflowState; + +import java.io.IOException; + +/** + * Transport Response from getting a workflow status + */ +public class GetWorkflowStateResponse extends ActionResponse implements ToXContentObject { + + /** The workflow state */ + public WorkflowState workflowState; + /** Flag to indicate if the entire state should be returned */ + public boolean allStatus; + + /** + * Instantiates a new GetWorkflowStateResponse from an input stream + * @param in the input stream to read from + * @throws IOException if the workflowId cannot be read from the input stream + */ + public GetWorkflowStateResponse(StreamInput in) throws IOException { + super(in); + workflowState = new WorkflowState(in); + allStatus = in.readBoolean(); + } + + /** + * Instatiates a new GetWorkflowStateResponse from an input stream + * @param workflowState the workflow state object + * @param allStatus whether to return all fields in state index + */ + public GetWorkflowStateResponse(WorkflowState workflowState, boolean allStatus) { + if (allStatus) { + this.workflowState = workflowState; + } else { + this.workflowState = new WorkflowState.Builder().workflowId(workflowState.getWorkflowId()) + .error(workflowState.getError()) + .state(workflowState.getState()) + .resourcesCreated(workflowState.resourcesCreated()) + .build(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + workflowState.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return workflowState.toXContent(xContentBuilder, params); + } +} diff --git 
a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java new file mode 100644 index 000000000..57fcc2b89 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; + +//TODO: Currently we only get the workflow status but we should change to be able to get the +// full template as well +/** + * Transport Action to get a 
specific workflow. Currently, we only support the action with _status + * in the API path but will add the ability to get the workflow and not just the status in the future + */ +public class GetWorkflowStateTransportAction extends HandledTransportAction { + + private final Logger logger = LogManager.getLogger(GetWorkflowStateTransportAction.class); + + private final Client client; + private final NamedXContentRegistry xContentRegistry; + + /** + * Intantiates a new GetWorkflowStateTransportAction + * @param transportService The TransportService + * @param actionFilters action filters + * @param client The client used to make the request to OS + * @param xContentRegistry contentRegister to parse get response + */ + @Inject + public GetWorkflowStateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(GetWorkflowStateAction.NAME, transportService, actionFilters, GetWorkflowStateRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, GetWorkflowStateRequest request, ActionListener listener) { + String workflowId = request.getWorkflowId(); + User user = ParseUtils.getUserContext(client); + GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + WorkflowState workflowState = WorkflowState.parse(parser); + listener.onResponse(new GetWorkflowStateResponse(workflowState, request.getAll())); + } catch (Exception e) { + logger.error("Failed to parse 
workflowState" + r.getId(), e); + listener.onFailure(new FlowFrameworkException("Failed to parse workflowState" + r.getId(), RestStatus.BAD_REQUEST)); + } + } else { + listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND)); + } else { + logger.error("Failed to get workflow status of: " + workflowId, e); + listener.onFailure(new FlowFrameworkException("Failed to get workflow status of: " + workflowId, RestStatus.NOT_FOUND)); + } + }), () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to get workflow: " + workflowId, e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java index f3bc1dd9e..e2a9b1931 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -17,83 +17,78 @@ import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; -import org.opensearch.flowframework.model.WorkflowState; -import org.opensearch.flowframework.util.ParseUtils; -import org.opensearch.index.IndexNotFoundException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import 
org.opensearch.flowframework.model.Template; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; -//TODO: Currently we only get the workflow status but we should change to be able to get the -// full template as well /** - * Transport Action to get a specific workflow. Currently, we only support the action with _status - * in the API path but will add the ability to get the workflow and not just the status in the future + * Transport action to retrieve a use case template within the Global Context */ -public class GetWorkflowTransportAction extends HandledTransportAction { +public class GetWorkflowTransportAction extends HandledTransportAction { private final Logger logger = LogManager.getLogger(GetWorkflowTransportAction.class); - + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private final Client client; - private final NamedXContentRegistry xContentRegistry; /** - * Intantiates a new CreateWorkflowTransportAction - * @param transportService The TransportService + * Instantiates a new GetWorkflowTransportAction instance + * @param transportService the transport service * @param actionFilters action filters - * @param client The client used to make the request to OS - * @param xContentRegistry contentRegister to parse get response + * @param flowFrameworkIndicesHandler The Flow Framework indices handler + * @param client the Opensearch Client */ @Inject public GetWorkflowTransportAction( TransportService transportService, ActionFilters actionFilters, - Client client, - NamedXContentRegistry xContentRegistry + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, + Client client ) { - super(GetWorkflowAction.NAME, transportService, actionFilters, 
GetWorkflowRequest::new); + super(GetWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new); + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; this.client = client; - this.xContentRegistry = xContentRegistry; } @Override - protected void doExecute(Task task, GetWorkflowRequest request, ActionListener listener) { - String workflowId = request.getWorkflowId(); - User user = ParseUtils.getUserContext(client); - GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - WorkflowState workflowState = WorkflowState.parse(parser); - listener.onResponse(new GetWorkflowResponse(workflowState, request.getAll())); - } catch (Exception e) { - logger.error("Failed to parse workflowState" + r.getId(), e); - listener.onFailure(new FlowFrameworkException("Failed to parse workflowState" + r.getId(), RestStatus.BAD_REQUEST)); + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { + if (flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { + + String workflowId = request.getWorkflowId(); + GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + + // Retrieve workflow by ID + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.wrap(response -> { + context.restore(); + + if (!response.isExists()) { + listener.onFailure( + new FlowFrameworkException( + "Failed to retrieve template (" + workflowId + ") from global context.", + RestStatus.NOT_FOUND + ) + ); + } else 
{ + listener.onResponse(new GetWorkflowResponse(Template.parse(response.getSourceAsString()))); } - } else { - listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND)); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND)); - } else { - logger.error("Failed to get workflow status of: " + workflowId, e); - listener.onFailure(new FlowFrameworkException("Failed to get workflow status of: " + workflowId, RestStatus.NOT_FOUND)); - } - }), () -> context.restore())); - } catch (Exception e) { - logger.error("Failed to get workflow: " + workflowId, e); - listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + }, exception -> { + logger.error("Failed to retrieve template from global context.", exception); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + })); + } catch (Exception e) { + logger.error("Failed to retrieve template from global context.", e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + + } else { + listener.onFailure(new FlowFrameworkException("There are no templates in the global_context", RestStatus.NOT_FOUND)); } + } } diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 2585ffb09..6370d2312 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -82,8 +82,8 @@ public void testPlugin() throws IOException { 4, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); - assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); - assertEquals(4, 
ffp.getActions().size()); + assertEquals(5, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); + assertEquals(5, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); assertEquals(5, ffp.getSettings().size()); } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java index 0f6ddab59..3a51f1a9e 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java @@ -29,25 +29,20 @@ public class RestGetWorkflowActionTests extends OpenSearchTestCase { private RestGetWorkflowAction restGetWorkflowAction; private String getPath; - private NodeClient nodeClient; private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + private NodeClient nodeClient; @Override public void setUp() throws Exception { super.setUp(); - this.getPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, "workflow_id", "_status"); + this.getPath = String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id"); flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); this.restGetWorkflowAction = new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting); this.nodeClient = mock(NodeClient.class); } - public void testConstructor() { - RestGetWorkflowAction getWorkflowAction = new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting); - assertNotNull(getWorkflowAction); - } - public void testRestGetWorkflowActionName() { String name = restGetWorkflowAction.getName(); assertEquals("get_workflow", name); @@ -60,6 +55,19 @@ public void testRestGetWorkflowActionRoutes() { assertEquals(this.getPath, routes.get(0).getPath()); } + public void testInvalidRequestWithContent() { + 
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + restGetWorkflowAction.handleRequest(request, channel, nodeClient); + }); + assertEquals("request [POST /_plugins/_flow_framework/workflow/{workflow_id}] does not support having a body", ex.getMessage()); + } + public void testNullWorkflowId() throws Exception { // Request with no params @@ -75,22 +83,6 @@ public void testNullWorkflowId() throws Exception { assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); } - public void testInvalidRequestWithContent() { - RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) - .withPath(this.getPath) - .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) - .build(); - - FakeRestChannel channel = new FakeRestChannel(request, false, 1); - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { - restGetWorkflowAction.handleRequest(request, channel, nodeClient); - }); - assertEquals( - "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_status] does not support having a body", - ex.getMessage() - ); - } - public void testFeatureFlagNotEnabled() throws Exception { when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java new file mode 100644 index 000000000..dc605a5cd --- /dev/null +++ 
b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; +import org.opensearch.test.rest.FakeRestRequest; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RestGetWorkflowStateActionTests extends OpenSearchTestCase { + private RestGetWorkflowStateAction restGetWorkflowStateAction; + private String getPath; + private NodeClient nodeClient; + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + @Override + public void setUp() throws Exception { + super.setUp(); + + this.getPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, "workflow_id", "_status"); + flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); + this.restGetWorkflowStateAction = new RestGetWorkflowStateAction(flowFrameworkFeatureEnabledSetting); + this.nodeClient = mock(NodeClient.class); + } + + public void testConstructor() { + RestGetWorkflowStateAction getWorkflowAction = new 
RestGetWorkflowStateAction(flowFrameworkFeatureEnabledSetting); + assertNotNull(getWorkflowAction); + } + + public void testRestGetWorkflowStateActionName() { + String name = restGetWorkflowStateAction.getName(); + assertEquals("get_workflow_state", name); + } + + public void testRestGetWorkflowStateActionRoutes() { + List routes = restGetWorkflowStateAction.routes(); + assertEquals(1, routes.size()); + assertEquals(RestRequest.Method.GET, routes.get(0).getMethod()); + assertEquals(this.getPath, routes.get(0).getPath()); + } + + public void testNullWorkflowId() throws Exception { + + // Request with no params + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, true, 1); + restGetWorkflowStateAction.handleRequest(request, channel, nodeClient); + + assertEquals(1, channel.errors().get()); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); + } + + public void testInvalidRequestWithContent() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + restGetWorkflowStateAction.handleRequest(request, channel, nodeClient); + }); + assertEquals( + "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_status] does not support having a body", + ex.getMessage() + ); + } + + public void testFeatureFlagNotEnabled() throws Exception { + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); + RestRequest request = new 
FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + restGetWorkflowStateAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java new file mode 100644 index 000000000..7aa0323b4 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.junit.Assert; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + +import org.mockito.Mockito; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GetWorkflowStateTransportActionTests extends OpenSearchTestCase { + + private GetWorkflowStateTransportAction getWorkflowStateTransportAction; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private Client client; + private ThreadPool threadPool; + private ThreadContext threadContext; + private ActionListener response; + private Task task; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.client = mock(Client.class); + this.threadPool = mock(ThreadPool.class); + this.getWorkflowStateTransportAction = new GetWorkflowStateTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client, + xContentRegistry() + ); + task = Mockito.mock(Task.class); + ThreadPool clientThreadPool = 
mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(client.threadPool()).thenReturn(clientThreadPool); + when(clientThreadPool.getThreadContext()).thenReturn(threadContext); + + response = new ActionListener() { + @Override + public void onResponse(GetWorkflowStateResponse getResponse) { + assertTrue(true); + } + + @Override + public void onFailure(Exception e) {} + }; + + } + + public void testGetTransportAction() throws IOException { + GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest("1234", false); + getWorkflowStateTransportAction.doExecute(task, getWorkflowRequest, response); + } + + public void testGetAction() { + Assert.assertNotNull(GetWorkflowStateAction.INSTANCE.name()); + Assert.assertEquals(GetWorkflowStateAction.INSTANCE.name(), GetWorkflowStateAction.NAME); + } + + public void testGetWorkflowStateRequest() throws IOException { + GetWorkflowStateRequest request = new GetWorkflowStateRequest("1234", false); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + GetWorkflowStateRequest newRequest = new GetWorkflowStateRequest(input); + Assert.assertEquals(request.getWorkflowId(), newRequest.getWorkflowId()); + Assert.assertEquals(request.getAll(), newRequest.getAll()); + Assert.assertNull(newRequest.validate()); + } + + public void testGetWorkflowStateResponse() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + String workflowId = randomAlphaOfLength(5); + WorkflowState workFlowState = new WorkflowState( + workflowId, + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ); + + GetWorkflowStateResponse response = new GetWorkflowStateResponse(workFlowState, false); + response.writeTo(out); + NamedWriteableAwareStreamInput input = new 
NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); + GetWorkflowStateResponse newResponse = new GetWorkflowStateResponse(input); + XContentBuilder builder = TestHelpers.builder(); + Assert.assertNotNull(newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + Map map = TestHelpers.XContentBuilderToMap(builder); + Assert.assertEquals(map.get("state"), workFlowState.getState()); + Assert.assertEquals(map.get("workflow_id"), workFlowState.getWorkflowId()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java index ab6d0a68f..d7db8a2c9 100644 --- a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java @@ -8,115 +8,151 @@ */ package org.opensearch.flowframework.transport; +import org.opensearch.Version; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; -import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; -import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import 
org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.index.get.GetResult; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import org.junit.Assert; -import java.io.IOException; -import java.time.Instant; -import java.util.Collections; +import java.util.List; import java.util.Map; -import org.mockito.Mockito; +import org.mockito.ArgumentCaptor; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class GetWorkflowTransportActionTests extends OpenSearchTestCase { - private GetWorkflowTransportAction getWorkflowTransportAction; + private ThreadPool threadPool; private Client client; - private ActionListener response; - private Task task; + private GetWorkflowTransportAction getTemplateTransportAction; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private Template template; @Override public void setUp() throws Exception { super.setUp(); + this.threadPool = mock(ThreadPool.class); this.client = mock(Client.class); - this.getWorkflowTransportAction = new GetWorkflowTransportAction( + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.getTemplateTransportAction = new GetWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), - client, - xContentRegistry() + flowFrameworkIndicesHandler, + client ); - task = Mockito.mock(Task.class); + + 
Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + List nodes = List.of(nodeA, nodeB); + List edges = List.of(edgeAB); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + + this.template = new Template( + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of("provision", workflow), + Map.of(), + TestHelpers.randomUser() + ); + ThreadPool clientThreadPool = mock(ThreadPool.class); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); when(client.threadPool()).thenReturn(clientThreadPool); when(clientThreadPool.getThreadContext()).thenReturn(threadContext); - response = new ActionListener() { - @Override - public void onResponse(GetWorkflowResponse getResponse) { - assertTrue(true); - } + } - @Override - public void onFailure(Exception e) {} - }; + public void testGetWorkflowNoGlobalContext() { - } + when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(false); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest("1", null); + getTemplateTransportAction.doExecute(mock(Task.class), workflowRequest, listener); - public void testGetTransportAction() throws IOException { - GetWorkflowRequest getWorkflowRequest = new GetWorkflowRequest("1234", false); - getWorkflowTransportAction.doExecute(task, getWorkflowRequest, response); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue().getMessage().contains("There are no templates in the 
global_context")); } - public void testGetAction() { - Assert.assertNotNull(GetWorkflowAction.INSTANCE.name()); - Assert.assertEquals(GetWorkflowAction.INSTANCE.name(), GetWorkflowAction.NAME); - } + public void testGetWorkflowSuccess() { + String workflowId = "12345"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); - public void testGetAnomalyDetectorRequest() throws IOException { - GetWorkflowRequest request = new GetWorkflowRequest("1234", false); - BytesStreamOutput out = new BytesStreamOutput(); - request.writeTo(out); - StreamInput input = out.bytes().streamInput(); - GetWorkflowRequest newRequest = new GetWorkflowRequest(input); - Assert.assertEquals(request.getWorkflowId(), newRequest.getWorkflowId()); - Assert.assertEquals(request.getAll(), newRequest.getAll()); - Assert.assertNull(newRequest.validate()); + when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(true); + + // Stub client.get to force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + this.template.toXContent(builder, null); + BytesReference templateBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(GLOBAL_CONTEXT_INDEX, workflowId, 1, 1, 1, true, templateBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + getTemplateTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + ArgumentCaptor templateCaptor = ArgumentCaptor.forClass(GetWorkflowResponse.class); + verify(listener, times(1)).onResponse(templateCaptor.capture()); + assertEquals(this.template.name(), templateCaptor.getValue().getTemplate().name()); } - public void testGetAnomalyDetectorResponse() throws IOException { - BytesStreamOutput out = 
new BytesStreamOutput(); - String workflowId = randomAlphaOfLength(5); - WorkflowState workFlowState = new WorkflowState( - workflowId, - "test", - "PROVISIONING", - "IN_PROGRESS", - Instant.now(), - Instant.now(), - TestHelpers.randomUser(), - Collections.emptyMap(), - Collections.emptyList() - ); + public void testGetWorkflowFailure() { + String workflowId = "12345"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + + when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(true); + + // Stub client.get to force on failure + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to retrieve template from global context.")); + return null; + }).when(client).get(any(GetRequest.class), any()); - GetWorkflowResponse response = new GetWorkflowResponse(workFlowState, false); - response.writeTo(out); - NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); - GetWorkflowResponse newResponse = new GetWorkflowResponse(input); - XContentBuilder builder = TestHelpers.builder(); - Assert.assertNotNull(newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS)); + getTemplateTransportAction.doExecute(mock(Task.class), workflowRequest, listener); - Map map = TestHelpers.XContentBuilderToMap(builder); - Assert.assertEquals(map.get("state"), workFlowState.getState()); - Assert.assertEquals(map.get("workflow_id"), workFlowState.getWorkflowId()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to retrieve template from global context.", exceptionCaptor.getValue().getMessage()); } } From 5045aae99735f1b97e98e0a09f757da4e0338aaa Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Wed, 13 Dec 
2023 12:01:54 -0800 Subject: [PATCH 19/27] [Feature/agent_framework] Adds a Search Workflow State API (#284) * Modifying workflow state index mapping and resources created Signed-off-by: Joshua Palis * Adding Search workflow state API Signed-off-by: Joshua Palis * Adding rest unit tests Signed-off-by: Joshua Palis * Transport unit tests Signed-off-by: Joshua Palis * Moving resourceType determination outside of the resources created class Signed-off-by: Joshua Palis --------- Signed-off-by: Joshua Palis --- .../flowframework/FlowFrameworkPlugin.java | 9 +- .../flowframework/common/CommonValue.java | 4 + .../indices/FlowFrameworkIndicesHandler.java | 8 +- .../flowframework/model/ResourceCreated.java | 54 +++++++----- .../rest/RestSearchWorkflowStateAction.java | 47 ++++++++++ .../transport/GetWorkflowStateResponse.java | 22 ++++- .../transport/SearchWorkflowStateAction.java | 29 +++++++ .../SearchWorkflowStateTransportAction.java | 50 +++++++++++ .../resources/mappings/workflow-state.json | 16 +++- .../FlowFrameworkPluginTests.java | 4 +- .../model/ResourceCreatedTests.java | 37 +++++--- .../RestSearchWorkflowStateActionTests.java | 85 +++++++++++++++++++ ...archWorkflowStateTransportActionTests.java | 79 +++++++++++++++++ 13 files changed, 403 insertions(+), 41 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/rest/RestSearchWorkflowStateAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java create mode 100644 src/test/java/org/opensearch/flowframework/rest/RestSearchWorkflowStateActionTests.java create mode 100644 src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 
ec9eb40da..40ddee2fa 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -32,6 +32,7 @@ import org.opensearch.flowframework.rest.RestGetWorkflowStateAction; import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; import org.opensearch.flowframework.rest.RestSearchWorkflowAction; +import org.opensearch.flowframework.rest.RestSearchWorkflowStateAction; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; import org.opensearch.flowframework.transport.GetWorkflowAction; @@ -41,6 +42,8 @@ import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; import org.opensearch.flowframework.transport.SearchWorkflowAction; +import org.opensearch.flowframework.transport.SearchWorkflowStateAction; +import org.opensearch.flowframework.transport.SearchWorkflowStateTransportAction; import org.opensearch.flowframework.transport.SearchWorkflowTransportAction; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; @@ -130,7 +133,8 @@ public List getRestHandlers( new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestGetWorkflowStateAction(flowFrameworkFeatureEnabledSetting), - new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting) + new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting), + new RestSearchWorkflowStateAction(flowFrameworkFeatureEnabledSetting) ); } @@ -141,7 +145,8 @@ public List getRestHandlers( new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class), new ActionHandler<>(SearchWorkflowAction.INSTANCE, SearchWorkflowTransportAction.class), new 
ActionHandler<>(GetWorkflowStateAction.INSTANCE, GetWorkflowStateTransportAction.class), - new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class) + new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class), + new ActionHandler<>(SearchWorkflowStateAction.INSTANCE, SearchWorkflowStateTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 0863565c0..8b8f9deae 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -168,6 +168,10 @@ private CommonValue() {} public static final String WORKFLOW_STEP_NAME = "workflow_step_name"; /** The field name for the step ID where a resource is created */ public static final String WORKFLOW_STEP_ID = "workflow_step_id"; + /** The field name for the resource type */ + public static final String RESOURCE_TYPE = "resource_type"; + /** The field name for the resource id */ + public static final String RESOURCE_ID = "resource_id"; /** The tools' field for an agent */ public static final String TOOLS_FIELD = "tools"; /** The memory field for an agent */ diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 63df7824c..2449297b6 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -34,6 +34,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import 
org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.ResourceCreated; @@ -500,7 +501,12 @@ public void updateResourceInStateIndex( String resourceId, ActionListener listener ) throws IOException { - ResourceCreated newResource = new ResourceCreated(workflowStepName, nodeId, resourceId); + ResourceCreated newResource = new ResourceCreated( + workflowStepName, + nodeId, + WorkflowResources.getResourceByWorkflowStep(workflowStepName), + resourceId + ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); newResource.toXContent(builder, ToXContentObject.EMPTY_PARAMS); diff --git a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java index d039e2f8c..b12f4d044 100644 --- a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java +++ b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java @@ -17,12 +17,13 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import java.io.IOException; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.RESOURCE_ID; +import static org.opensearch.flowframework.common.CommonValue.RESOURCE_TYPE; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP_NAME; @@ -36,17 +37,20 @@ public class ResourceCreated implements ToXContentObject, Writeable { private final String workflowStepName; private final String workflowStepId; + private final String resourceType; private final String resourceId; /** * Create this resources created 
object with given workflow step name, ID and resource ID. * @param workflowStepName The workflow step name associating to the step where it was created * @param workflowStepId The workflow step ID associating to the step where it was created + * @param resourceType The resource type * @param resourceId The resources ID for relating to the created resource */ - public ResourceCreated(String workflowStepName, String workflowStepId, String resourceId) { + public ResourceCreated(String workflowStepName, String workflowStepId, String resourceType, String resourceId) { this.workflowStepName = workflowStepName; this.workflowStepId = workflowStepId; + this.resourceType = resourceType; this.resourceId = resourceId; } @@ -58,6 +62,7 @@ public ResourceCreated(String workflowStepName, String workflowStepId, String re public ResourceCreated(StreamInput input) throws IOException { this.workflowStepName = input.readString(); this.workflowStepId = input.readString(); + this.resourceType = input.readString(); this.resourceId = input.readString(); } @@ -66,7 +71,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws XContentBuilder xContentBuilder = builder.startObject() .field(WORKFLOW_STEP_NAME, workflowStepName) .field(WORKFLOW_STEP_ID, workflowStepId) - .field(WorkflowResources.getResourceByWorkflowStep(workflowStepName), resourceId); + .field(RESOURCE_TYPE, resourceType) + .field(RESOURCE_ID, resourceId); return xContentBuilder.endObject(); } @@ -74,6 +80,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public void writeTo(StreamOutput out) throws IOException { out.writeString(workflowStepName); out.writeString(workflowStepId); + out.writeString(resourceType); out.writeString(resourceId); } @@ -86,6 +93,15 @@ public String resourceId() { return resourceId; } + /** + * Gets the resource type. + * + * @return the resource type. 
+ */ + public String resourceType() { + return resourceType; + } + /** * Gets the workflow step name associated to the created resource * @@ -114,6 +130,7 @@ public String workflowStepId() { public static ResourceCreated parse(XContentParser parser) throws IOException { String workflowStepName = null; String workflowStepId = null; + String resourceType = null; String resourceId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -128,15 +145,14 @@ public static ResourceCreated parse(XContentParser parser) throws IOException { case WORKFLOW_STEP_ID: workflowStepId = parser.text(); break; + case RESOURCE_TYPE: + resourceType = parser.text(); + break; + case RESOURCE_ID: + resourceId = parser.text(); + break; default: - if (!isValidFieldName(fieldName)) { - throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object."); - } else { - if (fieldName.equals(WorkflowResources.getResourceByWorkflowStep(workflowStepName))) { - resourceId = parser.text(); - } - break; - } + throw new IOException("Unable to parse field [" + fieldName + "] in a resources_created object."); } } if (workflowStepName == null) { @@ -147,17 +163,15 @@ public static ResourceCreated parse(XContentParser parser) throws IOException { logger.error("Resource created object failed parsing: workflowStepId: {}", workflowStepId); throw new FlowFrameworkException("A ResourceCreated object requires workflowStepId", RestStatus.BAD_REQUEST); } + if (resourceType == null) { + logger.error("Resource created object failed parsing: resourceType: {}", resourceType); + throw new FlowFrameworkException("A ResourceCreated object requires resourceType", RestStatus.BAD_REQUEST); + } if (resourceId == null) { logger.error("Resource created object failed parsing: resourceId: {}", resourceId); throw new FlowFrameworkException("A ResourceCreated object requires resourceId", RestStatus.BAD_REQUEST); } - return new ResourceCreated(workflowStepName, 
workflowStepId, resourceId); - } - - private static boolean isValidFieldName(String fieldName) { - return (WORKFLOW_STEP_NAME.equals(fieldName) - || WORKFLOW_STEP_ID.equals(fieldName) - || WorkflowResources.getAllResourcesCreated().contains(fieldName)); + return new ResourceCreated(workflowStepName, workflowStepId, resourceType, resourceId); } @Override @@ -165,7 +179,9 @@ public String toString() { return "resources_Created [workflow_step_name= " + workflowStepName + ", workflow_step_id= " - + workflowStepName + + workflowStepId + + ", resource_type= " + + resourceType + ", resource_id= " + resourceId + "]"; diff --git a/src/main/java/org/opensearch/flowframework/rest/RestSearchWorkflowStateAction.java b/src/main/java/org/opensearch/flowframework/rest/RestSearchWorkflowStateAction.java new file mode 100644 index 000000000..dfbcc0eb2 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestSearchWorkflowStateAction.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.transport.SearchWorkflowStateAction; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; + +/** + * Rest Action to facilitate requests to search workflow states + */ +public class RestSearchWorkflowStateAction extends AbstractSearchWorkflowAction { + + private static final String SEARCH_WORKFLOW_STATE_ACTION = "search_workflow_state_action"; + private static final String SEARCH_WORKFLOW_STATE_PATH = WORKFLOW_URI + "/state/_search"; + + /** + * Instantiates a new RestSearchWorkflowStateAction + * + * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled + */ + public RestSearchWorkflowStateAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) { + super( + ImmutableList.of(SEARCH_WORKFLOW_STATE_PATH), + WORKFLOW_STATE_INDEX, + WorkflowState.class, + SearchWorkflowStateAction.INSTANCE, + flowFrameworkFeatureEnabledSetting + ); + } + + @Override + public String getName() { + return SEARCH_WORKFLOW_STATE_ACTION; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java index fe155237e..6f8b9e14b 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java @@ -23,9 +23,9 @@ public class GetWorkflowStateResponse extends ActionResponse implements ToXContentObject { /** The workflow state */ - public WorkflowState workflowState; + private final WorkflowState workflowState; /** Flag to indicate if the 
entire state should be returned */ - public boolean allStatus; + private final boolean allStatus; /** * Instantiates a new GetWorkflowStateResponse from an input stream @@ -44,6 +44,7 @@ public GetWorkflowStateResponse(StreamInput in) throws IOException { * @param allStatus whether to return all fields in state index */ public GetWorkflowStateResponse(WorkflowState workflowState, boolean allStatus) { + this.allStatus = allStatus; if (allStatus) { this.workflowState = workflowState; } else { @@ -58,10 +59,27 @@ public GetWorkflowStateResponse(WorkflowState workflowState, boolean allStatus) @Override public void writeTo(StreamOutput out) throws IOException { workflowState.writeTo(out); + out.writeBoolean(allStatus); } @Override public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { return workflowState.toXContent(xContentBuilder, params); } + + /** + * Gets the workflow state. + * @return the workflow state + */ + public WorkflowState getWorkflowState() { + return workflowState; + } + + /** + * Gets whether to return the entire state. + * @return true if the entire state should be returned + */ + public boolean isAllStatus() { + return allStatus; + } } diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateAction.java new file mode 100644 index 000000000..6b331ef02 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateAction.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External Action for public facing RestSearchWorkflowStateAction + */ +public class SearchWorkflowStateAction extends ActionType { + + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow_state/search"; + /** An instance of this action */ + public static final SearchWorkflowStateAction INSTANCE = new SearchWorkflowStateAction(); + + private SearchWorkflowStateAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java new file mode 100644 index 000000000..b10bdaeb6 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +/** + * Transport Action to search workflow states + */ +public class SearchWorkflowStateTransportAction extends HandledTransportAction { + + private Client client; + + /** + * Instantiates a new SearchWorkflowStateTransportAction + * @param transportService the TransportService + * @param actionFilters action filters + * @param client The client used to make the request to OS + */ + @Inject + public SearchWorkflowStateTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + super(SearchWorkflowStateAction.NAME, transportService, actionFilters, SearchRequest::new); + this.client = client; + } + + @Override + protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + // TODO: AccessController should take care of letting the user with right permission to view the workflow + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.search(request, ActionListener.runBefore(actionListener, () -> context.restore())); + } catch (Exception e) { + actionListener.onFailure(e); + } + } +} diff --git a/src/main/resources/mappings/workflow-state.json b/src/main/resources/mappings/workflow-state.json index 86fbeef6e..fedce568c 100644 --- a/src/main/resources/mappings/workflow-state.json +++ b/src/main/resources/mappings/workflow-state.json @@ -31,7 +31,21 @@ "type": "object" }, "resources_created": { -
"type": "object" + "type": "nested", + "properties": { + "workflow_step_name": { + "type": "keyword" + }, + "workflow_step_id": { + "type": "keyword" + }, + "resource_type": { + "type": "keyword" + }, + "resource_id": { + "type": "keyword" + } + } }, "ui_metadata": { "type": "object", diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 6370d2312..cbf988eee 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -82,8 +82,8 @@ public void testPlugin() throws IOException { 4, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); - assertEquals(5, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); - assertEquals(5, ffp.getActions().size()); + assertEquals(6, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); + assertEquals(6, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); assertEquals(5, ffp.getSettings().size()); } diff --git a/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java b/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java index 216c18c9e..4f0bf5163 100644 --- a/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java +++ b/src/test/java/org/opensearch/flowframework/model/ResourceCreatedTests.java @@ -9,6 +9,7 @@ package org.opensearch.flowframework.model; import org.opensearch.flowframework.common.WorkflowResources; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -22,30 +23,38 @@ public void setUp() throws Exception { public void testParseFeature() throws IOException { String workflowStepName = 
WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(); - ResourceCreated ResourceCreated = new ResourceCreated(workflowStepName, "workflow_step_1", "L85p1IsBbfF"); - assertEquals(ResourceCreated.workflowStepName(), workflowStepName); - assertEquals(ResourceCreated.workflowStepId(), "workflow_step_1"); - assertEquals(ResourceCreated.resourceId(), "L85p1IsBbfF"); + String resourceType = WorkflowResources.getResourceByWorkflowStep(workflowStepName); + ResourceCreated resourceCreated = new ResourceCreated(workflowStepName, "workflow_step_1", resourceType, "L85p1IsBbfF"); + assertEquals(workflowStepName, resourceCreated.workflowStepName()); + assertEquals("workflow_step_1", resourceCreated.workflowStepId()); + assertEquals("connector_id", resourceCreated.resourceType()); + assertEquals("L85p1IsBbfF", resourceCreated.resourceId()); String expectedJson = - "{\"workflow_step_name\":\"create_connector\",\"workflow_step_id\":\"workflow_step_1\",\"connector_id\":\"L85p1IsBbfF\"}"; - String json = TemplateTestJsonUtil.parseToJson(ResourceCreated); + "{\"workflow_step_name\":\"create_connector\",\"workflow_step_id\":\"workflow_step_1\",\"resource_type\":\"connector_id\",\"resource_id\":\"L85p1IsBbfF\"}"; + String json = TemplateTestJsonUtil.parseToJson(resourceCreated); assertEquals(expectedJson, json); - ResourceCreated ResourceCreatedTwo = ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(json)); - assertEquals(workflowStepName, ResourceCreatedTwo.workflowStepName()); - assertEquals("workflow_step_1", ResourceCreatedTwo.workflowStepId()); - assertEquals("L85p1IsBbfF", ResourceCreatedTwo.resourceId()); + ResourceCreated resourceCreatedTwo = ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(json)); + assertEquals(workflowStepName, resourceCreatedTwo.workflowStepName()); + assertEquals("workflow_step_1", resourceCreatedTwo.workflowStepId()); + assertEquals("L85p1IsBbfF", resourceCreatedTwo.resourceId()); } public void testExceptions() throws IOException { String 
badJson = "{\"wrong\":\"A\",\"resource_id\":\"B\"}"; - IOException e = assertThrows(IOException.class, () -> ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(badJson))); - assertEquals("Unable to parse field [wrong] in a resources_created object.", e.getMessage()); + IOException badJsonException = assertThrows( + IOException.class, + () -> ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(badJson)) + ); + assertEquals("Unable to parse field [wrong] in a resources_created object.", badJsonException.getMessage()); String missingJson = "{\"resource_id\":\"B\"}"; - e = assertThrows(IOException.class, () -> ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(missingJson))); - assertEquals("Unable to parse field [resource_id] in a resources_created object.", e.getMessage()); + FlowFrameworkException missingJsonException = assertThrows( + FlowFrameworkException.class, + () -> ResourceCreated.parse(TemplateTestJsonUtil.jsonToParser(missingJson)) + ); + assertEquals("A ResourceCreated object requires workflowStepName", missingJsonException.getMessage()); } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestSearchWorkflowStateActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestSearchWorkflowStateActionTests.java new file mode 100644 index 000000000..028860831 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestSearchWorkflowStateActionTests.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentParseException; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; +import org.opensearch.test.rest.FakeRestRequest; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RestSearchWorkflowStateActionTests extends OpenSearchTestCase { + private RestSearchWorkflowStateAction restSearchWorkflowStateAction; + private String searchPath; + private NodeClient nodeClient; + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + @Override + public void setUp() throws Exception { + super.setUp(); + + this.searchPath = String.format(Locale.ROOT, "%s/%s", WORKFLOW_URI, "state/_search"); + flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); + this.restSearchWorkflowStateAction = new RestSearchWorkflowStateAction(flowFrameworkFeatureEnabledSetting); + this.nodeClient = mock(NodeClient.class); + } + + public void testRestSearchWorkflowStateActionName() { + String name = restSearchWorkflowStateAction.getName(); + assertEquals("search_workflow_state_action", name); + } + + public void testRestSearchWorkflowStateActionRoutes() { + List routes = restSearchWorkflowStateAction.routes(); + assertNotNull(routes); + assertEquals(2, routes.size()); + 
assertEquals(RestRequest.Method.POST, routes.get(0).getMethod()); + assertEquals(RestRequest.Method.GET, routes.get(1).getMethod()); + assertEquals(this.searchPath, routes.get(0).getPath()); + assertEquals(this.searchPath, routes.get(1).getPath()); + } + + public void testInvalidSearchRequest() { + final String requestContent = "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"template\":\"1.0.0\"}}]}}}"; + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) + .withPath(this.searchPath) + .withContent(new BytesArray(requestContent), MediaTypeRegistry.JSON) + .build(); + + XContentParseException ex = expectThrows(XContentParseException.class, () -> { + restSearchWorkflowStateAction.prepareRequest(request, nodeClient); + }); + assertEquals("unknown named object category [org.opensearch.index.query.QueryBuilder]", ex.getMessage()); + } + + public void testFeatureFlagNotEnabled() throws Exception { + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.searchPath) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + restSearchWorkflowStateAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java new file mode 100644 index 000000000..d5dcddb8e --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 
+ * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SearchWorkflowStateTransportActionTests extends OpenSearchTestCase { + + private SearchWorkflowStateTransportAction searchWorkflowStateTransportAction; + private Client client; + private ThreadPool threadPool; + private ThreadContext threadContext; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.client = mock(Client.class); + this.threadPool = mock(ThreadPool.class); + this.threadContext = new ThreadContext(Settings.EMPTY); + + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + this.searchWorkflowStateTransportAction = new SearchWorkflowStateTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client + ); + + } + + public void testFailedSearchWorkflow() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + SearchRequest searchRequest = new SearchRequest(); + + doAnswer(invocation -> { + 
ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Search failed")); + return null; + }).when(client).search(any(), any()); + + searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener); + verify(listener, times(1)).onFailure(any()); + } + + public void testSearchWorkflow() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + SearchRequest searchRequest = new SearchRequest(); + + searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener); + verify(client, times(1)).search(any(SearchRequest.class), any()); + } + +} From 13ce2f76f4aa948875d680f3393ed8a0f3e4e711 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 13 Dec 2023 14:18:48 -0800 Subject: [PATCH 20/27] Permit ordering of tools in register agent step (#283) Signed-off-by: Daniel Widdis --- .../flowframework/model/WorkflowNode.java | 16 ++++++++++ .../workflow/RegisterAgentStep.java | 32 +++++++++++-------- .../model/WorkflowNodeTests.java | 5 ++- .../workflow/RegisterAgentTests.java | 8 +---- 4 files changed, 39 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 706cd2c62..8680d3f43 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -24,6 +24,7 @@ import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; import static org.opensearch.flowframework.util.ParseUtils.parseStringToObjectMap; @@ -93,6 +94,10 @@ public XContentBuilder 
toXContent(XContentBuilder builder, Params params) throws for (PipelineProcessor p : (PipelineProcessor[]) e.getValue()) { xContentBuilder.value(p); } + } else if (TOOLS_FIELD.equals(e.getKey())) { + for (String t : (String[]) e.getValue()) { + xContentBuilder.value(t); + } } else { for (Map map : (Map[]) e.getValue()) { buildStringToObjectMap(xContentBuilder, map); @@ -151,6 +156,12 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { processorList.add(PipelineProcessor.parse(parser)); } userInputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0])); + } else if (TOOLS_FIELD.equals(inputFieldName)) { + List toolsList = new ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + toolsList.add(parser.text()); + } + userInputs.put(inputFieldName, toolsList.toArray(new String[0])); } else { List> mapList = new ArrayList<>(); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { @@ -173,6 +184,11 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { case DOUBLE: userInputs.put(inputFieldName, parser.doubleValue()); break; + case BIG_INTEGER: + userInputs.put(inputFieldName, parser.bigIntegerValue()); + break; + default: + throw new IOException("Unable to parse field [" + inputFieldName + "] in a node object."); } break; default: diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index f9f17b4f0..cacbb119b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.common.Nullable; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import 
org.opensearch.flowframework.common.WorkflowResources; @@ -28,6 +29,7 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Optional; @@ -155,7 +157,8 @@ public void onFailure(Exception e) { String description = (String) inputs.get(DESCRIPTION_FIELD); String llmModelId = (String) inputs.get(LLM_MODEL_ID); Map llmParameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), LLM_PARAMETERS); - List tools = getTools(previousNodeInputs, outputs); + String[] tools = (String[]) inputs.get(TOOLS_FIELD); + List toolsList = getTools(tools, previousNodeInputs, outputs); Map parameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), PARAMETERS_FIELD); MLMemorySpec memory = getMLMemorySpec(inputs.get(MEMORY_FIELD)); Instant createdTime = Instant.ofEpochMilli((Long) inputs.get(CREATED_TIME)); @@ -188,7 +191,7 @@ public void onFailure(Exception e) { builder.type(type) .llm(llmSpec) - .tools(tools) + .tools(toolsList) .parameters(parameters) .memory(memory) .createdTime(createdTime) @@ -210,24 +213,25 @@ public String getName() { return NAME; } - private List getTools(Map previousNodeInputs, Map outputs) { + private List getTools(@Nullable String[] tools, Map previousNodeInputs, Map outputs) { List mlToolSpecList = new ArrayList<>(); List previousNodes = previousNodeInputs.entrySet() .stream() .filter(e -> TOOLS_FIELD.equals(e.getValue())) .map(Map.Entry::getKey) .collect(Collectors.toList()); - - if (previousNodes != null) { - previousNodes.forEach((previousNode) -> { - WorkflowData previousNodeOutput = outputs.get(previousNode); - if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(TOOLS_FIELD)) { - MLToolSpec mlToolSpec = (MLToolSpec) previousNodeOutput.getContent().get(TOOLS_FIELD); - logger.info("Tool added {}", mlToolSpec.getType()); - mlToolSpecList.add(mlToolSpec); - } - }); - } + // Anything in tools is sorted first, 
followed by anything else in previous node inputs + List sortedNodes = tools == null ? new ArrayList<>() : Arrays.asList(tools); + previousNodes.removeAll(sortedNodes); + sortedNodes.addAll(previousNodes); + sortedNodes.forEach((node) -> { + WorkflowData previousNodeOutput = outputs.get(node); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(TOOLS_FIELD)) { + MLToolSpec mlToolSpec = (MLToolSpec) previousNodeOutput.getContent().get(TOOLS_FIELD); + logger.info("Tool added {}", mlToolSpec.getType()); + mlToolSpecList.add(mlToolSpec); + } + }); return mlToolSpecList; } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java index b9620c214..2bc7cb2ca 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java @@ -31,7 +31,8 @@ public void testNode() throws IOException { Map.entry("bar", Map.of("key", "value")), Map.entry("baz", new Map[] { Map.of("A", "a"), Map.of("B", "b") }), Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }), - Map.entry("created_time", 1689793598499L) + Map.entry("created_time", 1689793598499L), + Map.entry("tools", new String[] { "foo", "bar" }) ) ); assertEquals("A", nodeA.id()); @@ -46,6 +47,7 @@ public void testNode() throws IOException { assertEquals("test-type", pp[0].type()); assertEquals(Map.of("key2", "value2"), pp[0].params()); assertEquals(1689793598499L, map.get("created_time")); + assertArrayEquals(new String[] { "foo", "bar" }, (String[]) map.get("tools")); // node equality is based only on ID WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of(), Map.of("bar", "baz")); @@ -63,6 +65,7 @@ public void testNode() throws IOException { assertTrue(json.contains("\"bar\":{\"key\":\"value\"}")); 
assertTrue(json.contains("\"processors\":[{\"type\":\"test-type\",\"params\":{\"key2\":\"value2\"}}]")); assertTrue(json.contains("\"created_time\":1689793598499")); + assertTrue(json.contains("\"tools\":[\"foo\",\"bar\"]")); WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json)); assertEquals("A", nodeX.id()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index 115729e9c..885d22352 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -15,10 +15,8 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; -import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.test.OpenSearchTestCase; @@ -54,10 +52,6 @@ public void setUp() throws Exception { this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); - MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false); - - LLMSpec llmSpec = new LLMSpec("xyz", Collections.emptyMap()); - Map mlMemorySpec = Map.ofEntries( Map.entry(MLMemorySpec.MEMORY_TYPE_FIELD, "type"), Map.entry(MLMemorySpec.SESSION_ID_FIELD, "abc"), @@ -71,7 +65,7 @@ public void setUp() throws Exception { Map.entry("type", "type"), Map.entry("llm.model_id", "xyz"), Map.entry("llm.parameters", Collections.emptyMap()), - Map.entry("tools", tools), + Map.entry("tools", new String[] { "abc", "xyz" }), Map.entry("parameters", 
Collections.emptyMap()), Map.entry("memory", mlMemorySpec), Map.entry("created_time", 1689793598499L), From 2f760f4efaa8eb3bb2fc55e019ee867df50727ee Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 13 Dec 2023 18:11:16 -0800 Subject: [PATCH 21/27] Fix tools ordering class casting bug (#289) Signed-off-by: Daniel Widdis --- .../opensearch/flowframework/common/CommonValue.java | 4 +++- .../opensearch/flowframework/model/WorkflowNode.java | 6 +++--- .../flowframework/workflow/RegisterAgentStep.java | 10 ++++------ .../flowframework/model/WorkflowNodeTests.java | 6 +++--- .../flowframework/workflow/RegisterAgentTests.java | 9 ++++++++- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 8b8f9deae..bb9eaf108 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -172,8 +172,10 @@ private CommonValue() {} public static final String RESOURCE_TYPE = "resource_type"; /** The field name for the resource id */ public static final String RESOURCE_ID = "resource_id"; - /** The tools' field for an agent */ + /** The tools field for an agent */ public static final String TOOLS_FIELD = "tools"; + /** The tools order field for an agent */ + public static final String TOOLS_ORDER_FIELD = "tools_order"; /** The memory field for an agent */ public static final String MEMORY_FIELD = "memory"; /** The app type field for an agent */ diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 8680d3f43..6d166c68b 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -24,7 +24,7 @@ import java.util.Objects; import static 
org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD; import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; import static org.opensearch.flowframework.util.ParseUtils.parseStringToObjectMap; @@ -94,7 +94,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws for (PipelineProcessor p : (PipelineProcessor[]) e.getValue()) { xContentBuilder.value(p); } - } else if (TOOLS_FIELD.equals(e.getKey())) { + } else if (TOOLS_ORDER_FIELD.equals(e.getKey())) { for (String t : (String[]) e.getValue()) { xContentBuilder.value(t); } @@ -156,7 +156,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { processorList.add(PipelineProcessor.parse(parser)); } userInputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0])); - } else if (TOOLS_FIELD.equals(inputFieldName)) { + } else if (TOOLS_ORDER_FIELD.equals(inputFieldName)) { List toolsList = new ArrayList<>(); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { toolsList.add(parser.text()); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index cacbb119b..f72d78d11 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -46,6 +46,7 @@ import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD; 
import static org.opensearch.flowframework.common.CommonValue.TYPE; import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; @@ -64,8 +65,6 @@ public class RegisterAgentStep implements WorkflowStep { private static final String LLM_MODEL_ID = "llm.model_id"; private static final String LLM_PARAMETERS = "llm.parameters"; - private List mlToolSpecList; - /** * Instantiate this class * @param mlClient client to instantiate MLClient @@ -73,7 +72,6 @@ public class RegisterAgentStep implements WorkflowStep { */ public RegisterAgentStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) { this.mlClient = mlClient; - this.mlToolSpecList = new ArrayList<>(); this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @@ -136,6 +134,7 @@ public void onFailure(Exception e) { LLM_MODEL_ID, LLM_PARAMETERS, TOOLS_FIELD, + TOOLS_ORDER_FIELD, PARAMETERS_FIELD, MEMORY_FIELD, CREATED_TIME, @@ -157,8 +156,8 @@ public void onFailure(Exception e) { String description = (String) inputs.get(DESCRIPTION_FIELD); String llmModelId = (String) inputs.get(LLM_MODEL_ID); Map llmParameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), LLM_PARAMETERS); - String[] tools = (String[]) inputs.get(TOOLS_FIELD); - List toolsList = getTools(tools, previousNodeInputs, outputs); + String[] toolsOrder = (String[]) inputs.get(TOOLS_ORDER_FIELD); + List toolsList = getTools(toolsOrder, previousNodeInputs, outputs); Map parameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), PARAMETERS_FIELD); MLMemorySpec memory = getMLMemorySpec(inputs.get(MEMORY_FIELD)); Instant createdTime = Instant.ofEpochMilli((Long) inputs.get(CREATED_TIME)); @@ -285,7 +284,6 @@ private MLMemorySpec getMLMemorySpec(Object mlMemory) { sessionId = (String) map.get(MLMemorySpec.SESSION_ID_FIELD); windowSize = (Integer) map.get(MLMemorySpec.WINDOW_SIZE_FIELD); - @SuppressWarnings("unchecked") MLMemorySpec.MLMemorySpecBuilder builder = 
MLMemorySpec.builder(); builder.type(type); diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java index 2bc7cb2ca..08f820a36 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowNodeTests.java @@ -32,7 +32,7 @@ public void testNode() throws IOException { Map.entry("baz", new Map[] { Map.of("A", "a"), Map.of("B", "b") }), Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }), Map.entry("created_time", 1689793598499L), - Map.entry("tools", new String[] { "foo", "bar" }) + Map.entry("tools_order", new String[] { "foo", "bar" }) ) ); assertEquals("A", nodeA.id()); @@ -47,7 +47,7 @@ public void testNode() throws IOException { assertEquals("test-type", pp[0].type()); assertEquals(Map.of("key2", "value2"), pp[0].params()); assertEquals(1689793598499L, map.get("created_time")); - assertArrayEquals(new String[] { "foo", "bar" }, (String[]) map.get("tools")); + assertArrayEquals(new String[] { "foo", "bar" }, (String[]) map.get("tools_order")); // node equality is based only on ID WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of(), Map.of("bar", "baz")); @@ -65,7 +65,7 @@ public void testNode() throws IOException { assertTrue(json.contains("\"bar\":{\"key\":\"value\"}")); assertTrue(json.contains("\"processors\":[{\"type\":\"test-type\",\"params\":{\"key2\":\"value2\"}}]")); assertTrue(json.contains("\"created_time\":1689793598499")); - assertTrue(json.contains("\"tools\":[\"foo\",\"bar\"]")); + assertTrue(json.contains("\"tools_order\":[\"foo\",\"bar\"]")); WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json)); assertEquals("A", nodeX.id()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java 
b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index 885d22352..883aef99a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -15,8 +15,10 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.test.OpenSearchTestCase; @@ -52,6 +54,10 @@ public void setUp() throws Exception { this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); + MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false); + + LLMSpec llmSpec = new LLMSpec("xyz", Collections.emptyMap()); + Map mlMemorySpec = Map.ofEntries( Map.entry(MLMemorySpec.MEMORY_TYPE_FIELD, "type"), Map.entry(MLMemorySpec.SESSION_ID_FIELD, "abc"), @@ -65,7 +71,8 @@ public void setUp() throws Exception { Map.entry("type", "type"), Map.entry("llm.model_id", "xyz"), Map.entry("llm.parameters", Collections.emptyMap()), - Map.entry("tools", new String[] { "abc", "xyz" }), + Map.entry("tools", tools), + Map.entry("tools_order", new String[] { "abc", "xyz" }), Map.entry("parameters", Collections.emptyMap()), Map.entry("memory", mlMemorySpec), Map.entry("created_time", 1689793598499L), From 4852878b03042fea8679c3d970726cc4e7bfc257 Mon Sep 17 00:00:00 2001 From: Jackie Han Date: Thu, 14 Dec 2023 14:50:45 -0800 Subject: [PATCH 22/27] Combine create api with provision api by adding a provision param (#282) * replace dryrun parameter with provision in 
create workflow Signed-off-by: Jackie Han * test Signed-off-by: Jackie Han * test Signed-off-by: Jackie Han * test Signed-off-by: Jackie Han * Combine create api with provision api by adding a provision param Signed-off-by: Jackie Han * cleanup Signed-off-by: Jackie Han * keep dryrun option in create workflow Signed-off-by: Jackie Han * cleanup Signed-off-by: Jackie Han * keep both dryRun and provision parameter Signed-off-by: Jackie Han --------- Signed-off-by: Jackie Han --- .../rest/RestCreateWorkflowAction.java | 4 +- .../CreateWorkflowTransportAction.java | 24 ++- .../transport/WorkflowRequest.java | 22 +- .../CreateWorkflowTransportActionTests.java | 194 +++++++++++++++++- 4 files changed, 235 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 5e8509373..deeabdd76 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -31,6 +31,7 @@ import java.util.Locale; import static org.opensearch.flowframework.common.CommonValue.DRY_RUN; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -91,8 +92,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String workflowId = request.param(WORKFLOW_ID); Template template = Template.parse(request.content().utf8ToString()); boolean dryRun = request.paramAsBoolean(DRY_RUN, false); + boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun, 
requestTimeout, maxWorkflows); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun, provision, requestTimeout, maxWorkflows); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 6ca1c4661..765c9cae5 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -135,7 +135,28 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.info("create state workflow doc"); - listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + if (request.isProvision()) { + logger.info("provision parameter"); + WorkflowRequest workflowRequest = new WorkflowRequest(globalContextResponse.getId(), null); + client.execute( + ProvisionWorkflowAction.INSTANCE, + workflowRequest, + ActionListener.wrap(provisionResponse -> { + listener.onResponse(new WorkflowResponse(provisionResponse.getWorkflowId())); + }, exception -> { + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) + ); + } + logger.error("Failed to send back provision workflow exception", exception); + }) + ); + } else { + listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + } }, exception -> { logger.error("Failed to save workflow state : {}", exception.getMessage()); if (exception instanceof FlowFrameworkException) { @@ -246,5 +267,4 @@ private void validateWorkflows(Template template) throws Exception { 
workflowProcessSorter.validateGraph(sortedNodes); } } - } diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index d049be8f6..057f13d01 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -38,6 +38,11 @@ public class WorkflowRequest extends ActionRequest { */ private boolean dryRun; + /** + * Provision flag + */ + private boolean provision; + /** * Timeout for request */ @@ -54,7 +59,7 @@ public class WorkflowRequest extends ActionRequest { * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, false, null, null); + this(workflowId, template, false, false, null, null); } /** @@ -70,7 +75,7 @@ public WorkflowRequest( @Nullable TimeValue requestTimeout, @Nullable Integer maxWorkflows ) { - this(workflowId, template, false, requestTimeout, maxWorkflows); + this(workflowId, template, false, false, requestTimeout, maxWorkflows); } /** @@ -78,6 +83,7 @@ public WorkflowRequest( * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow * @param dryRun flag to indicate if validation is necessary + * @param provision flag to indicate if provision is necessary * @param requestTimeout timeout of the request * @param maxWorkflows max number of workflows */ @@ -85,12 +91,14 @@ public WorkflowRequest( @Nullable String workflowId, @Nullable Template template, boolean dryRun, + boolean provision, @Nullable TimeValue requestTimeout, @Nullable Integer maxWorkflows ) { this.workflowId = workflowId; this.template = template; this.dryRun = dryRun; + this.provision = provision; this.requestTimeout = requestTimeout; this.maxWorkflows = maxWorkflows; } @@ -106,6 +114,7 @@ 
public WorkflowRequest(StreamInput in) throws IOException { String templateJson = in.readOptionalString(); this.template = templateJson == null ? null : Template.parse(templateJson); this.dryRun = in.readBoolean(); + this.provision = in.readBoolean(); this.requestTimeout = in.readOptionalTimeValue(); this.maxWorkflows = in.readOptionalInt(); } @@ -136,6 +145,14 @@ public boolean isDryRun() { return this.dryRun; } + /** + * Gets the provision flag + * @return the provision boolean + */ + public boolean isProvision() { + return this.provision; + } + /** * Gets the timeout of the request * @return the requestTimeout @@ -158,6 +175,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(workflowId); out.writeOptionalString(template == null ? null : template.toJson()); out.writeBoolean(dryRun); + out.writeBoolean(provision); out.writeOptionalTimeValue(requestTimeout); out.writeOptionalInt(maxWorkflows); } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 6856a2122..2e67b59d8 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -22,6 +22,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; @@ -29,6 +30,7 @@ import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import 
org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -52,8 +54,9 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -89,7 +92,16 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool, clusterService, settings); + + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); + WorkflowStepFactory factory = new WorkflowStepFactory( + Settings.EMPTY, + clusterService, + client, + mlClient, + flowFrameworkIndicesHandler + ); + this.workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool, clusterService, settings); this.createWorkflowTransportAction = spy( new CreateWorkflowTransportAction( mock(TransportService.class), @@ -129,7 +141,16 @@ public void setUp() throws Exception { ); } - public void testFailedDryRunValidation() { + public void testDryRunValidation_withoutProvision_Success() { + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, true, false, null, null); + 
createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); + } + + public void testDryRunValidation_Failed() { WorkflowNode createConnector = new WorkflowNode( "workflow_step_1", @@ -183,7 +204,7 @@ public void testFailedDryRunValidation() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, null, null); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, false, null, null); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -198,6 +219,7 @@ public void testMaxWorkflow() { null, template, false, + false, WORKFLOW_REQUEST_TIMEOUT.get(settings), MAX_WORKFLOWS.get(settings) ); @@ -234,6 +256,7 @@ public void testFailedToCreateNewWorkflow() { null, template, false, + false, WORKFLOW_REQUEST_TIMEOUT.get(settings), MAX_WORKFLOWS.get(settings) ); @@ -271,6 +294,7 @@ public void testCreateNewWorkflow() { null, template, false, + false, WORKFLOW_REQUEST_TIMEOUT.get(settings), MAX_WORKFLOWS.get(settings) ); @@ -352,4 +376,166 @@ public void testUpdateWorkflow() { assertEquals("1", responseCaptor.getValue().getWorkflowId()); } + + public void testCreateWorkflow_withDryRun_withProvision_Success() { + + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + true, + true, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) + ); + + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + 
}).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + + // Bypass putTemplateToGlobalContext and force onResponse + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + + // Bypass putInitialStateToWorkflowState and force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + WorkflowResponse response = mock(WorkflowResponse.class); + when(response.getWorkflowId()).thenReturn("1"); + responseListener.onResponse(response); + return null; + }).when(client).execute(eq(ProvisionWorkflowAction.INSTANCE), any(WorkflowRequest.class), any(ActionListener.class)); + + ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + verify(listener, times(1)).onResponse(workflowResponseCaptor.capture()); + assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); + } + + public void testCreateWorkflow_withDryRun_withProvision_FailedProvisioning() { + Template validTemplate = generateValidTemplate(); + + 
@SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + true, + true, + WORKFLOW_REQUEST_TIMEOUT.get(settings), + MAX_WORKFLOWS.get(settings) + ); + + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + + // Bypass putTemplateToGlobalContext and force onResponse + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + + // Bypass putInitialStateToWorkflowState and force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + WorkflowResponse response = mock(WorkflowResponse.class); + when(response.getWorkflowId()).thenReturn("1"); + responseListener.onFailure(new Exception("failed")); + return null; + }).when(client).execute(eq(ProvisionWorkflowAction.INSTANCE), any(WorkflowRequest.class), 
any(ActionListener.class)); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("failed", exceptionCaptor.getValue().getMessage()); + } + + private Template generateValidTemplate() { + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + WorkflowResources.CREATE_CONNECTOR.getWorkflowStep(), + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) + ); + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + WorkflowResources.REGISTER_REMOTE_MODEL.getWorkflowStep(), + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + WorkflowResources.DEPLOY_MODEL.getWorkflowStep(), + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + + Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); + + Template validTemplate = new Template( + "test", + "description", + "use case", + Version.fromString("1.0.0"), + List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")), + Map.of("workflow", workflow), + Map.of(), + TestHelpers.randomUser() + ); + + return validTemplate; + } } From 3c45cb54630faf41b68355a0987a806b4e60aba6 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 14 Dec 2023 17:21:03 -0800 Subject: [PATCH 23/27] 
[Feature/agent_framework] Deprovision API (#271) * Deprovision REST and Transport Actions Signed-off-by: Daniel Widdis * Fix errors you find actually running the code Signed-off-by: Daniel Widdis * Add test for Rest deprovision action Signed-off-by: Daniel Widdis * Initial copypaste of Deprovision Transport Action Test Signed-off-by: Daniel Widdis * Add some delays to let deletions propagate, reset workflow state Signed-off-by: Daniel Widdis * Improved deprovisioning results and status updates Signed-off-by: Daniel Widdis * Fix bug in resource created parsing Signed-off-by: Daniel Widdis * Completed test implementations Signed-off-by: Daniel Widdis * Fixes after rebase Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- .../flowframework/FlowFrameworkPlugin.java | 5 + .../common/WorkflowResources.java | 50 ++- .../flowframework/model/ResourceCreated.java | 9 +- .../rest/RestDeprovisionWorkflowAction.java | 108 ++++++ .../transport/DeprovisionWorkflowAction.java | 27 ++ .../DeprovisionWorkflowTransportAction.java | 337 ++++++++++++++++++ .../transport/WorkflowRequest.java | 2 +- .../FlowFrameworkPluginTests.java | 4 +- .../RestDeprovisionWorkflowActionTests.java | 100 ++++++ ...provisionWorkflowTransportActionTests.java | 234 ++++++++++++ 10 files changed, 857 insertions(+), 19 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java create mode 100644 src/test/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowActionTests.java create mode 100644 src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java 
b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 40ddee2fa..544f6f3e1 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -28,6 +28,7 @@ import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.rest.RestCreateWorkflowAction; +import org.opensearch.flowframework.rest.RestDeprovisionWorkflowAction; import org.opensearch.flowframework.rest.RestGetWorkflowAction; import org.opensearch.flowframework.rest.RestGetWorkflowStateAction; import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; @@ -35,6 +36,8 @@ import org.opensearch.flowframework.rest.RestSearchWorkflowStateAction; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; +import org.opensearch.flowframework.transport.DeprovisionWorkflowAction; +import org.opensearch.flowframework.transport.DeprovisionWorkflowTransportAction; import org.opensearch.flowframework.transport.GetWorkflowAction; import org.opensearch.flowframework.transport.GetWorkflowStateAction; import org.opensearch.flowframework.transport.GetWorkflowStateTransportAction; @@ -131,6 +134,7 @@ public List getRestHandlers( return ImmutableList.of( new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService), new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting), + new RestDeprovisionWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestGetWorkflowStateAction(flowFrameworkFeatureEnabledSetting), new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting), @@ -143,6 +147,7 @@ public List getRestHandlers( return ImmutableList.of( new 
ActionHandler<>(CreateWorkflowAction.INSTANCE, CreateWorkflowTransportAction.class), new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class), + new ActionHandler<>(DeprovisionWorkflowAction.INSTANCE, DeprovisionWorkflowTransportAction.class), new ActionHandler<>(SearchWorkflowAction.INSTANCE, SearchWorkflowTransportAction.class), new ActionHandler<>(GetWorkflowStateAction.INSTANCE, GetWorkflowStateTransportAction.class), new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class), diff --git a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java index d43a9e0f9..1246574d7 100644 --- a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java +++ b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java @@ -23,32 +23,34 @@ public enum WorkflowResources { /** official workflow step name for creating a connector and associated created resource */ - CREATE_CONNECTOR("create_connector", "connector_id"), + CREATE_CONNECTOR("create_connector", "connector_id", "delete_connector"), /** official workflow step name for registering a remote model and associated created resource */ - REGISTER_REMOTE_MODEL("register_remote_model", "model_id"), + REGISTER_REMOTE_MODEL("register_remote_model", "model_id", "delete_model"), /** official workflow step name for registering a local model and associated created resource */ - REGISTER_LOCAL_MODEL("register_local_model", "model_id"), + REGISTER_LOCAL_MODEL("register_local_model", "model_id", "delete_model"), /** official workflow step name for registering a model group and associated created resource */ - REGISTER_MODEL_GROUP("register_model_group", "model_group_id"), + REGISTER_MODEL_GROUP("register_model_group", "model_group_id", null), // TODO /** official workflow step name for deploying a model and associated created resource */ - 
DEPLOY_MODEL("deploy_model", "model_id"), + DEPLOY_MODEL("deploy_model", "model_id", "undeploy_model"), /** official workflow step name for creating an ingest-pipeline and associated created resource */ - CREATE_INGEST_PIPELINE("create_ingest_pipeline", "pipeline_id"), + CREATE_INGEST_PIPELINE("create_ingest_pipeline", "pipeline_id", null), // TODO /** official workflow step name for creating an index and associated created resource */ - CREATE_INDEX("create_index", "index_name"), + CREATE_INDEX("create_index", "index_name", null), // TODO /** official workflow step name for register an agent and the associated created resource */ - REGISTER_AGENT("register_agent", "agent_id"); + REGISTER_AGENT("register_agent", "agent_id", "delete_agent"); private final String workflowStep; private final String resourceCreated; + private final String deprovisionStep; private static final Logger logger = LogManager.getLogger(WorkflowResources.class); private static final Set allResources = Stream.of(values()) .map(WorkflowResources::getResourceCreated) .collect(Collectors.toSet()); - WorkflowResources(String workflowStep, String resourceCreated) { + WorkflowResources(String workflowStep, String resourceCreated, String deprovisionStep) { this.workflowStep = workflowStep; this.resourceCreated = resourceCreated; + this.deprovisionStep = deprovisionStep; } /** @@ -68,7 +70,15 @@ public String getResourceCreated() { } /** - * gets the resources created type based on the workflowStep + * Returns the deprovisionStep for the given enum Constant + * @return the deprovisionStep of this data. + */ + public String getDeprovisionStep() { + return deprovisionStep; + } + + /** + * Gets the resources created type based on the workflowStep. 
* @param workflowStep workflow step name * @return the resource that will be created * @throws FlowFrameworkException if workflow step doesn't exist in enum @@ -76,7 +86,7 @@ public String getResourceCreated() { public static String getResourceByWorkflowStep(String workflowStep) throws FlowFrameworkException { if (workflowStep != null && !workflowStep.isEmpty()) { for (WorkflowResources mapping : values()) { - if (mapping.getWorkflowStep().equals(workflowStep)) { + if (workflowStep.equals(mapping.getWorkflowStep()) || workflowStep.equals(mapping.getDeprovisionStep())) { return mapping.getResourceCreated(); } } @@ -85,6 +95,24 @@ public static String getResourceByWorkflowStep(String workflowStep) throws FlowF throw new FlowFrameworkException("Unable to find resource type for step: " + workflowStep, RestStatus.BAD_REQUEST); } + /** + * Gets the deprovision step type based on the workflowStep. + * @param workflowStep workflow step name + * @return the corresponding step to deprovision + * @throws FlowFrameworkException if workflow step doesn't exist in enum + */ + public static String getDeprovisionStepByWorkflowStep(String workflowStep) throws FlowFrameworkException { + if (workflowStep != null && !workflowStep.isEmpty()) { + for (WorkflowResources mapping : values()) { + if (mapping.getWorkflowStep().equals(workflowStep)) { + return mapping.getDeprovisionStep(); + } + } + } + logger.error("Unable to find deprovision step for step: " + workflowStep); + throw new FlowFrameworkException("Unable to find deprovision step for step: " + workflowStep, RestStatus.BAD_REQUEST); + } + /** * Returns all the possible resource created types in enum * @return a set of all the resource created types diff --git a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java index b12f4d044..9cc096ef6 100644 --- a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java +++ 
b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java @@ -176,15 +176,14 @@ public static ResourceCreated parse(XContentParser parser) throws IOException { @Override public String toString() { - return "resources_Created [workflow_step_name= " + return "resources_Created [workflow_step_name=" + workflowStepName - + ", workflow_step_id= " + + ", workflow_step_id=" + workflowStepId - + ", resource_type= " + + ", resource_type=" + resourceType - + ", resource_id= " + + ", resource_id=" + resourceId + "]"; } - } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java new file mode 100644 index 000000000..467a683ce --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.DeprovisionWorkflowAction; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; + +/** + * Rest Action to facilitate requests to de-provision a workflow + */ +public class RestDeprovisionWorkflowAction extends BaseRestHandler { + + private static final String DEPROVISION_WORKFLOW_ACTION = "deprovision_workflow"; + private static final Logger logger = LogManager.getLogger(RestDeprovisionWorkflowAction.class); + private final FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + /** + * Instantiates a new RestDeprovisionWorkflowAction + * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled + */ + public RestDeprovisionWorkflowAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) { + this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; + } + + @Override + 
public String getName() { + return DEPROVISION_WORKFLOW_ACTION; + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + try { + if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { + throw new FlowFrameworkException( + "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", + RestStatus.FORBIDDEN + ); + } + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("No request body is required", RestStatus.BAD_REQUEST); + } + // Validate params + String workflowId = request.param(WORKFLOW_ID); + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + + return channel -> client.execute(DeprovisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + }, exception -> { + try { + FlowFrameworkException ex = exception instanceof FlowFrameworkException + ? 
(FlowFrameworkException) exception + : new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + } catch (IOException e) { + logger.error("Failed to send back deprovision workflow exception", e); + channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage())); + } + })); + + } catch (FlowFrameworkException ex) { + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } + } + + @Override + public List routes() { + return ImmutableList.of( + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_deprovision")) + ); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowAction.java new file mode 100644 index 000000000..8efcfbbc3 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowAction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license.
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External Action for public facing RestDeprovisionWorkflowAction + */ +public class DeprovisionWorkflowAction extends ActionType { + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/deprovision"; + /** An instance of this action */ + public static final DeprovisionWorkflowAction INSTANCE = new DeprovisionWorkflowAction(); + + private DeprovisionWorkflowAction() { + super(NAME, WorkflowResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java new file mode 100644 index 000000000..784b67374 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -0,0 +1,337 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ProvisioningProgress; +import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.flowframework.model.State; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.time.Instant; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; +import static 
org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; +import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; +import static org.opensearch.flowframework.common.WorkflowResources.getDeprovisionStepByWorkflowStep; +import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; + +/** + * Transport Action to deprovision a workflow from a stored use case template + */ +public class DeprovisionWorkflowTransportAction extends HandledTransportAction { + + private static final String DEPROVISION_SUFFIX = "_deprovision"; + + private final Logger logger = LogManager.getLogger(DeprovisionWorkflowTransportAction.class); + + private final ThreadPool threadPool; + private final Client client; + private final WorkflowProcessSorter workflowProcessSorter; + private final WorkflowStepFactory workflowStepFactory; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private final EncryptorUtils encryptorUtils; + + /** + * Instantiates a new DeprovisionWorkflowTransportAction + * @param transportService The TransportService + * @param actionFilters action filters + * @param threadPool The OpenSearch thread pool + * @param client The node client to retrieve a stored use case template + * @param workflowProcessSorter Utility class to generate a topologically sorted list of Process nodes + * @param workflowStepFactory The factory instantiating workflow steps + * @param flowFrameworkIndicesHandler Class to handle all internal system indices actions + * @param encryptorUtils Utility class to handle encryption/decryption + */ + @Inject + public DeprovisionWorkflowTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ThreadPool threadPool, + Client client, + WorkflowProcessSorter
workflowProcessSorter, + WorkflowStepFactory workflowStepFactory, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, + EncryptorUtils encryptorUtils + ) { + super(DeprovisionWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new); + this.threadPool = threadPool; + this.client = client; + this.workflowProcessSorter = workflowProcessSorter; + this.workflowStepFactory = workflowStepFactory; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + this.encryptorUtils = encryptorUtils; + } + + @Override + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { + // Retrieve use case template from global context + String workflowId = request.getWorkflowId(); + GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + + // Stash thread context to interact with system index + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.wrap(response -> { + context.restore(); + + if (!response.isExists()) { + listener.onFailure( + new FlowFrameworkException( + "Failed to retrieve template (" + workflowId + ") from global context.", + RestStatus.NOT_FOUND + ) + ); + return; + } + + // Parse template from document source + Template template = Template.parse(response.getSourceAsString()); + + // Decrypt template + template = encryptorUtils.decryptTemplateCredentials(template); + + // Sort and validate graph + Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW); + List provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); + workflowProcessSorter.validateGraph(provisionProcessSequence); + + // We have a valid template and sorted nodes, get the created resources + getResourcesAndExecute(request.getWorkflowId(), provisionProcessSequence, listener); + }, exception -> { + if (exception instanceof FlowFrameworkException) { + logger.error("Workflow 
validation failed for workflow : " + workflowId); + listener.onFailure(exception); + } else { + logger.error("Failed to retrieve template from global context.", exception); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + } + })); + } catch (Exception e) { + String message = "Failed to retrieve template from global context."; + logger.error(message, e); + listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(e))); + } + } + + private void getResourcesAndExecute( + String workflowId, + List provisionProcessSequence, + ActionListener listener + ) { + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { + // Get a map of step id to created resources + final Map resourceMap = response.getWorkflowState() + .resourcesCreated() + .stream() + .collect(Collectors.toMap(ResourceCreated::workflowStepId, Function.identity())); + + // Now finally do the deprovision + executeDeprovisionSequence(workflowId, resourceMap, provisionProcessSequence, listener); + }, exception -> { + String message = "Failed to get workflow state for workflow " + workflowId; + logger.error(message, exception); + listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(exception))); + })); + } + + private void executeDeprovisionSequence( + String workflowId, + Map resourceMap, + List provisionProcessSequence, + ActionListener listener + ) { + // Create a list of ProcessNodes with the corresponding deprovision workflow steps + List deprovisionProcessSequence = provisionProcessSequence.stream() + // Only include nodes that created a resource + .filter(pn -> resourceMap.containsKey(pn.id())) + // Create a new ProcessNode with a deprovision step + .map(pn -> { + String stepName = pn.workflowStep().getName(); + String deprovisionStep = 
getDeprovisionStepByWorkflowStep(stepName); + // Unimplemented steps presently return null, so skip + if (deprovisionStep == null) { + return null; + } + // New ID is old ID with deprovision added + String deprovisionStepId = pn.id() + DEPROVISION_SUFFIX; + return new ProcessNode( + deprovisionStepId, + workflowStepFactory.createStep(deprovisionStep), + Collections.emptyMap(), + new WorkflowData( + Map.of(getResourceByWorkflowStep(stepName), resourceMap.get(pn.id()).resourceId()), + workflowId, + deprovisionStepId + ), + Collections.emptyList(), + this.threadPool, + pn.nodeTimeout() + ); + }) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + // Deprovision in reverse order of provisioning to minimize risk of dependencies + Collections.reverse(deprovisionProcessSequence); + logger.info("Deprovisioning steps: {}", deprovisionProcessSequence.stream().map(ProcessNode::id).collect(Collectors.joining(", "))); + + // Repeat attempting to delete resources as long as at least one is successful + int resourceCount = deprovisionProcessSequence.size(); + while (resourceCount > 0) { + Iterator iter = deprovisionProcessSequence.iterator(); + while (iter.hasNext()) { + ProcessNode deprovisionNode = iter.next(); + ResourceCreated resource = getResourceFromDeprovisionNode(deprovisionNode, resourceMap); + String resourceNameAndId = getResourceNameAndId(resource); + CompletableFuture deprovisionFuture = deprovisionNode.execute(); + try { + deprovisionFuture.join(); + logger.info("Successful {} for {}", deprovisionNode.id(), resourceNameAndId); + // Remove from list so we don't try again + iter.remove(); + // Pause briefly before next step + Thread.sleep(100); + } catch (Throwable t) { + logger.info( + "Failed {} for {}: {}", + deprovisionNode.id(), + resourceNameAndId, + t.getCause() == null ? 
t.getMessage() : t.getCause().getMessage() + ); + } + } + if (deprovisionProcessSequence.size() < resourceCount) { + // If we've deleted something, decrement and try again if not zero + resourceCount = deprovisionProcessSequence.size(); + deprovisionProcessSequence = deprovisionProcessSequence.stream().map(pn -> { + return new ProcessNode( + pn.id(), + workflowStepFactory.createStep(pn.workflowStep().getName()), + pn.previousNodeInputs(), + pn.input(), + pn.predecessors(), + this.threadPool, + pn.nodeTimeout() + ); + }).collect(Collectors.toList()); + // Pause briefly before next loop + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + break; + } + } else { + // If nothing was deleted, exit loop + break; + } + } + // Get corresponding resources + List remainingResources = deprovisionProcessSequence.stream() + .map(pn -> getResourceFromDeprovisionNode(pn, resourceMap)) + .collect(Collectors.toList()); + logger.info("Resources remaining: {}", remainingResources); + updateWorkflowState(workflowId, remainingResources, listener); + } + + private void updateWorkflowState( + String workflowId, + List remainingResources, + ActionListener listener + ) { + if (remainingResources.isEmpty()) { + // Successful deprovision + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + Map.ofEntries( + Map.entry(STATE_FIELD, State.NOT_STARTED), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED), + Map.entry(PROVISION_START_TIME_FIELD, Instant.now().toEpochMilli()), + Map.entry(RESOURCES_CREATED_FIELD, Collections.emptyList()) + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to NOT_STARTED", workflowId); + }, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage()); }) + ); + // return workflow ID + listener.onResponse(new WorkflowResponse(workflowId)); + } else { + // Failed deprovision + 
flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + Map.ofEntries( + Map.entry(STATE_FIELD, State.COMPLETED), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.DONE), + Map.entry(PROVISION_START_TIME_FIELD, Instant.now().toEpochMilli()), + Map.entry(RESOURCES_CREATED_FIELD, remainingResources) + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to COMPLETED", workflowId); + }, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage()); }) + ); + // give user list of remaining resources + listener.onFailure( + new FlowFrameworkException( + "Failed to deprovision some resources: [" + + remainingResources.stream() + .map(DeprovisionWorkflowTransportAction::getResourceNameAndId) + .filter(Objects::nonNull) + .distinct() + .collect(Collectors.joining(", ")) + + "].", + RestStatus.ACCEPTED + ) + ); + } + } + + private static ResourceCreated getResourceFromDeprovisionNode(ProcessNode deprovisionNode, Map resourceMap) { + String deprovisionId = deprovisionNode.id(); + int pos = deprovisionId.indexOf(DEPROVISION_SUFFIX); + return pos > 0 ? 
resourceMap.get(deprovisionId.substring(0, pos)) : null; + } + + private static String getResourceNameAndId(ResourceCreated resource) { + if (resource == null) { + return null; + } + return getResourceByWorkflowStep(resource.workflowStepName()) + " " + resource.resourceId(); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 057f13d01..a030dccfa 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -19,7 +19,7 @@ import java.io.IOException; /** - * Transport Request to create and provision a workflow + * Transport Request to create, provision, and deprovision a workflow */ public class WorkflowRequest extends ActionRequest { diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index cbf988eee..9f9529ca5 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -82,8 +82,8 @@ public void testPlugin() throws IOException { 4, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); - assertEquals(6, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); - assertEquals(6, ffp.getActions().size()); + assertEquals(7, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); + assertEquals(7, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); assertEquals(5, ffp.getSettings().size()); } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowActionTests.java new file mode 
100644 index 000000000..a9170e35d --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowActionTests.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; +import org.opensearch.test.rest.FakeRestRequest; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RestDeprovisionWorkflowActionTests extends OpenSearchTestCase { + + private RestDeprovisionWorkflowAction deprovisionWorkflowRestAction; + private String deprovisionWorkflowPath; + private NodeClient nodeClient; + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + @Override + public void setUp() throws Exception { + super.setUp(); + flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); + + this.deprovisionWorkflowRestAction = new RestDeprovisionWorkflowAction(flowFrameworkFeatureEnabledSetting); + this.deprovisionWorkflowPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, "workflow_id", "_deprovision"); + this.nodeClient = 
mock(NodeClient.class); + } + + public void testRestDeprovisionWorkflowActionName() { + String name = deprovisionWorkflowRestAction.getName(); + assertEquals("deprovision_workflow", name); + } + + public void testRestDeprovisionWorkflowActionRoutes() { + List routes = deprovisionWorkflowRestAction.routes(); + assertEquals(1, routes.size()); + assertEquals(RestRequest.Method.POST, routes.get(0).getMethod()); + assertEquals(this.deprovisionWorkflowPath, routes.get(0).getPath()); + } + + public void testNullWorkflowId() throws Exception { + + // Request with no params + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.deprovisionWorkflowPath) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, true, 1); + deprovisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + + assertEquals(1, channel.errors().get()); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); + } + + public void testInvalidRequestWithContent() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.deprovisionWorkflowPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + deprovisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + }); + assertEquals( + "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_deprovision] does not support having a body", + ex.getMessage() + ); + } + + public void testFeatureFlagNotEnabled() throws Exception { + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); + RestRequest request = new
FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.deprovisionWorkflowPath) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + deprovisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java new file mode 100644 index 000000000..5d21c63d8 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java @@ -0,0 +1,234 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.Version; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.CreateConnectorStep; +import org.opensearch.flowframework.workflow.DeleteConnectorStep; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.junit.AfterClass; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.mockito.ArgumentCaptor; + +import 
static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class DeprovisionWorkflowTransportActionTests extends OpenSearchTestCase { + + private static ThreadPool threadPool = new TestThreadPool(DeprovisionWorkflowTransportActionTests.class.getName()); + private Client client; + private WorkflowProcessSorter workflowProcessSorter; + private WorkflowStepFactory workflowStepFactory; + private DeleteConnectorStep deleteConnectorStep; + private DeprovisionWorkflowTransportAction deprovisionWorkflowTransportAction; + private Template template; + private GetResult getResult; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private EncryptorUtils encryptorUtils; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.client = mock(Client.class); + this.workflowProcessSorter = mock(WorkflowProcessSorter.class); + this.workflowStepFactory = mock(WorkflowStepFactory.class); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.encryptorUtils = mock(EncryptorUtils.class); + + this.deprovisionWorkflowTransportAction = new DeprovisionWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + threadPool, + client, + workflowProcessSorter, + workflowStepFactory, + flowFrameworkIndicesHandler, + encryptorUtils + ); + + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode node = new WorkflowNode("step_1", "create_connector", Collections.emptyMap(), Collections.emptyMap()); + List 
nodes = List.of(node); + List edges = Collections.emptyList(); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + this.template = new Template( + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of(PROVISION_WORKFLOW, workflow), + Map.of(), + TestHelpers.randomUser() + ); + this.getResult = mock(GetResult.class); + + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + ProcessNode processNode = mock(ProcessNode.class); + when(processNode.id()).thenReturn("step_1"); + when(processNode.workflowStep()).thenReturn(new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); + when(processNode.previousNodeInputs()).thenReturn(Collections.emptyMap()); + when(processNode.input()).thenReturn(WorkflowData.EMPTY); + when(processNode.nodeTimeout()).thenReturn(TimeValue.timeValueSeconds(5)); + when(this.workflowProcessSorter.sortProcessNodes(any(Workflow.class), any(String.class))).thenReturn(List.of(processNode)); + this.deleteConnectorStep = mock(DeleteConnectorStep.class); + when(this.workflowStepFactory.createStep("delete_connector")).thenReturn(deleteConnectorStep); + + ThreadPool clientThreadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(client.threadPool()).thenReturn(clientThreadPool); + when(clientThreadPool.getThreadContext()).thenReturn(threadContext); + } + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS); + } + + public void testDeprovisionWorkflow() throws IOException { + String workflowId = "1"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + when(getResult.sourceAsString()).thenReturn(this.template.toJson()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + 
when(getResult.isExists()).thenReturn(true); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + + WorkflowState state = WorkflowState.builder() + .resourcesCreated(List.of(new ResourceCreated("create_connector", "step_1", "connector_id", "connectorId"))) + .build(); + responseListener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(GetWorkflowStateAction.class), any(GetWorkflowStateRequest.class), any()); + + when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn( + CompletableFuture.completedFuture(WorkflowData.EMPTY) + ); + + deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); + } + + public void testFailedToRetrieveTemplateFromGlobalContext() { + String workflowId = "1"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + when(getResult.sourceAsString()).thenReturn(this.template.toJson()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + when(getResult.isExists()).thenReturn(false); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + + verify(listener, 
times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to retrieve template (1) from global context.", exceptionCaptor.getValue().getMessage()); + } + + public void testFailToDeprovision() throws IOException { + String workflowId = "1"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + when(getResult.sourceAsString()).thenReturn(this.template.toJson()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + when(getResult.isExists()).thenReturn(true); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + + WorkflowState state = WorkflowState.builder() + .resourcesCreated(List.of(new ResourceCreated("deploy_model", "step_1", "model_id", "modelId"))) + .build(); + responseListener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(GetWorkflowStateAction.class), any(GetWorkflowStateRequest.class), any()); + + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new RuntimeException("rte")); + when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn(future); + + deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to deprovision some resources: [model_id modelId].", exceptionCaptor.getValue().getMessage()); + } +} From 4cbf9dc3f5d7bb7ebbf6bafaa0019bed5757e268 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Thu, 14 Dec 
2023 17:47:01 -0800 Subject: [PATCH 24/27] [Feature/agent_framework] Adding installed plugins validation (#290) * Adding installed plugins validation Signed-off-by: Joshua Palis * Adding failure success unit tests Signed-off-by: Joshua Palis * Combining graph and installed plugin validation Signed-off-by: Joshua Palis * Removing stray comment Signed-off-by: Joshua Palis --------- Signed-off-by: Joshua Palis --- .../flowframework/FlowFrameworkPlugin.java | 8 +- .../model/WorkflowStepValidator.java | 24 ++- .../CreateWorkflowTransportAction.java | 2 +- .../ProvisionWorkflowTransportAction.java | 2 +- .../workflow/WorkflowProcessSorter.java | 81 +++++++- .../resources/mappings/workflow-steps.json | 42 +++- .../CreateWorkflowTransportActionTests.java | 37 ++-- .../workflow/WorkflowProcessSorterTests.java | 192 +++++++++++++++++- 8 files changed, 353 insertions(+), 35 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 544f6f3e1..d69c0b588 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -116,7 +116,13 @@ public Collection createComponents( mlClient, flowFrameworkIndicesHandler ); - WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, clusterService, settings); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter( + workflowStepFactory, + threadPool, + clusterService, + client, + settings + ); return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); } diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java index eb1779e93..c9689b975 100644 --- 
a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java @@ -25,18 +25,23 @@ public class WorkflowStepValidator { private static final String INPUTS_FIELD = "inputs"; /** Outputs field name */ private static final String OUTPUTS_FIELD = "outputs"; + /** Required Plugins field name */ + private static final String REQUIRED_PLUGINS = "required_plugins"; private List inputs; private List outputs; + private List requiredPlugins; /** * Intantiate the object representing a Workflow Step validator * @param inputs the workflow step inputs * @param outputs the workflow step outputs + * @param requiredPlugins the required plugins for this workflow step */ - public WorkflowStepValidator(List inputs, List outputs) { + public WorkflowStepValidator(List inputs, List outputs, List requiredPlugins) { this.inputs = inputs; this.outputs = outputs; + this.requiredPlugins = requiredPlugins; } /** @@ -48,6 +53,7 @@ public WorkflowStepValidator(List inputs, List outputs) { public static WorkflowStepValidator parse(XContentParser parser) throws IOException { List parsedInputs = new ArrayList<>(); List parsedOutputs = new ArrayList<>(); + List requiredPlugins = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -66,11 +72,17 @@ public static WorkflowStepValidator parse(XContentParser parser) throws IOExcept parsedOutputs.add(parser.text()); } break; + case REQUIRED_PLUGINS: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + requiredPlugins.add(parser.text()); + } + break; default: throw new IOException("Unable to parse field [" + fieldName + "] in a WorkflowStepValidator object."); } } - return new WorkflowStepValidator(parsedInputs, parsedOutputs); + return new 
WorkflowStepValidator(parsedInputs, parsedOutputs, requiredPlugins); } /** @@ -88,4 +100,12 @@ public List getInputs() { public List getOutputs() { return List.copyOf(outputs); } + + /** + * Get the required plugins + * @return the outputs + */ + public List getRequiredPlugins() { + return List.copyOf(requiredPlugins); + } } diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 765c9cae5..92f89c082 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -264,7 +264,7 @@ protected void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, private void validateWorkflows(Template template) throws Exception { for (Workflow workflow : template.workflows().values()) { List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null); - workflowProcessSorter.validateGraph(sortedNodes); + workflowProcessSorter.validate(sortedNodes); } } } diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index cd4a54a57..ff36cfd1f 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -122,7 +122,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); - workflowProcessSorter.validateGraph(provisionProcessSequence); + workflowProcessSorter.validate(provisionProcessSequence); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, diff --git 
a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index da362383b..e564ad456 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -10,15 +10,21 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.cluster.node.info.NodeInfo; +import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; +import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.model.WorkflowValidator; +import org.opensearch.plugins.PluginInfo; import org.opensearch.threadpool.ThreadPool; import java.util.ArrayDeque; @@ -31,6 +37,7 @@ import java.util.Map; import java.util.Queue; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -49,6 +56,8 @@ public class WorkflowProcessSorter { private WorkflowStepFactory workflowStepFactory; private ThreadPool threadPool; private Integer maxWorkflowSteps; + private ClusterService clusterService; + private Client client; /** * Instantiate this class. @@ -56,17 +65,21 @@ public class WorkflowProcessSorter { * @param workflowStepFactory The factory which matches template step types to instances. 
* @param threadPool The OpenSearch Thread pool to pass to process nodes. * @param clusterService The OpenSearch cluster service. + * @param client The OpenSearch Client * @param settings OpenSerch settings */ public WorkflowProcessSorter( WorkflowStepFactory workflowStepFactory, ThreadPool threadPool, ClusterService clusterService, + Client client, Settings settings ) { this.workflowStepFactory = workflowStepFactory; this.threadPool = threadPool; this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings); + this.clusterService = clusterService; + this.client = client; clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it); } @@ -123,13 +136,75 @@ public List sortProcessNodes(Workflow workflow, String workflowId) } /** - * Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs + * Validates inputs and ensures the required plugins are installed for each step in a topologically sorted graph + * @param processNodes the topologically sorted list of process nodes + * @throws Exception if validation fails + */ + public void validate(List processNodes) throws Exception { + WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json"); + validatePluginsInstalled(processNodes, validator); + validateGraph(processNodes, validator); + } + + /** + * Validates a sorted workflow, determines if each process node's required plugins are currently installed * @param processNodes A list of process nodes + * @param validator The validation definitions for the workflow steps * @throws Exception on validation failure */ - public void validateGraph(List processNodes) throws Exception { + public void validatePluginsInstalled(List processNodes, WorkflowValidator validator) throws Exception { - WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json"); + // Retrieve node information to ascertain installed 
plugins + NodesInfoRequest nodesInfoRequest = new NodesInfoRequest(); + nodesInfoRequest.clear().addMetric(NodesInfoRequest.Metric.PLUGINS.metricName()); + CompletableFuture> installedPluginsFuture = new CompletableFuture<>(); + client.admin().cluster().nodesInfo(nodesInfoRequest, ActionListener.wrap(response -> { + List installedPlugins = new ArrayList<>(); + + // Retrieve installed plugin names from the local node + String localNodeId = clusterService.state().getNodes().getLocalNodeId(); + NodeInfo info = response.getNodesMap().get(localNodeId); + PluginsAndModules plugins = info.getInfo(PluginsAndModules.class); + for (PluginInfo pluginInfo : plugins.getPluginInfos()) { + installedPlugins.add(pluginInfo.getName()); + } + + installedPluginsFuture.complete(installedPlugins); + + }, exception -> { + logger.error("Failed to retrieve installed plugins"); + installedPluginsFuture.completeExceptionally(exception); + })); + + // Block execution until installed plugin list is returned + List installedPlugins = installedPluginsFuture.get(); + + // Iterate through process nodes in graph + for (ProcessNode processNode : processNodes) { + + // Retrieve required plugins of this node based on type + String nodeType = processNode.workflowStep().getName(); + List requiredPlugins = new ArrayList<>(validator.getWorkflowStepValidators().get(nodeType).getRequiredPlugins()); + if (!installedPlugins.containsAll(requiredPlugins)) { + requiredPlugins.removeAll(installedPlugins); + throw new FlowFrameworkException( + "The workflowStep " + + processNode.workflowStep().getName() + + " requires the following plugins to be installed : " + + requiredPlugins.toString(), + RestStatus.BAD_REQUEST + ); + } + } + } + + /** + * Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs + * @param processNodes A list of process nodes + * @param validator The validation definitions for the workflow steps + * @throws 
Exception on validation failure + */ + public void validateGraph(List processNodes, WorkflowValidator validator) throws Exception { // Iterate through process nodes in graph for (ProcessNode processNode : processNodes) { diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 149b1cfce..1c6e73a4c 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -1,7 +1,8 @@ { "noop": { "inputs":[], - "outputs":[] + "outputs":[], + "required_plugins":[] }, "create_index": { "inputs":[ @@ -10,7 +11,8 @@ ], "outputs":[ "index_name" - ] + ], + "required_plugins":[] }, "create_ingest_pipeline": { "inputs":[ @@ -23,7 +25,8 @@ ], "outputs":[ "pipeline_id" - ] + ], + "required_plugins":[] }, "create_connector": { "inputs":[ @@ -37,6 +40,9 @@ ], "outputs":[ "connector_id" + ], + "required_plugins":[ + "opensearch-ml" ] }, "delete_connector": { @@ -45,6 +51,9 @@ ], "outputs":[ "connector_id" + ], + "required_plugins":[ + "opensearch-ml" ] }, "register_local_model": { @@ -62,6 +71,9 @@ "outputs":[ "model_id", "register_model_status" + ], + "required_plugins":[ + "opensearch-ml" ] }, "register_remote_model": { @@ -73,6 +85,9 @@ "outputs": [ "model_id", "register_model_status" + ], + "required_plugins":[ + "opensearch-ml" ] }, "delete_model": { @@ -81,6 +96,9 @@ ], "outputs":[ "model_id" + ], + "required_plugins":[ + "opensearch-ml" ] }, "deploy_model": { @@ -89,6 +107,9 @@ ], "outputs":[ "deploy_model_status" + ], + "required_plugins":[ + "opensearch-ml" ] }, "undeploy_model": { @@ -97,6 +118,9 @@ ], "outputs":[ "success" + ], + "required_plugins":[ + "opensearch-ml" ] }, "register_model_group": { @@ -106,6 +130,9 @@ "outputs":[ "model_group_id", "model_group_status" + ], + "required_plugins":[ + "opensearch-ml" ] }, "register_agent": { @@ -121,6 +148,9 @@ ], "outputs":[ "agent_id" + ], + "required_plugins":[ + "opensearch-ml" ] }, "delete_agent": { @@ -129,6 +159,9 
@@ ], "outputs":[ "agent_id" + ], + "required_plugins":[ + "opensearch-ml" ] }, "create_tool": { @@ -137,6 +170,9 @@ ], "outputs": [ "tools" + ], + "required_plugins":[ + "opensearch-ml" ] } } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 2e67b59d8..0addb7f1f 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -29,8 +29,6 @@ import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; -import org.opensearch.flowframework.workflow.WorkflowStepFactory; -import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -58,6 +56,8 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -70,7 +70,7 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private WorkflowProcessSorter workflowProcessSorter; private Template template; - private Client client = mock(Client.class); + private Client client; private ThreadPool threadPool; private ClusterSettings clusterSettings; private ClusterService clusterService; @@ -79,6 +79,8 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { @Override public void setUp() throws 
Exception { super.setUp(); + client = mock(Client.class); + threadPool = mock(ThreadPool.class); settings = Settings.builder() .put("plugins.flow_framework.max_workflows", 2) @@ -93,15 +95,10 @@ public void setUp() throws Exception { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); - WorkflowStepFactory factory = new WorkflowStepFactory( - Settings.EMPTY, - clusterService, - client, - mlClient, - flowFrameworkIndicesHandler - ); - this.workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool, clusterService, settings); + // Validation functionality should not be invoked in these unit tests, mocking instead + this.workflowProcessSorter = mock(WorkflowProcessSorter.class); + + // Spy this action to stub check max workflows this.createWorkflowTransportAction = spy( new CreateWorkflowTransportAction( mock(TransportService.class), @@ -150,7 +147,7 @@ public void testDryRunValidation_withoutProvision_Success() { createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); } - public void testDryRunValidation_Failed() { + public void testDryRunValidation_Failed() throws Exception { WorkflowNode createConnector = new WorkflowNode( "workflow_step_1", @@ -204,12 +201,12 @@ public void testDryRunValidation_Failed() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); + // Stub validation failure + doThrow(Exception.class).when(workflowProcessSorter).validate(any()); WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, false, null, null); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); - verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("No start node 
detected: all nodes have a predecessor.", exceptionCaptor.getValue().getMessage()); + verify(listener, times(1)).onFailure(any()); } public void testMaxWorkflow() { @@ -377,12 +374,14 @@ public void testUpdateWorkflow() { assertEquals("1", responseCaptor.getValue().getWorkflowId()); } - public void testCreateWorkflow_withDryRun_withProvision_Success() { + public void testCreateWorkflow_withDryRun_withProvision_Success() throws Exception { Template validTemplate = generateValidTemplate(); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); + + doNothing().when(workflowProcessSorter).validate(any()); WorkflowRequest workflowRequest = new WorkflowRequest( null, validTemplate, @@ -436,11 +435,13 @@ public void testCreateWorkflow_withDryRun_withProvision_Success() { assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); } - public void testCreateWorkflow_withDryRun_withProvision_FailedProvisioning() { + public void testCreateWorkflow_withDryRun_withProvision_FailedProvisioning() throws Exception { + Template validTemplate = generateValidTemplate(); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); + doNothing().when(workflowProcessSorter).validate(any()); WorkflowRequest workflowRequest = new WorkflowRequest( null, validTemplate, diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index d1590acd8..2974470aa 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -8,12 +8,20 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.admin.cluster.node.info.NodeInfo; +import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; +import 
org.opensearch.action.admin.cluster.node.info.NodesInfoResponse; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.common.FlowFrameworkSettings; @@ -23,7 +31,9 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowValidator; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.plugins.PluginInfo; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -50,6 +60,8 @@ import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithType; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndTimeout; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.workflow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -73,12 +85,13 @@ private static List parse(String json) throws IOException { private static TestThreadPool testThreadPool; private static WorkflowProcessSorter workflowProcessSorter; + private static Client client = mock(Client.class); + private static ClusterService 
clusterService = mock(ClusterService.class); + private static WorkflowValidator validator; @BeforeClass - public static void setup() { + public static void setup() throws IOException { AdminClient adminClient = mock(AdminClient.class); - ClusterService clusterService = mock(ClusterService.class); - Client client = mock(Client.class); MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); @@ -100,7 +113,8 @@ public static void setup() { mlClient, flowFrameworkIndicesHandler ); - workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, settings); + workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, client, settings); + validator = WorkflowValidator.parse("mappings/workflow-steps.json"); } @AfterClass @@ -300,7 +314,7 @@ public void testSuccessfulGraphValidation() throws Exception { Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); - workflowProcessSorter.validateGraph(sortedProcessNodes); + workflowProcessSorter.validateGraph(sortedProcessNodes, validator); } public void testFailedGraphValidation() { @@ -324,9 +338,175 @@ public void testFailedGraphValidation() { List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); FlowFrameworkException ex = expectThrows( FlowFrameworkException.class, - () -> workflowProcessSorter.validateGraph(sortedProcessNodes) + () -> workflowProcessSorter.validateGraph(sortedProcessNodes, validator) ); assertEquals("Invalid graph, missing the following required inputs : [connector_id]", ex.getMessage()); assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); } + + public void testSuccessfulInstalledPluginValidation() throws Exception { + + // Mock and stub the cluster 
admin client to invoke the NodesInfoRequest + AdminClient adminClient = mock(AdminClient.class); + ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class); + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + + // Mock and stub the clusterservice to get the local node + ClusterState clusterState = mock(ClusterState.class); + DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.getNodes()).thenReturn(discoveryNodes); + when(discoveryNodes.getLocalNodeId()).thenReturn("123"); + + // Stub cluster admin client's node info request + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + + // Mock and stub Plugin info + PluginInfo mockedFlowPluginInfo = mock(PluginInfo.class); + PluginInfo mockedMlPluginInfo = mock(PluginInfo.class); + when(mockedFlowPluginInfo.getName()).thenReturn("opensearch-flow-framework"); + when(mockedMlPluginInfo.getName()).thenReturn("opensearch-ml"); + + // Mock and stub PluginsAndModules + PluginsAndModules mockedPluginsAndModules = mock(PluginsAndModules.class); + when(mockedPluginsAndModules.getPluginInfos()).thenReturn(List.of(mockedFlowPluginInfo, mockedMlPluginInfo)); + + // Mock and stub NodesInfoResponse to NodeInfo + NodeInfo nodeInfo = mock(NodeInfo.class); + @SuppressWarnings("unchecked") + Map mockedMap = mock(Map.class); + NodesInfoResponse response = mock(NodesInfoResponse.class); + when(response.getNodesMap()).thenReturn(mockedMap); + when(mockedMap.get(any())).thenReturn(nodeInfo); + when(nodeInfo.getInfo(any())).thenReturn(mockedPluginsAndModules); + + // stub on response to pass the mocked NodesInfoRepsonse + listener.onResponse(response); + return null; + + }).when(clusterAdminClient).nodesInfo(any(NodesInfoRequest.class), any()); + + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Map.of(), + 
Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) + ); + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + RegisterRemoteModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + DeployModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + + Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); + + workflowProcessSorter.validatePluginsInstalled(sortedProcessNodes, validator); + } + + public void testFailedInstalledPluginValidation() throws Exception { + + // Mock and stub the cluster admin client to invoke the NodesInfoRequest + AdminClient adminClient = mock(AdminClient.class); + ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class); + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + + // Mock and stub the clusterservice to get the local node + ClusterState clusterState = mock(ClusterState.class); + DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.getNodes()).thenReturn(discoveryNodes); + when(discoveryNodes.getLocalNodeId()).thenReturn("123"); + + // Stub cluster admin client's node info request + doAnswer(invocation -> { + ActionListener listener = 
invocation.getArgument(1); + + // Mock and stub Plugin info, We ommit the opensearch-ml info here to trigger validation failure + PluginInfo mockedFlowPluginInfo = mock(PluginInfo.class); + when(mockedFlowPluginInfo.getName()).thenReturn("opensearch-flow-framework"); + + // Mock and stub PluginsAndModules + PluginsAndModules mockedPluginsAndModules = mock(PluginsAndModules.class); + when(mockedPluginsAndModules.getPluginInfos()).thenReturn(List.of(mockedFlowPluginInfo)); + + // Mock and stub NodesInfoResponse to NodeInfo + NodeInfo nodeInfo = mock(NodeInfo.class); + @SuppressWarnings("unchecked") + Map mockedMap = mock(Map.class); + NodesInfoResponse response = mock(NodesInfoResponse.class); + when(response.getNodesMap()).thenReturn(mockedMap); + when(mockedMap.get(any())).thenReturn(nodeInfo); + when(nodeInfo.getInfo(any())).thenReturn(mockedPluginsAndModules); + + // stub on response to pass the mocked NodesInfoRepsonse + listener.onResponse(response); + return null; + + }).when(clusterAdminClient).nodesInfo(any(NodesInfoRequest.class), any()); + + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) + ); + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + RegisterRemoteModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + DeployModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new 
WorkflowEdge(registerModel.id(), deployModel.id()); + + Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); + + FlowFrameworkException exception = expectThrows( + FlowFrameworkException.class, + () -> workflowProcessSorter.validatePluginsInstalled(sortedProcessNodes, validator) + ); + + assertEquals( + "The workflowStep create_connector requires the following plugins to be installed : [opensearch-ml]", + exception.getMessage() + ); + } } From 856631de8ac218a3b02f1eff7ff51dc43867141f Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 14 Dec 2023 18:10:12 -0800 Subject: [PATCH 25/27] Add Delete Workflow API (#294) Signed-off-by: Daniel Widdis --- .../flowframework/FlowFrameworkPlugin.java | 5 + .../rest/RestDeleteWorkflowAction.java | 105 +++++++++++++++ .../transport/DeleteWorkflowAction.java | 28 ++++ .../DeleteWorkflowTransportAction.java | 66 ++++++++++ .../DeprovisionWorkflowTransportAction.java | 2 +- .../FlowFrameworkPluginTests.java | 4 +- .../rest/RestDeleteWorkflowActionTests.java | 96 ++++++++++++++ .../rest/RestGetWorkflowActionTests.java | 8 +- .../DeleteWorkflowTransportActionTests.java | 123 ++++++++++++++++++ 9 files changed, 430 insertions(+), 7 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java create mode 100644 src/test/java/org/opensearch/flowframework/rest/RestDeleteWorkflowActionTests.java create mode 100644 src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java 
b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index d69c0b588..4ffd69342 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -28,6 +28,7 @@ import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.rest.RestCreateWorkflowAction; +import org.opensearch.flowframework.rest.RestDeleteWorkflowAction; import org.opensearch.flowframework.rest.RestDeprovisionWorkflowAction; import org.opensearch.flowframework.rest.RestGetWorkflowAction; import org.opensearch.flowframework.rest.RestGetWorkflowStateAction; @@ -36,6 +37,8 @@ import org.opensearch.flowframework.rest.RestSearchWorkflowStateAction; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; +import org.opensearch.flowframework.transport.DeleteWorkflowAction; +import org.opensearch.flowframework.transport.DeleteWorkflowTransportAction; import org.opensearch.flowframework.transport.DeprovisionWorkflowAction; import org.opensearch.flowframework.transport.DeprovisionWorkflowTransportAction; import org.opensearch.flowframework.transport.GetWorkflowAction; @@ -139,6 +142,7 @@ public List getRestHandlers( ) { return ImmutableList.of( new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService), + new RestDeleteWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestDeprovisionWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting), @@ -152,6 +156,7 @@ public List getRestHandlers( public List> getActions() { return ImmutableList.of( new ActionHandler<>(CreateWorkflowAction.INSTANCE, CreateWorkflowTransportAction.class), 
+ new ActionHandler<>(DeleteWorkflowAction.INSTANCE, DeleteWorkflowTransportAction.class), new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class), new ActionHandler<>(DeprovisionWorkflowAction.INSTANCE, DeprovisionWorkflowTransportAction.class), new ActionHandler<>(SearchWorkflowAction.INSTANCE, SearchWorkflowTransportAction.class), diff --git a/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java new file mode 100644 index 000000000..e017ee581 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.DeleteWorkflowAction; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static 
org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; + +/** + * Rest Action to facilitate requests to delete a stored template + */ +public class RestDeleteWorkflowAction extends BaseRestHandler { + + private static final String DELETE_WORKFLOW_ACTION = "delete_workflow"; + private static final Logger logger = LogManager.getLogger(RestDeleteWorkflowAction.class); + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + /** + * Instantiates a new RestDeleteWorkflowAction + * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled + */ + public RestDeleteWorkflowAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) { + this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; + } + + @Override + public String getName() { + return DELETE_WORKFLOW_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.DELETE, String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, WORKFLOW_ID))); + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + String workflowId = request.param(WORKFLOW_ID); + try { + if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { + throw new FlowFrameworkException( + "This API is disabled. 
To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", + RestStatus.FORBIDDEN + ); + } + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); + } + // Validate params + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + return channel -> client.execute(DeleteWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + }, exception -> { + try { + FlowFrameworkException ex = exception instanceof FlowFrameworkException + ? (FlowFrameworkException) exception + : new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + + } catch (IOException e) { + logger.error("Failed to send back delete workflow exception", e); + channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage())); + } + })); + + } catch (FlowFrameworkException ex) { + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowAction.java new file mode 100644 index 000000000..006f1f205 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowAction.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.delete.DeleteResponse; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External Action for public facing RestDeleteWorkflowAction + */ +public class DeleteWorkflowAction extends ActionType { + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/delete"; + /** An instance of this action */ + public static final DeleteWorkflowAction INSTANCE = new DeleteWorkflowAction(); + + private DeleteWorkflowAction() { + super(NAME, DeleteResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java new file mode 100644 index 000000000..4bc45da22 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; + +/** + * Transport action to delete a use case template within the Global Context + */ +public class DeleteWorkflowTransportAction extends HandledTransportAction { + + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private final Client client; + + /** + * Instantiates a new DeleteWorkflowTransportAction instance + * @param transportService the transport service + * @param actionFilters action filters + * @param flowFrameworkIndicesHandler The Flow Framework indices handler + * @param client the OpenSearch Client + */ + @Inject + public DeleteWorkflowTransportAction( + TransportService transportService, + ActionFilters actionFilters, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, + Client client + ) { + super(DeleteWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new); + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + this.client = client; + } + + @Override + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { + if (flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { + String workflowId = request.getWorkflowId(); + 
DeleteRequest deleteRequest = new DeleteRequest(GLOBAL_CONTEXT_INDEX, workflowId); + + ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext(); + client.delete(deleteRequest, ActionListener.runBefore(listener, () -> context.restore())); + } else { + listener.onFailure(new FlowFrameworkException("There are no templates in the global context.", RestStatus.NOT_FOUND)); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index 784b67374..1777a2457 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -131,7 +131,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); - workflowProcessSorter.validateGraph(provisionProcessSequence); + workflowProcessSorter.validate(provisionProcessSequence); // We have a valid template and sorted nodes, get the created resources getResourcesAndExecute(request.getWorkflowId(), provisionProcessSequence, listener); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 9f9529ca5..b08c27cfb 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -82,8 +82,8 @@ public void testPlugin() throws IOException { 4, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); - assertEquals(7, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); - assertEquals(7, 
ffp.getActions().size()); + assertEquals(8, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); + assertEquals(8, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); assertEquals(5, ffp.getSettings().size()); } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestDeleteWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestDeleteWorkflowActionTests.java new file mode 100644 index 000000000..93fc27623 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestDeleteWorkflowActionTests.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; +import org.opensearch.test.rest.FakeRestRequest; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RestDeleteWorkflowActionTests extends OpenSearchTestCase { + private RestDeleteWorkflowAction restDeleteWorkflowAction; + private String getPath; + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + private NodeClient nodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + this.getPath 
= String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id"); + flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); + this.restDeleteWorkflowAction = new RestDeleteWorkflowAction(flowFrameworkFeatureEnabledSetting); + this.nodeClient = mock(NodeClient.class); + } + + public void testRestDeleteWorkflowActionName() { + String name = restDeleteWorkflowAction.getName(); + assertEquals("delete_workflow", name); + } + + public void testRestDeleteWorkflowActionRoutes() { + List routes = restDeleteWorkflowAction.routes(); + assertEquals(1, routes.size()); + assertEquals(RestRequest.Method.DELETE, routes.get(0).getMethod()); + assertEquals(this.getPath, routes.get(0).getPath()); + } + + public void testInvalidRequestWithContent() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.DELETE) + .withPath(this.getPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + restDeleteWorkflowAction.handleRequest(request, channel, nodeClient); + }); + assertEquals("request [DELETE /_plugins/_flow_framework/workflow/{workflow_id}] does not support having a body", ex.getMessage()); + } + + public void testNullWorkflowId() throws Exception { + + // Request with no params + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.DELETE) + .withPath(this.getPath) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, true, 1); + restDeleteWorkflowAction.handleRequest(request, channel, nodeClient); + + assertEquals(1, channel.errors().get()); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + 
assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); + } + + public void testFeatureFlagNotEnabled() throws Exception { + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.DELETE) + .withPath(this.getPath) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + restDeleteWorkflowAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); + } +} diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java index 3a51f1a9e..31cfc701f 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java @@ -56,7 +56,7 @@ public void testRestGetWorkflowActionRoutes() { } public void testInvalidRequestWithContent() { - RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) .withPath(this.getPath) .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) .build(); @@ -65,13 +65,13 @@ public void testInvalidRequestWithContent() { IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { restGetWorkflowAction.handleRequest(request, channel, nodeClient); }); - assertEquals("request [POST /_plugins/_flow_framework/workflow/{workflow_id}] does not support having a body", ex.getMessage()); + assertEquals("request [GET /_plugins/_flow_framework/workflow/{workflow_id}] does not support having a 
body", ex.getMessage()); } public void testNullWorkflowId() throws Exception { // Request with no params - RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) .withPath(this.getPath) .build(); @@ -85,7 +85,7 @@ public void testNullWorkflowId() throws Exception { public void testFeatureFlagNotEnabled() throws Exception { when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); - RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) .withPath(this.getPath) .build(); FakeRestChannel channel = new FakeRestChannel(request, false, 1); diff --git a/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java new file mode 100644 index 000000000..b0ae61e6d --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java @@ -0,0 +1,123 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.DocWriteResponse.Result; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import org.mockito.ArgumentCaptor; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class DeleteWorkflowTransportActionTests extends OpenSearchTestCase { + + private Client client; + private DeleteWorkflowTransportAction deleteWorkflowTransportAction; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.client = mock(Client.class); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.deleteWorkflowTransportAction = new DeleteWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + flowFrameworkIndicesHandler, + client + ); + + ThreadPool clientThreadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(client.threadPool()).thenReturn(clientThreadPool); + 
when(clientThreadPool.getThreadContext()).thenReturn(threadContext); + + } + + public void testDeleteWorkflowNoGlobalContext() { + + when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(false); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest("1", null); + deleteWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue().getMessage().contains("There are no templates in the global context.")); + } + + public void testDeleteWorkflowSuccess() { + String workflowId = "12345"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + + when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(true); + + // Stub client.delete to force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + responseListener.onResponse(new DeleteResponse(shardId, workflowId, 1, 1, 1, true)); + return null; + }).when(client).delete(any(DeleteRequest.class), any()); + + deleteWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(Result.DELETED, responseCaptor.getValue().getResult()); + } + + public void testDeleteWorkflowNotFound() { + String workflowId = "12345"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + + 
when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(true); + + // Stub client.delete to force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + responseListener.onResponse(new DeleteResponse(shardId, workflowId, 1, 1, 1, false)); + return null; + }).when(client).delete(any(DeleteRequest.class), any()); + + deleteWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(Result.NOT_FOUND, responseCaptor.getValue().getResult()); + } +} From 280ebc9e04c440ed5fead4a6fc2e269a572d829d Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 15 Dec 2023 09:49:02 -0800 Subject: [PATCH 26/27] Consume REST params and consistently handle error messages (#295) * Always consume the workflow_id param Signed-off-by: Daniel Widdis * Delegate no-content error message to BaseRestHandler Signed-off-by: Daniel Widdis * Don't lose FlowFrameworkException status Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- .../flowframework/rest/RestCreateWorkflowAction.java | 3 +-- .../flowframework/rest/RestDeleteWorkflowAction.java | 3 ++- .../rest/RestDeprovisionWorkflowAction.java | 6 +++--- .../flowframework/rest/RestGetWorkflowAction.java | 11 ++++++----- .../rest/RestGetWorkflowStateAction.java | 10 ++++++---- .../rest/RestProvisionWorkflowAction.java | 3 ++- .../rest/RestGetWorkflowStateActionTests.java | 8 ++++---- 7 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index deeabdd76..5d8aed031 100644 --- 
a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -78,6 +78,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + String workflowId = request.param(WORKFLOW_ID); if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { FlowFrameworkException ffe = new FlowFrameworkException( "This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", @@ -88,8 +89,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); } try { - - String workflowId = request.param(WORKFLOW_ID); Template template = Template.parse(request.content().utf8ToString()); boolean dryRun = request.paramAsBoolean(DRY_RUN, false); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java index e017ee581..cd0672e62 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java @@ -72,7 +72,8 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request } // Validate content if (request.hasContent()) { - throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); + // BaseRestHandler will give appropriate error message + return channel -> channel.sendResponse(null); } // Validate params if (workflowId == null) { diff --git a/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java index 467a683ce..bfd5c70d4 100644 --- 
a/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java @@ -57,7 +57,7 @@ public String getName() { @Override protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - + String workflowId = request.param(WORKFLOW_ID); try { if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( @@ -67,10 +67,10 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request } // Validate content if (request.hasContent()) { - throw new FlowFrameworkException("No request body is required", RestStatus.BAD_REQUEST); + // BaseRestHandler will give appropriate error message + return channel -> channel.sendResponse(null); } // Validate params - String workflowId = request.param(WORKFLOW_ID); if (workflowId == null) { throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java index 5a92e9c0e..93ea6d134 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java @@ -62,7 +62,7 @@ public List routes() { @Override protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - + String workflowId = request.param(WORKFLOW_ID); try { if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( @@ -70,13 +70,12 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request RestStatus.FORBIDDEN ); } - // Validate content if (request.hasContent()) { - throw new FlowFrameworkException("Invalid request format", 
RestStatus.BAD_REQUEST); + // BaseRestHandler will give appropriate error message + return channel -> channel.sendResponse(null); } // Validate params - String workflowId = request.param(WORKFLOW_ID); if (workflowId == null) { throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } @@ -87,7 +86,9 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); }, exception -> { try { - FlowFrameworkException ex = new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); + FlowFrameworkException ex = exception instanceof FlowFrameworkException + ? (FlowFrameworkException) exception + : new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java index ab7335b2d..20f8d69b7 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java @@ -57,7 +57,7 @@ public String getName() { @Override protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - + String workflowId = request.param(WORKFLOW_ID); try { if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( @@ -68,10 +68,10 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request // Validate content if (request.hasContent()) { - throw new FlowFrameworkException("No request body present", RestStatus.BAD_REQUEST); + // BaseRestHandler will 
give appropriate error message + return channel -> channel.sendResponse(null); } // Validate params - String workflowId = request.param(WORKFLOW_ID); if (workflowId == null) { throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } @@ -83,7 +83,9 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); }, exception -> { try { - FlowFrameworkException ex = new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); + FlowFrameworkException ex = exception instanceof FlowFrameworkException + ? (FlowFrameworkException) exception + : new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 9e6eb4d01..7e4d68183 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -77,7 +77,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } // Validate content if (request.hasContent()) { - throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); + // BaseRestHandler will give appropriate error message + return channel -> channel.sendResponse(null); } // Validate params if (workflowId == null) { diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java index dc605a5cd..06c4a7053 100644 --- 
a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java @@ -63,7 +63,7 @@ public void testRestGetWorkflowStateActionRoutes() { public void testNullWorkflowId() throws Exception { // Request with no params - RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) .withPath(this.getPath) .build(); @@ -76,7 +76,7 @@ public void testNullWorkflowId() throws Exception { } public void testInvalidRequestWithContent() { - RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) .withPath(this.getPath) .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) .build(); @@ -86,14 +86,14 @@ public void testInvalidRequestWithContent() { restGetWorkflowStateAction.handleRequest(request, channel, nodeClient); }); assertEquals( - "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_status] does not support having a body", + "request [GET /_plugins/_flow_framework/workflow/{workflow_id}/_status] does not support having a body", ex.getMessage() ); } public void testFeatureFlagNotEnabled() throws Exception { when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); - RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) .withPath(this.getPath) .build(); FakeRestChannel channel = new FakeRestChannel(request, false, 1); From ad8e14224fca2ba0c59a0c4c75c91b02b07c4c04 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 15 Dec 2023 10:00:21 -0800 
Subject: [PATCH 27/27] Allow YAML Templates (#296) Signed-off-by: Daniel Widdis --- .../java/org/opensearch/flowframework/model/Template.java | 4 ++-- .../flowframework/rest/RestCreateWorkflowAction.java | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index f4a8b1958..3f42c225e 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -260,9 +260,9 @@ public void writeTo(StreamOutput output) throws IOException { } /** - * Parse raw json content into a Template instance. + * Parse raw xContent into a Template instance. * - * @param parser json based content parser + * @param parser xContent based content parser * @return an instance of the template * @throws IOException if content can't be parsed correctly */ diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 5d8aed031..e254b66f7 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -18,6 +18,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Template; @@ -30,6 +31,7 @@ import java.util.List; import java.util.Locale; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.DRY_RUN; import static 
org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; @@ -89,7 +91,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); } try { - Template template = Template.parse(request.content().utf8ToString()); + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Template template = Template.parse(parser); boolean dryRun = request.paramAsBoolean(DRY_RUN, false); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false);