Skip to content

Commit

Permalink
Support for Default Model id
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Sep 25, 2023
1 parent dac5012 commit 6903f94
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.neuralsearch.plugin;

import org.opensearch.neuralsearch.processor.DefaultValueProcessor;
import org.opensearch.neuralsearch.processor.NeuralQueryProcessor;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED;

import java.util.Arrays;
Expand Down Expand Up @@ -42,6 +42,7 @@
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
Expand Down Expand Up @@ -80,6 +81,7 @@ public Collection<Object> createComponents(
final IndexNameExpressionResolver indexNameExpressionResolver,
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralSearchClusterUtil.instance().initialize(clusterService);
NeuralQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
Expand Down Expand Up @@ -131,8 +133,8 @@ public List<Setting<?>> getSettings() {
@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRequestProcessor>> getRequestProcessors(Parameters parameters) {
return Map.of(
DefaultValueProcessor.TYPE,
new DefaultValueProcessor.Factory(parameters.namedXContentRegistry)
NeuralQueryProcessor.TYPE,
new NeuralQueryProcessor.Factory()
);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import java.util.Map;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

public class NeuralQueryProcessor extends AbstractProcessor implements SearchRequestProcessor {

/**
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "default_query";

private final String modelId;

private final Map<String, Object> fieldInfoMap;

/**
* Returns the type of the processor.
*
* @return The processor type.
*/
@Override
public String getType() {
return TYPE;
}

protected NeuralQueryProcessor(
String tag,
String description,
boolean ignoreFailure,
String modelId,
Map<String, Object> fieldInfoMap
) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
this.fieldInfoMap = fieldInfoMap;
}

@Override
public SearchRequest processRequest(SearchRequest searchRequest) throws Exception {
QueryBuilder queryBuilder = searchRequest.source().query();
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, fieldInfoMap));
return searchRequest;
}

public static class Factory implements Processor.Factory<SearchRequestProcessor> {
private static final String DEFAULT_MODEL_ID = "default_model_id";
private static final String NEURAL_FIELD_MAP = "neural_field_map";

@Override
public NeuralQueryProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws Exception {
String modelId = (String) config.remove(DEFAULT_MODEL_ID);
Map<String, Object> neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_MAP);

if (modelId == null && neuralInfoMap == null) {
throw new IllegalArgumentException("model Id or neural info map either of them should be provided");
}

return new NeuralQueryProcessor(tag, description, ignoreFailure, modelId, neuralInfoMap);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
Expand All @@ -31,12 +32,10 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.*;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

import com.google.common.annotations.VisibleForTesting;

Expand Down Expand Up @@ -82,6 +81,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
@Setter(AccessLevel.PACKAGE)
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;

/**
* Constructor from stream input
Expand All @@ -93,7 +93,11 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
this.modelId = in.readString();
if (isClusterOnOrAfterMinRequiredVersion()) {
this.modelId = in.readOptionalString();
} else {
this.modelId = in.readString();
}
this.k = in.readVInt();
this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}
Expand All @@ -102,7 +106,11 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(this.fieldName);
out.writeString(this.queryText);
out.writeString(this.modelId);
if (isClusterOnOrAfterMinRequiredVersion()) {
out.writeOptionalString(this.modelId);
} else {
out.writeString(this.modelId);
}
out.writeVInt(this.k);
out.writeOptionalNamedWriteable(this.filter);
}
Expand Down Expand Up @@ -152,20 +160,21 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx
parseQueryParams(parser, neuralQueryBuilder);
if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
throw new ParsingException(
parser.getTokenLocation(),
"["
+ NAME
+ "] query doesn't support multiple fields, found ["
+ neuralQueryBuilder.fieldName()
+ "] and ["
+ parser.currentName()
+ "]"
parser.getTokenLocation(),
"["
+ NAME
+ "] query doesn't support multiple fields, found ["
+ neuralQueryBuilder.fieldName()
+ "] and ["
+ parser.currentName()
+ "]"
);
}
requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query");
requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query");
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");

if (!isClusterOnOrAfterMinRequiredVersion()) {
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");
}
return neuralQueryBuilder;
}

Expand All @@ -188,8 +197,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
neuralQueryBuilder.boost(parser.floatValue());
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] query does not support [" + currentFieldName + "]"
parser.getTokenLocation(),
"[" + NAME + "] query does not support [" + currentFieldName + "]"
);
}
} else if (token == XContentParser.Token.START_OBJECT) {
Expand All @@ -198,8 +207,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"
parser.getTokenLocation(),
"[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"
);
}
}
Expand All @@ -222,10 +231,10 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)))
((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)))
);
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get, filter());
}
Expand Down Expand Up @@ -258,4 +267,8 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}
}

private static boolean isClusterOnOrAfterMinRequiredVersion() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
}
}
Loading

0 comments on commit 6903f94

Please sign in to comment.