From daede22243b43d7967cfc813e367e688f666716a Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 5 Dec 2023 12:55:39 -0800 Subject: [PATCH] add query_context layer to search ext Signed-off-by: HenryL27 --- .../rerank/QueryContextSourceFetcher.java | 19 ++++++++++++++----- .../TextSimilarityRerankProcessorIT.java | 2 +- .../TextSimilarityRerankProcessorTests.java | 4 +++- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java index a4a7a26ba..45ba1efa1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java @@ -48,15 +48,24 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo List exts = searchRequest.source().ext(); Map params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams(); Map scoringContext = new HashMap<>(); - if (params.containsKey(QUERY_TEXT_FIELD)) { - if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { + if (!params.containsKey(NAME)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "must specify %s", NAME)); + } + Object ctxObj = params.remove(NAME); + if (!(ctxObj instanceof Map)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a map", NAME)); + } + @SuppressWarnings("unchecked") + Map ctxMap = (Map) ctxObj; + if (ctxMap.containsKey(QUERY_TEXT_FIELD)) { + if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot specify both \"%s\" and \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) ); } - scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); - } else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { - String path = (String) params.get(QUERY_TEXT_PATH_FIELD); + scoringContext.put(QUERY_TEXT_FIELD, (String) ctxMap.get(QUERY_TEXT_FIELD)); + } else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { + String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD); // Convert query to a map with io/xcontent shenanigans ByteArrayOutputStream baos = new ByteArrayOutputStream(); XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java index 9cf62f929..ea16c614e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java @@ -126,7 +126,7 @@ private void runQueries() throws Exception { } private Map search(String queryText) throws Exception { - String jsonQueryFrame = "{\"query\":{\"match_all\":{}},\"ext\":{\"rerank\":{\"query_text\":\"%s\"}}}"; + String jsonQueryFrame = "{\"query\":{\"match_all\":{}},\"ext\":{\"rerank\":{\"query_context\": {\"query_text\":\"%s\"}}}}"; String jsonQuery = String.format(LOCALE, jsonQueryFrame, queryText); log.info(jsonQuery); Request request = new Request("POST", "/" + INDEX_NAME + "/_search"); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java index a1ef46b91..85c176a57 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java @@ -105,7 +105,9 @@ private void setupParams(Map params) { NeuralQueryBuilder nqb = new NeuralQueryBuilder(); nqb.fieldName("embedding").k(3).modelId("embedding_id").queryText("Question about dolphins"); ssb.query(nqb); - List exts = List.of(new RerankSearchExtBuilder(new HashMap<>(params))); + List exts = List.of( + new RerankSearchExtBuilder(new HashMap<>(Map.of(QueryContextSourceFetcher.NAME, new HashMap<>(params)))) + ); ssb.ext(exts); doReturn(ssb).when(request).source(); }