Skip to content

Commit

Permalink
add description for VectorDBTool and remove json parsig for RAGTool (#…
Browse files Browse the repository at this point in the history
…121)

Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl authored Jan 12, 2024
1 parent c202f62 commit c81746e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 14 deletions.
4 changes: 1 addition & 3 deletions src/main/java/org/opensearch/agent/tools/RAGTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ public RAGTool(
this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize;
this.k = k == null ? DEFAULT_K : k;
this.inferenceModelId = inferenceModelId;

outputParser = new Parser() {
@Override
public Object parse(Object o) {
Expand All @@ -115,8 +114,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
}

try {
String question = parameters.get(INPUT_FIELD);
input = gson.fromJson(question, String.class);
input = parameters.get(INPUT_FIELD);
} catch (Exception e) {
log.error("Failed to read question from " + INPUT_FIELD, e);
listener.onFailure(new IllegalArgumentException("Failed to read question from " + INPUT_FIELD));
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/agent/tools/VectorDBTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
@ToolAnnotation(VectorDBTool.TYPE)
public class VectorDBTool extends AbstractRetrieverTool {
public static final String TYPE = "VectorDBTool";

public static String DEFAULT_DESCRIPTION =
"Use this tool to performs knn-based dense retrieval. It takes 1 argument named input which is a string query for dense retrieval. The tool returns the dense retrieval results for the query.";
public static final String MODEL_ID_FIELD = "model_id";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String K_FIELD = "k";
Expand Down
41 changes: 30 additions & 11 deletions src/test/java/org/opensearch/agent/tools/RAGToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,36 @@ public void testRunWithEmptySearchResponse() throws IOException {
verify(client).execute(any(), any(), any());
}

@Test
public void testRunWithQuestionJson() throws IOException {
NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry();
ragTool.setXContentRegistry(mockNamedXContentRegistry);

ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput();
SearchResponse mockedEmptySearchResponse = SearchResponse
.fromXContent(
JsonXContent.jsonXContent
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedEmptySearchResponseString)
);

doAnswer(invocation -> {
SearchRequest searchRequest = invocation.getArgument(0);
assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size());
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(mockedEmptySearchResponse);
return null;
}).when(client).search(any(), any());

doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
return null;
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());
ragTool.run(Map.of(INPUT_FIELD, "{question:'what is the population in seattle?'}"), listener);
verify(client).search(any(), any());
verify(client).execute(any(), any(), any());
}

@Test
@SneakyThrows
public void testRunWithRuntimeExceptionDuringSearch() {
Expand Down Expand Up @@ -261,17 +291,6 @@ public void testRunWithEmptyInput() {
ragTool.run(Map.of(INPUT_FIELD, ""), listener);
}

@Test
public void testRunWithMalformedInput() throws IOException {
ActionListener listener = mock(ActionListener.class);
ragTool.run(Map.of(INPUT_FIELD, "{hello?"), listener);
verify(listener).onFailure(any(RuntimeException.class));
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener).onFailure(argumentCaptor.capture());
assertEquals("Failed to read question from " + INPUT_FIELD, argumentCaptor.getValue().getMessage());

}

@Test
public void testFactory() {
RAGTool.Factory factoryMock = new RAGTool.Factory();
Expand Down

0 comments on commit c81746e

Please sign in to comment.