From 90d9b310811c682f55ee5415d53ad0be8f671156 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 7 Nov 2023 11:55:53 -0800 Subject: [PATCH] create agent Signed-off-by: Jing Zhang --- .../org/opensearch/ml/common/CommonValue.java | 12 ++++ .../opensearch/ml/common/agent/MLAgent.java | 58 ++++++++++++++----- .../ml/common/agent/MLToolSpec.java | 38 ++++++------ .../algorithms/agent/MLFlowAgentRunner.java | 10 ++-- .../algorithms/agent/MLReActAgentRunner.java | 6 +- .../opensearch/ml/engine/tools/AgentTool.java | 15 +++++ .../ml/engine/tools/CatIndexTool.java | 15 +++++ .../ml/engine/tools/MLModelTool.java | 15 +++++ .../opensearch/ml/engine/tools/MathTool.java | 15 +++++ .../ml/engine/tools/PainlessScriptTool.java | 15 +++++ .../ml/engine/tools/VectorDBTool.java | 15 +++++ .../opensearch/ml/common/spi/tools/Tool.java | 28 ++++++--- 12 files changed, 194 insertions(+), 48 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index eee4222d77..1cd4e4131d 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -338,9 +338,18 @@ public class CommonValue { + MLAgent.AGENT_NAME_FIELD + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + " \"" + + MLAgent.AGENT_TYPE_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + MLAgent.DESCRIPTION_FIELD + "\" : {\"type\": \"text\"},\n" + " \"" + + MLAgent.PROMPT_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + MLAgent.MODEL_ID_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + MLAgent.LLM_FIELD + "\" : {\"type\": \"flat_object\"},\n" + " \"" @@ -353,6 +362,9 @@ public class CommonValue { + MLAgent.MEMORY_FIELD + "\" : {\"type\": \"flat_object\"},\n" + " \"" + + MLAgent.MEMORY_ID_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + MLAgent.CREATED_TIME_FIELD + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + " \"" diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index 65a08924cc..7e85ec411b 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -17,11 +17,8 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.Set; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; @@ -32,20 +29,26 @@ public class MLAgent implements ToXContentObject, Writeable { public static final String AGENT_NAME_FIELD = "name"; public static final String AGENT_TYPE_FIELD = "type"; public static final String DESCRIPTION_FIELD = "description"; + public static final String PROMPT_FIELD = "prompt"; + public static final String MODEL_ID_FIELD = "model_id"; public static final String LLM_FIELD = "llm"; public static final String TOOLS_FIELD = "tools"; public static final String PARAMETERS_FIELD = "parameters"; public static final String MEMORY_FIELD = "memory"; + public static final String MEMORY_ID_FIELD = "memory_id"; public static final String CREATED_TIME_FIELD = "created_time"; public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; private String name; private String type; private String description; + private String prompt; + private String modelId; private LLMSpec llm; private List tools; private Map parameters; private MLMemorySpec memory; + private String memoryId; private Instant createdTime; private Instant lastUpdateTime; @@ -54,10 +57,13 @@ public class MLAgent implements ToXContentObject, Writeable { public MLAgent(String name, String type, String description, + String prompt, + String modelId, LLMSpec llm, List tools, Map parameters, MLMemorySpec memory, + String memoryId, Instant createdTime, Instant lastUpdateTime) { if (name == null) { @@ -66,27 +72,23 @@ public MLAgent(String name, this.name = name; this.type = type; this.description = description; + this.prompt = prompt; + this.modelId = modelId; this.llm = llm; this.tools = tools; this.parameters = parameters; this.memory = memory; + this.memoryId = memoryId; this.createdTime = createdTime; this.lastUpdateTime = lastUpdateTime; - if (!"flow".equals(type)) { - Set toolNames = new HashSet<>(); - for (MLToolSpec toolSpec : tools) { - String toolName = Optional.ofNullable(toolSpec.getAlias()).orElse(toolSpec.getName()); - if (toolNames.contains(toolName)) { - throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName); - } - } - } } public MLAgent(StreamInput input) throws IOException{ name = input.readString(); - type = input.readOptionalString(); + type = input.readString(); description = input.readOptionalString(); + prompt = input.readOptionalString(); + modelId = input.readString(); if (input.readBoolean()) { llm = new LLMSpec(input); } @@ -103,14 +105,17 @@ public MLAgent(StreamInput input) throws IOException{ if (input.readBoolean()) { memory = new MLMemorySpec(input); } + memoryId = input.readOptionalString(); createdTime = input.readInstant(); lastUpdateTime = input.readInstant(); } public void writeTo(StreamOutput out) throws IOException { out.writeString(name); - out.writeOptionalString(type); + out.writeString(type); out.writeOptionalString(description); + out.writeOptionalString(prompt); + out.writeString(modelId); if (llm != null) { out.writeBoolean(true); llm.writeTo(out); @@ -138,6 +143,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalString(memoryId); out.writeInstant(createdTime); out.writeInstant(lastUpdateTime); } @@ -154,6 +160,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (description != null) { builder.field(DESCRIPTION_FIELD, description); } + if (prompt != null) { + builder.field(PROMPT_FIELD, prompt); + } + if (modelId != null) { + builder.field(MODEL_ID_FIELD, modelId); + } if (llm != null) { builder.field(LLM_FIELD, llm); } @@ -166,6 +178,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (memory != null) { builder.field(MEMORY_FIELD, memory); } + if (memoryId != null) { + builder.field(MEMORY_ID_FIELD, memoryId); + } if (createdTime != null) { builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); } @@ -180,10 +195,13 @@ public static MLAgent parse(XContentParser parser) throws IOException { String name = null; String type = null; String description = null;; + String prompt = null; + String modelId = null; LLMSpec llm = null; List tools = null; Map parameters = null; MLMemorySpec memory = null; + String memoryId = null; Instant createdTime = null; Instant lastUpdateTime = null; @@ -202,6 +220,12 @@ public static MLAgent parse(XContentParser parser) throws IOException { case DESCRIPTION_FIELD: description = parser.text(); break; + case PROMPT_FIELD: + prompt = parser.text(); + break; + case MODEL_ID_FIELD: + modelId = parser.text(); + break; case LLM_FIELD: llm = LLMSpec.parse(parser); break; @@ -218,6 +242,9 @@ public static MLAgent parse(XContentParser parser) throws IOException { case MEMORY_FIELD: memory = MLMemorySpec.parse(parser); break; + case MEMORY_ID_FIELD: + memoryId = parser.text(); + break; case CREATED_TIME_FIELD: createdTime = Instant.ofEpochMilli(parser.longValue()); break; @@ -233,10 +260,13 @@ public static MLAgent parse(XContentParser parser) throws IOException { .name(name) .type(type) .description(description) + .prompt(prompt) + .modelId(modelId) .llm(llm) .tools(tools) .parameters(parameters) .memory(memory) + .memoryId(memoryId) .createdTime(createdTime) .lastUpdateTime(lastUpdateTime) .build(); diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java index 8cd606cbf5..27e38f16be 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -22,34 +22,34 @@ @Getter public class MLToolSpec implements ToXContentObject { + public static final String TOOL_TYPE_FIELD = "type"; public static final String TOOL_NAME_FIELD = "name"; - public static final String ALIAS_FIELD = "alias"; public static final String DESCRIPTION_FIELD = "description"; public static final String PARAMETERS_FIELD = "parameters"; + private String type; private String name; - private String alias; private String description; private Map parameters; @Builder(toBuilder = true) - public MLToolSpec(String name, - String alias, + public MLToolSpec(String type, + String name, String description, Map parameters) { - if (name == null) { - throw new IllegalArgumentException("agent name is null"); + if (type == null) { + throw new IllegalArgumentException("tool type is null"); } + this.type = type; this.name = name; - this.alias = alias; this.description = description; this.parameters = parameters; } public MLToolSpec(StreamInput input) throws IOException{ - name = input.readString(); - alias = input.readOptionalString(); + type = input.readString(); + name = input.readOptionalString(); description = input.readOptionalString(); if (input.readBoolean()) { parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString); @@ -57,8 +57,8 @@ public MLToolSpec(StreamInput input) throws IOException{ } public void writeTo(StreamOutput out) throws IOException { - out.writeString(name); - out.writeOptionalString(alias); + out.writeString(type); + out.writeOptionalString(name); out.writeOptionalString(description); if (parameters != null && parameters.size() > 0) { out.writeBoolean(true); @@ -71,12 +71,12 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); + if (type != null) { + builder.field(TOOL_TYPE_FIELD, type); + } if (name != null) { builder.field(TOOL_NAME_FIELD, name); } - if (alias != null) { - builder.field(ALIAS_FIELD, alias); - } if (description != null) { builder.field(DESCRIPTION_FIELD, description); } @@ -88,8 +88,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } public static MLToolSpec parse(XContentParser parser) throws IOException { + String type = null; String name = null; - String alias = null; String description = null; Map parameters = null; @@ -99,12 +99,12 @@ public static MLToolSpec parse(XContentParser parser) throws IOException { parser.nextToken(); switch (fieldName) { + case TOOL_TYPE_FIELD: + type = parser.text(); + break; case TOOL_NAME_FIELD: name = parser.text(); break; - case ALIAS_FIELD: - alias = parser.text(); - break; case DESCRIPTION_FIELD: description = parser.text(); break; @@ -117,8 +117,8 @@ public static MLToolSpec parse(XContentParser parser) throws IOException { } } return MLToolSpec.builder() + .type(type) .name(name) - .alias(alias) .description(description) .parameters(parameters) .build(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 6be256197a..b1c5521ede 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -77,8 +77,8 @@ public void run(MLAgent mlAgent, Map params, ActionListener { String outputKey = lastToolSpec.getName() + ".output"; - if (lastToolSpec.getAlias() != null) { - outputKey = lastToolSpec.getAlias() + ".output"; + if (lastToolSpec.getName() != null) { + outputKey = lastToolSpec.getName() + ".output"; } if (output instanceof List && !((List) output).isEmpty() && ((List) output).get(0) instanceof ModelTensors) { ModelTensors tensors = (ModelTensors) ((List) output).get(0); @@ -137,7 +137,7 @@ private Tool createTool(MLToolSpec toolSpec) { throw new IllegalArgumentException("Tool not found: " + toolSpec.getName()); } Tool tool = toolFactories.get(toolSpec.getName()).create(toolParams); - tool.setAlias(toolSpec.getAlias()); + tool.setName(toolSpec.getName()); if (toolSpec.getDescription() != null) { tool.setDescription(toolSpec.getDescription()); @@ -155,8 +155,8 @@ private Map getToolExecuteParams(MLToolSpec toolSpec, Map params, ActionListene } } Tool tool = toolFactories.get(toolSpec.getName()).create(toolParams); - tool.setAlias(toolSpec.getAlias()); + tool.setName(toolSpec.getName()); if (toolSpec.getDescription() != null) { tool.setDescription(toolSpec.getDescription()); } - String toolName = Optional.ofNullable(toolSpec.getAlias()).orElse(toolSpec.getName()); + String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getName()); tools.put(toolName, tool); toolSpecMap.put(toolName, toolSpec); } @@ -207,7 +207,7 @@ private void runReAct(LLMSpec llm, Map tools, Map entry : tools.entrySet()) { - String toolName = Optional.ofNullable(entry.getValue().getAlias()).orElse(entry.getValue().getName()); + String toolName = Optional.ofNullable(entry.getValue().getName()).orElse(entry.getValue().getName()); inputTools.add(toolName); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index e64349a187..2485881dca 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -58,11 +58,26 @@ public void run(Map parameters, ActionListener listener) } + @Override + public String getType() { + return null; + } + + @Override + public String getVersion() { + return null; + } + @Override public String getName() { return AgentTool.NAME; } + @Override + public void setName(String s) { + + } + @Override public boolean validate(Map parameters) { return true; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java index 3c16ef18ae..f7539edde2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -149,6 +149,16 @@ public void run(Map parameters, ActionListener listener) })); } + @Override + public String getType() { + return null; + } + + @Override + public String getVersion() { + return null; + } + @Data public static class IndexState { private String health; @@ -198,6 +208,11 @@ public String getName() { return CatIndexTool.NAME; } + @Override + public void setName(String s) { + + } + @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index 06d3d27b1f..bf2a79b38f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -78,12 +78,27 @@ public void run(Map parameters, ActionListener listener) })); } + @Override + public String getType() { + return null; + } + + @Override + public String getVersion() { + return null; + } + @Override public String getName() { return MLModelTool.NAME; } + @Override + public void setName(String s) { + + } + @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java index 65c79d3e0a..be14d9647c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java @@ -56,11 +56,26 @@ public void run(Map parameters, ActionListener listener) listener.onResponse((T)result); } + @Override + public String getType() { + return null; + } + + @Override + public String getVersion() { + return null; + } + @Override public String getName() { return MathTool.NAME; } + @Override + public void setName(String s) { + + } + @Override public boolean validate(Map parameters) { try { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java index 3598e2e66a..e5898a00d2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java @@ -64,12 +64,27 @@ public void run(Map parameters, ActionListener listener) listener.onResponse((T)s); } + @Override + public String getType() { + return null; + } + + @Override + public String getVersion() { + return null; + } + @Override public String getName() { return PainlessScriptTool.NAME; } + @Override + public void setName(String s) { + + } + @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java index e09fb38d4c..68c5e7a320 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java @@ -111,11 +111,26 @@ public void run(Map parameters, ActionListener listener) } } + @Override + public String getType() { + return null; + } + + @Override + public String getVersion() { + return null; + } + @Override public String getName() { return NAME; } + @Override + public void setName(String s) { + + } + @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java index 05fb90b804..2e2c7745ad 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java @@ -38,22 +38,28 @@ default T run(Map parameters) { default void setOutputParser(Parser parser) {}; /** - * Get tool name. + * Get tool type mapping to the run function. * @return */ - String getName(); + String getType(); + + /** + * Get tool version. + * @return + */ + String getVersion(); /** - * Get tool alias. + * Get tool name which is displayed in prompt. * @return */ - String getAlias(); + String getName(); /** - * Set tool alias. - * @param alias + * Set tool name which is displayed in prompt. + * @param name */ - void setAlias(String alias); + void setName(String name); /** * Get tool description. @@ -86,6 +92,14 @@ default boolean end(String input, Map toolParameters) { return false; } + /** + * The tool runs against the original human input. + * @return + */ + default boolean useOriginalInput() { + return false; + } + /** * Tool factory which can create instance of {@link Tool}. * @param The subclass this factory produces