Skip to content

Commit

Permalink
Fixing the backward incompatible changes coming from core in ScoreScr…
Browse files Browse the repository at this point in the history
…ipt class. (#1056) (#1063)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Aug 25, 2023
1 parent d89c85e commit 50fae34
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 47 deletions.
21 changes: 13 additions & 8 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.plugin.script;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.index.KNNVectorScriptDocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.index.fielddata.ScriptDocValues;
Expand Down Expand Up @@ -32,9 +33,10 @@ public KNNScoreScript(
String field,
BiFunction<T, T, Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext
LeafReaderContext leafContext,
IndexSearcher searcher
) {
super(params, lookup, leafContext);
super(params, lookup, searcher, leafContext);
this.queryValue = queryValue;
this.field = field;
this.scoringMethod = scoringMethod;
Expand All @@ -51,9 +53,10 @@ public LongType(
String field,
BiFunction<Long, Long, Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext
LeafReaderContext leafContext,
IndexSearcher searcher
) {
super(params, queryValue, field, scoringMethod, lookup, leafContext);
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
Expand Down Expand Up @@ -84,9 +87,10 @@ public BigIntegerType(
String field,
BiFunction<BigInteger, BigInteger, Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext
LeafReaderContext leafContext,
IndexSearcher searcher
) {
super(params, queryValue, field, scoringMethod, lookup, leafContext);
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
Expand Down Expand Up @@ -118,9 +122,10 @@ public KNNVectorType(
String field,
BiFunction<float[], float[], Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext
LeafReaderContext leafContext,
IndexSearcher searcher
) throws IOException {
super(params, queryValue, field, scoringMethod, lookup, leafContext);
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.plugin.script;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.script.ScoreScript;
Expand All @@ -21,13 +22,16 @@ public class KNNScoreScriptFactory implements ScoreScript.LeafFactory {
private Object query;
private KNNScoringSpace knnScoringSpace;

public KNNScoreScriptFactory(Map<String, Object> params, SearchLookup lookup) {
private IndexSearcher searcher;

public KNNScoreScriptFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher searcher) {
KNNCounter.SCRIPT_QUERY_REQUESTS.increment();
this.params = params;
this.lookup = lookup;
this.field = getValue(params, "field").toString();
this.similaritySpace = getValue(params, "space_type").toString();
this.query = getValue(params, "query_value");
this.searcher = searcher;

this.knnScoringSpace = KNNScoringSpaceFactory.create(
this.similaritySpace,
Expand Down Expand Up @@ -60,6 +64,6 @@ public boolean needs_score() {
*/
@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
return knnScoringSpace.getScoreScript(params, field, lookup, ctx);
return knnScoringSpace.getScoreScript(params, field, lookup, ctx, this.searcher);
}
}
83 changes: 59 additions & 24 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.plugin.script;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNWeight;
import org.apache.lucene.index.LeafReaderContext;
Expand All @@ -29,14 +30,16 @@ public interface KNNScoringSpace {
/**
* Return the correct scoring script for a given query. The scoring script
*
* @param params Map of parameters
* @param field Fieldname
* @param lookup SearchLookup
* @param ctx ctx LeafReaderContext to be used for scoring documents
* @param params Map of parameters
* @param field Fieldname
* @param lookup SearchLookup
* @param ctx ctx LeafReaderContext to be used for scoring documents
* @param searcher IndexSearcher
* @return ScoreScript for this query
* @throws IOException throws IOException if ScoreScript cannot be constructed
*/
ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx) throws IOException;
ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher)
throws IOException;

class L2 implements KNNScoringSpace {

Expand All @@ -62,9 +65,14 @@ public L2(Object query, MappedFieldType fieldType) {
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v));
}

public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
public ScoreScript getScoreScript(
Map<String, Object> params,
String field,
SearchLookup lookup,
LeafReaderContext ctx,
IndexSearcher searcher
) throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher);
}
}

Expand Down Expand Up @@ -94,9 +102,14 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) {
this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude);
}

