Skip to content

Commit

Permalink
Add support for including intermediary tool response to agent response (
Browse files Browse the repository at this point in the history
#1611)

* Add support for including intermediay tool response to agent response

Signed-off-by: Arjun kumar Giri <[email protected]>

* Add include intermediary response support to react agent runner

Signed-off-by: Arjun kumar Giri <[email protected]>

* Rebase from upstream

Updated version of awssdk (#1605) (#1606)

Signed-off-by: Owais Kazi <[email protected]>
(cherry picked from commit 7e44eb2)

Co-authored-by: Owais Kazi <[email protected]>

create agent

Signed-off-by: Jing Zhang <[email protected]>

CatIndexTool implementation (#1582)

* CatIndexTool implementation

Signed-off-by: Daniel Widdis <[email protected]>

* Match implementation to REST API to respect index permissions

Signed-off-by: Daniel Widdis <[email protected]>

* Javadoc fixes

Signed-off-by: Daniel Widdis <[email protected]>

* Reduce fields to exactly match _cat/indices

Signed-off-by: Daniel Widdis <[email protected]>

* Add TODO, clean up unused code/imports

Signed-off-by: Daniel Widdis <[email protected]>

* Rebase with upstream

Signed-off-by: Daniel Widdis <[email protected]>

* Remove alias getters/setters accidentally kept when rebasing

Signed-off-by: Daniel Widdis <[email protected]>

* Update test json format to list

Signed-off-by: Daniel Widdis <[email protected]>

* Remove unused modelId

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
Signed-off-by: Arjun kumar Giri <[email protected]>

* Revert "Rebase from upstream"

This reverts commit 49cff61.

Signed-off-by: Arjun kumar Giri <[email protected]>

* Add unit tests

Signed-off-by: Arjun kumar Giri <[email protected]>

* Rebase upstream changes

Signed-off-by: Arjun kumar Giri <[email protected]>

* Fixed PR feedback

Signed-off-by: Arjun kumar Giri <[email protected]>

---------

Signed-off-by: Arjun kumar Giri <[email protected]>
Signed-off-by: Daniel Widdis <[email protected]>
Signed-off-by: arjunkumargiri <[email protected]>
Co-authored-by: Arjun kumar Giri <[email protected]>
Co-authored-by: Yaliang Wu <[email protected]>
  • Loading branch information
3 people authored Nov 14, 2023
1 parent 568cd6b commit ea3d62b
Show file tree
Hide file tree
Showing 5 changed files with 385 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,29 @@ public class MLToolSpec implements ToXContentObject {
public static final String TOOL_NAME_FIELD = "name";
public static final String DESCRIPTION_FIELD = "description";
public static final String PARAMETERS_FIELD = "parameters";
public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response";

private String type;
private String name;
private String description;
private Map<String, String> parameters;
private boolean includeOutputInAgentResponse;


@Builder(toBuilder = true)
public MLToolSpec(String type,
String name,
String description,
Map<String, String> parameters) {
Map<String, String> parameters,
boolean includeOutputInAgentResponse) {
if (type == null) {
throw new IllegalArgumentException("tool type is null");
}
this.type = type;
this.name = name;
this.description = description;
this.parameters = parameters;
this.includeOutputInAgentResponse = includeOutputInAgentResponse;
}

public MLToolSpec(StreamInput input) throws IOException{
Expand All @@ -54,6 +58,7 @@ public MLToolSpec(StreamInput input) throws IOException{
if (input.readBoolean()) {
parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
}
includeOutputInAgentResponse = input.readBoolean();
}

public void writeTo(StreamOutput out) throws IOException {
Expand All @@ -66,6 +71,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeBoolean(includeOutputInAgentResponse);
}

@Override
Expand All @@ -83,6 +89,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (parameters != null && parameters.size() > 0) {
builder.field(PARAMETERS_FIELD, parameters);
}
builder.field(INCLUDE_OUTPUT_IN_AGENT_RESPONSE, includeOutputInAgentResponse);
builder.endObject();
return builder;
}
Expand All @@ -92,6 +99,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
String name = null;
String description = null;
Map<String, String> parameters = null;
boolean includeOutputInAgentResponse = false;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -111,6 +119,9 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
case PARAMETERS_FIELD:
parameters = getParameterMap(parser.map());
break;
case INCLUDE_OUTPUT_IN_AGENT_RESPONSE:
includeOutputInAgentResponse = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
Expand All @@ -121,6 +132,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
.name(name)
.description(description)
.parameters(parameters)
.includeOutputInAgentResponse(includeOutputInAgentResponse)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.StepListener;
import org.opensearch.client.Client;
Expand All @@ -18,13 +19,15 @@
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.memory.Memory;
import org.opensearch.ml.common.spi.tools.Tool;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -57,29 +60,39 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
List<MLToolSpec> toolSpecs = mlAgent.getTools();
StepListener<Object> firstStepListener = null;
Tool firstTool = null;
List<ModelTensor> flowAgentOutput = new ArrayList<>();
Map<String, String> firstToolExecuteParams = null;
StepListener<Object> lastStepListener = null;
StepListener<Object> previousStepListener = null;
if (toolSpecs.size() == 0) {
listener.onFailure(new IllegalArgumentException("no tool configured"));
return;
}
for (int i = 0 ;i<toolSpecs.size(); i++) {
MLToolSpec toolSpec = toolSpecs.get(i);
Tool tool = createTool(toolSpec);

for (int i = 0 ; i <= toolSpecs.size() ; i++) {
if (i == 0) {
MLToolSpec toolSpec = toolSpecs.get(i);
Tool tool = createTool(toolSpec);
firstStepListener = new StepListener();
lastStepListener = firstStepListener;
previousStepListener = firstStepListener;
firstTool = tool;
firstToolExecuteParams = getToolExecuteParams(toolSpec, params);
} else {
MLToolSpec lastToolSpec = toolSpecs.get(i - 1);
MLToolSpec previousToolSpec = toolSpecs.get(i - 1);
StepListener<Object> nextStepListener = new StepListener<>();
int finalI = i;
lastStepListener.whenComplete(output -> {
String outputKey = lastToolSpec.getType() + ".output";
if (lastToolSpec.getName() != null) {
outputKey = lastToolSpec.getName() + ".output";
previousStepListener.whenComplete(output -> {
String key = previousToolSpec.getName();
String outputKey = previousToolSpec.getName() != null ? previousToolSpec.getName() + ".output"
: previousToolSpec.getType() + ".output";

if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) {
String result = output instanceof String ? (String) output :
AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));

ModelTensor stepOutput = ModelTensor.builder().name(key).result(result).build();
flowAgentOutput.add(stepOutput);
}

if (output instanceof List && !((List) output).isEmpty() && ((List) output).get(0) instanceof ModelTensors) {
ModelTensors tensors = (ModelTensors) ((List) output).get(0);
Object response = tensors.getMlModelTensors().get(0).getDataAsMap().get("response");
Expand All @@ -93,18 +106,23 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
params.put(outputKey, escapeJson(toJson(output.toString())));
}
}
if (finalI < toolSpecs.size() - 1) {

if (finalI == toolSpecs.size()) {
listener.onResponse(flowAgentOutput);
return;
}

MLToolSpec toolSpec = toolSpecs.get(finalI);
Tool tool = createTool(toolSpec);
if (finalI < toolSpecs.size()) {
tool.run(getToolExecuteParams(toolSpec, params), nextStepListener);
} else {
tool.run(getToolExecuteParams(toolSpec, params), listener);
}

}, e -> {
log.error("Failed to run flow agent", e);
listener.onFailure(e);
});
if (i < toolSpecs.size() - 1) {
lastStepListener = nextStepListener;
}
previousStepListener = nextStepListener;
}
}
if (toolSpecs.size() == 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.StepListener;
Expand Down Expand Up @@ -38,6 +39,8 @@
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
import org.opensearch.search.SearchHit;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand Down Expand Up @@ -250,9 +253,11 @@ private void runReAct(LLMSpec llm, Map<String, Tool> tools, Map<String, MLToolSp
AtomicBoolean getFinalAnswer = new AtomicBoolean(false);
AtomicReference<String> lastTool = new AtomicReference<>();
AtomicReference<String> lastThought = new AtomicReference<>();
AtomicReference<String> currentAction = new AtomicReference<>();
AtomicReference<String> lastAction = new AtomicReference<>();
AtomicReference<String> lastActionInput = new AtomicReference<>();
AtomicReference<String> lastActionResult = new AtomicReference<>();
List<ModelTensor> outputModelTensors = new ArrayList<>();

StepListener<?> lastStepListener = null;
int maxIterations = Integer.parseInt(maxIteration) * 2;
Expand All @@ -268,7 +273,7 @@ private void runReAct(LLMSpec llm, Map<String, Tool> tools, Map<String, MLToolSp

lastStepListener.whenComplete(output -> {
StringBuilder sessionMsgAnswerBuilder = new StringBuilder("");
if (finalI % 2 == 0) {
if (finalI % 2 == 0) { // LLM response handler to identify next action
MLTaskResponse llmResponse = (MLTaskResponse) output;
ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput();
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
Expand Down Expand Up @@ -312,23 +317,19 @@ private void runReAct(LLMSpec llm, Map<String, Tool> tools, Map<String, MLToolSp
}
cotModelTensors.add(ModelTensors.builder().mlModelTensors(Arrays.asList(ModelTensor.builder().name("response").result(finalAnswer).build())).build());

List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors.add(ModelTensors.builder().mlModelTensors(Arrays.asList(ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("response", finalAnswer)).build())).build());
getFinalAnswer.set(true);
if (verbose) {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
} else {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
publishResponse(listener, outputModelTensors, finalAnswer);
}
return;
}
if (finalI == maxIterations - 1) {
if (verbose) {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
} else {
List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors.add(ModelTensors.builder().mlModelTensors(Arrays.asList(ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("response", thought)).build())).build());
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
publishResponse(listener, outputModelTensors, finalAnswer);
}
}

Expand All @@ -344,6 +345,7 @@ private void runReAct(LLMSpec llm, Map<String, Tool> tools, Map<String, MLToolSp
}
}
action = toolName;
currentAction.set(action);

if (action != null && tools.containsKey(action) && inputTools.contains(action)) {
Map<String, String> toolParams = new HashMap<>();
Expand Down Expand Up @@ -397,10 +399,18 @@ private void runReAct(LLMSpec llm, Map<String, Tool> tools, Map<String, MLToolSp
newPrompt.set(substitutor.replace(finalPrompt));
tmpParameters.put(PROMPT, newPrompt.get());
}
} else {
} else { // Handle tool output
Object result = output;
modelTensors.add(ModelTensors.builder().mlModelTensors(Arrays.asList(ModelTensor.builder().dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + result)).build())).build());

MLToolSpec toolSpec = toolSpecMap.get(currentAction.get());
if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) {
String outputString = output instanceof String ? (String) output :
AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
ModelTensor modelTensor = ModelTensor.builder().name(toolSpec.getName()).result(outputString).build();
outputModelTensors.add(modelTensor);
}

String toolResponse = tmpParameters.get("prompt.tool_response");
StringSubstitutor toolResponseSubstitutor = new StringSubstitutor(ImmutableMap.of("observation", result), "${parameters.", "}");
toolResponse = toolResponseSubstitutor.replace(toolResponse);
Expand All @@ -418,13 +428,12 @@ private void runReAct(LLMSpec llm, Map<String, Tool> tools, Map<String, MLToolSp
ActionRequest request = new MLPredictionTaskRequest(llm.getModelId(), RemoteInferenceMLInput.builder()
.algorithm(FunctionName.REMOTE)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()).build());

if (finalI == maxIterations - 1) {
if (verbose) {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
} else {
List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors.add(ModelTensors.builder().mlModelTensors(Arrays.asList(ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("response", lastThought.get())).build())).build());
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
publishResponse(listener, outputModelTensors, lastThought.get());
}
} else {
client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
Expand All @@ -445,6 +454,13 @@ private void runReAct(LLMSpec llm, Map<String, Tool> tools, Map<String, MLToolSp
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
}

private void publishResponse(ActionListener<Object> listener, List<ModelTensor> outputModelTensors, String finalAnswer) {
List<ModelTensors> finalModelTensors = new ArrayList<>();
outputModelTensors.add(ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("response", finalAnswer)).build());
finalModelTensors.add(ModelTensors.builder().mlModelTensors(outputModelTensors).build());
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
}

private String addPrefixSuffixToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> prefixMap = new HashMap<>();
String prefix = parameters.containsKey(PROMPT_PREFIX) ? parameters.get(PROMPT_PREFIX) : "";
Expand Down
Loading

0 comments on commit ea3d62b

Please sign in to comment.