Skip to content

Commit

Permalink
Fix by skipping put "input" parameter if it already exists.
Browse files Browse the repository at this point in the history
Signed-off-by: Heng Qian <[email protected]>
  • Loading branch information
qianheng-aws committed Aug 23, 2024
1 parent df16131 commit 2eb24e2
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -480,4 +480,37 @@ public static Map<String, String> constructToolParams(
}
return toolParams;
}

public static Map<String, String> getToolExecuteParams(MLToolSpec toolSpec, Map<String, String> params) {
Map<String, String> executeParams = new HashMap<>();
// tooSpec parameter may override the parameters in params.
if (toolSpec.getParameters() != null) {
executeParams.putAll(toolSpec.getParameters());
}
for (String key : params.keySet()) {
// To avoid overriding the default "input" parameter, skip put if it already exists.
if (key.equals("input") && executeParams.containsKey("input"))
continue;
String toBeReplaced = null;
if (key.startsWith(toolSpec.getType() + ".")) {
toBeReplaced = toolSpec.getType() + ".";
}
if (toolSpec.getName() != null && key.startsWith(toolSpec.getName() + ".")) {
toBeReplaced = toolSpec.getName() + ".";
}
if (toBeReplaced != null) {
executeParams.put(key.replace(toBeReplaced, ""), params.get(key));
} else {
executeParams.put(key, params.get(key));
}
}

if (executeParams.containsKey("input")) {
String input = executeParams.get("input");
StringSubstitutor substitutor = new StringSubstitutor(executeParams, "${parameters.", "}");
input = substitutor.replace(input);
executeParams.put("input", input);
}
return executeParams;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolExecuteParams;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName;
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION;

Expand All @@ -22,13 +23,11 @@
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.StepListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
Expand Down Expand Up @@ -406,34 +405,4 @@ String parseResponse(Object output) throws IOException {
}
}
}

@VisibleForTesting
Map<String, String> getToolExecuteParams(MLToolSpec toolSpec, Map<String, String> params) {
Map<String, String> executeParams = new HashMap<>();
if (toolSpec.getParameters() != null) {
executeParams.putAll(toolSpec.getParameters());
}
for (String key : params.keySet()) {
String toBeReplaced = null;
if (key.startsWith(toolSpec.getType() + ".")) {
toBeReplaced = toolSpec.getType() + ".";
}
if (toolSpec.getName() != null && key.startsWith(toolSpec.getName() + ".")) {
toBeReplaced = toolSpec.getName() + ".";
}
if (toBeReplaced != null) {
executeParams.put(key.replace(toBeReplaced, ""), params.get(key));
} else {
executeParams.put(key, params.get(key));
}
}

if (executeParams.containsKey("input")) {
String input = executeParams.get("input");
StringSubstitutor substitutor = new StringSubstitutor(executeParams, "${parameters.", "}");
input = substitutor.replace(input);
executeParams.put("input", input);
}
return executeParams;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.apache.commons.text.StringEscapeUtils.escapeJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolExecuteParams;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName;

import java.io.IOException;
Expand All @@ -18,7 +19,6 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.StepListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
Expand Down Expand Up @@ -266,35 +266,4 @@ Tool createTool(MLToolSpec toolSpec) {
}
return tool;
}

@VisibleForTesting
Map<String, String> getToolExecuteParams(MLToolSpec toolSpec, Map<String, String> params) {
Map<String, String> executeParams = new HashMap<>();
for (String key : params.keySet()) {
String toBeReplaced = null;
if (key.startsWith(toolSpec.getType() + ".")) {
toBeReplaced = toolSpec.getType() + ".";
}
if (toolSpec.getName() != null && key.startsWith(toolSpec.getName() + ".")) {
toBeReplaced = toolSpec.getName() + ".";
}
if (toBeReplaced != null) {
executeParams.put(key.replace(toBeReplaced, ""), params.get(key));
} else {
executeParams.put(key, params.get(key));
}
}
// tooSpec parameter may override the parameters in params.
if (toolSpec.getParameters() != null) {
executeParams.putAll(toolSpec.getParameters());
}

if (executeParams.containsKey("input")) {
String input = executeParams.get("input");
StringSubstitutor substitutor = new StringSubstitutor(executeParams, "${parameters.", "}");
input = substitutor.replace(input);
executeParams.put("input", input);
}
return executeParams;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ public void testGetToolExecuteParams() {

Map<String, String> params = Map.of("toolType.param2", "value2", "toolName.param3", "value3", "param4", "value4");

Map<String, String> result = mlFlowAgentRunner.getToolExecuteParams(toolSpec, params);
Map<String, String> result = AgentUtils.getToolExecuteParams(toolSpec, params);

assertEquals("value1", result.get("param1"));
assertEquals("value3", result.get("param3"));
Expand Down Expand Up @@ -322,7 +322,7 @@ public void testGetToolExecuteParamsWithInputSubstitution() {
);

// Execute the method
Map<String, String> result = mlFlowAgentRunner.getToolExecuteParams(toolSpec, params);
Map<String, String> result = AgentUtils.getToolExecuteParams(toolSpec, params);

// Assertions
assertEquals("value1", result.get("param1"));
Expand All @@ -335,6 +335,33 @@ public void testGetToolExecuteParamsWithInputSubstitution() {
assertEquals(expectedInput, result.get("input"));
}

@Test
public void testGetToolExecuteParamsWithInputConflict() {
// Setup ToolSpec with parameters
MLToolSpec toolSpec = mock(MLToolSpec.class);
when(toolSpec.getParameters())
.thenReturn(Map.of("param1", "value1", "input", "Input contains ${parameters.param1}, ${parameters.param4}"));
when(toolSpec.getType()).thenReturn("toolType");
when(toolSpec.getName()).thenReturn("toolName");

// Setup params with a special 'input' key for substitution
Map<String, String> params = Map
.of("toolType.param2", "value2", "toolName.param3", "value3", "param4", "value4", "input", "Input In Params");

// Execute the method
Map<String, String> result = AgentUtils.getToolExecuteParams(toolSpec, params);

// Assertions
assertEquals("value1", result.get("param1"));
assertEquals("value3", result.get("param3"));
assertEquals("value4", result.get("param4"));
assertFalse(result.containsKey("toolType.param2"));

// The input parameter from params won't override the input parameter from ToolSpec
String expectedInput = "Input contains value1, value4";
assertEquals(expectedInput, result.get("input"));
}

@Test
public void testCreateTool() {
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).description("description").type(FIRST_TOOL).build();
Expand Down

0 comments on commit 2eb24e2

Please sign in to comment.