public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
public ScoreScript getScoreScript(
Map<String, Object> params,
String field,
SearchLookup lookup,
LeafReaderContext ctx,
IndexSearcher searcher
) throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher);
}
}

Expand Down Expand Up @@ -127,16 +140,22 @@ public HammingBit(Object query, MappedFieldType fieldType) {
}

@SuppressWarnings("unchecked")
public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
throws IOException {
public ScoreScript getScoreScript(
Map<String, Object> params,
String field,
SearchLookup lookup,
LeafReaderContext ctx,
IndexSearcher searcher
) throws IOException {
if (this.processedQuery instanceof Long) {
return new KNNScoreScript.LongType(
params,
(Long) this.processedQuery,
field,
(BiFunction<Long, Long, Float>) this.scoringMethod,
lookup,
ctx
ctx,
searcher
);
}

Expand All @@ -146,7 +165,8 @@ public ScoreScript getScoreScript(Map<String, Object> params, String field, Sear
field,
(BiFunction<BigInteger, BigInteger, Float>) this.scoringMethod,
lookup,
ctx
ctx,
searcher
);
}
}
Expand Down Expand Up @@ -175,9 +195,14 @@ public L1(Object query, MappedFieldType fieldType) {
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v));
}

public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
public ScoreScript getScoreScript(
Map<String, Object> params,
String field,
SearchLookup lookup,
LeafReaderContext ctx,
IndexSearcher searcher
) throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher);
}
}

Expand Down Expand Up @@ -205,9 +230,14 @@ public LInf(Object query, MappedFieldType fieldType) {
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v));
}

public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
public ScoreScript getScoreScript(
Map<String, Object> params,
String field,
SearchLookup lookup,
LeafReaderContext ctx,
IndexSearcher searcher
) throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher);
}
}

Expand Down Expand Up @@ -238,9 +268,14 @@ public InnerProd(Object query, MappedFieldType fieldType) {
}

@Override
public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx)
throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
public ScoreScript getScoreScript(
Map<String, Object> params,
String field,
SearchLookup lookup,
LeafReaderContext ctx,
IndexSearcher searcher
) throws IOException {
return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher);
}
}
}
27 changes: 15 additions & 12 deletions src/test/java/org/opensearch/knn/index/KNNSettingsTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,30 +122,33 @@ public void testFilteredSearchAdvanceSetting_whenValuesProvidedByUsers_thenValid

// validate if we are able to set MinValues for the setting
final Settings filteredSearchAdvanceSettingsWithMinValues = Settings.builder()
.put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, userDefinedThresholdMinValue)
.put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, userDefinedPctThresholdMinValue)
.build();
.put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, userDefinedThresholdMinValue)
.put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, userDefinedPctThresholdMinValue)
.build();

mockNode.client()
.admin()
.indices()
.updateSettings(new UpdateSettingsRequest(filteredSearchAdvanceSettingsWithMinValues, INDEX_NAME))
.actionGet();
.admin()
.indices()
.updateSettings(new UpdateSettingsRequest(filteredSearchAdvanceSettingsWithMinValues, INDEX_NAME))
.actionGet();

int filteredSearchThresholdPctMinValue = KNNSettings.getFilteredExactSearchThresholdPct(INDEX_NAME);
int filteredSearchThresholdMinValue = KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME);

// Validate if less than MinValues are set then Exception Happens
final Settings filteredSearchAdvanceSettingsWithLessThanMinValues = Settings.builder()
.put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, -1)
.put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, -1)
.build();
.put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, -1)
.put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, -1)
.build();

Assert.assertThrows(IllegalArgumentException.class, () -> mockNode.client()
Assert.assertThrows(
IllegalArgumentException.class,
() -> mockNode.client()
.admin()
.indices()
.updateSettings(new UpdateSettingsRequest(filteredSearchAdvanceSettingsWithLessThanMinValues, INDEX_NAME))
.actionGet());
.actionGet()
);

mockNode.close();
assertEquals(userDefinedPctThreshold, filteredSearchThresholdPct);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,6 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces

when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length));


final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY);

final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
Expand Down

0 comments on commit 50fae34

Please sign in to comment.