Skip to content

Commit

Permalink
Support for default model Id
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Sep 26, 2023
1 parent 88ba2e9 commit c06d07a
Show file tree
Hide file tree
Showing 11 changed files with 233 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

package org.opensearch.neuralsearch.plugin;

import org.opensearch.neuralsearch.processor.NeuralQueryProcessor;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
Expand All @@ -29,6 +29,7 @@
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralQueryProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
Expand Down Expand Up @@ -132,10 +133,9 @@ public List<Setting<?>> getSettings() {
}

@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRequestProcessor>> getRequestProcessors(Parameters parameters) {
return Map.of(
NeuralQueryProcessor.TYPE,
new NeuralQueryProcessor.Factory()
);
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRequestProcessor>> getRequestProcessors(
Parameters parameters
) {
return Map.of(NeuralQueryProcessor.TYPE, new NeuralQueryProcessor.Factory());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ public class NeuralQueryProcessor extends AbstractProcessor implements SearchReq
/**
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "default_query";
public static final String TYPE = "neural_query";

private final String modelId;
final String modelId;

private final Map<String, Object> fieldInfoMap;
final Map<String, Object> neuralFieldMap;

/**
* Returns the type of the processor.
Expand All @@ -37,39 +37,39 @@ public String getType() {
}

protected NeuralQueryProcessor(
String tag,
String description,
boolean ignoreFailure,
String modelId,
Map<String, Object> fieldInfoMap
String tag,
String description,
boolean ignoreFailure,
String modelId,
Map<String, Object> fieldInfoMap
) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
this.fieldInfoMap = fieldInfoMap;
this.neuralFieldMap = fieldInfoMap;
}

@Override
public SearchRequest processRequest(SearchRequest searchRequest) throws Exception {
QueryBuilder queryBuilder = searchRequest.source().query();
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, fieldInfoMap));
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldMap));
return searchRequest;
}

public static class Factory implements Processor.Factory<SearchRequestProcessor> {
private static final String DEFAULT_MODEL_ID = "default_model_id";
private static final String NEURAL_FIELD_MAP = "neural_field_map";
private static final String NEURAL_FIELD_DEFAULT_ID = "neural_field_default_id";

@Override
public NeuralQueryProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws Exception {
String modelId = (String) config.remove(DEFAULT_MODEL_ID);
Map<String, Object> neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_MAP);
Map<String, Object> neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_DEFAULT_ID);

if (modelId == null && neuralInfoMap == null) {
throw new IllegalArgumentException("model Id or neural info map either of them should be provided");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx
parseQueryParams(parser, neuralQueryBuilder);
if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
throw new ParsingException(
parser.getTokenLocation(),
"["
+ NAME
+ "] query doesn't support multiple fields, found ["
+ neuralQueryBuilder.fieldName()
+ "] and ["
+ parser.currentName()
+ "]"
parser.getTokenLocation(),
"["
+ NAME
+ "] query doesn't support multiple fields, found ["
+ neuralQueryBuilder.fieldName()
+ "] and ["
+ parser.currentName()
+ "]"
);
}
requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query");
Expand Down Expand Up @@ -197,8 +197,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
neuralQueryBuilder.boost(parser.floatValue());
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] query does not support [" + currentFieldName + "]"
parser.getTokenLocation(),
"[" + NAME + "] query does not support [" + currentFieldName + "]"
);
}
} else if (token == XContentParser.Token.START_OBJECT) {
Expand All @@ -207,8 +207,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"
parser.getTokenLocation(),
"[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"
);
}
}
Expand All @@ -231,10 +231,10 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)))
((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)))
);
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get, filter());
}
Expand Down Expand Up @@ -271,4 +271,4 @@ public String getWriteableName() {
/**
 * Tells whether the minimum version reported by {@code NeuralSearchClusterUtil} is on or after
 * {@code MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID}.
 *
 * @return {@code true} when the reported cluster minimum version is on or after that constant
 */
private static boolean isClusterOnOrAfterMinRequiredVersion() {
    return NeuralSearchClusterUtil.instance()
        .getClusterMinVersion()
        .onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ public void accept(QueryBuilder queryBuilder) {
if (queryBuilder instanceof NeuralQueryBuilder) {
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder;
if (neuralFieldMap != null
&& neuralQueryBuilder.fieldName() != null
&& neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) {
&& neuralQueryBuilder.fieldName() != null
&& neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) {
String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName());
neuralQueryBuilder.modelId(fieldDefaultModelId);
} else if (modelId != null) {
neuralQueryBuilder.modelId(modelId);
} else {
throw new IllegalArgumentException(
"model id must be provided in neural query or a default model id must be set in search request processor"
"model id must be provided in neural query or a default model id must be set in search request processor"
);
}
}
Expand All @@ -45,4 +45,4 @@ public void accept(QueryBuilder queryBuilder) {
public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) {
return this;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.util;

import java.util.Locale;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -47,11 +49,15 @@ public Version getClusterMinVersion() {
return this.clusterService.state().getNodes().getMinNodeVersion();
} catch (Exception exception) {
log.error(
String.format("Failed to get cluster minimum node version, returning current node version %s instead.", Version.CURRENT),
exception
String.format(
Locale.ROOT,
"Failed to get cluster minimum node version, returning current node version %s instead.",
Version.CURRENT
),
exception
);
return Version.CURRENT;
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,29 @@ protected void createPipelineProcessor(String modelId, String pipelineName) thro
assertEquals("true", node.get("acknowledged").toString());
}

/**
 * Registers a search pipeline under {@code /_search/pipeline/<pipelineName>}.
 * The request body is the content of the classpath resource
 * {@code processor/SearchRequestPipelineConfiguration.json}, used as a format string
 * with {@code modelId} as its single argument. The call asserts that the response
 * reports {@code acknowledged} as {@code true}.
 *
 * @param modelId model id substituted into the pipeline configuration template
 * @param pipelineName name of the search pipeline to create
 * @throws Exception if the resource cannot be read or the request fails
 */
protected void createSearchRequestProcessor(String modelId, String pipelineName) throws Exception {
    final Path templatePath = Path.of(classLoader.getResource("processor/SearchRequestPipelineConfiguration.json").toURI());
    final String pipelineDefinition = String.format(LOCALE, Files.readString(templatePath), modelId);
    final String endpoint = "/_search/pipeline/" + pipelineName;

    final Response createResponse = makeRequest(
        client(),
        "PUT",
        endpoint,
        null,
        toHttpEntity(pipelineDefinition),
        ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
    );

    final Map<String, Object> responseBody = XContentHelper.convertToMap(
        XContentType.JSON.xContent(),
        EntityUtils.toString(createResponse.getEntity()),
        false
    );
    assertEquals("true", responseBody.get("acknowledged").toString());
}

/**
* Get the number of documents in a particular index
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.Optional;

import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.NeuralQueryProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
Expand All @@ -22,6 +23,7 @@
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;

public class NeuralSearchTests extends OpenSearchQueryTestCase {
Expand Down Expand Up @@ -73,4 +75,14 @@ public void testSearchPhaseResultsProcessors() {
);
assertTrue(scoringProcessor instanceof NormalizationProcessorFactory);
}

/**
 * Verifies that the plugin registers a {@link NeuralQueryProcessor.Factory} under
 * {@link NeuralQueryProcessor#TYPE} as a search request processor.
 */
public void testRequestProcessors() {
    NeuralSearch plugin = new NeuralSearch();
    SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
    Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRequestProcessor>> processors = plugin.getRequestProcessors(
        parameters
    );
    assertNotNull(processors);

    org.opensearch.search.pipeline.Processor.Factory<SearchRequestProcessor> neuralQueryFactory = processors.get(
        NeuralQueryProcessor.TYPE
    );
    assertNotNull(neuralQueryFactory);
    // Check the concrete type as well, mirroring the factory-type check in testSearchPhaseResultsProcessors.
    assertTrue(neuralQueryFactory instanceof NeuralQueryProcessor.Factory);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.test.OpenSearchTestCase;

/**
 * Unit tests for {@link NeuralQueryProcessor} and its {@link NeuralQueryProcessor.Factory}.
 */
public class NeuralQueryProcessorTests extends OpenSearchTestCase {

    private static final String DEFAULT_MODEL_ID = "vasdcvkcjkbldbjkd";
    private static final String FIELD_NAME = "fieldName";
    private static final String FIELD_MODEL_ID = "bahbkcdkacb";

    /**
     * The factory must carry the configured default model id and per-field map into the
     * processor, and must reject a configuration that provides neither of them.
     */
    public void testFactory() throws Exception {
        NeuralQueryProcessor.Factory factory = new NeuralQueryProcessor.Factory();
        NeuralQueryProcessor processor = createTestProcessor(factory);
        assertEquals(DEFAULT_MODEL_ID, processor.modelId);
        assertEquals(FIELD_MODEL_ID, processor.neuralFieldMap.get(FIELD_NAME).toString());

        // Neither "default_model_id" nor "neural_field_default_id" is configured.
        // A mutable map is passed because the factory removes entries from the config it receives.
        expectThrows(
            IllegalArgumentException.class,
            () -> factory.create(Collections.emptyMap(), null, null, false, new HashMap<>(), null)
        );
    }

    /**
     * Processing a request must return the same request instance and assign the processor-level
     * default model id to a neural query that has no field-specific mapping.
     */
    public void testProcessRequest() throws Exception {
        NeuralQueryProcessor.Factory factory = new NeuralQueryProcessor.Factory();
        NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
        SearchRequest searchRequest = new SearchRequest();
        searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder));
        NeuralQueryProcessor processor = createTestProcessor(factory);

        SearchRequest processedRequest = processor.processRequest(searchRequest);

        assertSame(searchRequest, processedRequest);
        // The builder has no field name, so the visitor falls back to the processor's default model id.
        assertEquals(DEFAULT_MODEL_ID, neuralQueryBuilder.modelId());
    }

    /**
     * Builds a processor configured with both a default model id and a single field-to-model mapping.
     *
     * @param factory factory used to create the processor
     * @return the created processor
     * @throws Exception if the factory rejects the configuration
     */
    public NeuralQueryProcessor createTestProcessor(NeuralQueryProcessor.Factory factory) throws Exception {
        Map<String, Object> configMap = new HashMap<>();
        configMap.put("default_model_id", DEFAULT_MODEL_ID);
        configMap.put("neural_field_default_id", Map.of(FIELD_NAME, FIELD_MODEL_ID));
        return factory.create(Collections.emptyMap(), null, null, false, configMap, null);
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.util;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.opensearch.Version;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;

/**
 * Static test helpers for building mocked cluster components.
 */
public class NeuralSearchClusterTestUtils {

    private NeuralSearchClusterTestUtils() {
        // Static helpers only; not meant to be instantiated.
    }

    /**
     * Create new mock for ClusterService whose state exposes nodes reporting the given minimum version.
     *
     * @param version min version for cluster nodes
     * @return mocked {@link ClusterService} for which
     *         {@code state().getNodes().getMinNodeVersion()} returns {@code version}
     */
    public static ClusterService mockClusterService(final Version version) {
        ClusterService clusterService = mock(ClusterService.class);
        ClusterState clusterState = mock(ClusterState.class);
        when(clusterService.state()).thenReturn(clusterState);
        DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class);
        when(clusterState.getNodes()).thenReturn(discoveryNodes);
        when(discoveryNodes.getMinNodeVersion()).thenReturn(version);
        return clusterService;
    }
}
Loading

0 comments on commit c06d07a

Please sign in to comment.