From f7a794429b7f88f44daf42901af247feafd65337 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 13 Nov 2023 18:03:36 -0800 Subject: [PATCH] Add rerank processor interfaces Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 38 +++++ .../rerank/CrossEncoderRerankProcessor.java | 115 +++++++++++++++ .../processor/rerank/RerankProcessor.java | 64 +++++++++ .../processor/rerank/RerankType.java | 48 +++++++ .../rerank/RescoringRerankProcessor.java | 136 ++++++++++++++++++ .../query/ext/RerankSearchExtBuilder.java | 98 +++++++++++++ 6 files changed, 499 insertions(+) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java new file mode 100644 index 000000000..7449743ff --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -0,0 +1,38 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.factory; + +import java.util.Map; + +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +public class RerankProcessorFactory implements Processor.Factory { + + @Override + public SearchResponseProcessor create( + final Map> processorFactories, + final String tag, + final String description, + final boolean ignoreFailure, + final Map config, + final Processor.PipelineContext pipelineContext + ) { + return null; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java new file mode 100644 index 000000000..ea2152378 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ObjectPath; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; +import org.opensearch.search.SearchExtBuilder; + +public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { + + public static final String MODEL_ID_FIELD = "model_id"; + public static final String QUERY_TEXT_FIELD = "query_text"; + public static final String QUERY_TEXT_PATH_FIELD = "query_text_path"; + public static final String RERANK_CONTEXT_FIELD = "rerank_context_field"; + + protected final String modelId; + protected final String rerank_context; + + protected final MLCommonsClientAccessor mlCommonsClientAccessor; + + private final Environment environment; + + public CrossEncoderRerankProcessor( + String description, + String tag, + boolean ignoreFailure, + String modelId, + String rerank_context, + MLCommonsClientAccessor mlCommonsClientAccessor, + Environment environment + ) { + super(RerankType.CROSS_ENCODER, description, tag, ignoreFailure); + this.modelId = modelId; + this.rerank_context = rerank_context; + this.mlCommonsClientAccessor = mlCommonsClientAccessor; + this.environment = 
environment; + } + + @Override + public void generateScoringContext( + SearchRequest searchRequest, + SearchResponse searchResponse, + ActionListener> listener + ) { + try { + List exts = searchRequest.source().ext(); + Map params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams(); + Map scoringContext = new HashMap<>(); + if (params.containsKey(QUERY_TEXT_FIELD)) { + if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { + throw new IllegalArgumentException("Cannot specify both \"query_text\" and \"query_text_path\""); + } + scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); + } else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { + String path = (String) params.get(QUERY_TEXT_PATH_FIELD); + // Convert query to a map with io/xcontent shenanigans + PipedOutputStream os = new PipedOutputStream(); + XContentBuilder builder = XContentType.CBOR.contentBuilder(os); + searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS); + PipedInputStream is = new PipedInputStream(os); + XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, is); + Map map = parser.map(); + // Get the text at the path + Object queryText = ObjectPath.eval(path, map); + if (!(queryText instanceof String)) { + throw new IllegalArgumentException("query_text_path must point to a string field"); + } + scoringContext.put(QUERY_TEXT_FIELD, (String) queryText); + } else { + throw new IllegalArgumentException("Must specify either \"query_text\" or \"query_text_path\""); + } + listener.onResponse(scoringContext); + } catch (Exception e) { + listener.onFailure(e); + } + } + + @Override + public void rescoreSearchResponse(SearchResponse response, Map scoringContext, ActionListener> listener) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'rescoreSearchResponse'"); + } + +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java 
b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java new file mode 100644 index 000000000..62ab61da4 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.util.Map; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +public interface RerankProcessor extends SearchResponseProcessor { + + /** + * Generate the information that this processor needs in order to rerank. + * That could be as simple as grabbing a field from the search request or + * as complicated as a lookup to some external service + * @param searchRequest the search query + * @param searchResponse the search results, in case they're relevant + * @param listener be async + */ + public void generateScoringContext( + SearchRequest searchRequest, + SearchResponse searchResponse, + ActionListener> listener + ); + + /** + * Given the scoring context generated by the processor and the search results, + * rerank the search results. Do so asynchronously. 
+ * @param searchResponse the search results to rerank + * @param scoringContext the information this processor needs in order to rerank + * @param listener be async + */ + public void rerank(SearchResponse searchResponse, Map scoringContext, ActionListener listener); + + @Override + default void processResponseAsync(SearchRequest request, SearchResponse response, ActionListener responseListener) { + try { + generateScoringContext( + request, + response, + ActionListener.wrap(context -> { rerank(response, context, responseListener); }, e -> { responseListener.onFailure(e); }) + ); + } catch (Exception e) { + responseListener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java new file mode 100644 index 000000000..6bfb9feed --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * enum for distinguishing various reranking methods
 */
public enum RerankType {

    CROSS_ENCODER("cross-encoder");

    private final String label;

    RerankType(String label) {
        this.label = label;
    }

    /**
     * @return the user-facing label of this rerank type, e.g. {@code "cross-encoder"}
     */
    public String getLabel() {
        return label;
    }

    /**
     * Construct a RerankType from its label.
     *
     * <p>The lookup is by label ({@code "cross-encoder"}). The constant name
     * ({@code "CROSS_ENCODER"}) is accepted as well, because that is the only spelling the
     * previous {@code valueOf}-based lookup recognised.
     *
     * @param label label of a RerankType
     * @return RerankType represented by the label
     * @throws IllegalArgumentException if {@code label} is null or matches no rerank type
     */
    public static RerankType from(String label) {
        for (RerankType type : values()) {
            if (type.label.equals(label) || type.name().equals(label)) {
                return type;
            }
        }
        throw new IllegalArgumentException("Wrong rerank type name: " + label);
    }
}
+ */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import lombok.AllArgsConstructor; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.profile.SearchProfileShardResults; + +@AllArgsConstructor +public abstract class RescoringRerankProcessor implements RerankProcessor { + + private final RerankType type; + private final String description; + private final String tag; + private final boolean ignoreFailure; + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + throw new UnsupportedOperationException("Use asyncProcessResponse unless you can guarantee to not deadlock yourself"); + } + + @Override + public String getType() { + return "rerank-" + type.getLabel(); + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return description; + } + + @Override + public boolean isIgnoreFailure() { + return ignoreFailure; + } + + /** + * Generate a list of new scores for all of the documents, given the scoring context + * @param response search results to rescore + * @param scoringContext extra information needed to score the search results; e.g. model id + * @param listener be async. 
recieves the list of new scores + */ + public abstract void rescoreSearchResponse( + SearchResponse response, + Map scoringContext, + ActionListener> listener + ); + + @Override + public void rerank(SearchResponse searchResponse, Map scoringContext, ActionListener listener) { + try { + rescoreSearchResponse(searchResponse, scoringContext, ActionListener.wrap(scores -> { + // Assign new scores + SearchHit[] hits = searchResponse.getHits().getHits(); + assert (hits.length == scores.size()); + for (int i = 0; i < hits.length; i++) { + hits[i].score(scores.get(i)); + } + // Re-sort by the new scores + Collections.sort(Arrays.asList(hits), new Comparator() { + @Override + public int compare(SearchHit hit1, SearchHit hit2) { + return Float.compare(hit1.getScore(), hit2.getScore()); + } + }); + // Reconstruct the search response, replacing the max score + SearchHits newHits = new SearchHits( + hits, + searchResponse.getHits().getTotalHits(), + hits[0].getScore(), + searchResponse.getHits().getSortFields(), + searchResponse.getHits().getCollapseField(), + searchResponse.getHits().getCollapseValues() + ); + SearchResponseSections newInternalResponse = new SearchResponseSections( + newHits, + searchResponse.getAggregations(), + searchResponse.getSuggest(), + searchResponse.isTimedOut(), + searchResponse.isTerminatedEarly(), + new SearchProfileShardResults(searchResponse.getProfileResults()), + searchResponse.getNumReducePhases(), + searchResponse.getInternalResponse().getSearchExtBuilders() + ); + SearchResponse newResponse = new SearchResponse( + newInternalResponse, + searchResponse.getScrollId(), + searchResponse.getTotalShards(), + searchResponse.getSuccessfulShards(), + searchResponse.getSkippedShards(), + searchResponse.getTook().millis(), + searchResponse.getPhaseTook(), + searchResponse.getShardFailures(), + searchResponse.getClusters(), + searchResponse.pointInTimeId() + ); + listener.onResponse(newResponse); + }, e -> { listener.onFailure(e); })); + } catch 
(Exception e) { + listener.onFailure(e); + } + } + +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java new file mode 100644 index 000000000..ad3756aa8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java @@ -0,0 +1,98 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.opensearch.neuralsearch.query.ext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; + +@AllArgsConstructor +public class RerankSearchExtBuilder extends SearchExtBuilder { + + public final static String PARAM_FIELD_NAME = "rerank"; + @Getter + protected Map params; + + public RerankSearchExtBuilder(StreamInput in) throws IOException { + params = in.readMap(); + } + + @Override + public String getWriteableName() { + return PARAM_FIELD_NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(params); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PARAM_FIELD_NAME, this.params); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.params); + } + + @Override + public boolean equals(Object obj) { + return (obj instanceof RerankSearchExtBuilder) && params.equals(((RerankSearchExtBuilder) obj).params); + } + + /** + * Pick out the first RerankSearchExtBuilder from a list of SearchExtBuilders + * @param builders list of SearchExtBuilders + * @return the RerankSearchExtBuilder + */ + public static RerankSearchExtBuilder fromExtBuilderList(List builders) { + Optional b = builders.stream().filter(bldr -> bldr instanceof RerankSearchExtBuilder).findFirst(); + if (b.isPresent()) { + return (RerankSearchExtBuilder) b.get(); + } else { + return null; + } + } + + /** + * Parse XContent to rerankSearchExtBuilder + * @param parser 
parser parsing this searchExt + * @return RerankSearchExtBuilder represented by this searchExt + * @throws IOException if problems parsing + */ + public static RerankSearchExtBuilder parse(XContentParser parser) throws IOException { + return new RerankSearchExtBuilder(parser.map()); + } + +}