diff --git a/src/main/java/org/opensearch/agent/tools/RAGTool.java b/src/main/java/org/opensearch/agent/tools/RAGTool.java index a88930fd..e3670bd0 100644 --- a/src/main/java/org/opensearch/agent/tools/RAGTool.java +++ b/src/main/java/org/opensearch/agent/tools/RAGTool.java @@ -90,7 +90,6 @@ public RAGTool( this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize; this.k = k == null ? DEFAULT_K : k; this.inferenceModelId = inferenceModelId; - outputParser = new Parser() { @Override public Object parse(Object o) { @@ -115,8 +114,7 @@ public void run(Map parameters, ActionListener listener) } try { - String question = parameters.get(INPUT_FIELD); - input = gson.fromJson(question, String.class); + input = parameters.get(INPUT_FIELD); } catch (Exception e) { log.error("Failed to read question from " + INPUT_FIELD, e); listener.onFailure(new IllegalArgumentException("Failed to read question from " + INPUT_FIELD)); diff --git a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java index aea3a01c..dfbbed26 100644 --- a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java +++ b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java @@ -28,6 +28,9 @@ @ToolAnnotation(VectorDBTool.TYPE) public class VectorDBTool extends AbstractRetrieverTool { public static final String TYPE = "VectorDBTool"; + + public static String DEFAULT_DESCRIPTION = + "Use this tool to performs knn-based dense retrieval. It takes 1 argument named input which is a string query for dense retrieval. The tool returns the dense retrieval results for the query."; public static final String MODEL_ID_FIELD = "model_id"; public static final String EMBEDDING_FIELD = "embedding_field"; public static final String K_FIELD = "k"; diff --git a/src/test/java/org/opensearch/agent/tools/RAGToolTests.java b/src/test/java/org/opensearch/agent/tools/RAGToolTests.java index 79bfcebf..8ef43468 100644 --- a/src/test/java/org/opensearch/agent/tools/RAGToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/RAGToolTests.java @@ -203,6 +203,36 @@ public void testRunWithEmptySearchResponse() throws IOException { verify(client).execute(any(), any(), any()); } + @Test + public void testRunWithQuestionJson() throws IOException { + NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry(); + ragTool.setXContentRegistry(mockNamedXContentRegistry); + + ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); + SearchResponse mockedEmptySearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedEmptySearchResponseString) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedEmptySearchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + ragTool.run(Map.of(INPUT_FIELD, "{question:'what is the population in seattle?'}"), listener); + verify(client).search(any(), any()); + verify(client).execute(any(), any(), any()); + } + @Test @SneakyThrows public void testRunWithRuntimeExceptionDuringSearch() { @@ -261,17 +291,6 @@ public void testRunWithEmptyInput() { ragTool.run(Map.of(INPUT_FIELD, ""), listener); } - @Test - public void testRunWithMalformedInput() throws IOException { - ActionListener listener = mock(ActionListener.class); - ragTool.run(Map.of(INPUT_FIELD, "{hello?"), listener); - verify(listener).onFailure(any(RuntimeException.class)); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(listener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to read question from " + INPUT_FIELD, argumentCaptor.getValue().getMessage()); - - } - @Test public void testFactory() { RAGTool.Factory factoryMock = new RAGTool.Factory();