forked from opensearch-project/neural-search
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding support for generic re-ranker interface and opensearch ml re-r…
…anker for improving search relavancy. (opensearch-project#494) * Add rerank processor interfaces Signed-off-by: HenryL27 <[email protected]> * add cross-encoder specific logic and factory Signed-off-by: HenryL27 <[email protected]> * add unittests Signed-off-by: HenryL27 <[email protected]> * add integration test Signed-off-by: HenryL27 <[email protected]> * use string.format() instead of concatenation Signed-off-by: HenryL27 <[email protected]> * rename generateScoringContext to generateRerankingContext Signed-off-by: HenryL27 <[email protected]> * add name change in test too. whoops Signed-off-by: HenryL27 <[email protected]> * start refactoring with contextSaourceFetchers Signed-off-by: HenryL27 <[email protected]> * refactor to use contextSourceFetchers to get context Signed-off-by: HenryL27 <[email protected]> * rename CrossEncoder to TextSimilarity Signed-off-by: HenryL27 <[email protected]> * add query_context layer to search ext Signed-off-by: HenryL27 <[email protected]> * add javadocs Signed-off-by: HenryL27 <[email protected]> * update to new asyncProcessResponse api Signed-off-by: HenryL27 <[email protected]> * rename reranktype to ML_OPENSEARCH Signed-off-by: HenryL27 <[email protected]> * improve error messages for bad rerank type config Signed-off-by: HenryL27 <[email protected]> * simplify configuration/factory logic Signed-off-by: HenryL27 <[email protected]> * improve handling for non-flat-string context fields Signed-off-by: HenryL27 <[email protected]> * rename TextSimilarity files to MLOpenSearch files Signed-off-by: HenryL27 <[email protected]> * apply spotless after rebase Signed-off-by: HenryL27 <[email protected]> * update changelog Signed-off-by: HenryL27 <[email protected]> * after rebase Signed-off-by: HenryL27 <[email protected]> * Address pr comments and fix XContent in search ext Signed-off-by: HenryL27 <[email protected]> * move contextSourceFetchers to their own subdirectory Signed-off-by: HenryL27 <[email protected]> * Apply suggestions from code review Co-authored-by: Martin Gaievski <[email protected]> Signed-off-by: HenryL27 <[email protected]> * CR changes Signed-off-by: HenryL27 <[email protected]> * finish CR comments and fix broken unittest Signed-off-by: HenryL27 <[email protected]> * fix unittest names Signed-off-by: HenryL27 <[email protected]> --------- Signed-off-by: HenryL27 <[email protected]> Co-authored-by: Martin Gaievski <[email protected]>
- Loading branch information
1 parent
8877db5
commit a90784c
Showing
20 changed files
with
1,833 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
132 changes: 132 additions & 0 deletions
132
src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor.factory; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.Set; | ||
import java.util.StringJoiner; | ||
|
||
import org.opensearch.ingest.ConfigurationUtils; | ||
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; | ||
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; | ||
import org.opensearch.neuralsearch.processor.rerank.RerankType; | ||
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; | ||
import org.opensearch.search.pipeline.Processor; | ||
import org.opensearch.search.pipeline.SearchResponseProcessor; | ||
|
||
import com.google.common.collect.Sets; | ||
|
||
import lombok.AllArgsConstructor; | ||
|
||
/** | ||
* Factory for rerank processors. Must: | ||
* - Instantiate the right kind of rerank processor | ||
* - Instantiate the appropriate context source fetchers | ||
*/ | ||
@AllArgsConstructor | ||
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> { | ||
|
||
public static final String RERANK_PROCESSOR_TYPE = "rerank"; | ||
public static final String CONTEXT_CONFIG_FIELD = "context"; | ||
|
||
private final MLCommonsClientAccessor clientAccessor; | ||
|
||
@Override | ||
public SearchResponseProcessor create( | ||
final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, | ||
final String tag, | ||
final String description, | ||
final boolean ignoreFailure, | ||
final Map<String, Object> config, | ||
final Processor.PipelineContext pipelineContext | ||
) { | ||
RerankType type = findRerankType(config); | ||
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type); | ||
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag); | ||
switch (type) { | ||
case ML_OPENSEARCH: | ||
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel()); | ||
String modelId = ConfigurationUtils.readStringProperty( | ||
RERANK_PROCESSOR_TYPE, | ||
tag, | ||
rerankerConfig, | ||
MLOpenSearchRerankProcessor.MODEL_ID_FIELD | ||
); | ||
return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); | ||
default: | ||
throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel())); | ||
} | ||
} | ||
|
||
private RerankType findRerankType(final Map<String, Object> config) throws IllegalArgumentException { | ||
// Set of rerank type labels in the config | ||
Set<String> rerankTypes = Sets.intersection(config.keySet(), RerankType.labelMap().keySet()); | ||
// A rerank type must be provided | ||
if (rerankTypes.size() == 0) { | ||
StringJoiner msgBuilder = new StringJoiner(", ", "No rerank type found. Possible rerank types are: [", "]"); | ||
for (RerankType t : RerankType.values()) { | ||
msgBuilder.add(t.getLabel()); | ||
} | ||
throw new IllegalArgumentException(msgBuilder.toString()); | ||
} | ||
// Only one rerank type may be provided | ||
if (rerankTypes.size() > 1) { | ||
StringJoiner msgBuilder = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted."); | ||
rerankTypes.forEach(rt -> msgBuilder.add(rt)); | ||
throw new IllegalArgumentException(msgBuilder.toString()); | ||
} | ||
return RerankType.from(rerankTypes.iterator().next()); | ||
} | ||
|
||
/** | ||
* Factory class for context fetchers. Constructs a list of context fetchers | ||
* specified in the pipeline config (and maybe the query context fetcher) | ||
*/ | ||
private static class ContextFetcherFactory { | ||
|
||
/** | ||
* Map rerank types to whether they should include the query context source fetcher | ||
* @param type the constructing RerankType | ||
* @return does this RerankType depend on the QueryContextSourceFetcher? | ||
*/ | ||
public static boolean shouldIncludeQueryContextFetcher(RerankType type) { | ||
return type == RerankType.ML_OPENSEARCH; | ||
} | ||
|
||
/** | ||
* Create necessary queryContextFetchers for this processor | ||
* @param config processor config object. Look for "context" field to find fetchers | ||
* @param includeQueryContextFetcher should I include the queryContextFetcher? | ||
* @return list of contextFetchers for the processor to use | ||
*/ | ||
public static List<ContextSourceFetcher> createFetchers( | ||
Map<String, Object> config, | ||
boolean includeQueryContextFetcher, | ||
String tag | ||
) { | ||
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD); | ||
List<ContextSourceFetcher> fetchers = new ArrayList<>(); | ||
for (String key : contextConfig.keySet()) { | ||
Object cfg = contextConfig.get(key); | ||
switch (key) { | ||
case DocumentContextSourceFetcher.NAME: | ||
fetchers.add(DocumentContextSourceFetcher.create(cfg)); | ||
break; | ||
default: | ||
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key)); | ||
} | ||
} | ||
if (includeQueryContextFetcher) { | ||
fetchers.add(new QueryContextSourceFetcher()); | ||
} | ||
return fetchers; | ||
} | ||
} | ||
} |
83 changes: 83 additions & 0 deletions
83
src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor.rerank; | ||
|
||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
|
||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; | ||
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; | ||
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; | ||
|
||
/** | ||
* Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore | ||
*/ | ||
public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor { | ||
|
||
public static final String MODEL_ID_FIELD = "model_id"; | ||
|
||
protected final String modelId; | ||
|
||
protected final MLCommonsClientAccessor mlCommonsClientAccessor; | ||
|
||
/** | ||
* Constructor | ||
* @param description | ||
* @param tag | ||
* @param ignoreFailure | ||
* @param modelId id of TEXT_SIMILARITY model | ||
* @param contextSourceFetchers | ||
* @param mlCommonsClientAccessor | ||
*/ | ||
public MLOpenSearchRerankProcessor( | ||
final String description, | ||
final String tag, | ||
final boolean ignoreFailure, | ||
final String modelId, | ||
final List<ContextSourceFetcher> contextSourceFetchers, | ||
final MLCommonsClientAccessor mlCommonsClientAccessor | ||
) { | ||
super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers); | ||
this.modelId = modelId; | ||
this.mlCommonsClientAccessor = mlCommonsClientAccessor; | ||
} | ||
|
||
@Override | ||
public void rescoreSearchResponse( | ||
final SearchResponse response, | ||
final Map<String, Object> rerankingContext, | ||
final ActionListener<List<Float>> listener | ||
) { | ||
Object ctxObj = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD); | ||
if (!(ctxObj instanceof List<?>)) { | ||
listener.onFailure( | ||
new IllegalStateException( | ||
String.format( | ||
Locale.ROOT, | ||
"No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?", | ||
RerankProcessorFactory.CONTEXT_CONFIG_FIELD, | ||
DocumentContextSourceFetcher.NAME | ||
) | ||
) | ||
); | ||
return; | ||
} | ||
List<?> ctxList = (List<?>) ctxObj; | ||
List<String> contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList()); | ||
mlCommonsClientAccessor.inferenceSimilarity( | ||
modelId, | ||
(String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD), | ||
contexts, | ||
listener | ||
); | ||
} | ||
|
||
} |
Oops, something went wrong.