Skip to content

Commit

Permalink
add query_context layer to search ext
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Dec 5, 2023
1 parent a6626ee commit daede22
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,24 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo
List<SearchExtBuilder> exts = searchRequest.source().ext();
Map<String, Object> params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams();
Map<String, Object> scoringContext = new HashMap<>();
if (params.containsKey(QUERY_TEXT_FIELD)) {
if (params.containsKey(QUERY_TEXT_PATH_FIELD)) {
if (!params.containsKey(NAME)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "must specify %s", NAME));
}
Object ctxObj = params.remove(NAME);
if (!(ctxObj instanceof Map<?, ?>)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a map", NAME));
}
@SuppressWarnings("unchecked")
Map<String, Object> ctxMap = (Map<String, Object>) ctxObj;
if (ctxMap.containsKey(QUERY_TEXT_FIELD)) {
if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot specify both \"%s\" and \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD)
);
}
scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD));
} else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) {
String path = (String) params.get(QUERY_TEXT_PATH_FIELD);
scoringContext.put(QUERY_TEXT_FIELD, (String) ctxMap.get(QUERY_TEXT_FIELD));
} else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) {
String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD);
// Convert query to a map with io/xcontent shenanigans
ByteArrayOutputStream baos = new ByteArrayOutputStream();
XContentBuilder builder = XContentType.CBOR.contentBuilder(baos);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ private void runQueries() throws Exception {
}

private Map<String, Object> search(String queryText) throws Exception {
String jsonQueryFrame = "{\"query\":{\"match_all\":{}},\"ext\":{\"rerank\":{\"query_text\":\"%s\"}}}";
String jsonQueryFrame = "{\"query\":{\"match_all\":{}},\"ext\":{\"rerank\":{\"query_context\": {\"query_text\":\"%s\"}}}}";
String jsonQuery = String.format(LOCALE, jsonQueryFrame, queryText);
log.info(jsonQuery);
Request request = new Request("POST", "/" + INDEX_NAME + "/_search");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ private void setupParams(Map<String, Object> params) {
NeuralQueryBuilder nqb = new NeuralQueryBuilder();
nqb.fieldName("embedding").k(3).modelId("embedding_id").queryText("Question about dolphins");
ssb.query(nqb);
List<SearchExtBuilder> exts = List.of(new RerankSearchExtBuilder(new HashMap<>(params)));
List<SearchExtBuilder> exts = List.of(
new RerankSearchExtBuilder(new HashMap<>(Map.of(QueryContextSourceFetcher.NAME, new HashMap<>(params))))
);
ssb.ext(exts);
doReturn(ssb).when(request).source();
}
Expand Down

0 comments on commit daede22

Please sign in to comment.