Skip to content

Commit

Permalink
Address pr comments and fix XContent in search ext
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Jan 10, 2024
1 parent 7a6595f commit 708fb66
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private RerankType findRerankType(final Map<String, Object> config) throws Illeg
* Factory class for context fetchers. Constructs a list of context fetchers
* specified in the pipeline config (and maybe the query context fetcher)
*/
protected static class ContextFetcherFactory {
private static class ContextFetcherFactory {

/**
* Map rerank types to whether they should include the query context source fetcher
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import org.opensearch.search.SearchHit;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

/**
* Context Source Fetcher that gets context from the search results (documents)
*/
@Log4j2
@AllArgsConstructor
public class DocumentContextSourceFetcher implements ContextSourceFetcher {

Expand Down Expand Up @@ -59,6 +61,14 @@ private String contextFromSearchHit(final SearchHit hit, final String field) {
Object sourceValue = ObjectPath.eval(field, hit.getSourceAsMap());
return String.valueOf(sourceValue);
} else {
log.warn(
String.format(
Locale.ROOT,
"Could not find field %s in document %s for reranking! Using the empty string instead.",
field,
hit.getId()
)
);
return "";
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -56,14 +57,7 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo
scoringContext.put(QUERY_TEXT_FIELD, (String) ctxMap.get(QUERY_TEXT_FIELD));
} else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) {
String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD);
// Convert query to a map with io/xcontent shenanigans
ByteArrayOutputStream baos = new ByteArrayOutputStream();
XContentBuilder builder = XContentType.CBOR.contentBuilder(baos);
searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS);
builder.close();
ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais);
Map<String, Object> map = parser.map();
Map<String, Object> map = requestToMap(searchRequest);
// Get the text at the path
Object queryText = ObjectPath.eval(path, map);
if (!(queryText instanceof String)) {
Expand All @@ -87,4 +81,22 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo
public String getName() {
return NAME;
}

/**
* Convert a search request to a general map by streaming out as XContent and then back in,
* with the intention of representing the query as a user would see it
* @param request Search request to turn into xcontent
* @return Map representing the XContent-ified search request
* @throws IOException
*/
private static Map<String, Object> requestToMap(SearchRequest request) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
XContentBuilder builder = XContentType.CBOR.contentBuilder(baos);
request.source().toXContent(builder, ToXContent.EMPTY_PARAMS);
builder.close();
ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais);
Map<String, Object> map = parser.map();
return map;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ public void rerank(SearchResponse searchResponse, Map<String, Object> rerankingC
// Assign new scores
SearchHit[] hits = searchResponse.getHits().getHits();
if (hits.length != scores.size()) {
throw new Exception("scores and hits are not the same length");
throw new RuntimeException("scores and hits are not the same length");
}
// NOTE: Assumes that the new scores came back in the same order
for (int i = 0; i < hits.length; i++) {
hits[i].score(scores.get(i));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,34 @@

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;

@Log4j2
/**
* Holds ext data from the query for reranking processors. Since
* there can be multiple kinds of rerank processors with different
* contexts, all we can assume is that there's keys and objects.
* e.g. ext might look like
* {
* "query": {blah},
* "ext": {
* "rerank": {
* "query_context": {
* "query_text": "some question to rerank about"
* }
* }
* }
* }
* or
* {
* "query": {blah},
* "ext": {
* "rerank": {
* "query_context": {
* "query_path": "query.neural.embedding.query_text"
* }
* }
* }
* }
*/
@AllArgsConstructor
public class RerankSearchExtBuilder extends SearchExtBuilder {

Expand All @@ -44,7 +69,10 @@ public void writeTo(StreamOutput out) throws IOException {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.field(PARAM_FIELD_NAME, this.params);
for (String key : this.params.keySet()) {
builder.field(key, this.params.get(key));
}
return builder;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ public class RerankProcessorFactoryTests extends OpenSearchTestCase {
final String TAG = "default-tag";
final String DESC = "processor description";

RerankProcessorFactory factory;
private RerankProcessorFactory factory;

@Mock
MLCommonsClientAccessor clientAccessor;
private MLCommonsClientAccessor clientAccessor;

@Mock
PipelineContext pipelineContext;
private PipelineContext pipelineContext;

@Before
public void setup() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
@Log4j2
public class MLOpenSearchRerankProcessorIT extends BaseNeuralSearchIT {

final static String PIPELINE_NAME = "rerank-mlos-pipeline";
final static String INDEX_NAME = "rerank-test";
final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy";
final static String TEXT_REP_2 = "Fish like to eat plankton";
final static String INDEX_CONFIG = "{\"mappings\": {\"properties\": {\"text_representation\": {\"type\": \"text\"}}}}";
private final static String PIPELINE_NAME = "rerank-mlos-pipeline";
private final static String INDEX_NAME = "rerank-test";
private final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy";
private final static String TEXT_REP_2 = "Fish like to eat plankton";
private final static String INDEX_CONFIG = "{\"mappings\": {\"properties\": {\"text_representation\": {\"type\": \"text\"}}}}";
private String modelId;

@After
@SneakyThrows
Expand All @@ -42,13 +43,14 @@ public void tearDown() {
* this happens in case we leave model from previous test case. We use new model for every test, and old model
* can be undeployed and deleted to free resources after each test case execution.
*/
deleteModel(modelId);
deleteSearchPipeline(PIPELINE_NAME);
findDeployedModels().forEach(this::deleteModel);
deleteIndex(INDEX_NAME);
}

public void testCrossEncoderRerankProcessor() throws Exception {
String modelId = uploadTextSimilarityModel();
@SneakyThrows
public void testCrossEncoderRerankProcessor() {
modelId = uploadTextSimilarityModel();
loadModel(modelId);
createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/RerankMLOpenSearchPipelineConfiguration.json");
setupIndex();
Expand All @@ -59,7 +61,7 @@ private String uploadTextSimilarityModel() throws Exception {
String requestBody = Files.readString(
Path.of(classLoader.getResource("processor/UploadTextSimilarityModelRequestBody.json").toURI())
);
return uploadModel(requestBody);
return registerModelGroupAndUploadModel(requestBody);
}

private void setupIndex() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,28 +46,25 @@
import org.opensearch.search.pipeline.Processor.PipelineContext;
import org.opensearch.test.OpenSearchTestCase;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase {

@Mock
SearchRequest request;
private SearchRequest request;

SearchResponse response;
private SearchResponse response;

@Mock
MLCommonsClientAccessor mlCommonsClientAccessor;
private MLCommonsClientAccessor mlCommonsClientAccessor;

@Mock
PipelineContext pipelineContext;
private PipelineContext pipelineContext;

@Mock
PipelineProcessingContext ppctx;
private PipelineProcessingContext ppctx;

RerankProcessorFactory factory;
private RerankProcessorFactory factory;

MLOpenSearchRerankProcessor processor;
private MLOpenSearchRerankProcessor processor;

@Before
public void setup() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@
import static org.mockito.Mockito.mock;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.test.OpenSearchTestCase;

Expand All @@ -29,7 +36,20 @@ public class RerankSearchExtBuilderTests extends OpenSearchTestCase {

@Before
public void setup() {
params = Map.of("query_text", "question about the meaning of life, the universe, and everything");
params = Map.of("query_context", Map.of("query_text", "question about the meaning of life, the universe, and everything"));
}

@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(
List.of(
new NamedXContentRegistry.Entry(
SearchExtBuilder.class,
new ParseField(RerankSearchExtBuilder.PARAM_FIELD_NAME),
parser -> RerankSearchExtBuilder.parse(parser)
)
)
);
}

public void testStreaming() throws IOException {
Expand All @@ -43,19 +63,23 @@ public void testStreaming() throws IOException {
assert (b1.equals(b2));
}

// public void testToXContent() throws IOException {
// RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(new HashMap<>(params));
// XContentBuilder builder = XContentType.JSON.contentBuilder();
// builder.startObject();
// b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS);
// builder.endObject();
// String extString = builder.toString();
// log.info(extString);
// XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString);
// RerankSearchExtBuilder b2 = RerankSearchExtBuilder.parse(parser);
// assert (b2.getParams().equals(params));
// assert (b1.equals(b2));
// }
public void testToXContent() throws IOException {
RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(new HashMap<>(params));
XContentBuilder builder = XContentType.JSON.contentBuilder();
builder.startObject();
b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS);
builder.endObject();
String extString = builder.toString();
log.info(extString);
XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString);
SearchExtBuilder b2 = parser.namedObject(SearchExtBuilder.class, RerankSearchExtBuilder.PARAM_FIELD_NAME, parser);
assert (b2 instanceof RerankSearchExtBuilder);
RerankSearchExtBuilder b3 = (RerankSearchExtBuilder) b2;
log.info(b1.getParams().toString());
log.info(b3.getParams().toString());
assert (b3.getParams().equals(params));
assert (b1.equals(b3));
}

public void testPullFromListOfExtBuilders() {
RerankSearchExtBuilder builder = new RerankSearchExtBuilder(params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"function_name": "TEXT_SIMILARITY",
"description": "test model",
"model_format": "TORCH_SCRIPT",
"model_group_id": "<MODEL_GROUP_ID>",
"model_group_id": "%s",
"model_content_hash_value": "90e39a926101d1a4e542aade0794319404689b12acfd5d7e65c03d91c668b5cf",
"model_config": {
"model_type": "bert",
Expand Down

0 comments on commit 708fb66

Please sign in to comment.