Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add optional config field to tool step (#899) #920

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.17...2.x)
### Features
- Add ApiSpecFetcher for Fetching and Comparing API Specifications ([#651](https://github.com/opensearch-project/flow-framework/issues/651))
- Add optional config field to tool step ([#899](https://github.com/opensearch-project/flow-framework/pull/899))

### Enhancements
- Incrementally remove resources from workflow state during deprovisioning ([#898](https://github.com/opensearch-project/flow-framework/pull/898))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ private CommonValue() {}
public static final String TOOLS_FIELD = "tools";
/** The tools order field for an agent */
public static final String TOOLS_ORDER_FIELD = "tools_order";
/** The tools config field */
public static final String CONFIG_FIELD = "config";
/** The memory field for an agent */
public static final String MEMORY_FIELD = "memory";
/** The app type field for an agent */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.common.agent.MLToolSpec;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static org.opensearch.flowframework.common.CommonValue.CONFIG_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
Expand All @@ -38,7 +41,21 @@ public class ToolStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(ToolStep.class);
PlainActionFuture<WorkflowData> toolFuture = PlainActionFuture.newFuture();
static final String NAME = "create_tool";

/** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */
public static final String NAME = "create_tool";
/** Required input keys */
public static final Set<String> REQUIRED_INPUTS = Set.of(TYPE);
/** Optional input keys */
public static final Set<String> OPTIONAL_INPUTS = Set.of(
NAME_FIELD,
DESCRIPTION_FIELD,
PARAMETERS_FIELD,
CONFIG_FIELD,
INCLUDE_OUTPUT_IN_AGENT_RESPONSE
);
/** Provided output keys */
public static final Set<String> PROVIDED_OUTPUTS = Set.of(TOOLS_FIELD);

@Override
public PlainActionFuture<WorkflowData> execute(
Expand All @@ -48,13 +65,10 @@ public PlainActionFuture<WorkflowData> execute(
Map<String, String> previousNodeInputs,
Map<String, String> params
) {
Set<String> requiredKeys = Set.of(TYPE);
Set<String> optionalKeys = Set.of(NAME_FIELD, DESCRIPTION_FIELD, PARAMETERS_FIELD, INCLUDE_OUTPUT_IN_AGENT_RESPONSE);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
REQUIRED_INPUTS,
OPTIONAL_INPUTS,
currentNodeInputs,
outputs,
previousNodeInputs,
Expand All @@ -69,11 +83,13 @@ public PlainActionFuture<WorkflowData> execute(
// parse connector_id, model_id and agent_id from previous node inputs
Set<String> toolParameterKeys = Set.of(CONNECTOR_ID, MODEL_ID, AGENT_ID);
Map<String, String> parameters = getToolsParametersMap(
inputs.get(PARAMETERS_FIELD),
inputs.getOrDefault(PARAMETERS_FIELD, new HashMap<>()),
previousNodeInputs,
outputs,
toolParameterKeys
);
@SuppressWarnings("unchecked")
Map<String, String> config = (Map<String, String>) inputs.getOrDefault(CONFIG_FIELD, Collections.emptyMap());

MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder();

Expand All @@ -90,6 +106,7 @@ public PlainActionFuture<WorkflowData> execute(
if (includeOutputInAgentResponse != null) {
builder.includeOutputInAgentResponse(includeOutputInAgentResponse);
}
builder.configMap(config);

MLToolSpec mlToolSpec = builder.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX;
import static org.opensearch.flowframework.common.CommonValue.SUCCESS;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TYPE;
import static org.opensearch.flowframework.common.CommonValue.URL;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;
Expand Down Expand Up @@ -231,7 +230,7 @@ public enum WorkflowSteps {
DELETE_AGENT(DeleteAgentStep.NAME, List.of(AGENT_ID), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null),

/** Create Tool Step */
CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null),
CREATE_TOOL(ToolStep.NAME, ToolStep.REQUIRED_INPUTS, ToolStep.PROVIDED_OUTPUTS, List.of(OPENSEARCH_ML), null),

/** Create Ingest Pipeline Step */
CREATE_INGEST_PIPELINE(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ public void setUp() throws Exception {
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
MockitoAnnotations.openMocks(this);

MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false);
MLToolSpec tools = MLToolSpec.builder()
.type("tool1")
.name("CatIndexTool")
.description("desc")
.parameters(Collections.emptyMap())
.includeOutputInAgentResponse(false)
.build();

LLMSpec llmSpec = new LLMSpec("xyz", Collections.emptyMap());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public void setUp() throws Exception {
Map.entry("name", "name"),
Map.entry("description", "description"),
Map.entry("parameters", Collections.emptyMap()),
Map.entry("config", Map.of("foo", "bar")),
Map.entry("include_output_in_agent_response", "false")
),
"test-id",
Expand Down Expand Up @@ -102,6 +103,7 @@ public void testTool() throws ExecutionException, InterruptedException {
);
assertTrue(future.isDone());
assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass());
assertEquals(Map.of("foo", "bar"), ((MLToolSpec) future.get().getContent().get("tools")).getConfigMap());
}

public void testBoolParseFail() {
Expand Down
Loading