diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b1bc8270..adf7b3209 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.17...2.x) ### Features - Add ApiSpecFetcher for Fetching and Comparing API Specifications ([#651](https://github.com/opensearch-project/flow-framework/issues/651)) +- Add optional config field to tool step ([#899](https://github.com/opensearch-project/flow-framework/pull/899)) ### Enhancements - Incrementally remove resources from workflow state during deprovisioning ([#898](https://github.com/opensearch-project/flow-framework/pull/898)) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 898675d94..9c88788b3 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -164,6 +164,8 @@ private CommonValue() {} public static final String TOOLS_FIELD = "tools"; /** The tools order field for an agent */ public static final String TOOLS_ORDER_FIELD = "tools_order"; + /** The tools config field */ + public static final String CONFIG_FIELD = "config"; /** The memory field for an agent */ public static final String MEMORY_FIELD = "memory"; /** The app type field for an agent */ diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 45e2ee240..9d13c6953 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -17,10 +17,13 @@ import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.common.agent.MLToolSpec; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.Set; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; @@ -38,7 +41,21 @@ public class ToolStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(ToolStep.class); PlainActionFuture toolFuture = PlainActionFuture.newFuture(); - static final String NAME = "create_tool"; + + /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ + public static final String NAME = "create_tool"; + /** Required input keys */ + public static final Set REQUIRED_INPUTS = Set.of(TYPE); + /** Optional input keys */ + public static final Set OPTIONAL_INPUTS = Set.of( + NAME_FIELD, + DESCRIPTION_FIELD, + PARAMETERS_FIELD, + CONFIG_FIELD, + INCLUDE_OUTPUT_IN_AGENT_RESPONSE + ); + /** Provided output keys */ + public static final Set PROVIDED_OUTPUTS = Set.of(TOOLS_FIELD); @Override public PlainActionFuture execute( @@ -48,13 +65,10 @@ public PlainActionFuture execute( Map previousNodeInputs, Map params ) { - Set requiredKeys = Set.of(TYPE); - Set optionalKeys = Set.of(NAME_FIELD, DESCRIPTION_FIELD, PARAMETERS_FIELD, INCLUDE_OUTPUT_IN_AGENT_RESPONSE); - try { Map inputs = ParseUtils.getInputsFromPreviousSteps( - requiredKeys, - optionalKeys, + REQUIRED_INPUTS, + OPTIONAL_INPUTS, currentNodeInputs, outputs, previousNodeInputs, @@ -69,11 +83,13 @@ public PlainActionFuture execute( // parse connector_id, model_id and agent_id from previous node inputs Set toolParameterKeys = Set.of(CONNECTOR_ID, MODEL_ID, AGENT_ID); Map parameters = getToolsParametersMap( - inputs.get(PARAMETERS_FIELD), + inputs.getOrDefault(PARAMETERS_FIELD, new HashMap<>()), previousNodeInputs, outputs, toolParameterKeys ); + @SuppressWarnings("unchecked") + Map config = (Map) inputs.getOrDefault(CONFIG_FIELD, Collections.emptyMap()); MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder(); @@ -90,6 +106,7 @@ public PlainActionFuture execute( if (includeOutputInAgentResponse != null) { builder.includeOutputInAgentResponse(includeOutputInAgentResponse); } + builder.configMap(config); MLToolSpec mlToolSpec = builder.build(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 9fc8baada..65e8dea78 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -52,7 +52,6 @@ import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX; import static org.opensearch.flowframework.common.CommonValue.SUCCESS; -import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; import static org.opensearch.flowframework.common.CommonValue.URL; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; @@ -231,7 +230,7 @@ public enum WorkflowSteps { DELETE_AGENT(DeleteAgentStep.NAME, List.of(AGENT_ID), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null), /** Create Tool Step */ - CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null), + CREATE_TOOL(ToolStep.NAME, ToolStep.REQUIRED_INPUTS, ToolStep.PROVIDED_OUTPUTS, List.of(OPENSEARCH_ML), null), /** Create Ingest Pipeline Step */ CREATE_INGEST_PIPELINE( diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index c2b3dcca1..8def95f58 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -56,7 +56,13 @@ public void setUp() throws Exception { this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); MockitoAnnotations.openMocks(this); - MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false); + MLToolSpec tools = MLToolSpec.builder() + .type("tool1") + .name("CatIndexTool") + .description("desc") + .parameters(Collections.emptyMap()) + .includeOutputInAgentResponse(false) + .build(); LLMSpec llmSpec = new LLMSpec("xyz", Collections.emptyMap()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java index 029b5c835..2b5e5b7fa 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -61,6 +61,7 @@ public void setUp() throws Exception { Map.entry("name", "name"), Map.entry("description", "description"), Map.entry("parameters", Collections.emptyMap()), + Map.entry("config", Map.of("foo", "bar")), Map.entry("include_output_in_agent_response", "false") ), "test-id", @@ -102,6 +103,7 @@ public void testTool() throws ExecutionException, InterruptedException { ); assertTrue(future.isDone()); assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass()); + assertEquals(Map.of("foo", "bar"), ((MLToolSpec) future.get().getContent().get("tools")).getConfigMap()); } public void testBoolParseFail() {