Skip to content

Commit

Permalink
create agent
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es committed Nov 7, 2023
1 parent e8b063a commit 90d9b31
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 48 deletions.
12 changes: 12 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,18 @@ public class CommonValue {
+ MLAgent.AGENT_NAME_FIELD
+ "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n"
+ " \""
+ MLAgent.AGENT_TYPE_FIELD
+ "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n"
+ " \""
+ MLAgent.DESCRIPTION_FIELD
+ "\" : {\"type\": \"text\"},\n"
+ " \""
+ MLAgent.PROMPT_FIELD
+ "\" : {\"type\": \"text\"},\n"
+ " \""
+ MLAgent.MODEL_ID_FIELD
+ "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n"
+ " \""
+ MLAgent.LLM_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
Expand All @@ -353,6 +362,9 @@ public class CommonValue {
+ MLAgent.MEMORY_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
+ MLAgent.MEMORY_ID_FIELD
+ "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n"
+ " \""
+ MLAgent.CREATED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
Expand Down
58 changes: 44 additions & 14 deletions common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
Expand All @@ -32,20 +29,26 @@ public class MLAgent implements ToXContentObject, Writeable {
public static final String AGENT_NAME_FIELD = "name";
public static final String AGENT_TYPE_FIELD = "type";
public static final String DESCRIPTION_FIELD = "description";
public static final String PROMPT_FIELD = "prompt";
public static final String MODEL_ID_FIELD = "model_id";
public static final String LLM_FIELD = "llm";
public static final String TOOLS_FIELD = "tools";
public static final String PARAMETERS_FIELD = "parameters";
public static final String MEMORY_FIELD = "memory";
public static final String MEMORY_ID_FIELD = "memory_id";
public static final String CREATED_TIME_FIELD = "created_time";
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";

private String name;
private String type;
private String description;
private String prompt;
private String modelId;
private LLMSpec llm;
private List<MLToolSpec> tools;
private Map<String, String> parameters;
private MLMemorySpec memory;
private String memoryId;

private Instant createdTime;
private Instant lastUpdateTime;
Expand All @@ -54,10 +57,13 @@ public class MLAgent implements ToXContentObject, Writeable {
public MLAgent(String name,
String type,
String description,
String prompt,
String modelId,
LLMSpec llm,
List<MLToolSpec> tools,
Map<String, String> parameters,
MLMemorySpec memory,
String memoryId,
Instant createdTime,
Instant lastUpdateTime) {
if (name == null) {
Expand All @@ -66,27 +72,23 @@ public MLAgent(String name,
this.name = name;
this.type = type;
this.description = description;
this.prompt = prompt;
this.modelId = modelId;
this.llm = llm;
this.tools = tools;
this.parameters = parameters;
this.memory = memory;
this.memoryId = memoryId;
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
if (!"flow".equals(type)) {
Set<String> toolNames = new HashSet<>();
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getAlias()).orElse(toolSpec.getName());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName);
}
}
}
}

public MLAgent(StreamInput input) throws IOException{
name = input.readString();
type = input.readOptionalString();
type = input.readString();
description = input.readOptionalString();
prompt = input.readOptionalString();
modelId = input.readString();
if (input.readBoolean()) {
llm = new LLMSpec(input);
}
Expand All @@ -103,14 +105,17 @@ public MLAgent(StreamInput input) throws IOException{
if (input.readBoolean()) {
memory = new MLMemorySpec(input);
}
memoryId = input.readOptionalString();
createdTime = input.readInstant();
lastUpdateTime = input.readInstant();
}

public void writeTo(StreamOutput out) throws IOException {
out.writeString(name);
out.writeOptionalString(type);
out.writeString(type);
out.writeOptionalString(description);
out.writeOptionalString(prompt);
out.writeString(modelId);
if (llm != null) {
out.writeBoolean(true);
llm.writeTo(out);
Expand Down Expand Up @@ -138,6 +143,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalString(memoryId);
out.writeInstant(createdTime);
out.writeInstant(lastUpdateTime);
}
Expand All @@ -154,6 +160,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
if (prompt != null) {
builder.field(PROMPT_FIELD, prompt);
}
if (modelId != null) {
builder.field(MODEL_ID_FIELD, modelId);
}
if (llm != null) {
builder.field(LLM_FIELD, llm);
}
Expand All @@ -166,6 +178,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (memory != null) {
builder.field(MEMORY_FIELD, memory);
}
if (memoryId != null) {
builder.field(MEMORY_ID_FIELD, memoryId);
}
if (createdTime != null) {
builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli());
}
Expand All @@ -180,10 +195,13 @@ public static MLAgent parse(XContentParser parser) throws IOException {
String name = null;
String type = null;
String description = null;;
String prompt = null;
String modelId = null;
LLMSpec llm = null;
List<MLToolSpec> tools = null;
Map<String, String> parameters = null;
MLMemorySpec memory = null;
String memoryId = null;
Instant createdTime = null;
Instant lastUpdateTime = null;

Expand All @@ -202,6 +220,12 @@ public static MLAgent parse(XContentParser parser) throws IOException {
case DESCRIPTION_FIELD:
description = parser.text();
break;
case PROMPT_FIELD:
prompt = parser.text();
break;
case MODEL_ID_FIELD:
modelId = parser.text();
break;
case LLM_FIELD:
llm = LLMSpec.parse(parser);
break;
Expand All @@ -218,6 +242,9 @@ public static MLAgent parse(XContentParser parser) throws IOException {
case MEMORY_FIELD:
memory = MLMemorySpec.parse(parser);
break;
case MEMORY_ID_FIELD:
memoryId = parser.text();
break;
case CREATED_TIME_FIELD:
createdTime = Instant.ofEpochMilli(parser.longValue());
break;
Expand All @@ -233,10 +260,13 @@ public static MLAgent parse(XContentParser parser) throws IOException {
.name(name)
.type(type)
.description(description)
.prompt(prompt)
.modelId(modelId)
.llm(llm)
.tools(tools)
.parameters(parameters)
.memory(memory)
.memoryId(memoryId)
.createdTime(createdTime)
.lastUpdateTime(lastUpdateTime)
.build();
Expand Down
38 changes: 19 additions & 19 deletions common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,43 +22,43 @@

@Getter
public class MLToolSpec implements ToXContentObject {
public static final String TOOL_TYPE_FIELD = "type";
public static final String TOOL_NAME_FIELD = "name";
public static final String ALIAS_FIELD = "alias";
public static final String DESCRIPTION_FIELD = "description";
public static final String PARAMETERS_FIELD = "parameters";

private String type;
private String name;
private String alias;
private String description;
private Map<String, String> parameters;


@Builder(toBuilder = true)
public MLToolSpec(String name,
String alias,
public MLToolSpec(String type,
String name,
String description,
Map<String, String> parameters) {
if (name == null) {
throw new IllegalArgumentException("agent name is null");
if (type == null) {
throw new IllegalArgumentException("tool type is null");
}
this.type = type;
this.name = name;
this.alias = alias;
this.description = description;
this.parameters = parameters;
}

public MLToolSpec(StreamInput input) throws IOException{
name = input.readString();
alias = input.readOptionalString();
type = input.readString();
name = input.readOptionalString();
description = input.readOptionalString();
if (input.readBoolean()) {
parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
}
}

public void writeTo(StreamOutput out) throws IOException {
out.writeString(name);
out.writeOptionalString(alias);
out.writeString(type);
out.writeOptionalString(name);
out.writeOptionalString(description);
if (parameters != null && parameters.size() > 0) {
out.writeBoolean(true);
Expand All @@ -71,12 +71,12 @@ public void writeTo(StreamOutput out) throws IOException {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (type != null) {
builder.field(TOOL_TYPE_FIELD, type);
}
if (name != null) {
builder.field(TOOL_NAME_FIELD, name);
}
if (alias != null) {
builder.field(ALIAS_FIELD, alias);
}
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
Expand All @@ -88,8 +88,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

public static MLToolSpec parse(XContentParser parser) throws IOException {
String type = null;
String name = null;
String alias = null;
String description = null;
Map<String, String> parameters = null;

Expand All @@ -99,12 +99,12 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
parser.nextToken();

switch (fieldName) {
case TOOL_TYPE_FIELD:
type = parser.text();
break;
case TOOL_NAME_FIELD:
name = parser.text();
break;
case ALIAS_FIELD:
alias = parser.text();
break;
case DESCRIPTION_FIELD:
description = parser.text();
break;
Expand All @@ -117,8 +117,8 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
}
}
return MLToolSpec.builder()
.type(type)
.name(name)
.alias(alias)
.description(description)
.parameters(parameters)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
int finalI = i;
lastStepListener.whenComplete(output -> {
String outputKey = lastToolSpec.getName() + ".output";
if (lastToolSpec.getAlias() != null) {
outputKey = lastToolSpec.getAlias() + ".output";
if (lastToolSpec.getName() != null) {
outputKey = lastToolSpec.getName() + ".output";
}
if (output instanceof List && !((List) output).isEmpty() && ((List) output).get(0) instanceof ModelTensors) {
ModelTensors tensors = (ModelTensors) ((List) output).get(0);
Expand Down Expand Up @@ -137,7 +137,7 @@ private Tool createTool(MLToolSpec toolSpec) {
throw new IllegalArgumentException("Tool not found: " + toolSpec.getName());
}
Tool tool = toolFactories.get(toolSpec.getName()).create(toolParams);
tool.setAlias(toolSpec.getAlias());
tool.setName(toolSpec.getName());

if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
Expand All @@ -155,8 +155,8 @@ private Map<String, String> getToolExecuteParams(MLToolSpec toolSpec, Map<String
if (key.startsWith(toolSpec.getName() + ".")) {
toBeReplaced = toolSpec.getName()+".";
}
if (toolSpec.getAlias() != null && key.startsWith(toolSpec.getAlias() + ".")) {
toBeReplaced = toolSpec.getAlias()+".";
if (toolSpec.getName() != null && key.startsWith(toolSpec.getName() + ".")) {
toBeReplaced = toolSpec.getName()+".";
}
if (toBeReplaced != null) {
executeParams.put(key.replace(toBeReplaced, ""), params.get(key));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,12 @@ private void runAgent(MLAgent mlAgent, Map<String, String> params, ActionListene
}
}
Tool tool = toolFactories.get(toolSpec.getName()).create(toolParams);
tool.setAlias(toolSpec.getAlias());
tool.setName(toolSpec.getName());

if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
}
String toolName = Optional.ofNullable(toolSpec.getAlias()).orElse(toolSpec.getName());
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getName());
tools.put(toolName, tool);
toolSpecMap.put(toolName, toolSpec);
}
Expand Down Expand Up @@ -207,7 +207,7 @@ private void runReAct(LLMSpec llm, Map<String, Tool> tools, Map<String, MLToolSp
inputTools.addAll(gson.fromJson(parameters.get(TOOLS), List.class));
} else {
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
String toolName = Optional.ofNullable(entry.getValue().getAlias()).orElse(entry.getValue().getName());
String toolName = Optional.ofNullable(entry.getValue().getName()).orElse(entry.getValue().getName());
inputTools.add(toolName);
}
}
Expand Down
Loading

0 comments on commit 90d9b31

Please sign in to comment.