Skip to content

Commit

Permalink
Addressing Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Sep 26, 2023
1 parent badced1 commit 9c010e7
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
- Enabled support for applying default modelId in neural search query ([#337](https://github.com/opensearch-project/neural-search/pull/337)
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
private MLCommonsClientAccessor clientAccessor;
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();;
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();

@Override
public Collection<Object> createComponents(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

/**
* Neural Search Query Request Processor
*/
public class NeuralQueryProcessor extends AbstractProcessor implements SearchRequestProcessor {

/**
Expand Down Expand Up @@ -48,6 +51,11 @@ protected NeuralQueryProcessor(
this.neuralFieldDefaultIdMap = neuralFieldDefaultIdMap;
}

/**
* Processes the Search Request.
*
* @return The Search Request.
*/
@Override
public SearchRequest processRequest(SearchRequest searchRequest) {
QueryBuilder queryBuilder = searchRequest.source().query();
Expand All @@ -59,6 +67,11 @@ public static class Factory implements Processor.Factory<SearchRequestProcessor>
private static final String DEFAULT_MODEL_ID = "default_model_id";
private static final String NEURAL_FIELD_DEFAULT_ID = "neural_field_default_id";

/**
* Create the processor object.
*
* @return NeuralQueryProcessor
*/
@Override
public NeuralQueryProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,52 @@

import java.util.Map;

import lombok.AllArgsConstructor;

import org.apache.lucene.search.BooleanClause;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilderVisitor;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

/**
* Neural Search Query Visitor. It visits the each and every component of query buikder tree.
*/
@AllArgsConstructor
public class NeuralSearchQueryVisitor implements QueryBuilderVisitor {

private String modelId;
private Map<String, Object> neuralFieldMap;

public NeuralSearchQueryVisitor(String modelId, Map<String, Object> neuralFieldMap) {
this.modelId = modelId;
this.neuralFieldMap = neuralFieldMap;
}

/**
* Accept method accepts every query builder from the search request,
* and processes it if the required conditions in accept method are satisfied.
*/
@Override
public void accept(QueryBuilder queryBuilder) {
if (queryBuilder instanceof NeuralQueryBuilder) {
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder;
if (neuralFieldMap != null
&& neuralQueryBuilder.fieldName() != null
&& neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) {
String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName());
neuralQueryBuilder.modelId(fieldDefaultModelId);
} else if (modelId != null) {
neuralQueryBuilder.modelId(modelId);
} else {
throw new IllegalArgumentException(
"model id must be provided in neural query or a default model id must be set in search request processor"
);
if (neuralQueryBuilder.modelId() == null) {
if (neuralFieldMap != null
&& neuralQueryBuilder.fieldName() != null
&& neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) {
String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName());
neuralQueryBuilder.modelId(fieldDefaultModelId);

Check warning on line 39 in src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java#L38-L39

Added lines #L38 - L39 were not covered by tests
} else if (modelId != null) {
neuralQueryBuilder.modelId(modelId);
} else {
throw new IllegalArgumentException(

Check warning on line 43 in src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java#L43

Added line #L43 was not covered by tests
"model id must be provided in neural query or a default model id must be set in search request processor"
);
}
}
}
}

/**
* Retrieves the child visitor from the Visitor object.
*
* @return The sub Query Visitor
*/
@Override
public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) {
return this;

Check warning on line 58 in src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java#L58

Added line #L58 was not covered by tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import org.opensearch.Version;
import org.opensearch.cluster.service.ClusterService;

/**
* Class abstracts information related to underlying OpenSearch cluster
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Log4j2
public class NeuralSearchClusterUtil {
Expand Down

0 comments on commit 9c010e7

Please sign in to comment.