Skip to content

Commit

Permalink
Add rerank processor interfaces
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Dec 1, 2023
1 parent b3c73bd commit f7a7944
Show file tree
Hide file tree
Showing 6 changed files with 499 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.neuralsearch.processor.factory;

import java.util.Map;

import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

/**
 * Factory for search response processors that rerank search results.
 *
 * NOTE(review): this is a placeholder. No processor is constructed yet, so
 * {@link #create} fails fast instead of handing a {@code null} processor back
 * to the caller, which would only surface later as a NullPointerException.
 */
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> {

    /**
     * Create a rerank processor from its pipeline configuration.
     *
     * @param processorFactories the available processor factories, keyed by name
     * @param tag tag of the processor being created
     * @param description description of the processor being created
     * @param ignoreFailure the ignore-failure flag for the processor being created
     * @param config the configuration map for the processor being created
     * @param pipelineContext context of the pipeline the processor is created for
     * @return never returns normally until the factory is implemented
     * @throws UnsupportedOperationException always, because creation is not implemented yet
     */
    @Override
    public SearchResponseProcessor create(
        final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
        final String tag,
        final String description,
        final boolean ignoreFailure,
        final Map<String, Object> config,
        final Processor.PipelineContext pipelineContext
    ) {
        // TODO implement; throwing mirrors the other unimplemented stubs instead of returning null
        throw new UnsupportedOperationException("Unimplemented method 'create'");
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.neuralsearch.processor.rerank;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ObjectPath;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.env.Environment;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.search.SearchExtBuilder;

/**
 * Rescoring rerank processor of type {@link RerankType#CROSS_ENCODER}.
 *
 * The scoring context it generates holds a single entry, {@value #QUERY_TEXT_FIELD},
 * taken either directly from the rerank search extension parameters or from a
 * path into the search source. Rescoring itself is not implemented yet.
 */
public class CrossEncoderRerankProcessor extends RescoringRerankProcessor {

    public static final String MODEL_ID_FIELD = "model_id";
    public static final String QUERY_TEXT_FIELD = "query_text";
    public static final String QUERY_TEXT_PATH_FIELD = "query_text_path";
    public static final String RERANK_CONTEXT_FIELD = "rerank_context_field";

    // Stored for subclasses and future rescoring; not read anywhere in this class yet.
    protected final String modelId;
    protected final String rerank_context;

    protected final MLCommonsClientAccessor mlCommonsClientAccessor;

    private final Environment environment;

    /**
     * Build the processor; every argument is stored as-is.
     *
     * @param description processor description, passed to the superclass
     * @param tag processor tag, passed to the superclass
     * @param ignoreFailure ignore-failure flag, passed to the superclass
     * @param modelId presumably the id of the cross-encoder model to score with (TODO confirm)
     * @param rerank_context presumably the document field used as rerank context (TODO confirm)
     * @param mlCommonsClientAccessor accessor for the ML client
     * @param environment node environment
     */
    public CrossEncoderRerankProcessor(
        String description,
        String tag,
        boolean ignoreFailure,
        String modelId,
        String rerank_context,
        MLCommonsClientAccessor mlCommonsClientAccessor,
        Environment environment
    ) {
        super(RerankType.CROSS_ENCODER, description, tag, ignoreFailure);
        this.modelId = modelId;
        this.rerank_context = rerank_context;
        this.mlCommonsClientAccessor = mlCommonsClientAccessor;
        this.environment = environment;
    }

    /**
     * Resolve the query text to rerank against and hand it to the listener in a
     * map under {@value #QUERY_TEXT_FIELD}.
     *
     * Exactly one of {@value #QUERY_TEXT_FIELD} or {@value #QUERY_TEXT_PATH_FIELD}
     * must be present in the rerank extension parameters. Any exception raised
     * while resolving it is reported through {@code listener.onFailure}.
     *
     * @param searchRequest the search request carrying the rerank extension
     * @param searchResponse the search results; not used by this implementation
     * @param listener receives the scoring context, or the failure
     */
    @Override
    public void generateScoringContext(
        SearchRequest searchRequest,
        SearchResponse searchResponse,
        ActionListener<Map<String, Object>> listener
    ) {
        try {
            List<SearchExtBuilder> exts = searchRequest.source().ext();
            Map<String, Object> params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams();
            Map<String, Object> scoringContext = new HashMap<>();
            if (params.containsKey(QUERY_TEXT_FIELD)) {
                if (params.containsKey(QUERY_TEXT_PATH_FIELD)) {
                    throw new IllegalArgumentException("Cannot specify both \"query_text\" and \"query_text_path\"");
                }
                scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD));
            } else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) {
                String path = (String) params.get(QUERY_TEXT_PATH_FIELD);
                // Get the text at the path inside the serialized search source
                Object queryText = ObjectPath.eval(path, searchSourceAsMap(searchRequest));
                if (!(queryText instanceof String)) {
                    throw new IllegalArgumentException("query_text_path must point to a string field");
                }
                scoringContext.put(QUERY_TEXT_FIELD, (String) queryText);
            } else {
                throw new IllegalArgumentException("Must specify either \"query_text\" or \"query_text_path\"");
            }
            listener.onResponse(scoringContext);
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Serialize the request's search source to CBOR and parse it back into a map.
     *
     * An in-memory byte array is used rather than piped streams: a pipe written
     * and read on the same thread either fails with "Pipe not connected" or
     * blocks forever waiting for data.
     *
     * @param searchRequest request whose source is converted
     * @return the search source as a nested map
     * @throws IOException if serializing or parsing the source fails
     */
    private static Map<String, Object> searchSourceAsMap(SearchRequest searchRequest) throws IOException {
        ByteArrayOutputStream serialized = new ByteArrayOutputStream();
        // Closing the builder flushes its buffered content into the byte array before it is read
        try (XContentBuilder builder = XContentType.CBOR.contentBuilder(serialized)) {
            searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS);
        }
        try (
            XContentParser parser = XContentType.CBOR.xContent()
                .createParser(NamedXContentRegistry.EMPTY, null, new ByteArrayInputStream(serialized.toByteArray()))
        ) {
            return parser.map();
        }
    }

    /**
     * Not implemented yet.
     *
     * @param response the search results to rescore
     * @param scoringContext the context produced by {@link #generateScoringContext}
     * @param listener would receive the new scores; never invoked by this stub
     * @throws UnsupportedOperationException always
     */
    @Override
    public void rescoreSearchResponse(SearchResponse response, Map<String, Object> scoringContext, ActionListener<List<Float>> listener) {
        // TODO Auto-generated method stub
        throw new UnsupportedOperationException("Unimplemented method 'rescoreSearchResponse'");
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.neuralsearch.processor.rerank;

import java.util.Map;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.pipeline.SearchResponseProcessor;

/**
 * A search response processor that reranks in two asynchronous steps: first it
 * builds a scoring context, then it reranks the response using that context.
 */
public interface RerankProcessor extends SearchResponseProcessor {

    /**
     * Collect whatever this processor needs before it can rerank. This may be
     * as cheap as reading a field off the search request or as involved as a
     * call out to an external service.
     *
     * @param searchRequest the search query
     * @param searchResponse the search results, should they matter for the context
     * @param listener notified with the scoring context, or with the failure
     */
    void generateScoringContext(
        SearchRequest searchRequest,
        SearchResponse searchResponse,
        ActionListener<Map<String, Object>> listener
    );

    /**
     * Rerank the search results using a previously generated scoring context,
     * reporting the outcome asynchronously.
     *
     * @param searchResponse the search results to rerank
     * @param scoringContext the information this processor needs in order to rerank
     * @param listener notified with the reranked response, or with the failure
     */
    void rerank(SearchResponse searchResponse, Map<String, Object> scoringContext, ActionListener<SearchResponse> listener);

    /**
     * Chain the two steps: generate the scoring context, then rerank with it.
     * A failure in either step, including an exception thrown while starting
     * context generation, is forwarded to {@code responseListener}.
     */
    @Override
    default void processResponseAsync(SearchRequest request, SearchResponse response, ActionListener<SearchResponse> responseListener) {
        try {
            final ActionListener<Map<String, Object>> onContext = ActionListener.wrap(
                scoringContext -> rerank(response, scoringContext, responseListener),
                failure -> responseListener.onFailure(failure)
            );
            generateScoringContext(request, response, onContext);
        } catch (Exception e) {
            responseListener.onFailure(e);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.neuralsearch.processor.rerank;

import lombok.Getter;

/**
 * Enum for distinguishing various reranking methods. Each constant carries the
 * label it is referred to by.
 */
public enum RerankType {

    CROSS_ENCODER("cross-encoder");

    @Getter
    private final String label;

    RerankType(String label) {
        this.label = label;
    }

    /**
     * Construct a RerankType from the label.
     *
     * The label (for example {@code "cross-encoder"}) is matched first. The
     * constant name (for example {@code "CROSS_ENCODER"}) is still accepted so
     * callers relying on the earlier name-based lookup keep working.
     *
     * @param label label of a RerankType
     * @return RerankType represented by the label
     * @throws IllegalArgumentException if no RerankType matches, including when label is null
     */
    public static RerankType from(String label) {
        // Enum.valueOf only knows constant names, so labels are compared explicitly
        for (RerankType type : values()) {
            if (type.label.equals(label)) {
                return type;
            }
        }
        for (RerankType type : values()) {
            if (type.name().equals(label)) {
                return type;
            }
        }
        throw new IllegalArgumentException("Wrong rerank type name: " + label);
    }
}
Loading

0 comments on commit f7a7944

Please sign in to comment.