From 50fae344401157df85cc7c32c5af16fe245cabdb Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Fri, 25 Aug 2023 09:44:38 -0700 Subject: [PATCH] Fixing the backward incompatible changes coming from core in ScoreScript class. (#1056) (#1063) Signed-off-by: Navneet Verma --- .../knn/plugin/script/KNNScoreScript.java | 21 +++-- .../plugin/script/KNNScoreScriptFactory.java | 8 +- .../knn/plugin/script/KNNScoringSpace.java | 83 +++++++++++++------ .../knn/index/KNNSettingsTests.java | 27 +++--- .../knn/index/query/KNNWeightTests.java | 1 - 5 files changed, 93 insertions(+), 47 deletions(-) diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java index f190a3e1d..d7a84817b 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.script; +import org.apache.lucene.search.IndexSearcher; import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.apache.lucene.index.LeafReaderContext; import org.opensearch.index.fielddata.ScriptDocValues; @@ -32,9 +33,10 @@ public KNNScoreScript( String field, BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext + LeafReaderContext leafContext, + IndexSearcher searcher ) { - super(params, lookup, leafContext); + super(params, lookup, searcher, leafContext); this.queryValue = queryValue; this.field = field; this.scoringMethod = scoringMethod; @@ -51,9 +53,10 @@ public LongType( String field, BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext + LeafReaderContext leafContext, + IndexSearcher searcher ) { - super(params, queryValue, field, scoringMethod, lookup, leafContext); + super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher); } /** @@ -84,9 +87,10 @@ public BigIntegerType( String field, BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext + LeafReaderContext leafContext, + IndexSearcher searcher ) { - super(params, queryValue, field, scoringMethod, lookup, leafContext); + super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher); } /** @@ -118,9 +122,10 @@ public KNNVectorType( String field, BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext + LeafReaderContext leafContext, + IndexSearcher searcher ) throws IOException { - super(params, queryValue, field, scoringMethod, lookup, leafContext); + super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher); } /** diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java index b686a20f0..63b367b2d 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.script; +import org.apache.lucene.search.IndexSearcher; import org.opensearch.knn.plugin.stats.KNNCounter; import org.apache.lucene.index.LeafReaderContext; import org.opensearch.script.ScoreScript; @@ -21,13 +22,16 @@ public class KNNScoreScriptFactory implements ScoreScript.LeafFactory { private Object query; private KNNScoringSpace knnScoringSpace; - public KNNScoreScriptFactory(Map params, SearchLookup lookup) { + private IndexSearcher searcher; + + public KNNScoreScriptFactory(Map params, SearchLookup lookup, IndexSearcher searcher) { KNNCounter.SCRIPT_QUERY_REQUESTS.increment(); this.params = params; this.lookup = lookup; this.field = getValue(params, "field").toString(); this.similaritySpace = getValue(params, "space_type").toString(); this.query = getValue(params, "query_value"); + this.searcher = searcher; this.knnScoringSpace = KNNScoringSpaceFactory.create( this.similaritySpace, @@ -60,6 +64,6 @@ public boolean needs_score() { */ @Override public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { - return knnScoringSpace.getScoreScript(params, field, lookup, ctx); + return knnScoringSpace.getScoreScript(params, field, lookup, ctx, this.searcher); } } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 16bf6e204..0e4c9f815 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.script; +import org.apache.lucene.search.IndexSearcher; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.KNNWeight; import org.apache.lucene.index.LeafReaderContext; @@ -29,14 +30,16 @@ public interface KNNScoringSpace { /** * Return the correct scoring script for a given query. The scoring script * - * @param params Map of parameters - * @param field Fieldname - * @param lookup SearchLookup - * @param ctx ctx LeafReaderContext to be used for scoring documents + * @param params Map of parameters + * @param field Fieldname + * @param lookup SearchLookup + * @param ctx ctx LeafReaderContext to be used for scoring documents + * @param searcher IndexSearcher * @return ScoreScript for this query * @throws IOException throws IOException if ScoreScript cannot be constructed */ - ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) throws IOException; + ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher) + throws IOException; class L2 implements KNNScoringSpace { @@ -62,9 +65,14 @@ public L2(Object query, MappedFieldType fieldType) { this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) - throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); + public ScoreScript getScoreScript( + Map params, + String field, + SearchLookup lookup, + LeafReaderContext ctx, + IndexSearcher searcher + ) throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher); } } @@ -94,9 +102,14 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) - throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); + public ScoreScript getScoreScript( + Map params, + String field, + SearchLookup lookup, + LeafReaderContext ctx, + IndexSearcher searcher + ) throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher); } } @@ -127,8 +140,13 @@ public HammingBit(Object query, MappedFieldType fieldType) { } @SuppressWarnings("unchecked") - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) - throws IOException { + public ScoreScript getScoreScript( + Map params, + String field, + SearchLookup lookup, + LeafReaderContext ctx, + IndexSearcher searcher + ) throws IOException { if (this.processedQuery instanceof Long) { return new KNNScoreScript.LongType( params, @@ -136,7 +154,8 @@ public ScoreScript getScoreScript(Map params, String field, Sear field, (BiFunction) this.scoringMethod, lookup, - ctx + ctx, + searcher ); } @@ -146,7 +165,8 @@ public ScoreScript getScoreScript(Map params, String field, Sear field, (BiFunction) this.scoringMethod, lookup, - ctx + ctx, + searcher ); } } @@ -175,9 +195,14 @@ public L1(Object query, MappedFieldType fieldType) { this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) - throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); + public ScoreScript getScoreScript( + Map params, + String field, + SearchLookup lookup, + LeafReaderContext ctx, + IndexSearcher searcher + ) throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher); } } @@ -205,9 +230,14 @@ public LInf(Object query, MappedFieldType fieldType) { this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) - throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); + public ScoreScript getScoreScript( + Map params, + String field, + SearchLookup lookup, + LeafReaderContext ctx, + IndexSearcher searcher + ) throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher); } } @@ -238,9 +268,14 @@ public InnerProd(Object query, MappedFieldType fieldType) { } @Override - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) - throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); + public ScoreScript getScoreScript( + Map params, + String field, + SearchLookup lookup, + LeafReaderContext ctx, + IndexSearcher searcher + ) throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher); } } } diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index 9432be33e..17a58cbca 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -122,30 +122,33 @@ public void testFilteredSearchAdvanceSetting_whenValuesProvidedByUsers_thenValid // validate if we are able to set MinValues for the setting final Settings filteredSearchAdvanceSettingsWithMinValues = Settings.builder() - .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, userDefinedThresholdMinValue) - .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, userDefinedPctThresholdMinValue) - .build(); + .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, userDefinedThresholdMinValue) + .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, userDefinedPctThresholdMinValue) + .build(); mockNode.client() - .admin() - .indices() - .updateSettings(new UpdateSettingsRequest(filteredSearchAdvanceSettingsWithMinValues, INDEX_NAME)) - .actionGet(); + .admin() + .indices() + .updateSettings(new UpdateSettingsRequest(filteredSearchAdvanceSettingsWithMinValues, INDEX_NAME)) + .actionGet(); int filteredSearchThresholdPctMinValue = KNNSettings.getFilteredExactSearchThresholdPct(INDEX_NAME); int filteredSearchThresholdMinValue = KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME); // Validate if less than MinValues are set then Exception Happens final Settings filteredSearchAdvanceSettingsWithLessThanMinValues = Settings.builder() - .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, -1) - .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, -1) - .build(); + .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, -1) + .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, -1) + .build(); - Assert.assertThrows(IllegalArgumentException.class, () -> mockNode.client() + Assert.assertThrows( + IllegalArgumentException.class, + () -> mockNode.client() .admin() .indices() .updateSettings(new UpdateSettingsRequest(filteredSearchAdvanceSettingsWithLessThanMinValues, INDEX_NAME)) - .actionGet()); + .actionGet() + ); mockNode.close(); assertEquals(userDefinedPctThreshold, filteredSearchThresholdPct); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 9cc624377..6b7bb3208 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -487,7 +487,6 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length)); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY); final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);