Skip to content

Commit

Permalink
add integration test
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Dec 1, 2023
1 parent f172279 commit f7eda01
Show file tree
Hide file tree
Showing 11 changed files with 256 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand Down Expand Up @@ -221,14 +220,9 @@ private MLInput createMLTextInput(final List<String> targetResponseFilters, List
return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
}

private MLInput createMLTextPairsInput(final List<Pair<String, String>> pairs) {
final MLInputDataset inputDataset = new TextSimilarityInputDataSet(pairs);
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
}

private MLInput createMLTextPairsInput(final String query, final List<String> inputText) {
List<Pair<String, String>> pairs = inputText.stream().map(text -> Pair.of(query, text)).collect(Collectors.toList());
return createMLTextPairsInput(pairs);
final MLInputDataset inputDataset = new TextSimilarityInputDataSet(query, inputText);
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Map;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor;
Expand All @@ -29,6 +30,7 @@

import com.google.common.annotations.VisibleForTesting;

@Log4j2
@AllArgsConstructor
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> {

Expand All @@ -49,7 +51,7 @@ public SearchResponseProcessor create(
switch (type) {
case CROSS_ENCODER:
@SuppressWarnings("unchecked")
Map<String, String> rerankerConfig = (Map<String, String>) config.get(type.getLabel());
Map<String, String> rerankerConfig = (Map<String, String>) config.remove(type.getLabel());
String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD);
if (modelId == null) {
throw new IllegalArgumentException(CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Map;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -33,6 +34,7 @@
import org.opensearch.search.SearchHits;
import org.opensearch.search.profile.SearchProfileShardResults;

@Log4j2
@AllArgsConstructor
public abstract class RescoringRerankProcessor implements RerankProcessor {

Expand Down Expand Up @@ -80,6 +82,7 @@ public abstract void rescoreSearchResponse(

@Override
public void rerank(SearchResponse searchResponse, Map<String, Object> scoringContext, ActionListener<SearchResponse> listener) {
log.info("==================RERANKING==================");
try {
rescoreSearchResponse(searchResponse, scoringContext, ActionListener.wrap(scores -> {
// Assign new scores
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchExtBuilder;

@Log4j2
@AllArgsConstructor
public class RerankSearchExtBuilder extends SearchExtBuilder {

Expand Down Expand Up @@ -89,8 +91,7 @@ public static RerankSearchExtBuilder fromExtBuilderList(List<SearchExtBuilder> b
* @throws IOException if problems parsing
*/
public static RerankSearchExtBuilder parse(XContentParser parser) throws IOException {
@SuppressWarnings("unchecked")
RerankSearchExtBuilder ans = new RerankSearchExtBuilder((Map<String, Object>) parser.map().get(PARAM_FIELD_NAME));
RerankSearchExtBuilder ans = new RerankSearchExtBuilder((Map<String, Object>) parser.map());
return ans;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,19 +317,19 @@ protected void createSearchRequestProcessor(String modelId, String pipelineName)

protected void createSearchPipelineViaConfig(String modelId, String pipelineName, String configPath) throws Exception {
Response pipelineCreateResponse = makeRequest(
client(),
"PUT",
"/_search/pipeline/" + pipelineName,
null,
toHttpEntity(
String.format(
LOCALE,
Files.readString(Path.of(classLoader.getResource(configPath).toURI())),
modelId
)
),
client(),
"PUT",
"/_search/pipeline/" + pipelineName,
null,
toHttpEntity(String.format(LOCALE, Files.readString(Path.of(classLoader.getResource(configPath).toURI())), modelId)),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> node = XContentHelper.convertToMap(
XContentType.JSON.xContent(),
EntityUtils.toString(pipelineCreateResponse.getEntity()),
false
);
assertEquals("true", node.get("acknowledged").toString());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static org.mockito.Mockito.mock;

import java.util.HashMap;
import java.util.Map;

import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -55,7 +56,7 @@ public void setup() {
}

public void testRerankProcessorFactory_EmptyConfig_ThenFail() {
Map<String, Object> config = Map.of();
Map<String, Object> config = new HashMap<>(Map.of());
assertThrows(
"no rerank type found",
IllegalArgumentException.class,
Expand All @@ -64,7 +65,9 @@ public void testRerankProcessorFactory_EmptyConfig_ThenFail() {
}

public void testRerankProcessorFactory_NonExistentType_ThenFail() {
Map<String, Object> config = Map.of("jpeo rvgh we iorgn", Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id"));
Map<String, Object> config = new HashMap<>(
Map.of("jpeo rvgh we iorgn", Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id"))
);
assertThrows(
"no rerank type found",
IllegalArgumentException.class,
Expand All @@ -73,13 +76,17 @@ public void testRerankProcessorFactory_NonExistentType_ThenFail() {
}

public void testRerankProcessorFactory_CrossEncoder_HappyPath() {
Map<String, Object> config = Map.of(
RerankType.CROSS_ENCODER.getLabel(),
Map<String, Object> config = new HashMap<>(
Map.of(
CrossEncoderRerankProcessor.MODEL_ID_FIELD,
"model-id",
CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD,
"text_representation"
RerankType.CROSS_ENCODER.getLabel(),
new HashMap<>(
Map.of(
CrossEncoderRerankProcessor.MODEL_ID_FIELD,
"model-id",
CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD,
"text_representation"
)
)
)
);
SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext);
Expand All @@ -89,17 +96,21 @@ public void testRerankProcessorFactory_CrossEncoder_HappyPath() {
}

public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() {
Map<String, Object> config = Map.of(
"poafn aorr;anv",
Map.of(";oawhls", "aowirhg "),
RerankType.CROSS_ENCODER.getLabel(),
Map<String, Object> config = new HashMap<>(
Map.of(
CrossEncoderRerankProcessor.MODEL_ID_FIELD,
"model-id",
CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD,
"text_representation",
"pqiohg rpowierhg",
"pw;oith4pt3ih go"
"poafn aorr;anv",
Map.of(";oawhls", "aowirhg "),
RerankType.CROSS_ENCODER.getLabel(),
new HashMap<>(
Map.of(
CrossEncoderRerankProcessor.MODEL_ID_FIELD,
"model-id",
CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD,
"text_representation",
"pqiohg rpowierhg",
"pw;oith4pt3ih go"
)
)
)
);
SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext);
Expand All @@ -109,7 +120,7 @@ public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() {
}

public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() {
Map<String, Object> config = Map.of(RerankType.CROSS_ENCODER.getLabel(), Map.of());
Map<String, Object> config = new HashMap<>(Map.of(RerankType.CROSS_ENCODER.getLabel(), Map.of()));
assertThrows(
CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified",
IllegalArgumentException.class,
Expand All @@ -118,9 +129,8 @@ public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() {
}

public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() {
Map<String, Object> config = Map.of(
RerankType.CROSS_ENCODER.getLabel(),
Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")
Map<String, Object> config = new HashMap<>(
Map.of(RerankType.CROSS_ENCODER.getLabel(), new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")))
);
assertThrows(
CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified",
Expand All @@ -130,9 +140,11 @@ public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() {
}

public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() {
Map<String, Object> config = Map.of(
RerankType.CROSS_ENCODER.getLabel(),
Map.of(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, "text_representation")
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.CROSS_ENCODER.getLabel(),
new HashMap<>(Map.of(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, "text_representation"))
)
);
assertThrows(
CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.neuralsearch.processor.rerank;

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;

import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.After;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;

import com.google.common.collect.ImmutableList;

/**
 * Integration test for the cross-encoder rerank search pipeline.
 *
 * <p>The test uploads and loads a cross-encoder model, creates a search pipeline from a JSON
 * configuration resource, indexes two short documents, and then checks that the order of the
 * returned hits follows the {@code query_text} supplied in the {@code ext.rerank} section of
 * the search request rather than the order produced by the {@code match_all} query.
 */
@Log4j2
public class CrossEncoderRerankProcessorIT extends BaseNeuralSearchIT {

    static final String PIPELINE_NAME = "rerank-ce-pipeline";
    static final String INDEX_NAME = "rerank-test";
    static final String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy";
    static final String TEXT_REP_2 = "Fish like to eat plankton";
    static final String INDEX_CONFIG = "{\"mappings\": {\"properties\": {\"text_representation\": {\"type\": \"text\"}}}}";

    /**
     * Removes the search pipeline, every deployed model and the test index after each test.
     */
    @After
    @SneakyThrows
    public void tearDown() {
        super.tearDown();
        /* this is required to minimize chance of model not being deployed due to open memory CB,
         * this happens in case we leave model from previous test case. We use new model for every test, and old model
         * can be undeployed and deleted to free resources after each test case execution.
         */
        deleteSearchPipeline(PIPELINE_NAME);
        findDeployedModels().forEach(this::deleteModel);
        deleteIndex(INDEX_NAME);
    }

    /**
     * End-to-end scenario: model upload, pipeline creation, indexing, then reranked searches.
     *
     * @throws Exception if any REST call or resource lookup fails
     */
    public void testCrossEncoderRerankProcessor() throws Exception {
        String modelId = uploadCrossEncoderModel();
        loadModel(modelId);
        createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/CrossEncoderRerankPipelineConfiguration.json");
        setupIndex();
        runQueries();
    }

    /**
     * Uploads the cross-encoder model described by the bundled request-body resource.
     *
     * @return the value returned by {@code uploadModel} for the request body
     * @throws Exception if the resource cannot be read or the upload fails
     */
    private String uploadCrossEncoderModel() throws Exception {
        String requestBody = Files.readString(
            Path.of(classLoader.getResource("processor/UploadCrossEncoderModelRequestBody.json").toURI())
        );
        return uploadModel(requestBody);
    }

    /**
     * Creates the test index and indexes the two sample documents.
     *
     * @throws Exception if index creation or either indexing request fails
     */
    private void setupIndex() throws Exception {
        createIndexWithConfiguration(INDEX_NAME, INDEX_CONFIG, PIPELINE_NAME);
        indexDocument(TEXT_REP_1);
        indexDocument(TEXT_REP_2);
    }

    /**
     * Indexes a single document whose {@code text_representation} field holds {@code text},
     * refreshing the index, and asserts that the response reports the document as created.
     *
     * @param text value for the {@code text_representation} field
     * @throws Exception if the request fails or the response cannot be parsed
     */
    private void indexDocument(String text) throws Exception {
        Response response = makeRequest(
            client(),
            "POST",
            INDEX_NAME + "/_doc?refresh",
            null,
            toHttpEntity(String.format(LOCALE, "{\"text_representation\": \"%s\"}", text)),
            ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
        );
        Map<String, Object> body = XContentHelper.convertToMap(
            XContentType.JSON.xContent(),
            EntityUtils.toString(response.getEntity()),
            false
        );
        assertEquals("created", body.get("result"));
    }

    /**
     * Runs two searches whose rerank query texts favour different documents and checks that
     * the top two hits come back in the corresponding order.
     *
     * @throws Exception if a search request fails
     */
    private void runQueries() throws Exception {
        assertHitOrder("What do fish eat?", TEXT_REP_2, TEXT_REP_1);
        assertHitOrder("Who loves fish?", TEXT_REP_1, TEXT_REP_2);
    }

    /**
     * Searches with {@code queryText} as the rerank query and asserts that the leading hits
     * carry the expected {@code text_representation} values, in order.
     *
     * <p>JUnit assertions are used instead of the {@code assert} keyword so the checks cannot
     * be skipped when JVM assertions are disabled, and so failures report expected vs. actual.
     *
     * @param queryText     text placed in {@code ext.rerank.query_text}
     * @param expectedTexts expected {@code text_representation} of hit 0, hit 1, ...
     * @throws Exception if the search request fails
     */
    private void assertHitOrder(String queryText, String... expectedTexts) throws Exception {
        Map<String, Object> response = search(queryText);
        @SuppressWarnings("unchecked")
        List<Map<String, ?>> hits = (List<Map<String, ?>>) ((Map<String, ?>) response.get("hits")).get("hits");
        // Fail with a readable message rather than an IndexOutOfBoundsException below.
        assertTrue(
            String.format(LOCALE, "query [%s]: expected at least %d hits but got %d", queryText, expectedTexts.length, hits.size()),
            hits.size() >= expectedTexts.length
        );
        for (int rank = 0; rank < expectedTexts.length; rank++) {
            @SuppressWarnings("unchecked")
            Map<String, String> source = (Map<String, String>) hits.get(rank).get("_source");
            assertEquals(
                String.format(LOCALE, "query [%s]: unexpected document at rank %d", queryText, rank),
                expectedTexts[rank],
                source.get("text_representation")
            );
        }
    }

    /**
     * Executes a {@code match_all} search against the test index through the rerank pipeline.
     *
     * @param queryText text placed in {@code ext.rerank.query_text}; inserted into the JSON
     *                  body verbatim, so it must not contain characters that need JSON escaping
     * @return the parsed JSON response body
     * @throws Exception if the request fails or the response cannot be parsed
     */
    private Map<String, Object> search(String queryText) throws Exception {
        String jsonQueryFrame = "{\"query\":{\"match_all\":{}},\"ext\":{\"rerank\":{\"query_text\":\"%s\"}}}";
        String jsonQuery = String.format(LOCALE, jsonQueryFrame, queryText);
        log.info(jsonQuery);
        Request request = new Request("POST", "/" + INDEX_NAME + "/_search");
        request.addParameter("search_pipeline", PIPELINE_NAME);
        request.setJsonEntity(jsonQuery);

        Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

        String responseBody = EntityUtils.toString(response.getEntity());

        return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);
    }
}
Loading

0 comments on commit f7eda01

Please sign in to comment.