Skip to content

Commit

Permalink
Add abstract tool support (opensearch-project#1758)
Browse files Browse the repository at this point in the history
* Add abstract tool support

Signed-off-by: Arjun kumar Giri <[email protected]>

* Removed unnecessary gradle dependencies

Signed-off-by: Arjun kumar Giri <[email protected]>

* Fix windows build failure

Signed-off-by: Arjun kumar Giri <[email protected]>

---------

Signed-off-by: Arjun kumar Giri <[email protected]>
  • Loading branch information
arjunkumargiri authored Dec 21, 2023
1 parent e662249 commit 48c56b8
Show file tree
Hide file tree
Showing 18 changed files with 272 additions and 344 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
Expand All @@ -36,7 +37,7 @@
@Log4j2
@Getter
@Setter
public abstract class AbstractRetrieverTool implements Tool {
public abstract class AbstractRetrieverTool extends AbstractTool {
public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index.";
public static final String INPUT_FIELD = "input";
public static final String INDEX_FIELD = "index";
Expand All @@ -52,12 +53,15 @@ public abstract class AbstractRetrieverTool implements Tool {
protected Integer docSize;

protected AbstractRetrieverTool(
String type,
String description,
Client client,
NamedXContentRegistry xContentRegistry,
String index,
String[] sourceFields,
Integer docSize
) {
super(type, description);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.index = index;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,31 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* This tool supports running any Agent.
*/
@Log4j2
@ToolAnnotation(AgentTool.TYPE)
public class AgentTool implements Tool {
public class AgentTool extends AbstractTool {
public static final String TYPE = "AgentTool";
private final Client client;

private String agentId;
@Setter
@Getter
private String name = TYPE;

@VisibleForTesting
static String DEFAULT_DESCRIPTION = "Use this tool to run any agent.";
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;

public AgentTool(Client client, String agentId) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;
this.agentId = agentId;
}
Expand All @@ -68,26 +62,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)

}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,47 +43,30 @@
import org.opensearch.core.action.ActionResponse;
import org.opensearch.index.IndexSettings;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;

import lombok.Getter;
import lombok.Setter;

@ToolAnnotation(CatIndexTool.TYPE)
public class CatIndexTool implements Tool {
public class CatIndexTool extends AbstractTool {
public static final String TYPE = "CatIndexTool";
private static final String DEFAULT_DESCRIPTION = "Use this tool to get index information.";

@Setter
@Getter
private String name = CatIndexTool.TYPE;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;
@Getter
private String version;

private Client client;
@Setter
private Parser<?, ?> inputParser;
@Setter
private Parser<?, ?> outputParser;
@SuppressWarnings("unused")
private ClusterService clusterService;

public CatIndexTool(Client client, ClusterService clusterService) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;
this.clusterService = clusterService;

outputParser = new Parser<>() {
@Override
public Object parse(Object o) {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
}
};
this.setOutputParser((Parser<Object, Object>) parser -> {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) parser;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
});
}

@Override
Expand Down Expand Up @@ -295,16 +278,6 @@ public void onFailure(final Exception e) {
}, size);
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,27 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;

import lombok.Getter;
import lombok.Setter;

@ToolAnnotation(IndexMappingTool.NAME)
public class IndexMappingTool implements Tool {
public static final String NAME = "IndexMappingTool";
@ToolAnnotation(IndexMappingTool.TYPE)
public class IndexMappingTool extends AbstractTool {
public static final String TYPE = "IndexMappingTool";

private static final String DEFAULT_DESCRIPTION = "Use this tool to get index mapping information.";
@Setter
@Getter
private String name = IndexMappingTool.NAME;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;
@Getter
private String type;
@Getter
private String version;
private Client client;
@Setter
private Parser<?, ?> inputParser;
@Setter
private Parser<?, ?> outputParser;

public IndexMappingTool(Client client) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;

outputParser = new Parser<>() {
@Override
public Object parse(Object o) {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
}
};
this.setOutputParser((Parser<Object, Object>) o -> {
@SuppressWarnings("unchecked")
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,36 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* This tool supports running any ml-commons model.
*/
@Log4j2
@ToolAnnotation(MLModelTool.TYPE)
public class MLModelTool implements Tool {
public class MLModelTool extends AbstractTool {
public static final String TYPE = "MLModelTool";

@Setter
@Getter
private String name = TYPE;
private static String DEFAULT_DESCRIPTION = "Use this tool to run any model.";
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;
private Client client;
private String modelId;
@Setter
private Parser inputParser;
@Setter
private Parser outputParser;

public MLModelTool(Client client, String modelId) {
super(TYPE, DEFAULT_DESCRIPTION);
this.client = client;
this.modelId = modelId;

outputParser = o -> {
this.setOutputParser(o -> {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
};
});
}

@Override
Expand All @@ -69,33 +58,13 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.<MLTaskResponse>wrap(r -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
modelTensorOutput.getMlModelOutputs();
listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
listener.onResponse((T) this.getOutputParser().parse(modelTensorOutput.getMlModelOutputs()));
}, e -> {
log.error("Failed to run model " + modelId, e);
listener.onFailure(e);
}));
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,24 @@
import java.util.regex.Pattern;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.spi.tools.AbstractTool;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
import org.opensearch.script.ScriptService;

import lombok.Getter;
import lombok.Setter;

@ToolAnnotation(MathTool.TYPE)
public class MathTool implements Tool {
public class MathTool extends AbstractTool {
public static final String TYPE = "MathTool";

@Setter
@Getter
private String name = TYPE;

@Setter
private ScriptService scriptService;

private static String DEFAULT_DESCRIPTION = "Use this tool to calculate any math problem.";
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;

public MathTool(ScriptService scriptService) {
super(TYPE, DEFAULT_DESCRIPTION);
this.scriptService = scriptService;
}

Expand All @@ -59,26 +52,6 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
listener.onResponse((T) result);
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
try {
Expand Down
Loading

0 comments on commit 48c56b8

Please sign in to comment.