Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for context_size and include 'interaction_id' in SearchRe… #1385

Merged
merged 2 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions search-processors/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ plugins {
id 'java'
id 'jacoco'
id "io.freefair.lombok"
id 'com.diffplug.spotless' version '6.18.0'
}

repositories {
Expand Down Expand Up @@ -73,3 +74,12 @@ jacocoTestCoverageVerification {
}

check.dependsOn jacocoTestCoverageVerification

spotless {
java {
removeUnusedImports()
importOrder 'java', 'javax', 'org', 'com'

eclipse().configFile rootProject.file('.eclipseformat.xml')
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ public class GenerativeQAProcessorConstants {
// The field in search results that contain the context to be sent to the LLM.
public static final String CONFIG_NAME_CONTEXT_FIELD_LIST = "context_field_list";

public static final String CONFIG_NAME_SYSTEM_PROMPT = "system_prompt";
public static final String CONFIG_NAME_USER_INSTRUCTIONS = "user_instructions";

public static final Setting<Boolean> RAG_PIPELINE_FEATURE_ENABLED = Setting
.boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import java.util.Map;
import java.util.function.BooleanSupplier;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

import java.util.Map;
import java.util.function.BooleanSupplier;

/**
* Defines the request processor for generative QA search pipelines.
*/
Expand All @@ -35,7 +35,13 @@ public class GenerativeQARequestProcessor extends AbstractProcessor implements S
private String modelId;
private final BooleanSupplier featureFlagSupplier;

protected GenerativeQARequestProcessor(String tag, String description, boolean ignoreFailure, String modelId, BooleanSupplier supplier) {
protected GenerativeQARequestProcessor(
String tag,
String description,
boolean ignoreFailure,
String modelId,
BooleanSupplier supplier
) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
this.featureFlagSupplier = supplier;
Expand Down Expand Up @@ -76,12 +82,17 @@ public SearchRequestProcessor create(
PipelineContext pipelineContext
) throws Exception {
if (featureFlagSupplier.getAsBoolean()) {
return new GenerativeQARequestProcessor(tag, description, ignoreFailure,
ConfigurationUtils.readStringProperty(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
),
return new GenerativeQARequestProcessor(
tag,
description,
ignoreFailure,
ConfigurationUtils
.readStringProperty(
GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
),
this.featureFlagSupplier
);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BooleanSupplier;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
Expand All @@ -32,22 +37,20 @@
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BooleanSupplier;
import com.google.gson.JsonArray;

import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* Defines the response processor for generative QA search pipelines.
Expand All @@ -58,11 +61,16 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements

private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;

// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.
private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;

// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this be a hard add?


private final String llmModel;
private final List<String> contextFields;

private final String systemPrompt;
private final String userInstructions;

@Setter
private ConversationalMemoryClient memoryClient;

Expand All @@ -73,11 +81,23 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements

private final BooleanSupplier featureFlagSupplier;

protected GenerativeQAResponseProcessor(Client client, String tag, String description, boolean ignoreFailure,
Llm llm, String llmModel, List<String> contextFields, BooleanSupplier supplier) {
protected GenerativeQAResponseProcessor(
Client client,
String tag,
String description,
boolean ignoreFailure,
Llm llm,
String llmModel,
List<String> contextFields,
String systemPrompt,
String userInstructions,
BooleanSupplier supplier
) {
super(tag, description, ignoreFailure);
this.llmModel = llmModel;
this.contextFields = contextFields;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
this.llm = llm;
this.memoryClient = new ConversationalMemoryClient(client);
this.featureFlagSupplier = supplier;
Expand All @@ -93,45 +113,112 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}

GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);

Integer timeout = params.getTimeout();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one liner?

Integer timeout = (params.getTimeout() != null && params.getTimeout() != GenerativeQAParameters.SIZE_NULL_VALUE) ? params.getTimeout() : MAX_PROCESSOR_TIME_IN_SECONDS;

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, that would involve calling .getTimeout() up to three times. Does the Java compiler do something clever to avoid that?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. We could do:

Integer timeoutValue = params.getTimeout();
Integer timeout = (timeoutValue != null && timeoutValue != GenerativeQAParameters.SIZE_NULL_VALUE) ? timeoutValue : MAX_PROCESSOR_TIME_IN_SECONDS;

But upto you, keep the code as it is if you want.

if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
}
log.info("Timeout for this request: {} seconds.", timeout);

String llmQuestion = params.getLlmQuestion();
String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
if (llmModel == null) {
throw new IllegalArgumentException("llm_model cannot be null.");
}
String conversationId = params.getConversationId();
log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);
List<Interaction> chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW);
List<String> searchResults = getSearchResults(response);
ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, searchResults));
String answer = (String) output.getAnswers().get(0);
Instant start = Instant.now();
Integer interactionSize = params.getInteractionSize();
if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) {
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW;
}
log.info("Using interaction size of {}", interactionSize);
List<Interaction> chatHistory = (conversationId == null)
? Collections.emptyList()
: memoryClient.getInteractions(conversationId, interactionSize);
log.info("Retrieved chat history. ({})", getDuration(start));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious to know what's the goal of adding getDuration(start) in the log?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It logs the elapsed time for client calls. I can use debug for the log level.


Integer topN = params.getContextSize();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one liner?

Integer topN = params.getContextSize() != null ? params.getContextSize() : GenerativeQAParameters.SIZE_NULL_VALUE;

if (topN == null) {
topN = GenerativeQAParameters.SIZE_NULL_VALUE;
}
List<String> searchResults = getSearchResults(response, topN);

log.info("system_prompt: {}", systemPrompt);
log.info("user_instructions: {}", userInstructions);
start = Instant.now();
ChatCompletionOutput output = llm
.doChatCompletion(
LlmIOUtil
.createChatCompletionInput(systemPrompt, userInstructions, llmModel, llmQuestion, chatHistory, searchResults, timeout)
);
log.info("doChatCompletion complete. ({})", getDuration(start));

String answer = null;
String errorMessage = null;
String interactionId = null;
if (conversationId != null) {
interactionId = memoryClient.createInteraction(conversationId, llmQuestion, PromptUtil.DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE, answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, jsonArrayToString(searchResults));
if (output.isErrorOccurred()) {
errorMessage = output.getErrors().get(0);
} else {
answer = (String) output.getAnswers().get(0);

if (conversationId != null) {
start = Instant.now();
interactionId = memoryClient
.createInteraction(
conversationId,
llmQuestion,
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
jsonArrayToString(searchResults)
);
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
}
}

return insertAnswer(response, answer, interactionId);
return insertAnswer(response, answer, errorMessage, interactionId);
}

long getDuration(Instant start) {
return Duration.between(start, Instant.now()).toMillis();
}

@Override
public String getType() {
return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
}

private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) {
private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {

// TODO return the interaction id in the response.

return new GenerativeSearchResponse(answer, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(),
response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters());
return new GenerativeSearchResponse(
answer,
errorMessage,
response.getInternalResponse(),
response.getScrollId(),
response.getTotalShards(),
response.getSuccessfulShards(),
response.getSkippedShards(),
response.getSuccessfulShards(),
response.getShardFailures(),
response.getClusters(),
interactionId
);
}

private List<String> getSearchResults(SearchResponse response) {
private List<String> getSearchResults(SearchResponse response, Integer topN) {
List<String> searchResults = new ArrayList<>();
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> docSourceMap = hit.getSourceAsMap();
SearchHit[] hits = response.getHits().getHits();
int total = hits.length;
int end = (topN != GenerativeQAParameters.SIZE_NULL_VALUE) ? Math.min(topN, total) : total;
for (int i = 0; i < end; i++) {
Map<String, Object> docSourceMap = hits[i].getSourceAsMap();
for (String contextField : contextFields) {
Object context = docSourceMap.get(contextField);
if (context == null) {
log.error("Context " + contextField + " not found in search hit " + hit);
log.error("Context " + contextField + " not found in search hit " + hits[i]);
// TODO throw a more meaningful error here?
throw new RuntimeException();
}
Expand Down Expand Up @@ -167,36 +254,68 @@ public SearchResponseProcessor create(
PipelineContext pipelineContext
) throws Exception {
if (this.featureFlagSupplier.getAsBoolean()) {
String modelId = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
);
String llmModel = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_LLM_MODEL
);
List<String> contextFields = ConfigurationUtils.readList(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST
);
String modelId = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
);
String llmModel = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_LLM_MODEL
);
List<String> contextFields = ConfigurationUtils
.readList(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST
);
if (contextFields.isEmpty()) {
throw newConfigurationException(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
throw newConfigurationException(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST,
"required property can't be empty."
);
}
log.info("model_id {}, llm_model {}, context_field_list {}", modelId, llmModel, contextFields);
return new GenerativeQAResponseProcessor(client,
String systemPrompt = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT
);
String userInstructions = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS
);
log
.info(
"model_id {}, llm_model {}, context_field_list {}, system_prompt {}, user_instructions {}",
modelId,
llmModel,
contextFields,
systemPrompt,
userInstructions
);
return new GenerativeQAResponseProcessor(
client,
tag,
description,
ignoreFailure,
ModelLocator.getLlm(modelId, client),
llmModel,
contextFields,
systemPrompt,
userInstructions,
featureFlagSupplier
);
} else {
Expand Down
Loading