Skip to content

Commit

Permalink
add memory factory; fix tool interface
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Nov 12, 2023
1 parent 7d1741c commit 6f4394b
Show file tree
Hide file tree
Showing 47 changed files with 296 additions and 231 deletions.
44 changes: 32 additions & 12 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ public class CommonValue {
public static final String ML_MAP_RESPONSE_KEY = "response";
public static final String ML_AGENT_INDEX = ".plugins-ml-agent";
public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1;

public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
Expand Down Expand Up @@ -328,7 +332,6 @@ public class CommonValue {
+ " }\n"
+ "}";


public static final String ML_AGENT_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_AGENT_INDEX_SCHEMA_VERSION
Expand All @@ -339,17 +342,11 @@ public class CommonValue {
+ "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n"
+ " \""
+ MLAgent.AGENT_TYPE_FIELD
+ "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n"
+ "\" : {\"type\":\"keyword\"},\n"
+ " \""
+ MLAgent.DESCRIPTION_FIELD
+ "\" : {\"type\": \"text\"},\n"
+ " \""
+ MLAgent.PROMPT_FIELD
+ "\" : {\"type\": \"text\"},\n"
+ " \""
+ MLAgent.MODEL_ID_FIELD
+ "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n"
+ " \""
+ MLAgent.LLM_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
Expand All @@ -362,14 +359,37 @@ public class CommonValue {
+ MLAgent.MEMORY_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
+ MLAgent.MEMORY_ID_FIELD
+ "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n"
+ " \""
+ MLAgent.CREATED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ MLAgent.LAST_UPDATED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";

public static final String ML_MEMORY_META_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_MEMORY_META_INDEX_SCHEMA_VERSION
+ " },\n"
+ " \"properties\": {\n"
+ " \"name\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n"
+ " \"application_type\" : {\"type\":\"keyword\"},\n"
+ " \"created_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \"last_updated_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";

public static final String ML_MEMORY_MESSAGE_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION
+ " },\n"
+ " \"properties\": {\n"
+ " \"question\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n"
+ " \"response\" : {\"type\":\"text\"},\n"
+ " \"session_id\" : {\"type\":\"keyword\"},\n"
+ " \"final_answer\" : {\"type\":\"boolean\"},\n"
+ " \"created_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \"last_updated_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";

}
49 changes: 12 additions & 37 deletions common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
Expand All @@ -29,8 +32,6 @@ public class MLAgent implements ToXContentObject, Writeable {
public static final String AGENT_NAME_FIELD = "name";
public static final String AGENT_TYPE_FIELD = "type";
public static final String DESCRIPTION_FIELD = "description";
public static final String PROMPT_FIELD = "prompt";
public static final String MODEL_ID_FIELD = "model_id";
public static final String LLM_FIELD = "llm";
public static final String TOOLS_FIELD = "tools";
public static final String PARAMETERS_FIELD = "parameters";
Expand All @@ -57,13 +58,10 @@ public class MLAgent implements ToXContentObject, Writeable {
public MLAgent(String name,
String type,
String description,
String prompt,
String modelId,
LLMSpec llm,
List<MLToolSpec> tools,
Map<String, String> parameters,
MLMemorySpec memory,
String memoryId,
Instant createdTime,
Instant lastUpdateTime) {
if (name == null) {
Expand All @@ -72,13 +70,10 @@ public MLAgent(String name,
this.name = name;
this.type = type;
this.description = description;
this.prompt = prompt;
this.modelId = modelId;
this.llm = llm;
this.tools = tools;
this.parameters = parameters;
this.memory = memory;
this.memoryId = memoryId;
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
}
Expand All @@ -87,8 +82,6 @@ public MLAgent(StreamInput input) throws IOException{
name = input.readString();
type = input.readString();
description = input.readOptionalString();
prompt = input.readOptionalString();
modelId = input.readString();
if (input.readBoolean()) {
llm = new LLMSpec(input);
}
Expand All @@ -105,17 +98,23 @@ public MLAgent(StreamInput input) throws IOException{
if (input.readBoolean()) {
memory = new MLMemorySpec(input);
}
memoryId = input.readOptionalString();
createdTime = input.readInstant();
lastUpdateTime = input.readInstant();
if (!"flow".equals(type)) {
Set<String> toolNames = new HashSet<>();
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName);
}
}
}
}

public void writeTo(StreamOutput out) throws IOException {
out.writeString(name);
out.writeString(type);
out.writeOptionalString(description);
out.writeOptionalString(prompt);
out.writeString(modelId);
if (llm != null) {
out.writeBoolean(true);
llm.writeTo(out);
Expand Down Expand Up @@ -160,12 +159,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
if (prompt != null) {
builder.field(PROMPT_FIELD, prompt);
}
if (modelId != null) {
builder.field(MODEL_ID_FIELD, modelId);
}
if (llm != null) {
builder.field(LLM_FIELD, llm);
}
Expand All @@ -178,9 +171,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (memory != null) {
builder.field(MEMORY_FIELD, memory);
}
if (memoryId != null) {
builder.field(MEMORY_ID_FIELD, memoryId);
}
if (createdTime != null) {
builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli());
}
Expand All @@ -195,13 +185,10 @@ public static MLAgent parse(XContentParser parser) throws IOException {
String name = null;
String type = null;
String description = null;;
String prompt = null;
String modelId = null;
LLMSpec llm = null;
List<MLToolSpec> tools = null;
Map<String, String> parameters = null;
MLMemorySpec memory = null;
String memoryId = null;
Instant createdTime = null;
Instant lastUpdateTime = null;

Expand All @@ -220,12 +207,6 @@ public static MLAgent parse(XContentParser parser) throws IOException {
case DESCRIPTION_FIELD:
description = parser.text();
break;
case PROMPT_FIELD:
prompt = parser.text();
break;
case MODEL_ID_FIELD:
modelId = parser.text();
break;
case LLM_FIELD:
llm = LLMSpec.parse(parser);
break;
Expand All @@ -242,9 +223,6 @@ public static MLAgent parse(XContentParser parser) throws IOException {
case MEMORY_FIELD:
memory = MLMemorySpec.parse(parser);
break;
case MEMORY_ID_FIELD:
memoryId = parser.text();
break;
case CREATED_TIME_FIELD:
createdTime = Instant.ofEpochMilli(parser.longValue());
break;
Expand All @@ -260,13 +238,10 @@ public static MLAgent parse(XContentParser parser) throws IOException {
.name(name)
.type(type)
.description(description)
.prompt(prompt)
.modelId(modelId)
.llm(llm)
.tools(tools)
.parameters(parameters)
.memory(memory)
.memoryId(memoryId)
.createdTime(createdTime)
.lastUpdateTime(lastUpdateTime)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ public class MLAgentExecutor implements Executable {
private ClusterService clusterService;
private NamedXContentRegistry xContentRegistry;
private Map<String, Tool.Factory> toolFactories;
private Map<String, Memory> memoryMap;
private Map<String, Memory.Factory> memoryFactoryMap;

public MLAgentExecutor(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map<String, Tool.Factory> toolFactories, Map<String, Memory> memoryMap) {
public MLAgentExecutor(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map<String, Tool.Factory> toolFactories, Map<String, Memory.Factory> memoryFactoryMap) {
this.client = client;
this.settings = settings;
this.clusterService = clusterService;
this.xContentRegistry = xContentRegistry;
this.toolFactories = toolFactories;
this.memoryMap = memoryMap;
this.memoryFactoryMap = memoryFactoryMap;
}

@Override
Expand Down Expand Up @@ -130,10 +130,10 @@ public void execute(Input input, ActionListener<Output> listener) {
listener.onFailure(ex);
});
if ("flow".equals(mlAgent.getType())) {
MLFlowAgentRunner flowAgentExecutor = new MLFlowAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap);
MLFlowAgentRunner flowAgentExecutor = new MLFlowAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap);
flowAgentExecutor.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
} else if ("cot".equals(mlAgent.getType())) {
MLReActAgentRunner reactAgentExecutor = new MLReActAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap);
MLReActAgentRunner reactAgentExecutor = new MLReActAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap);
reactAgentExecutor.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ public class MLFlowAgentRunner {
private ClusterService clusterService;
private NamedXContentRegistry xContentRegistry;
private Map<String, Tool.Factory> toolFactories;
private Map<String, Memory> memoryMap;
private Map<String, Memory.Factory> memoryFactoryMap;

public MLFlowAgentRunner(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map<String, Tool.Factory> toolFactories, Map<String, Memory> memoryMap) {
public MLFlowAgentRunner(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map<String, Tool.Factory> toolFactories, Map<String, Memory.Factory> memoryFactoryMap) {
this.client = client;
this.settings = settings;
this.clusterService = clusterService;
this.xContentRegistry = xContentRegistry;
this.toolFactories = toolFactories;
this.memoryMap = memoryMap;
this.memoryFactoryMap = memoryFactoryMap;
}

public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener) {
Expand Down
Loading

0 comments on commit 6f4394b

Please sign in to comment.