Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.15] [Backport 2.17] Bug Fix: Fix for rag processor throwing NPE when optional parameters are not provided #3078

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative.ext;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.Objects;

import org.opensearch.core.ParseField;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ObjectParser;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;

import com.google.common.base.Preconditions;

import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
Expand All @@ -45,57 +45,42 @@
@NoArgsConstructor
public class GenerativeQAParameters implements Writeable, ToXContentObject {

private static final ObjectParser<GenerativeQAParameters, Void> PARSER;

// Optional parameter; if provided, conversational memory will be used for RAG
// and the current interaction will be saved in the conversation referenced by this id.
private static final ParseField CONVERSATION_ID = new ParseField("memory_id");
private static final String CONVERSATION_ID = "memory_id";

// Optional parameter; if an LLM model is not set at the search pipeline level, one must be
// provided at the search request level.
private static final ParseField LLM_MODEL = new ParseField("llm_model");
private static final String LLM_MODEL = "llm_model";

// Required parameter; this is sent to LLMs as part of the user prompt.
// TODO support question rewriting when chat history is not used (conversation_id is not provided).
private static final ParseField LLM_QUESTION = new ParseField("llm_question");
private static final String LLM_QUESTION = "llm_question";

// Optional parameter; this parameter controls the number of search results ("contexts") to
// include in the user prompt.
private static final ParseField CONTEXT_SIZE = new ParseField("context_size");
private static final String CONTEXT_SIZE = "context_size";

// Optional parameter; this parameter controls the number of the interactions to include
// in the user prompt.
private static final ParseField INTERACTION_SIZE = new ParseField("message_size");
private static final String INTERACTION_SIZE = "message_size";

// Optional parameter; this parameter controls how long the search pipeline waits for a response
// from a remote inference endpoint before timing out the request.
private static final ParseField TIMEOUT = new ParseField("timeout");
private static final String TIMEOUT = "timeout";

// Optional parameter: this parameter allows request-level customization of the "system" (role) prompt.
private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT);
private static final String SYSTEM_PROMPT = "system_prompt";

// Optional parameter: this parameter allows request-level customization of the "user" (role) prompt.
private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS);
private static final String USER_INSTRUCTIONS = "user_instructions";

// Optional parameter; this parameter indicates the name of the field in the LLM response
// that contains the chat completion text, i.e. "answer".
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");
private static final String LLM_RESPONSE_FIELD = "llm_response_field";

public static final int SIZE_NULL_VALUE = -1;

static {
PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new);
PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID);
PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL);
PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION);
PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT);
PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS);
PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
}

@Setter
@Getter
private String conversationId;
Expand Down Expand Up @@ -132,6 +117,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
@Getter
private String llmResponseField;

@Builder
public GenerativeQAParameters(
String conversationId,
String llmModel,
Expand All @@ -148,7 +134,7 @@ public GenerativeQAParameters(

// TODO: keep this requirement until we can extract the question from the query or from the request processor parameters
// for question rewriting.
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided.");
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided.");
this.llmQuestion = llmQuestion;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
Expand All @@ -172,16 +158,45 @@ public GenerativeQAParameters(StreamInput input) throws IOException {

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
return xContentBuilder
.field(CONVERSATION_ID.getPreferredName(), this.conversationId)
.field(LLM_MODEL.getPreferredName(), this.llmModel)
.field(LLM_QUESTION.getPreferredName(), this.llmQuestion)
.field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt)
.field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions)
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
.field(TIMEOUT.getPreferredName(), this.timeout)
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField);
xContentBuilder.startObject();
if (this.conversationId != null) {
xContentBuilder.field(CONVERSATION_ID, this.conversationId);
}

if (this.llmModel != null) {
xContentBuilder.field(LLM_MODEL, this.llmModel);
}

if (this.llmQuestion != null) {
xContentBuilder.field(LLM_QUESTION, this.llmQuestion);
}

if (this.systemPrompt != null) {
xContentBuilder.field(SYSTEM_PROMPT, this.systemPrompt);
}

if (this.userInstructions != null) {
xContentBuilder.field(USER_INSTRUCTIONS, this.userInstructions);
}

if (this.contextSize != null) {
xContentBuilder.field(CONTEXT_SIZE, this.contextSize);
}

if (this.interactionSize != null) {
xContentBuilder.field(INTERACTION_SIZE, this.interactionSize);
}

if (this.timeout != null) {
xContentBuilder.field(TIMEOUT, this.timeout);
}

