Skip to content

Commit

Permalink
add unittests
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Dec 1, 2023
1 parent aceb846 commit f172279
Show file tree
Hide file tree
Showing 10 changed files with 691 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import com.google.common.annotations.VisibleForTesting;

@AllArgsConstructor
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> {

Expand All @@ -49,14 +51,21 @@ public SearchResponseProcessor create(
@SuppressWarnings("unchecked")
Map<String, String> rerankerConfig = (Map<String, String>) config.get(type.getLabel());
String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD);
if (modelId == null) {
throw new IllegalArgumentException(CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified");
}
String rerankContext = rerankerConfig.get(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD);
if (rerankContext == null) {
throw new IllegalArgumentException(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified");
}
return new CrossEncoderRerankProcessor(description, tag, ignoreFailure, modelId, rerankContext, clientAccessor);
default:
throw new IllegalArgumentException("could not find constructor for reranker type " + type.getLabel());
}
}

private RerankType findRerankType(final Map<String, Object> config) throws IllegalArgumentException {
@VisibleForTesting
RerankType findRerankType(final Map<String, Object> config) throws IllegalArgumentException {
for (String key : config.keySet()) {
try {
RerankType attempt = RerankType.from(key);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
*/
package org.opensearch.neuralsearch.processor.rerank;

import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import lombok.extern.log4j.Log4j2;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.xcontent.XContentType;
Expand All @@ -38,6 +40,7 @@
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.SearchHit;

@Log4j2
public class CrossEncoderRerankProcessor extends RescoringRerankProcessor {

public static final String MODEL_ID_FIELD = "model_id";
Expand Down Expand Up @@ -76,26 +79,29 @@ public void generateScoringContext(
Map<String, Object> scoringContext = new HashMap<>();
if (params.containsKey(QUERY_TEXT_FIELD)) {
if (params.containsKey(QUERY_TEXT_PATH_FIELD)) {
throw new IllegalArgumentException("Cannot specify both \"query_text\" and \"query_text_path\"");
throw new IllegalArgumentException(
"Cannot specify both \"" + QUERY_TEXT_FIELD + "\" and \"" + QUERY_TEXT_PATH_FIELD + "\""
);
}
scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD));
} else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) {
String path = (String) params.get(QUERY_TEXT_PATH_FIELD);
// Convert query to a map with io/xcontent shenanigans
PipedOutputStream os = new PipedOutputStream();
XContentBuilder builder = XContentType.CBOR.contentBuilder(os);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
XContentBuilder builder = XContentType.CBOR.contentBuilder(baos);
searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS);
PipedInputStream is = new PipedInputStream(os);
XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, is);
builder.close();
ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais);
Map<String, Object> map = parser.map();
// Get the text at the path
Object queryText = ObjectPath.eval(path, map);
if (!(queryText instanceof String)) {
throw new IllegalArgumentException("query_text_path must point to a string field");
throw new IllegalArgumentException(QUERY_TEXT_PATH_FIELD + " must point to a string field");
}
scoringContext.put(QUERY_TEXT_FIELD, (String) queryText);
} else {
throw new IllegalArgumentException("Must specify either \"query_text\" or \"query_text_path\"");
throw new IllegalArgumentException("Must specify either \"" + QUERY_TEXT_FIELD + "\" or \"" + QUERY_TEXT_PATH_FIELD + "\"");
}
listener.onResponse(scoringContext);
} catch (Exception e) {
Expand All @@ -115,7 +121,7 @@ public void rescoreSearchResponse(SearchResponse response, Map<String, Object> s
private String contextFromSearchHit(final SearchHit hit) {
if (hit.getFields().containsKey(this.rerankContext)) {
return (String) hit.field(this.rerankContext).getValue();
} else if (hit.getSourceAsMap().containsKey(this.rerankContext)) {
} else if (hit.hasSource() && hit.getSourceAsMap().containsKey(this.rerankContext)) {
return (String) hit.getSourceAsMap().get(this.rerankContext);
} else {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
*/
package org.opensearch.neuralsearch.processor.rerank;

import java.util.Arrays;
import java.util.Optional;

import lombok.Getter;

/**
Expand All @@ -39,9 +42,10 @@ private RerankType(String label) {
* @return RerankType represented by the label
*/
public static RerankType from(String label) {
try {
return RerankType.valueOf(label);
} catch (Exception e) {
Optional<RerankType> typeMaybe = Arrays.stream(RerankType.values()).filter(rrt -> rrt.label.equals(label)).findFirst();
if (typeMaybe.isPresent()) {
return typeMaybe.get();
} else {
throw new IllegalArgumentException("Wrong rerank type name: " + label);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,18 @@ public void rerank(SearchResponse searchResponse, Map<String, Object> scoringCon
rescoreSearchResponse(searchResponse, scoringContext, ActionListener.wrap(scores -> {
// Assign new scores
SearchHit[] hits = searchResponse.getHits().getHits();
assert (hits.length == scores.size());
if (hits.length != scores.size()) {
throw new Exception("scores and hits are not the same length");
}
for (int i = 0; i < hits.length; i++) {
hits[i].score(scores.get(i));
}
// Re-sort by the new scores
Collections.sort(Arrays.asList(hits), new Comparator<SearchHit>() {
@Override
public int compare(SearchHit hit1, SearchHit hit2) {
return Float.compare(hit1.getScore(), hit2.getScore());
// backwards to sort DESC
return Float.compare(hit2.getScore(), hit1.getScore());
}
});
// Reconstruct the search response, replacing the max score
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ public void writeTo(StreamOutput out) throws IOException {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(PARAM_FIELD_NAME, this.params);
builder.endObject();
return builder;
return builder.field(PARAM_FIELD_NAME, this.params);
}

@Override
Expand Down Expand Up @@ -92,7 +89,9 @@ public static RerankSearchExtBuilder fromExtBuilderList(List<SearchExtBuilder> b
* @throws IOException if problems parsing
*/
public static RerankSearchExtBuilder parse(XContentParser parser) throws IOException {
return new RerankSearchExtBuilder(parser.map());
@SuppressWarnings("unchecked")
RerankSearchExtBuilder ans = new RerankSearchExtBuilder((Map<String, Object>) parser.map().get(PARAM_FIELD_NAME));
return ans;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,23 @@ protected void createSearchRequestProcessor(String modelId, String pipelineName)
assertEquals("true", node.get("acknowledged").toString());
}

protected void createSearchPipelineViaConfig(String modelId, String pipelineName, String configPath) throws Exception {
Response pipelineCreateResponse = makeRequest(
client(),
"PUT",
"/_search/pipeline/" + pipelineName,
null,
toHttpEntity(
String.format(
LOCALE,
Files.readString(Path.of(classLoader.getResource(configPath).toURI())),
modelId
)
),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}

/**
* Get the number of documents in a particular index
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,71 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
}

public void testInferenceSimilarity_whenValidInput_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createManyModelTensorOutputs(TestCommonConstants.PREDICT_VECTOR_ARRAY));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
final RuntimeException exception = new RuntimeException();
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(exception);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
);

Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down Expand Up @@ -353,4 +418,21 @@ private ModelTensorOutput createModelTensorOutput(final Map<String, String> map)
tensorsList.add(modelTensors);
return new ModelTensorOutput(tensorsList);
}

private ModelTensorOutput createManyModelTensorOutputs(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
for (Float score : output) {
List<ModelTensor> tensorList = new ArrayList<>();
String name = "logits";
Number[] data = new Number[] { score };
long[] shape = new long[] { 1 };
MLResultDataType dataType = MLResultDataType.FLOAT32;
MLResultDataType mlResultDataType = MLResultDataType.valueOf(dataType.name());
ModelTensor tensor = ModelTensor.builder().name(name).data(data).shape(shape).dataType(mlResultDataType).build();
tensorList.add(tensor);
tensorsList.add(new ModelTensors(tensorList));
}
ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList);
return modelTensorOutput;
}
}
Loading

0 comments on commit f172279

Please sign in to comment.