From 801b5e401e212a4b21825298991caed6746f9ca6 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 9 Oct 2024 19:39:01 -0700 Subject: [PATCH] [Backport 2.17] Bug Fix: Fix for rag processor throwing NPE when optional parameters are not provided (#3066) (#3079) (cherry picked from commit 8b5b38e4b28e182a0cdfbf1c55c6ef00e6663a57) Co-authored-by: Pavan Yekbote --- .../ext/GenerativeQAParameters.java | 153 +++++++++++++----- .../ext/GenerativeQAParamExtBuilderTests.java | 42 ++++- .../ext/GenerativeQAParametersTests.java | 13 +- 3 files changed, 161 insertions(+), 47 deletions(-) diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index 01dc97db75..7222b7369d 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -17,22 +17,22 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.ext; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.util.Objects; -import org.opensearch.core.ParseField; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; import com.google.common.base.Preconditions; +import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; @@ -45,57 +45,42 @@ @NoArgsConstructor public class GenerativeQAParameters implements Writeable, ToXContentObject { - private static final ObjectParser PARSER; - // Optional parameter; if provided, conversational memory will be used for RAG // and the current interaction will be saved in the conversation referenced by this id. - private static final ParseField CONVERSATION_ID = new ParseField("memory_id"); + private static final String CONVERSATION_ID = "memory_id"; // Optional parameter; if an LLM model is not set at the search pipeline level, one must be // provided at the search request level. - private static final ParseField LLM_MODEL = new ParseField("llm_model"); + private static final String LLM_MODEL = "llm_model"; // Required parameter; this is sent to LLMs as part of the user prompt. // TODO support question rewriting when chat history is not used (conversation_id is not provided). - private static final ParseField LLM_QUESTION = new ParseField("llm_question"); + private static final String LLM_QUESTION = "llm_question"; // Optional parameter; this parameter controls the number of search results ("contexts") to // include in the user prompt. - private static final ParseField CONTEXT_SIZE = new ParseField("context_size"); + private static final String CONTEXT_SIZE = "context_size"; // Optional parameter; this parameter controls the number of the interactions to include // in the user prompt. - private static final ParseField INTERACTION_SIZE = new ParseField("message_size"); + private static final String INTERACTION_SIZE = "message_size"; // Optional parameter; this parameter controls how long the search pipeline waits for a response // from a remote inference endpoint before timing out the request. - private static final ParseField TIMEOUT = new ParseField("timeout"); + private static final String TIMEOUT = "timeout"; // Optional parameter: this parameter allows request-level customization of the "system" (role) prompt. - private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT); + private static final String SYSTEM_PROMPT = "system_prompt"; // Optional parameter: this parameter allows request-level customization of the "user" (role) prompt. - private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS); + private static final String USER_INSTRUCTIONS = "user_instructions"; // Optional parameter; this parameter indicates the name of the field in the LLM response // that contains the chat completion text, i.e. "answer". - private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field"); + private static final String LLM_RESPONSE_FIELD = "llm_response_field"; public static final int SIZE_NULL_VALUE = -1; - static { - PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new); - PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID); - PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL); - PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION); - PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT); - PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS); - PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE); - PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE); - PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT); - PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD); - } - @Setter @Getter private String conversationId; @@ -132,6 +117,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { @Getter private String llmResponseField; + @Builder public GenerativeQAParameters( String conversationId, String llmModel, @@ -148,7 +134,7 @@ public GenerativeQAParameters( // TODO: keep this requirement until we can extract the question from the query or from the request processor parameters // for question rewriting. - Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided."); + Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided."); this.llmQuestion = llmQuestion; this.systemPrompt = systemPrompt; this.userInstructions = userInstructions; @@ -172,16 +158,45 @@ public GenerativeQAParameters(StreamInput input) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { - return xContentBuilder - .field(CONVERSATION_ID.getPreferredName(), this.conversationId) - .field(LLM_MODEL.getPreferredName(), this.llmModel) - .field(LLM_QUESTION.getPreferredName(), this.llmQuestion) - .field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt) - .field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions) - .field(CONTEXT_SIZE.getPreferredName(), this.contextSize) - .field(INTERACTION_SIZE.getPreferredName(), this.interactionSize) - .field(TIMEOUT.getPreferredName(), this.timeout) - .field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField); + xContentBuilder.startObject(); + if (this.conversationId != null) { + xContentBuilder.field(CONVERSATION_ID, this.conversationId); + } + + if (this.llmModel != null) { + xContentBuilder.field(LLM_MODEL, this.llmModel); + } + + if (this.llmQuestion != null) { + xContentBuilder.field(LLM_QUESTION, this.llmQuestion); + } + + if (this.systemPrompt != null) { + xContentBuilder.field(SYSTEM_PROMPT, this.systemPrompt); + } + + if (this.userInstructions != null) { + xContentBuilder.field(USER_INSTRUCTIONS, this.userInstructions); + } + + if (this.contextSize != null) { + xContentBuilder.field(CONTEXT_SIZE, this.contextSize); + } + + if (this.interactionSize != null) { + xContentBuilder.field(INTERACTION_SIZE, this.interactionSize); + } + + if (this.timeout != null) { + xContentBuilder.field(TIMEOUT, this.timeout); + } + + if (this.llmResponseField != null) { + xContentBuilder.field(LLM_RESPONSE_FIELD, this.llmResponseField); + } + + xContentBuilder.endObject(); + return xContentBuilder; } @Override @@ -200,7 +215,67 @@ public void writeTo(StreamOutput out) throws IOException { } public static GenerativeQAParameters parse(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); + String conversationId = null; + String llmModel = null; + String llmQuestion = null; + String systemPrompt = null; + String userInstructions = null; + Integer contextSize = null; + Integer interactionSize = null; + Integer timeout = null; + String llmResponseField = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String field = parser.currentName(); + parser.nextToken(); + + switch (field) { + case CONVERSATION_ID: + conversationId = parser.text(); + break; + case LLM_MODEL: + llmModel = parser.text(); + break; + case LLM_QUESTION: + llmQuestion = parser.text(); + break; + case SYSTEM_PROMPT: + systemPrompt = parser.text(); + break; + case USER_INSTRUCTIONS: + userInstructions = parser.text(); + break; + case CONTEXT_SIZE: + contextSize = parser.intValue(); + break; + case INTERACTION_SIZE: + interactionSize = parser.intValue(); + break; + case TIMEOUT: + timeout = parser.intValue(); + break; + case LLM_RESPONSE_FIELD: + llmResponseField = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + + return GenerativeQAParameters + .builder() + .conversationId(conversationId) + .llmModel(llmModel) + .llmQuestion(llmQuestion) + .systemPrompt(systemPrompt) + .userInstructions(userInstructions) + .contextSize(contextSize) + .interactionSize(interactionSize) + .timeout(timeout) + .llmResponseField(llmResponseField) + .build(); } @Override diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 49f164cdb5..95374b14ea 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -21,18 +21,23 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import java.io.EOFException; import java.io.IOException; +import java.util.Collections; +import org.junit.Assert; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentHelper; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; import org.opensearch.test.OpenSearchTestCase; public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase { @@ -107,21 +112,38 @@ public void testMiscMethods() throws IOException { } public void testParse() throws IOException { - XContentParser xcParser = mock(XContentParser.class); - when(xcParser.nextToken()).thenReturn(XContentParser.Token.START_OBJECT).thenReturn(XContentParser.Token.END_OBJECT); - GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(xcParser); + String requiredJsonStr = "{\"llm_question\":\"this is test llm question\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + requiredJsonStr + ); + + parser.nextToken(); + GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(parser); assertNotNull(builder); assertNotNull(builder.getParams()); + GenerativeQAParameters params = builder.getParams(); + Assert.assertEquals("this is test llm question", params.getLlmQuestion()); } public void testXContentRoundTrip() throws IOException { GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); + XContentType xContentType = randomFrom(XContentType.values()); - BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true); + XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()); + builder = extBuilder.toXContent(builder, EMPTY_PARAMS); + BytesReference serialized = BytesReference.bytes(builder); + XContentParser parser = createParser(xContentType.xContent(), serialized); + parser.nextToken(); GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser); + assertEquals(extBuilder, deserialized); GenerativeQAParameters parameters = deserialized.getParams(); assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getContextSize()); @@ -133,10 +155,16 @@ public void testXContentRoundTripAllValues() throws IOException { GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); + XContentType xContentType = randomFrom(XContentType.values()); - BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true); + XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()); + builder = extBuilder.toXContent(builder, EMPTY_PARAMS); + BytesReference serialized = BytesReference.bytes(builder); + XContentParser parser = createParser(xContentType.xContent(), serialized); + parser.nextToken(); GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser); + assertEquals(extBuilder, deserialized); } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java index c36dcdb2a5..2d7d459202 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -200,7 +200,18 @@ public void testToXConent() throws IOException { assertNotNull(parameters.toXContent(builder, null)); } - public void testToXConentAllOptionalParameters() throws IOException { + public void testToXContentEmptyParams() throws IOException { + GenerativeQAParameters parameters = new GenerativeQAParameters(); + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + parameters.toXContent(builder, null); + assertNotNull(parameters.toXContent(builder, null)); + } + + public void testToXContentAllOptionalParameters() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c";