From b8de7994ccc32b6d5cb8ac783708d8aaf89f32e5 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 9 Oct 2024 11:16:29 -0700 Subject: [PATCH] =?UTF-8?q?Allow=20llmQuestion=20to=20be=20optional=20when?= =?UTF-8?q?=20llmMessages=20is=20used.=20=20(Issue=20#3=E2=80=A6=20(#3072)?= =?UTF-8?q?=20(#3082)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow llmQuestion to be optional when llmMessages is used. (Issue #3067) Signed-off-by: Austin Lee * Remove unused lines. Signed-off-by: Austin Lee --------- Signed-off-by: Austin Lee (cherry picked from commit 48d275d5512c2de2cadb511264a742075d64aabd) Co-authored-by: Austin Lee --- .../ml/rest/RestMLRAGSearchProcessorIT.java | 23 +++++++++++++------ .../ext/GenerativeQAParameters.java | 14 +++++------ .../ext/GenerativeQAParamExtBuilderTests.java | 3 +-- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index b97d73ae8f..b16abef59f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -360,7 +360,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " \"ext\": {\n" + " \"generative_qa_parameters\": {\n" + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" + " \"system_prompt\": \"%s\",\n" + " \"user_instructions\": \"%s\",\n" + " \"context_size\": %d,\n" @@ -379,8 +378,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " \"ext\": {\n" + " \"generative_qa_parameters\": {\n" + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" - // + " \"system_prompt\": \"%s\",\n" + " \"user_instructions\": \"%s\",\n" + " \"context_size\": %d,\n" + " \"message_size\": %d,\n" @@ -726,8 +723,12 @@ public void testBM25WithBedrock() throws Exception { public void testBM25WithBedrockConverse() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null) { + System.out.println("Skipping testBM25WithBedrockConverse because AWS_ACCESS_KEY_ID is null"); return; } + + System.out.println("Running testBM25WithBedrockConverse"); + Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -779,8 +780,11 @@ public void testBM25WithBedrockConverse() throws Exception { public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null) { + System.out.println("Skipping testBM25WithBedrockConverseUsingLlmMessages because AWS_ACCESS_KEY_ID is null"); return; } + System.out.println("Running testBM25WithBedrockConverseUsingLlmMessages"); + Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -840,8 +844,11 @@ public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception { public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null) { + System.out.println("Skipping testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat because AWS_ACCESS_KEY_ID is null"); return; } + + System.out.println("Running testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat"); Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -899,8 +906,11 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws public void testBM25WithOpenAIWithConversation() throws Exception { // Skip test if key is null if (OPENAI_KEY == null) { + System.out.println("Skipping testBM25WithOpenAIWithConversation because OPENAI_KEY is null"); return; } + System.out.println("Running testBM25WithOpenAIWithConversation"); + Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -957,8 +967,11 @@ public void testBM25WithOpenAIWithConversation() throws Exception { public void testBM25WithOpenAIWithConversationAndImage() throws Exception { // Skip test if key is null if (OPENAI_KEY == null) { + System.out.println("Skipping testBM25WithOpenAIWithConversationAndImage because OPENAI_KEY is null"); return; } + System.out.println("Running testBM25WithOpenAIWithConversationAndImage"); + Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -1251,7 +1264,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.source, requestParameters.match, requestParameters.llmModel, - requestParameters.llmQuestion, requestParameters.systemPrompt, requestParameters.userInstructions, requestParameters.contextSize, @@ -1274,8 +1286,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.source, requestParameters.match, requestParameters.llmModel, - requestParameters.llmQuestion, - // requestParameters.systemPrompt, requestParameters.userInstructions, requestParameters.contextSize, requestParameters.interactionSize, @@ -1315,7 +1325,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.source, requestParameters.match, requestParameters.llmModel, - requestParameters.llmQuestion, requestParameters.systemPrompt, requestParameters.userInstructions, requestParameters.contextSize, diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index d5ec0e47c1..6ff89093b3 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -167,9 +167,11 @@ public GenerativeQAParameters( this.conversationId = conversationId; this.llmModel = llmModel; - // TODO: keep this requirement until we can extract the question from the query or from the request processor parameters - // for question rewriting. - Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided."); + Preconditions + .checkArgument( + !(Strings.isNullOrEmpty(llmQuestion) && (llmMessages == null || llmMessages.isEmpty())), + "At least one of " + LLM_QUESTION + " or " + LLM_MESSAGES_FIELD + " must be provided." + ); this.llmQuestion = llmQuestion; this.systemPrompt = systemPrompt; this.userInstructions = userInstructions; @@ -185,7 +187,7 @@ public GenerativeQAParameters( public GenerativeQAParameters(StreamInput input) throws IOException { this.conversationId = input.readOptionalString(); this.llmModel = input.readOptionalString(); - this.llmQuestion = input.readString(); + this.llmQuestion = input.readOptionalString(); this.systemPrompt = input.readOptionalString(); this.userInstructions = input.readOptionalString(); this.contextSize = input.readInt(); @@ -246,9 +248,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(conversationId); out.writeOptionalString(llmModel); - - Preconditions.checkNotNull(llmQuestion, "llm_question must not be null."); - out.writeString(llmQuestion); + out.writeOptionalString(llmQuestion); out.writeOptionalString(systemPrompt); out.writeOptionalString(userInstructions); out.writeInt(contextSize); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 2772884f11..8a5ade0072 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -121,8 +121,7 @@ public void testMiscMethods() throws IOException { StreamOutput so = mock(StreamOutput.class); builder1.writeTo(so); - verify(so, times(5)).writeOptionalString(any()); - verify(so, times(1)).writeString(any()); + verify(so, times(6)).writeOptionalString(any()); } public void testParse() throws IOException {