Skip to content

Commit

Permalink
Addressing Comments of Navneet
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Sep 29, 2023
1 parent ef5f939 commit 29017ba
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.*;
import org.opensearch.neuralsearch.processor.EnrichingQueryDefaultProcessor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
Expand Down Expand Up @@ -142,6 +145,6 @@ public List<Setting<?>> getSettings() {
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRequestProcessor>> getRequestProcessors(
Parameters parameters
) {
return Map.of(EnrichingQueryDefaultProcessor.TYPE, new EnrichingQueryDefaultProcessor.Factory());
return Map.of(NeuralQueryEnricherProcessor.TYPE, new NeuralQueryEnricherProcessor.Factory());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.neuralsearch.processor;

import static org.opensearch.ingest.ConfigurationUtils.*;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE;

import java.util.Map;

import lombok.Getter;
Expand All @@ -25,12 +28,12 @@
*/
@Setter
@Getter
public class EnrichingQueryDefaultProcessor extends AbstractProcessor implements SearchRequestProcessor {
public class NeuralQueryEnricherProcessor extends AbstractProcessor implements SearchRequestProcessor {

/**
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "enriching_query_defaults";
public static final String TYPE = "neural_query_enricher";

private final String modelId;

Expand All @@ -46,7 +49,7 @@ public String getType() {
return TYPE;
}

private EnrichingQueryDefaultProcessor(
private NeuralQueryEnricherProcessor(
String tag,
String description,
boolean ignoreFailure,
Expand Down Expand Up @@ -77,30 +80,25 @@ public static class Factory implements Processor.Factory<SearchRequestProcessor>
/**
* Create the processor object.
*
* @return EnrichingQueryDefaultProcessor
* @return NeuralQueryEnricherProcessor
*/
@Override
public EnrichingQueryDefaultProcessor create(
public NeuralQueryEnricherProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws IllegalArgumentException {
String modelId;
try {
modelId = (String) config.remove(DEFAULT_MODEL_ID);
} catch (ClassCastException e) {
throw new IllegalArgumentException("model Id must of String type");
}
String modelId = readOptionalStringProperty(TYPE, tag, config, DEFAULT_MODEL_ID);
Map<String, Object> neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_DEFAULT_ID);

if (modelId == null && neuralInfoMap == null) {
throw new IllegalArgumentException("model Id or neural info map either of them should be provided");
}

return new EnrichingQueryDefaultProcessor(tag, description, ignoreFailure, modelId, neuralInfoMap);
return new NeuralQueryEnricherProcessor(tag, description, ignoreFailure, modelId, neuralInfoMap);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import java.util.Optional;

import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.EnrichingQueryDefaultProcessor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
Expand Down Expand Up @@ -83,6 +83,6 @@ public void testRequestProcessors() {
parameters
);
assertNotNull(processors);
assertNotNull(processors.get(EnrichingQueryDefaultProcessor.TYPE));
assertNotNull(processors.get(NeuralQueryEnricherProcessor.TYPE));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import com.google.common.primitives.Floats;

public class EnrichingQueryDefaultProcessorIT extends BaseNeuralSearchIT {
public class NeuralQueryEnricherProcessorIT extends BaseNeuralSearchIT {

private static final String index = "my-nlp-index";
private static final String search_pipeline = "search-pipeline";
Expand All @@ -47,7 +47,7 @@ public void tearDown() {
}

@SneakyThrows
public void testEnrichingQueryProcessor_whenNoModelIdPassed_thenSuccess() {
public void testNeuralQueryEnricherProcessor_whenNoModelIdPassed_thenSuccess() {
initializeIndexIfNotExist();
String modelId = getDeployedModelId();
createSearchRequestProcessor(modelId, search_pipeline);
Expand All @@ -65,7 +65,7 @@ public void testEnrichingQueryProcessor_whenNoModelIdPassed_thenSuccess() {

@SneakyThrows
private void initializeIndexIfNotExist() {
if (index.equals(EnrichingQueryDefaultProcessorIT.index) && !indexExists(index)) {
if (index.equals(NeuralQueryEnricherProcessorIT.index) && !indexExists(index)) {
prepareKnnIndex(
index,
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
import java.util.HashMap;
import java.util.Map;

import org.opensearch.OpenSearchParseException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.test.OpenSearchTestCase;

public class EnrichingQueryDefaultProcessorTests extends OpenSearchTestCase {
public class NeuralQueryEnricherProcessorTests extends OpenSearchTestCase {

public void testFactory_whenMissingQueryParam_thenThrowException() throws Exception {
EnrichingQueryDefaultProcessor.Factory factory = new EnrichingQueryDefaultProcessor.Factory();
EnrichingQueryDefaultProcessor processor = createTestProcessor(factory);
NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory();
NeuralQueryEnricherProcessor processor = createTestProcessor(factory);
assertEquals("vasdcvkcjkbldbjkd", processor.getModelId());
assertEquals("bahbkcdkacb", processor.getNeuralFieldDefaultIdMap().get("fieldName").toString());

Expand All @@ -30,33 +31,33 @@ public void testFactory_whenMissingQueryParam_thenThrowException() throws Except
}

public void testFactory_whenModelIdIsNotString_thenFail() {
EnrichingQueryDefaultProcessor.Factory factory = new EnrichingQueryDefaultProcessor.Factory();
NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory();
Map<String, Object> configMap = new HashMap<>();
configMap.put("default_model_id", 55555L);
expectThrows(IllegalArgumentException.class, () -> factory.create(Collections.emptyMap(), null, null, false, configMap, null));
expectThrows(OpenSearchParseException.class, () -> factory.create(Collections.emptyMap(), null, null, false, configMap, null));
}

public void testProcessRequest_whenVisitingQueryBuilder_thenSuccess() throws Exception {
EnrichingQueryDefaultProcessor.Factory factory = new EnrichingQueryDefaultProcessor.Factory();
NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory();
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
SearchRequest searchRequest = new SearchRequest();
searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder));
EnrichingQueryDefaultProcessor processor = createTestProcessor(factory);
NeuralQueryEnricherProcessor processor = createTestProcessor(factory);
SearchRequest processSearchRequest = processor.processRequest(searchRequest);
assertEquals(processSearchRequest, searchRequest);
}

public void testType() throws Exception {
EnrichingQueryDefaultProcessor.Factory factory = new EnrichingQueryDefaultProcessor.Factory();
EnrichingQueryDefaultProcessor processor = createTestProcessor(factory);
assertEquals(EnrichingQueryDefaultProcessor.TYPE, processor.getType());
NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory();
NeuralQueryEnricherProcessor processor = createTestProcessor(factory);
assertEquals(NeuralQueryEnricherProcessor.TYPE, processor.getType());
}

private EnrichingQueryDefaultProcessor createTestProcessor(EnrichingQueryDefaultProcessor.Factory factory) throws Exception {
private NeuralQueryEnricherProcessor createTestProcessor(NeuralQueryEnricherProcessor.Factory factory) throws Exception {
Map<String, Object> configMap = new HashMap<>();
configMap.put("default_model_id", "vasdcvkcjkbldbjkd");
configMap.put("neural_field_default_id", Map.of("fieldName", "bahbkcdkacb"));
EnrichingQueryDefaultProcessor processor = factory.create(Collections.emptyMap(), null, null, false, configMap, null);
NeuralQueryEnricherProcessor processor = factory.create(Collections.emptyMap(), null, null, false, configMap, null);
return processor;
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"request_processors": [
{
"enriching_query_defaults": {
"neural_query_enricher": {
"tag": "tag1",
"description": "This processor is going to restrict to publicly visible documents",
"default_model_id": "%s"
Expand Down

0 comments on commit 29017ba

Please sign in to comment.