Skip to content

Commit

Permalink
Add support for context_size and include 'interaction_id' in SearchRe… (
Browse files Browse the repository at this point in the history
#1385)

* Add support for context_size and include 'interaction_id' in SearchResponse. [Issue #1372]

Signed-off-by: Austin Lee <[email protected]>

* Added spotless, removed unused code, added more comments.

Signed-off-by: Austin Lee <[email protected]>

---------

Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee authored Oct 3, 2023
1 parent 9359487 commit ae6995a
Show file tree
Hide file tree
Showing 31 changed files with 1,114 additions and 348 deletions.
10 changes: 10 additions & 0 deletions search-processors/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ plugins {
id 'java'
id 'jacoco'
id "io.freefair.lombok"
id 'com.diffplug.spotless' version '6.18.0'
}

repositories {
Expand Down Expand Up @@ -73,3 +74,12 @@ jacocoTestCoverageVerification {
}

check.dependsOn jacocoTestCoverageVerification

spotless {
java {
removeUnusedImports()
importOrder 'java', 'javax', 'org', 'com'

eclipse().configFile rootProject.file('.eclipseformat.xml')
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ public class GenerativeQAProcessorConstants {
// The field in search results that contain the context to be sent to the LLM.
public static final String CONFIG_NAME_CONTEXT_FIELD_LIST = "context_field_list";

public static final String CONFIG_NAME_SYSTEM_PROMPT = "system_prompt";
public static final String CONFIG_NAME_USER_INSTRUCTIONS = "user_instructions";

public static final Setting<Boolean> RAG_PIPELINE_FEATURE_ENABLED = Setting
.boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import java.util.Map;
import java.util.function.BooleanSupplier;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

import java.util.Map;
import java.util.function.BooleanSupplier;

/**
* Defines the request processor for generative QA search pipelines.
*/
Expand All @@ -35,7 +35,13 @@ public class GenerativeQARequestProcessor extends AbstractProcessor implements S
private String modelId;
private final BooleanSupplier featureFlagSupplier;

protected GenerativeQARequestProcessor(String tag, String description, boolean ignoreFailure, String modelId, BooleanSupplier supplier) {
protected GenerativeQARequestProcessor(
String tag,
String description,
boolean ignoreFailure,
String modelId,
BooleanSupplier supplier
) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
this.featureFlagSupplier = supplier;
Expand Down Expand Up @@ -76,12 +82,17 @@ public SearchRequestProcessor create(
PipelineContext pipelineContext
) throws Exception {
if (featureFlagSupplier.getAsBoolean()) {
return new GenerativeQARequestProcessor(tag, description, ignoreFailure,
ConfigurationUtils.readStringProperty(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
),
return new GenerativeQARequestProcessor(
tag,
description,
ignoreFailure,
ConfigurationUtils
.readStringProperty(
GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
),
this.featureFlagSupplier
);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BooleanSupplier;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
Expand All @@ -32,22 +37,20 @@
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BooleanSupplier;
import com.google.gson.JsonArray;

import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* Defines the response processor for generative QA search pipelines.
Expand All @@ -58,11 +61,16 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements

private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;

// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.
private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;

// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.

private final String llmModel;
private final List<String> contextFields;

private final String systemPrompt;
private final String userInstructions;

@Setter
private ConversationalMemoryClient memoryClient;

Expand All @@ -73,11 +81,23 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements

private final BooleanSupplier featureFlagSupplier;

protected GenerativeQAResponseProcessor(Client client, String tag, String description, boolean ignoreFailure,
Llm llm, String llmModel, List<String> contextFields, BooleanSupplier supplier) {
protected GenerativeQAResponseProcessor(
Client client,
String tag,
String description,
boolean ignoreFailure,
Llm llm,
String llmModel,
List<String> contextFields,
String systemPrompt,
String userInstructions,
BooleanSupplier supplier
) {
super(tag, description, ignoreFailure);
this.llmModel = llmModel;
this.contextFields = contextFields;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
this.llm = llm;
this.memoryClient = new ConversationalMemoryClient(client);
this.featureFlagSupplier = supplier;
Expand All @@ -93,45 +113,112 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}

GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);

Integer timeout = params.getTimeout();
if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
}
log.info("Timeout for this request: {} seconds.", timeout);

String llmQuestion = params.getLlmQuestion();
String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
if (llmModel == null) {
throw new IllegalArgumentException("llm_model cannot be null.");
}
String conversationId = params.getConversationId();
log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);
List<Interaction> chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW);
List<String> searchResults = getSearchResults(response);
ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, searchResults));
String answer = (String) output.getAnswers().get(0);
Instant start = Instant.now();
Integer interactionSize = params.getInteractionSize();
if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) {
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW;
}
log.info("Using interaction size of {}", interactionSize);
List<Interaction> chatHistory = (conversationId == null)
? Collections.emptyList()
: memoryClient.getInteractions(conversationId, interactionSize);
log.info("Retrieved chat history. ({})", getDuration(start));

Integer topN = params.getContextSize();
if (topN == null) {
topN = GenerativeQAParameters.SIZE_NULL_VALUE;
}
List<String> searchResults = getSearchResults(response, topN);

log.info("system_prompt: {}", systemPrompt);
log.info("user_instructions: {}", userInstructions);
start = Instant.now();
ChatCompletionOutput output = llm
.doChatCompletion(
LlmIOUtil
.createChatCompletionInput(systemPrompt, userInstructions, llmModel, llmQuestion, chatHistory, searchResults, timeout)
);
log.info("doChatCompletion complete. ({})", getDuration(start));

String answer = null;
String errorMessage = null;
String interactionId = null;
if (conversationId != null) {
interactionId = memoryClient.createInteraction(conversationId, llmQuestion, PromptUtil.DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE, answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, jsonArrayToString(searchResults));
if (output.isErrorOccurred()) {
errorMessage = output.getErrors().get(0);
} else {
answer = (String) output.getAnswers().get(0);

if (conversationId != null) {
start = Instant.now();
interactionId = memoryClient
.createInteraction(
conversationId,
llmQuestion,
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
jsonArrayToString(searchResults)
);
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
}
}

return insertAnswer(response, answer, interactionId);
return insertAnswer(response, answer, errorMessage, interactionId);
}

long getDuration(Instant start) {
return Duration.between(start, Instant.now()).toMillis();
}

@Override
public String getType() {
return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
}

private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) {
private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {

// TODO return the interaction id in the response.

return new GenerativeSearchResponse(answer, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(),
response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters());
return new GenerativeSearchResponse(
answer,
errorMessage,
response.getInternalResponse(),
response.getScrollId(),
response.getTotalShards(),
response.getSuccessfulShards(),
response.getSkippedShards(),
response.getSuccessfulShards(),
response.getShardFailures(),
response.getClusters(),
interactionId
);
}

private List<String> getSearchResults(SearchResponse response) {
private List<String> getSearchResults(SearchResponse response, Integer topN) {
List<String> searchResults = new ArrayList<>();
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> docSourceMap = hit.getSourceAsMap();
SearchHit[] hits = response.getHits().getHits();
int total = hits.length;
int end = (topN != GenerativeQAParameters.SIZE_NULL_VALUE) ? Math.min(topN, total) : total;
for (int i = 0; i < end; i++) {
Map<String, Object> docSourceMap = hits[i].getSourceAsMap();
for (String contextField : contextFields) {
Object context = docSourceMap.get(contextField);
if (context == null) {
log.error("Context " + contextField + " not found in search hit " + hit);
log.error("Context " + contextField + " not found in search hit " + hits[i]);
// TODO throw a more meaningful error here?
throw new RuntimeException();
}
Expand Down Expand Up @@ -167,36 +254,68 @@ public SearchResponseProcessor create(
PipelineContext pipelineContext
) throws Exception {
if (this.featureFlagSupplier.getAsBoolean()) {
String modelId = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
);
String llmModel = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_LLM_MODEL
);
List<String> contextFields = ConfigurationUtils.readList(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST
);
String modelId = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
);
String llmModel = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_LLM_MODEL
);
List<String> contextFields = ConfigurationUtils
.readList(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST
);
if (contextFields.isEmpty()) {
throw newConfigurationException(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
throw newConfigurationException(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST,
"required property can't be empty."
);
}
log.info("model_id {}, llm_model {}, context_field_list {}", modelId, llmModel, contextFields);
return new GenerativeQAResponseProcessor(client,
String systemPrompt = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT
);
String userInstructions = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS
);
log
.info(
"model_id {}, llm_model {}, context_field_list {}, system_prompt {}, user_instructions {}",
modelId,
llmModel,
contextFields,
systemPrompt,
userInstructions
);
return new GenerativeQAResponseProcessor(
client,
tag,
description,
ignoreFailure,
ModelLocator.getLlm(modelId, client),
llmModel,
contextFields,
systemPrompt,
userInstructions,
featureFlagSupplier
);
} else {
Expand Down
Loading

0 comments on commit ae6995a

Please sign in to comment.