Skip to content

Commit

Permalink
rename TextSimilarity files to MLOpenSearch files
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Dec 18, 2023
1 parent 3c13ee0 commit 11b5be1
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.QueryContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.neuralsearch.processor.rerank.TextSimilarityRerankProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

Expand Down Expand Up @@ -70,9 +70,9 @@ public SearchResponseProcessor create(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
TextSimilarityRerankProcessor.MODEL_ID_FIELD
MLOpenSearchRerankProcessor.MODEL_ID_FIELD
);
return new TextSimilarityRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor);
return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor);
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
/**
* Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore
*/
public class TextSimilarityRerankProcessor extends RescoringRerankProcessor {
public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor {

public static final String MODEL_ID_FIELD = "model_id";

Expand All @@ -47,7 +47,7 @@ public class TextSimilarityRerankProcessor extends RescoringRerankProcessor {
* @param contextSourceFetchers
* @param mlCommonsClientAccessor
*/
public TextSimilarityRerankProcessor(
public MLOpenSearchRerankProcessor(
String description,
String tag,
boolean ignoreFailure,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import org.opensearch.OpenSearchParseException;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.neuralsearch.processor.rerank.TextSimilarityRerankProcessor;
import org.opensearch.search.pipeline.Processor.PipelineContext;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -68,7 +68,7 @@ public void testRerankProcessorFactory_EmptyConfig_ThenFail() {

public void testRerankProcessorFactory_NonExistentType_ThenFail() {
Map<String, Object> config = new HashMap<>(
Map.of("jpeo rvgh we iorgn", Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id"))
Map.of("jpeo rvgh we iorgn", Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id"))
);
assertThrows(
"no rerank type found",
Expand All @@ -81,14 +81,14 @@ public void testRerankProcessorFactory_CrossEncoder_HappyPath() {
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")),
new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")),
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation"))))
)
);
SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext);
assert (processor instanceof RerankProcessor);
assert (processor instanceof TextSimilarityRerankProcessor);
assert (processor instanceof MLOpenSearchRerankProcessor);
assert (processor.getType().equals(RerankProcessor.TYPE));
}

Expand All @@ -98,22 +98,22 @@ public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() {
"poafn aorr;anv",
Map.of(";oawhls", "aowirhg "),
RerankType.ML_OPENSEARCH.getLabel(),
new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id", "pqiohg rpowierhg", "pw;oith4pt3ih go")),
new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id", "pqiohg rpowierhg", "pw;oith4pt3ih go")),
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation"))))
)
);
SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext);
assert (processor instanceof RerankProcessor);
assert (processor instanceof TextSimilarityRerankProcessor);
assert (processor instanceof MLOpenSearchRerankProcessor);
assert (processor.getType().equals(RerankProcessor.TYPE));
}

public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() {
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")),
new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")),
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -143,7 +143,7 @@ public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() {

public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() {
Map<String, Object> config = new HashMap<>(
Map.of(RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")))
Map.of(RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")))
);
assertThrows(
String.format(Locale.ROOT, "[%s] required property is missing", RerankProcessorFactory.CONTEXT_CONFIG_FIELD),
Expand All @@ -162,7 +162,7 @@ public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() {
)
);
assertThrows(
String.format(Locale.ROOT, "[%s] required property is missing", TextSimilarityRerankProcessor.MODEL_ID_FIELD),
String.format(Locale.ROOT, "[%s] required property is missing", MLOpenSearchRerankProcessor.MODEL_ID_FIELD),
OpenSearchParseException.class,
() -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext)
);
Expand All @@ -172,7 +172,7 @@ public void testRerankProcessorFactory_CrossEncoder_BadContextDocField_ThenFail(
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")),
new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")),
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, "text_representation"))
)
Expand All @@ -188,7 +188,7 @@ public void testRerankProcessorFactory_CrossEncoder_EmptyContextDocField_ThenFai
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")),
new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")),
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>()))
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
import com.google.common.collect.ImmutableList;

@Log4j2
public class TextSimilarityRerankProcessorIT extends BaseNeuralSearchIT {
public class MLOpenSearchRerankProcessorIT extends BaseNeuralSearchIT {

final static String PIPELINE_NAME = "rerank-ts-pipeline";
final static String PIPELINE_NAME = "rerank-mlos-pipeline";
final static String INDEX_NAME = "rerank-test";
final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy";
final static String TEXT_REP_2 = "Fish like to eat plankton";
Expand All @@ -63,7 +63,7 @@ public void tearDown() {
public void testCrossEncoderRerankProcessor() throws Exception {
String modelId = uploadTextSimilarityModel();
loadModel(modelId);
createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/TextSimilarityRerankPipelineConfiguration.json");
createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/RerankMLOpenSearchPipelineConfiguration.json");
setupIndex();
runQueries();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
import org.opensearch.test.OpenSearchTestCase;

@Log4j2
public class TextSimilarityRerankProcessorTests extends OpenSearchTestCase {
public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase {

@Mock
SearchRequest request;
Expand All @@ -80,7 +80,7 @@ public class TextSimilarityRerankProcessorTests extends OpenSearchTestCase {

RerankProcessorFactory factory;

TextSimilarityRerankProcessor processor;
MLOpenSearchRerankProcessor processor;

@Before
public void setup() {
Expand All @@ -89,12 +89,12 @@ public void setup() {
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")),
new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")),
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation"))))
)
);
processor = (TextSimilarityRerankProcessor) factory.create(
processor = (MLOpenSearchRerankProcessor) factory.create(
Map.of(),
"rerank processor",
"processor for reranking with a cross encoder",
Expand Down

0 comments on commit 11b5be1

Please sign in to comment.