Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for generic re-ranker interface and opensearch ml re-ranker for improving search relavancy. #494

Merged
merged 27 commits into from
Jan 16, 2024
Merged
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
bdf5d9b
Add rerank processor interfaces
HenryL27 Nov 14, 2023
17cff65
add cross-encoder specific logic and factory
HenryL27 Nov 14, 2023
8d476db
add unittests
HenryL27 Nov 16, 2023
4efa463
add integration test
HenryL27 Nov 18, 2023
de96761
use string.format() instead of concatenation
HenryL27 Dec 1, 2023
6f85824
rename generateScoringContext to generateRerankingContext
HenryL27 Dec 1, 2023
a30180c
add name change in test too. whoops
HenryL27 Dec 1, 2023
b8820ec
start refactoring with contextSaourceFetchers
HenryL27 Dec 4, 2023
5e1c00b
refactor to use contextSourceFetchers to get context
HenryL27 Dec 5, 2023
2976807
rename CrossEncoder to TextSimilarity
HenryL27 Dec 5, 2023
5332fee
add query_context layer to search ext
HenryL27 Dec 5, 2023
aa1d524
add javadocs
HenryL27 Dec 5, 2023
77301d9
update to new asyncProcessResponse api
HenryL27 Dec 11, 2023
e8de412
rename reranktype to ML_OPENSEARCH
HenryL27 Dec 18, 2023
a7090b2
improve error messages for bad rerank type config
HenryL27 Dec 18, 2023
797eaf6
simplify configuration/factory logic
HenryL27 Dec 18, 2023
ddf2866
improve handling for non-flat-string context fields
HenryL27 Dec 18, 2023
14c8f89
rename TextSimilarity files to MLOpenSearch files
HenryL27 Dec 18, 2023
577f855
apply spotless after rebase
HenryL27 Dec 19, 2023
e3cf218
update changelog
HenryL27 Dec 21, 2023
7a6595f
after rebase
HenryL27 Jan 9, 2024
708fb66
Address pr comments and fix XContent in search ext
HenryL27 Jan 10, 2024
2d04075
move contextSourceFetchers to their own subdirectory
HenryL27 Jan 10, 2024
a39428b
Apply suggestions from code review
HenryL27 Jan 11, 2024
f462965
CR changes
HenryL27 Jan 11, 2024
db8bec1
finish CR comments and fix broken unittest
HenryL27 Jan 11, 2024
7962ffa
fix unittest names
HenryL27 Jan 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
simplify configuration/factory logic
Signed-off-by: HenryL27 <[email protected]>
HenryL27 committed Jan 9, 2024
commit 797eaf6b00bea0e343c61ffd315b133f7b438f0d
Original file line number Diff line number Diff line change
@@ -23,10 +23,10 @@
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;

import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher;
@@ -63,22 +63,19 @@ public SearchResponseProcessor create(
) {
RerankType type = findRerankType(config);
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag);
switch (type) {
case ML_OPENSEARCH:
@SuppressWarnings("unchecked")
Map<String, String> rerankerConfig = (Map<String, String>) config.remove(type.getLabel());
String modelId = rerankerConfig.get(TextSimilarityRerankProcessor.MODEL_ID_FIELD);
if (modelId == null) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "%s must be specified", TextSimilarityRerankProcessor.MODEL_ID_FIELD)
);
}
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());
String modelId = ConfigurationUtils.readStringProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
TextSimilarityRerankProcessor.MODEL_ID_FIELD
);
return new TextSimilarityRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor);
default:
throw new IllegalArgumentException(
String.format(Locale.ROOT, "could not find constructor for reranker type %s", type.getLabel())
);
throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel()));
}
}

