diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 5779bbf2c..953b91766 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -89,7 +89,7 @@ private RerankType findRerankType(final Map config) throws Illeg * Factory class for context fetchers. Constructs a list of context fetchers * specified in the pipeline config (and maybe the query context fetcher) */ - protected static class ContextFetcherFactory { + private static class ContextFetcherFactory { /** * Map rerank types to whether they should include the query context source fetcher diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java index 7dc577502..34fd42d86 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java @@ -18,10 +18,12 @@ import org.opensearch.search.SearchHit; import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; /** * Context Source Fetcher that gets context from the search results (documents) */ +@Log4j2 @AllArgsConstructor public class DocumentContextSourceFetcher implements ContextSourceFetcher { @@ -59,6 +61,14 @@ private String contextFromSearchHit(final SearchHit hit, final String field) { Object sourceValue = ObjectPath.eval(field, hit.getSourceAsMap()); return String.valueOf(sourceValue); } else { + log.warn( + String.format( + Locale.ROOT, + "Could not find field %s in document %s for reranking! 
Using the empty string instead.", + field, + hit.getId() + ) + ); return ""; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java index 5000dd756..b027e3f6f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java @@ -6,6 +6,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -56,14 +57,7 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo scoringContext.put(QUERY_TEXT_FIELD, (String) ctxMap.get(QUERY_TEXT_FIELD)); } else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD); - // Convert query to a map with io/xcontent shenanigans - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); - searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS); - builder.close(); - ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); - XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais); - Map map = parser.map(); + Map map = requestToMap(searchRequest); // Get the text at the path Object queryText = ObjectPath.eval(path, map); if (!(queryText instanceof String)) { @@ -87,4 +81,22 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo public String getName() { return NAME; } + + /** + * Convert a search request to a general map by streaming out as XContent and then back in, + * with the intention of representing the query as a user would see it + * @param request Search request to turn into xcontent + 
* @return Map representing the XContent-ified search request + * @throws IOException if the request source cannot be written out as CBOR XContent or parsed back into a map + */ + private static Map requestToMap(SearchRequest request) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); + request.source().toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.close(); + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais); + Map map = parser.map(); + return map; + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index c52a5223d..43efb795d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -59,8 +59,9 @@ public void rerank(SearchResponse searchResponse, Map rerankingC // Assign new scores SearchHit[] hits = searchResponse.getHits().getHits(); if (hits.length != scores.size()) { - throw new Exception("scores and hits are not the same length"); + throw new RuntimeException("scores and hits are not the same length"); } + // NOTE: Assumes that the new scores came back in the same order as the hits for (int i = 0; i < hits.length; i++) { hits[i].score(scores.get(i)); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java index 4c0026a89..3909c7499 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java @@ -18,9 +18,34 @@ import lombok.AllArgsConstructor; import lombok.Getter; -import lombok.extern.log4j.Log4j2; 
-@Log4j2 +/** + * Holds ext data from the query for reranking processors. Since + * there can be multiple kinds of rerank processors with different + * contexts, all we can assume is that there are keys and objects. + * e.g. ext might look like + * { + * "query": {blah}, + * "ext": { + * "rerank": { + * "query_context": { + * "query_text": "some question to rerank about" + * } + * } + * } + * } + * or + * { + * "query": {blah}, + * "ext": { + * "rerank": { + * "query_context": { + * "query_path": "query.neural.embedding.query_text" + * } + * } + * } + * } + */ @AllArgsConstructor public class RerankSearchExtBuilder extends SearchExtBuilder { @@ -44,7 +69,10 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.field(PARAM_FIELD_NAME, this.params); + for (String key : this.params.keySet()) { + builder.field(key, this.params.get(key)); + } + return builder; } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index 7488e2607..fa15eda46 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -29,13 +29,13 @@ public class RerankProcessorFactoryTests extends OpenSearchTestCase { final String TAG = "default-tag"; final String DESC = "processor description"; - RerankProcessorFactory factory; + private RerankProcessorFactory factory; @Mock - MLCommonsClientAccessor clientAccessor; + private MLCommonsClientAccessor clientAccessor; @Mock - PipelineContext pipelineContext; + private PipelineContext pipelineContext; @Before public void setup() { diff --git 
a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java index 1fa6ac629..c5bdb77f4 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java @@ -28,11 +28,12 @@ @Log4j2 public class MLOpenSearchRerankProcessorIT extends BaseNeuralSearchIT { - final static String PIPELINE_NAME = "rerank-mlos-pipeline"; - final static String INDEX_NAME = "rerank-test"; - final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy"; - final static String TEXT_REP_2 = "Fish like to eat plankton"; - final static String INDEX_CONFIG = "{\"mappings\": {\"properties\": {\"text_representation\": {\"type\": \"text\"}}}}"; + private final static String PIPELINE_NAME = "rerank-mlos-pipeline"; + private final static String INDEX_NAME = "rerank-test"; + private final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy"; + private final static String TEXT_REP_2 = "Fish like to eat plankton"; + private final static String INDEX_CONFIG = "{\"mappings\": {\"properties\": {\"text_representation\": {\"type\": \"text\"}}}}"; + private String modelId; @After @SneakyThrows @@ -42,13 +43,14 @@ public void tearDown() { * this happens in case we leave model from previous test case. We use new model for every test, and old model * can be undeployed and deleted to free resources after each test case execution. 
*/ + deleteModel(modelId); deleteSearchPipeline(PIPELINE_NAME); - findDeployedModels().forEach(this::deleteModel); deleteIndex(INDEX_NAME); } - public void testCrossEncoderRerankProcessor() throws Exception { - String modelId = uploadTextSimilarityModel(); + @SneakyThrows + public void testCrossEncoderRerankProcessor() { + modelId = uploadTextSimilarityModel(); loadModel(modelId); createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/RerankMLOpenSearchPipelineConfiguration.json"); setupIndex(); @@ -59,7 +61,7 @@ private String uploadTextSimilarityModel() throws Exception { String requestBody = Files.readString( Path.of(classLoader.getResource("processor/UploadTextSimilarityModelRequestBody.json").toURI()) ); - return uploadModel(requestBody); + return registerModelGroupAndUploadModel(requestBody); } private void setupIndex() throws Exception { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index 53019d7ce..018677b60 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -46,28 +46,25 @@ import org.opensearch.search.pipeline.Processor.PipelineContext; import org.opensearch.test.OpenSearchTestCase; -import lombok.extern.log4j.Log4j2; - -@Log4j2 public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase { @Mock - SearchRequest request; + private SearchRequest request; - SearchResponse response; + private SearchResponse response; @Mock - MLCommonsClientAccessor mlCommonsClientAccessor; + private MLCommonsClientAccessor mlCommonsClientAccessor; @Mock - PipelineContext pipelineContext; + private PipelineContext pipelineContext; @Mock - PipelineProcessingContext ppctx; + private PipelineProcessingContext ppctx; - 
RerankProcessorFactory factory; + private RerankProcessorFactory factory; - MLOpenSearchRerankProcessor processor; + private MLOpenSearchRerankProcessor processor; @Before public void setup() { diff --git a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java index e1724b014..ea0af1eb5 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java @@ -7,16 +7,23 @@ import static org.mockito.Mockito.mock; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.junit.Before; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.ParseField; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.BytesStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchExtBuilder; import org.opensearch.test.OpenSearchTestCase; @@ -29,7 +36,20 @@ public class RerankSearchExtBuilderTests extends OpenSearchTestCase { @Before public void setup() { - params = Map.of("query_text", "question about the meaning of life, the universe, and everything"); + params = Map.of("query_context", Map.of("query_text", "question about the meaning of life, the universe, and everything")); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new 
NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry( + SearchExtBuilder.class, + new ParseField(RerankSearchExtBuilder.PARAM_FIELD_NAME), + parser -> RerankSearchExtBuilder.parse(parser) + ) + ) + ); } public void testStreaming() throws IOException { @@ -43,19 +63,23 @@ public void testStreaming() throws IOException { assert (b1.equals(b2)); } - // public void testToXContent() throws IOException { - // RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(new HashMap<>(params)); - // XContentBuilder builder = XContentType.JSON.contentBuilder(); - // builder.startObject(); - // b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS); - // builder.endObject(); - // String extString = builder.toString(); - // log.info(extString); - // XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString); - // RerankSearchExtBuilder b2 = RerankSearchExtBuilder.parse(parser); - // assert (b2.getParams().equals(params)); - // assert (b1.equals(b2)); - // } + public void testToXContent() throws IOException { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(new HashMap<>(params)); + XContentBuilder builder = XContentType.JSON.contentBuilder(); + builder.startObject(); + b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + builder.endObject(); + String extString = builder.toString(); + log.info(extString); + XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString); + SearchExtBuilder b2 = parser.namedObject(SearchExtBuilder.class, RerankSearchExtBuilder.PARAM_FIELD_NAME, parser); + assert (b2 instanceof RerankSearchExtBuilder); + RerankSearchExtBuilder b3 = (RerankSearchExtBuilder) b2; + log.info(b1.getParams().toString()); + log.info(b3.getParams().toString()); + assert (b3.getParams().equals(params)); + assert (b1.equals(b3)); + } public void testPullFromListOfExtBuilders() { RerankSearchExtBuilder builder = new RerankSearchExtBuilder(params); diff --git 
a/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json index 82529202d..3c23f6f21 100644 --- a/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json +++ b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json @@ -4,7 +4,7 @@ "function_name": "TEXT_SIMILARITY", "description": "test model", "model_format": "TORCH_SCRIPT", - "model_group_id": "", + "model_group_id": "%s", "model_content_hash_value": "90e39a926101d1a4e542aade0794319404689b12acfd5d7e65c03d91c668b5cf", "model_config": { "model_type": "bert",