From 2eb24e2f70522959c4ed7b8b59858977eaa0b51a Mon Sep 17 00:00:00 2001 From: Heng Qian Date: Fri, 23 Aug 2024 08:39:02 +0800 Subject: [PATCH] Fix by skipping put "input" parameter if it already exists. Signed-off-by: Heng Qian --- .../engine/algorithms/agent/AgentUtils.java | 33 +++++++++++++++++++ .../MLConversationalFlowAgentRunner.java | 33 +------------------ .../algorithms/agent/MLFlowAgentRunner.java | 33 +------------------ .../agent/MLFlowAgentRunnerTest.java | 31 +++++++++++++++-- 4 files changed, 64 insertions(+), 66 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index f424b3f624..78030c4554 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -480,4 +480,37 @@ public static Map constructToolParams( } return toolParams; } + + public static Map getToolExecuteParams(MLToolSpec toolSpec, Map params) { + Map executeParams = new HashMap<>(); + // tooSpec parameter may override the parameters in params. + if (toolSpec.getParameters() != null) { + executeParams.putAll(toolSpec.getParameters()); + } + for (String key : params.keySet()) { + // To avoid overriding the default "input" parameter, skip put if it already exists. + if (key.equals("input") && executeParams.containsKey("input")) + continue; + String toBeReplaced = null; + if (key.startsWith(toolSpec.getType() + ".")) { + toBeReplaced = toolSpec.getType() + "."; + } + if (toolSpec.getName() != null && key.startsWith(toolSpec.getName() + ".")) { + toBeReplaced = toolSpec.getName() + "."; + } + if (toBeReplaced != null) { + executeParams.put(key.replace(toBeReplaced, ""), params.get(key)); + } else { + executeParams.put(key, params.get(key)); + } + } + + if (executeParams.containsKey("input")) { + String input = executeParams.get("input"); + StringSubstitutor substitutor = new StringSubstitutor(executeParams, "${parameters.", "}"); + input = substitutor.replace(input); + executeParams.put("input", input); + } + return executeParams; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index 672890c030..1d5d6687cb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -14,6 +14,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolExecuteParams; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName; import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; @@ -22,13 +23,11 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; -import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.StepListener; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -406,34 +405,4 @@ String parseResponse(Object output) throws IOException { } } } - - @VisibleForTesting - Map getToolExecuteParams(MLToolSpec toolSpec, Map params) { - Map executeParams = new HashMap<>(); - if (toolSpec.getParameters() != null) { - executeParams.putAll(toolSpec.getParameters()); - } - for (String key : params.keySet()) { - String toBeReplaced = null; - if (key.startsWith(toolSpec.getType() + ".")) { - toBeReplaced = toolSpec.getType() + "."; - } - if (toolSpec.getName() != null && key.startsWith(toolSpec.getName() + ".")) { - toBeReplaced = toolSpec.getName() + "."; - } - if (toBeReplaced != null) { - executeParams.put(key.replace(toBeReplaced, ""), params.get(key)); - } else { - executeParams.put(key, params.get(key)); - } - } - - if (executeParams.containsKey("input")) { - String input = executeParams.get("input"); - StringSubstitutor substitutor = new StringSubstitutor(executeParams, "${parameters.", "}"); - input = substitutor.replace(input); - executeParams.put("input", input); - } - return executeParams; - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 865594b743..0b672285be 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -7,6 +7,7 @@ import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolExecuteParams; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName; import java.io.IOException; @@ -18,7 +19,6 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.StepListener; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -266,35 +266,4 @@ Tool createTool(MLToolSpec toolSpec) { } return tool; } - - @VisibleForTesting - Map getToolExecuteParams(MLToolSpec toolSpec, Map params) { - Map executeParams = new HashMap<>(); - for (String key : params.keySet()) { - String toBeReplaced = null; - if (key.startsWith(toolSpec.getType() + ".")) { - toBeReplaced = toolSpec.getType() + "."; - } - if (toolSpec.getName() != null && key.startsWith(toolSpec.getName() + ".")) { - toBeReplaced = toolSpec.getName() + "."; - } - if (toBeReplaced != null) { - executeParams.put(key.replace(toBeReplaced, ""), params.get(key)); - } else { - executeParams.put(key, params.get(key)); - } - } - // tooSpec parameter may override the parameters in params. - if (toolSpec.getParameters() != null) { - executeParams.putAll(toolSpec.getParameters()); - } - - if (executeParams.containsKey("input")) { - String input = executeParams.get("input"); - StringSubstitutor substitutor = new StringSubstitutor(executeParams, "${parameters.", "}"); - input = substitutor.replace(input); - executeParams.put("input", input); - } - return executeParams; - } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java index 609609438a..6efe0119b0 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java @@ -292,7 +292,7 @@ public void testGetToolExecuteParams() { Map params = Map.of("toolType.param2", "value2", "toolName.param3", "value3", "param4", "value4"); - Map result = mlFlowAgentRunner.getToolExecuteParams(toolSpec, params); + Map result = AgentUtils.getToolExecuteParams(toolSpec, params); assertEquals("value1", result.get("param1")); assertEquals("value3", result.get("param3")); @@ -322,7 +322,7 @@ public void testGetToolExecuteParamsWithInputSubstitution() { ); // Execute the method - Map result = mlFlowAgentRunner.getToolExecuteParams(toolSpec, params); + Map result = AgentUtils.getToolExecuteParams(toolSpec, params); // Assertions assertEquals("value1", result.get("param1")); @@ -335,6 +335,33 @@ public void testGetToolExecuteParamsWithInputSubstitution() { assertEquals(expectedInput, result.get("input")); } + @Test + public void testGetToolExecuteParamsWithInputConflict() { + // Setup ToolSpec with parameters + MLToolSpec toolSpec = mock(MLToolSpec.class); + when(toolSpec.getParameters()) + .thenReturn(Map.of("param1", "value1", "input", "Input contains ${parameters.param1}, ${parameters.param4}")); + when(toolSpec.getType()).thenReturn("toolType"); + when(toolSpec.getName()).thenReturn("toolName"); + + // Setup params with a special 'input' key for substitution + Map params = Map + .of("toolType.param2", "value2", "toolName.param3", "value3", "param4", "value4", "input", "Input In Params"); + + // Execute the method + Map result = AgentUtils.getToolExecuteParams(toolSpec, params); + + // Assertions + assertEquals("value1", result.get("param1")); + assertEquals("value3", result.get("param3")); + assertEquals("value4", result.get("param4")); + assertFalse(result.containsKey("toolType.param2")); + + // The input parameter from params won't override the input parameter from ToolSpec + String expectedInput = "Input contains value1, value4"; + assertEquals(expectedInput, result.get("input")); + } + @Test public void testCreateTool() { MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).description("description").type(FIRST_TOOL).build();