if (this.llmResponseField != null) {
xContentBuilder.field(LLM_RESPONSE_FIELD, this.llmResponseField);
}

xContentBuilder.endObject();
return xContentBuilder;
}

@Override
Expand All @@ -200,7 +215,67 @@ public void writeTo(StreamOutput out) throws IOException {
}

public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
String conversationId = null;
String llmModel = null;
String llmQuestion = null;
String systemPrompt = null;
String userInstructions = null;
Integer contextSize = null;
Integer interactionSize = null;
Integer timeout = null;
String llmResponseField = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String field = parser.currentName();
parser.nextToken();

switch (field) {
case CONVERSATION_ID:
conversationId = parser.text();
break;
case LLM_MODEL:
llmModel = parser.text();
break;
case LLM_QUESTION:
llmQuestion = parser.text();
break;
case SYSTEM_PROMPT:
systemPrompt = parser.text();
break;
case USER_INSTRUCTIONS:
userInstructions = parser.text();
break;
case CONTEXT_SIZE:
contextSize = parser.intValue();
break;
case INTERACTION_SIZE:
interactionSize = parser.intValue();
break;
case TIMEOUT:
timeout = parser.intValue();
break;
case LLM_RESPONSE_FIELD:
llmResponseField = parser.text();
break;
default:
parser.skipChildren();
break;
}
}

return GenerativeQAParameters
.builder()
.conversationId(conversationId)
.llmModel(llmModel)
.llmQuestion(llmQuestion)
.systemPrompt(systemPrompt)
.userInstructions(userInstructions)
.contextSize(contextSize)
.interactionSize(interactionSize)
.timeout(timeout)
.llmResponseField(llmResponseField)
.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,23 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

import java.io.EOFException;
import java.io.IOException;
import java.util.Collections;

import org.junit.Assert;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentHelper;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchModule;
import org.opensearch.test.OpenSearchTestCase;

public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase {
Expand Down Expand Up @@ -107,21 +112,38 @@ public void testMiscMethods() throws IOException {
}

public void testParse() throws IOException {
XContentParser xcParser = mock(XContentParser.class);
when(xcParser.nextToken()).thenReturn(XContentParser.Token.START_OBJECT).thenReturn(XContentParser.Token.END_OBJECT);
GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(xcParser);
String requiredJsonStr = "{\"llm_question\":\"this is test llm question\"}";

XContentParser parser = XContentType.JSON
.xContent()
.createParser(
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
null,
requiredJsonStr
);

parser.nextToken();
GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(parser);
assertNotNull(builder);
assertNotNull(builder.getParams());
GenerativeQAParameters params = builder.getParams();
Assert.assertEquals("this is test llm question", params.getLlmQuestion());
}

public void testXContentRoundTrip() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(param1);

XContentType xContentType = randomFrom(XContentType.values());
BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true);
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent());
builder = extBuilder.toXContent(builder, EMPTY_PARAMS);
BytesReference serialized = BytesReference.bytes(builder);

XContentParser parser = createParser(xContentType.xContent(), serialized);
parser.nextToken();
GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser);

assertEquals(extBuilder, deserialized);
GenerativeQAParameters parameters = deserialized.getParams();
assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getContextSize());
Expand All @@ -133,10 +155,16 @@ public void testXContentRoundTripAllValues() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3, null);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(param1);

XContentType xContentType = randomFrom(XContentType.values());
BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true);
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent());
builder = extBuilder.toXContent(builder, EMPTY_PARAMS);
BytesReference serialized = BytesReference.bytes(builder);

XContentParser parser = createParser(xContentType.xContent(), serialized);
parser.nextToken();
GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser);

assertEquals(extBuilder, deserialized);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,18 @@ public void testToXConent() throws IOException {
assertNotNull(parameters.toXContent(builder, null));
}

public void testToXConentAllOptionalParameters() throws IOException {
public void testToXContentEmptyParams() throws IOException {
GenerativeQAParameters parameters = new GenerativeQAParameters();
XContent xc = mock(XContent.class);
OutputStream os = mock(OutputStream.class);
XContentGenerator generator = mock(XContentGenerator.class);
when(xc.createGenerator(any(), any(), any())).thenReturn(generator);
XContentBuilder builder = new XContentBuilder(xc, os);
parameters.toXContent(builder, null);
assertNotNull(parameters.toXContent(builder, null));
}

public void testToXContentAllOptionalParameters() throws IOException {
String conversationId = "a";
String llmModel = "b";
String llmQuestion = "c";
Expand Down
Loading