Skip to content

Commit

Permalink
[Backport 2.x] Add optional config field to tool step (opensearch-pro…
Browse files Browse the repository at this point in the history
…ject#899) (opensearch-project#920)

Add optional config field to tool step (opensearch-project#899)

* Add optional config field to tool step



* Complete TODOs now that upstream is merged



---------

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis authored Oct 21, 2024
1 parent 0bb3ea0 commit 7baa162
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.17...2.x)
### Features
- Add ApiSpecFetcher for Fetching and Comparing API Specifications ([#651](https://github.com/opensearch-project/flow-framework/issues/651))
- Add optional config field to tool step ([#899](https://github.com/opensearch-project/flow-framework/pull/899))

### Enhancements
- Incrementally remove resources from workflow state during deprovisioning ([#898](https://github.com/opensearch-project/flow-framework/pull/898))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ private CommonValue() {}
public static final String TOOLS_FIELD = "tools";
/** The tools order field for an agent */
public static final String TOOLS_ORDER_FIELD = "tools_order";
/** The tools config field */
public static final String CONFIG_FIELD = "config";
/** The memory field for an agent */
public static final String MEMORY_FIELD = "memory";
/** The app type field for an agent */
Expand Down
31 changes: 24 additions & 7 deletions src/main/java/org/opensearch/flowframework/workflow/ToolStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.common.agent.MLToolSpec;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static org.opensearch.flowframework.common.CommonValue.CONFIG_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
Expand All @@ -38,7 +41,21 @@ public class ToolStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(ToolStep.class);
PlainActionFuture<WorkflowData> toolFuture = PlainActionFuture.newFuture();
static final String NAME = "create_tool";

/** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */
public static final String NAME = "create_tool";
/** Required input keys */
public static final Set<String> REQUIRED_INPUTS = Set.of(TYPE);
/** Optional input keys */
public static final Set<String> OPTIONAL_INPUTS = Set.of(
NAME_FIELD,
DESCRIPTION_FIELD,
PARAMETERS_FIELD,
CONFIG_FIELD,
INCLUDE_OUTPUT_IN_AGENT_RESPONSE
);
/** Provided output keys */
public static final Set<String> PROVIDED_OUTPUTS = Set.of(TOOLS_FIELD);

@Override
public PlainActionFuture<WorkflowData> execute(
Expand All @@ -48,13 +65,10 @@ public PlainActionFuture<WorkflowData> execute(
Map<String, String> previousNodeInputs,
Map<String, String> params
) {
Set<String> requiredKeys = Set.of(TYPE);
Set<String> optionalKeys = Set.of(NAME_FIELD, DESCRIPTION_FIELD, PARAMETERS_FIELD, INCLUDE_OUTPUT_IN_AGENT_RESPONSE);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
REQUIRED_INPUTS,
OPTIONAL_INPUTS,
currentNodeInputs,
outputs,
previousNodeInputs,
Expand All @@ -69,11 +83,13 @@ public PlainActionFuture<WorkflowData> execute(
// parse connector_id, model_id and agent_id from previous node inputs
Set<String> toolParameterKeys = Set.of(CONNECTOR_ID, MODEL_ID, AGENT_ID);
Map<String, String> parameters = getToolsParametersMap(
inputs.get(PARAMETERS_FIELD),
inputs.getOrDefault(PARAMETERS_FIELD, new HashMap<>()),
previousNodeInputs,
outputs,
toolParameterKeys
);
@SuppressWarnings("unchecked")
Map<String, String> config = (Map<String, String>) inputs.getOrDefault(CONFIG_FIELD, Collections.emptyMap());

MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder();

Expand All @@ -90,6 +106,7 @@ public PlainActionFuture<WorkflowData> execute(
if (includeOutputInAgentResponse != null) {
builder.includeOutputInAgentResponse(includeOutputInAgentResponse);
}
builder.configMap(config);

MLToolSpec mlToolSpec = builder.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX;
import static org.opensearch.flowframework.common.CommonValue.SUCCESS;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TYPE;
import static org.opensearch.flowframework.common.CommonValue.URL;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;
Expand Down Expand Up @@ -231,7 +230,7 @@ public enum WorkflowSteps {
DELETE_AGENT(DeleteAgentStep.NAME, List.of(AGENT_ID), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null),

/** Create Tool Step */
CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null),
CREATE_TOOL(ToolStep.NAME, ToolStep.REQUIRED_INPUTS, ToolStep.PROVIDED_OUTPUTS, List.of(OPENSEARCH_ML), null),

/** Create Ingest Pipeline Step */
CREATE_INGEST_PIPELINE(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ public void setUp() throws Exception {
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
MockitoAnnotations.openMocks(this);

MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false);
MLToolSpec tools = MLToolSpec.builder()
.type("tool1")
.name("CatIndexTool")
.description("desc")
.parameters(Collections.emptyMap())
.includeOutputInAgentResponse(false)
.build();

LLMSpec llmSpec = new LLMSpec("xyz", Collections.emptyMap());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public void setUp() throws Exception {
Map.entry("name", "name"),
Map.entry("description", "description"),
Map.entry("parameters", Collections.emptyMap()),
Map.entry("config", Map.of("foo", "bar")),
Map.entry("include_output_in_agent_response", "false")
),
"test-id",
Expand Down Expand Up @@ -102,6 +103,7 @@ public void testTool() throws ExecutionException, InterruptedException {
);
assertTrue(future.isDone());
assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass());
assertEquals(Map.of("foo", "bar"), ((MLToolSpec) future.get().getContent().get("tools")).getConfigMap());
}

public void testBoolParseFail() {
Expand Down

0 comments on commit 7baa162

Please sign in to comment.