From 7aff64736771863339198eb7aa196f802ba687c0 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Fri, 28 Jul 2023 07:57:24 -0700 Subject: [PATCH] [Feature] Expose term frequency in Painless script score context Signed-off-by: Louis Chu --- .../expression/ExpressionScoreScript.java | 2 +- .../expression/ExpressionScriptEngine.java | 3 +- .../action/PainlessExecuteAction.java | 6 +- .../painless/spi/org.opensearch.score.txt | 4 + .../ScriptScoreFunctionBuilder.java | 2 +- .../ScriptScoreQueryBuilder.java | 2 +- .../functionscore/TermFrequencyFunction.java | 131 ++++++++++++++++++ .../org/opensearch/script/ScoreScript.java | 23 ++- .../opensearch/script/ScoreScriptUtils.java | 68 +++++++++ .../search/lookup/LeafSearchLookup.java | 6 + 10 files changed, 240 insertions(+), 7 deletions(-) create mode 100644 server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java diff --git a/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScoreScript.java b/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScoreScript.java index 6be299146a181..3932559f7685c 100644 --- a/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScoreScript.java +++ b/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScoreScript.java @@ -66,7 +66,7 @@ public boolean needs_score() { @Override public ScoreScript newInstance(final LeafReaderContext leaf) throws IOException { - return new ScoreScript(null, null, null) { + return new ScoreScript(null, null, null, null) { // Fake the scorer until setScorer is called. 
DoubleValues values = source.getValues(leaf, new DoubleValues() { @Override diff --git a/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScriptEngine.java b/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScriptEngine.java index 1c3dc69359952..035d2402857e0 100644 --- a/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScriptEngine.java +++ b/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScriptEngine.java @@ -37,6 +37,7 @@ import org.apache.lucene.expressions.js.JavascriptCompiler; import org.apache.lucene.expressions.js.VariableContext; import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.IndexSearcher; import org.opensearch.SpecialPermission; import org.opensearch.common.Nullable; import org.opensearch.index.fielddata.IndexFieldData; @@ -110,7 +111,7 @@ public FilterScript.LeafFactory newFactory(Map<String, Object> params, SearchLoo contexts.put(ScoreScript.CONTEXT, (Expression expr) -> new ScoreScript.Factory() { @Override - public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) { + public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) { return newScoreScript(expr, lookup, params); } diff --git a/modules/lang-painless/src/main/java/org/opensearch/painless/action/PainlessExecuteAction.java b/modules/lang-painless/src/main/java/org/opensearch/painless/action/PainlessExecuteAction.java index 9a2c8c1f0aa55..ceb9d459e4eb9 100644 --- a/modules/lang-painless/src/main/java/org/opensearch/painless/action/PainlessExecuteAction.java +++ b/modules/lang-painless/src/main/java/org/opensearch/painless/action/PainlessExecuteAction.java @@ -558,7 +558,11 @@ static Response innerShardOperation(Request request, ScriptService scriptService } else if (scriptContext == ScoreScript.CONTEXT) { return prepareRamIndex(request, (context, leafReaderContext) -> { 
ScoreScript.Factory factory = scriptService.compile(request.script, ScoreScript.CONTEXT); - ScoreScript.LeafFactory leafFactory = factory.newFactory(request.getScript().getParams(), context.lookup()); + ScoreScript.LeafFactory leafFactory = factory.newFactory( + request.getScript().getParams(), + context.lookup(), + context.searcher() + ); ScoreScript scoreScript = leafFactory.newInstance(leafReaderContext); scoreScript.setDocument(0); diff --git a/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.score.txt b/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.score.txt index cca7e07a95388..24649ca78b354 100644 --- a/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.score.txt +++ b/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.score.txt @@ -23,6 +23,10 @@ class org.opensearch.script.ScoreScript @no_import { } static_import { + int termFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TermFreq + float tf(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TF + long totalTermFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TotalTermFreq + long sumTotalTermFreq(org.opensearch.script.ScoreScript, String) bound_to org.opensearch.script.ScoreScriptUtils$SumTotalTermFreq double saturation(double, double) from_class org.opensearch.script.ScoreScriptUtils double sigmoid(double, double, double) from_class org.opensearch.script.ScoreScriptUtils double randomScore(org.opensearch.script.ScoreScript, int, String) bound_to org.opensearch.script.ScoreScriptUtils$RandomScoreField diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java b/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java index 
a8c27d468a8f2..3615668d68260 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java @@ -114,7 +114,7 @@ protected int doHashCode() { protected ScoreFunction doToFunction(QueryShardContext context) { try { ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT); - ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup()); + ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup(), context.searcher()); return new ScriptScoreFunction( script, searchScript, diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreQueryBuilder.java index 8d67a4be38dfb..9bb41588fb53c 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreQueryBuilder.java @@ -187,7 +187,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException { ); } ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT); - ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup()); + ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup(), context.searcher()); final QueryBuilder queryBuilder = this.query; Query query = queryBuilder.toQuery(context); return new ScriptScoreQuery( diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java new file mode 100644 index 0000000000000..71123e956e36e --- /dev/null +++ 
b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java @@ -0,0 +1,131 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query.functionscore; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource; +import org.apache.lucene.queries.function.valuesource.TFValueSource; +import org.apache.lucene.queries.function.valuesource.TermFreqValueSource; +import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource; +import org.apache.lucene.search.IndexSearcher; +import org.opensearch.common.lucene.BytesRefs; + +import java.io.IOException; +import java.util.Map; + +public abstract class TermFrequencyFunction { + + protected final String field; + protected final String term; + protected final int docId; + protected Map context; + + public TermFrequencyFunction(String field, String term, int docId, Map context) { + this.field = field; + this.term = term; + this.docId = docId; + this.context = context; + } + + public abstract Object execute(LeafReaderContext readerContext) throws IOException; + + public static class TermFrequencyFunctionFactory { + public static TermFrequencyFunction createFunction( + TermFrequencyFunctionNamesEnum functionName, + String field, + String term, + int docId, + Map context + ) { + switch (functionName) { + case TERM_FREQ: + return new TermFreqFunction(field, term, docId, context); + case TF: + return new TFFunction(field, term, docId, context); + case TOTAL_TERM_FREQ: + return new TotalTermFreq(field, term, docId, context); + case SUM_TOTAL_TERM_FREQ: + return new SumTotalTermFreq(field, term, docId, context); + default: + throw new IllegalArgumentException("Unsupported function: " + functionName); + } + } + } + + public 
static class TermFreqFunction extends TermFrequencyFunction { + + public TermFreqFunction(String field, String term, int docId, Map context) { + super(field, term, docId, context); + } + + @Override + public Integer execute(LeafReaderContext readerContext) throws IOException { + TermFreqValueSource valueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term)); + return valueSource.getValues(null, readerContext).intVal(docId); + } + } + + public static class TFFunction extends TermFrequencyFunction { + + public TFFunction(String field, String term, int docId, Map context) { + super(field, term, docId, context); + } + + @Override + public Float execute(LeafReaderContext readerContext) throws IOException { + TFValueSource valueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term)); + return valueSource.getValues(context, readerContext).floatVal(docId); + } + } + + public static class TotalTermFreq extends TermFrequencyFunction { + + public TotalTermFreq(String field, String term, int docId, Map context) { + super(field, term, docId, context); + } + + @Override + public Long execute(LeafReaderContext readerContext) throws IOException { + TotalTermFreqValueSource valueSource = new TotalTermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term)); + valueSource.createWeight(context, (IndexSearcher) context.get("searcher")); + return valueSource.getValues(context, readerContext).longVal(docId); + } + } + + public static class SumTotalTermFreq extends TermFrequencyFunction { + + public SumTotalTermFreq(String field, String term, int docId, Map context) { + super(field, term, docId, context); + } + + @Override + public Long execute(LeafReaderContext readerContext) throws IOException { + SumTotalTermFreqValueSource valueSource = new SumTotalTermFreqValueSource(field); + valueSource.createWeight(context, (IndexSearcher) context.get("searcher")); + return valueSource.getValues(context, readerContext).longVal(docId); + } + } + 
+ public enum TermFrequencyFunctionNamesEnum { + TERM_FREQ("termFreq"), + TF("tf"), + TOTAL_TERM_FREQ("totalTermFreq"), + SUM_TOTAL_TERM_FREQ("sumTotalTermFreq"); + + private final String termFrequencyFunctionName; + + private TermFrequencyFunctionNamesEnum(String termFrequencyFunctionName) { + this.termFrequencyFunctionName = termFrequencyFunctionName; + } + + public String getTermFrequencyFunctionName() { + return termFrequencyFunctionName; + } + } +} diff --git a/server/src/main/java/org/opensearch/script/ScoreScript.java b/server/src/main/java/org/opensearch/script/ScoreScript.java index 5c6553ffc2a28..3bf788f7cebb6 100644 --- a/server/src/main/java/org/opensearch/script/ScoreScript.java +++ b/server/src/main/java/org/opensearch/script/ScoreScript.java @@ -33,10 +33,14 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Scorable; import org.opensearch.Version; import org.opensearch.common.logging.DeprecationLogger; import org.opensearch.index.fielddata.ScriptDocValues; +import org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum; +import org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionFactory; + import org.opensearch.search.lookup.LeafSearchLookup; import org.opensearch.search.lookup.SearchLookup; import org.opensearch.search.lookup.SourceLookup; @@ -115,7 +119,9 @@ public Explanation get(double score, Explanation subQueryExplanation) { private String indexName = null; private Version indexVersion = null; - public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) { + private final IndexSearcher indexSearcher; + + public ScoreScript(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) { // null check needed b/c of expression engine subclass if (lookup == null) { assert params == null; @@ -123,12 
+129,14 @@ public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderCo this.params = null; this.leafLookup = null; this.docBase = 0; + this.indexSearcher = null; } else { this.leafLookup = lookup.getLeafSearchLookup(leafContext); params = new HashMap<>(params); params.putAll(leafLookup.asMap()); this.params = new DynamicMap(params, PARAMS_FUNCTIONS); this.docBase = leafContext.docBase; + this.indexSearcher = indexSearcher; } } @@ -144,6 +152,17 @@ public Map<String, ScriptDocValues<?>> getDoc() { return leafLookup.doc(); } + public Object getTermFrequency(TermFrequencyFunctionNamesEnum functionName, String field, String val) throws IOException { + Map context = new HashMap<>() { + { + put("searcher", indexSearcher); + } + }; + return leafLookup.executeTermFrequencyFunction( + TermFrequencyFunctionFactory.createFunction(functionName, field, val, docId, context) + ); + } + /** Set the current document to run the script on next. */ public void setDocument(int docid) { this.docId = docid; @@ -268,7 +287,7 @@ public interface LeafFactory { */ public interface Factory extends ScriptFactory { - ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup); + ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher); } diff --git a/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java b/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java index b94ff77a1d0b7..b55cbd2c91af1 100644 --- a/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java +++ b/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java @@ -47,6 +47,10 @@ import java.time.ZoneId; import static org.opensearch.common.util.BitMixer.mix32; +import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.SUM_TOTAL_TERM_FREQ; +import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TERM_FREQ; +import static 
org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TF; +import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TOTAL_TERM_FREQ; /** * Utilities for scoring scripts @@ -69,6 +73,70 @@ public static double sigmoid(double value, double k, double a) { return Math.pow(value, a) / (Math.pow(k, a) + Math.pow(value, a)); } + public static final class TermFreq { + private final ScoreScript scoreScript; + + public TermFreq(ScoreScript scoreScript) { + this.scoreScript = scoreScript; + } + + public int termFreq(String field, String term) { + try { + return (int) scoreScript.getTermFrequency(TERM_FREQ, field, term); + } catch (Exception e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } + } + } + + public static final class TF { + private final ScoreScript scoreScript; + + public TF(ScoreScript scoreScript) { + this.scoreScript = scoreScript; + } + + public float tf(String field, String term) { + try { + return (float) scoreScript.getTermFrequency(TF, field, term); + } catch (Exception e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } + } + } + + public static final class TotalTermFreq { + private final ScoreScript scoreScript; + + public TotalTermFreq(ScoreScript scoreScript) { + this.scoreScript = scoreScript; + } + + public long totalTermFreq(String field, String term) { + try { + return (long) scoreScript.getTermFrequency(TOTAL_TERM_FREQ, field, term); + } catch (Exception e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } + } + } + + public static final class SumTotalTermFreq { + private final ScoreScript scoreScript; + + public SumTotalTermFreq(ScoreScript scoreScript) { + this.scoreScript = scoreScript; + } + + public long sumTotalTermFreq(String field) { + try { + return (long) scoreScript.getTermFrequency(SUM_TOTAL_TERM_FREQ, field, null); + } catch (Exception e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } 
+ } + } + /** * random score based on the documents' values of the given field * diff --git a/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java b/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java index 1c87f26053060..26b3b34f929cd 100644 --- a/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java +++ b/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java @@ -33,7 +33,9 @@ package org.opensearch.search.lookup; import org.apache.lucene.index.LeafReaderContext; +import org.opensearch.index.query.functionscore.TermFrequencyFunction; +import java.io.IOException; import java.util.HashMap; import java.util.Map; @@ -87,4 +89,8 @@ public void setDocument(int docId) { sourceLookup.setSegmentAndDocument(ctx, docId); fieldsLookup.setDocument(docId); } + + public Object executeTermFrequencyFunction(TermFrequencyFunction function) throws IOException { + return function.execute(ctx); + } }