@@ -97,9 +94,7 @@ RerankType findRerankType(final Map<String, Object> config) throws IllegalArgume
// Only one rerank type may be provided
if (rerankTypes.size() > 1) {
StringJoiner msgBuilder = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted.");
for (String rt : rerankTypes) {
msgBuilder.add(rt);
}
rerankTypes.forEach(rt -> msgBuilder.add(rt));
throw new IllegalArgumentException(msgBuilder.toString());
}
return RerankType.from(rerankTypes.iterator().next());
@@ -131,26 +126,18 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) {
* @param includeQueryContextFetcher should I include the queryContextFetcher?
* @return list of contextFetchers for the processor to use
*/
public static List<ContextSourceFetcher> createFetchers(Map<String, Object> config, boolean includeQueryContextFetcher) {
public static List<ContextSourceFetcher> createFetchers(
Map<String, Object> config,
boolean includeQueryContextFetcher,
String tag
) {
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD);
List<ContextSourceFetcher> fetchers = new ArrayList<>();
@SuppressWarnings("unchecked")
Map<String, Object> contextConfig = (Map<String, Object>) config.remove(CONTEXT_CONFIG_FIELD);
if (contextConfig == null) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field must be provided", CONTEXT_CONFIG_FIELD));
}
for (String key : contextConfig.keySet()) {
Object cfg = contextConfig.get(key);
switch (key) {
case DocumentContextSourceFetcher.NAME:
Object cfg = contextConfig.get(key);
if (!(cfg instanceof List<?>)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of strings", key));
}
List<?> fields = (List<?>) contextConfig.get(key);
if (fields.size() == 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", key));
}
List<String> strfields = fields.stream().map(field -> (String) field).collect(Collectors.toList());
fetchers.add(new DocumentContextSourceFetcher(strfields));
fetchers.add(DocumentContextSourceFetcher.create(cfg));
break;
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key));
Original file line number Diff line number Diff line change
@@ -45,4 +45,5 @@ public interface ContextSourceFetcher {
* @return Name of the fetcher
*/
public String getName();

}
Original file line number Diff line number Diff line change
@@ -20,7 +20,9 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;

@@ -75,4 +77,21 @@ private String contextFromSearchHit(final SearchHit hit, final String field) {
public String getName() {
return NAME;
}

/**
* Create a document context source fetcher from list of field names provided by configuration
* @param config configuration object grabbed from parsed API request. Should be a list of strings
* @return a new DocumentContextSourceFetcher or throws IllegalArgumentException if config is malformed
*/
public static DocumentContextSourceFetcher create(Object config) {
if (!(config instanceof List)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of field names", NAME));
}
List<?> fields = (List<?>) config;
if (fields.size() == 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", NAME));
}
List<String> strfields = fields.stream().map(field -> (String) field).collect(Collectors.toList());
return new DocumentContextSourceFetcher(strfields);
}
}
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@

import org.junit.Before;
import org.mockito.Mock;
import org.opensearch.OpenSearchParseException;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
@@ -134,8 +135,8 @@ public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() {
public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() {
Map<String, Object> config = new HashMap<>(Map.of(RerankType.ML_OPENSEARCH.getLabel(), Map.of()));
assertThrows(
TextSimilarityRerankProcessor.MODEL_ID_FIELD + " must be specified",
IllegalArgumentException.class,
String.format(Locale.ROOT, "[%s] required property is missing", RerankProcessorFactory.CONTEXT_CONFIG_FIELD),
OpenSearchParseException.class,
() -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext)
);
}
@@ -145,8 +146,8 @@ public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() {
Map.of(RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")))
);
assertThrows(
String.format(Locale.ROOT, "%s field must be provided", RerankProcessorFactory.CONTEXT_CONFIG_FIELD),
IllegalArgumentException.class,
String.format(Locale.ROOT, "[%s] required property is missing", RerankProcessorFactory.CONTEXT_CONFIG_FIELD),
OpenSearchParseException.class,
() -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext)
);
}
@@ -161,8 +162,8 @@ public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() {
)
);
assertThrows(
TextSimilarityRerankProcessor.MODEL_ID_FIELD + " must be specified",
IllegalArgumentException.class,
String.format(Locale.ROOT, "[%s] required property is missing", TextSimilarityRerankProcessor.MODEL_ID_FIELD),
OpenSearchParseException.class,
() -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext)
);
}