diff --git a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java index 3651a7354e5de..f329677a94340 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java @@ -34,6 +34,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; import org.opensearch.action.index.IndexRequestBuilder; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchType; @@ -93,7 +94,7 @@ public String getType() { public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { assert scriptSource.equals("explainable_script"); assert context == ScoreScript.CONTEXT; - ScoreScript.Factory factory = (params1, lookup) -> new ScoreScript.LeafFactory() { + ScoreScript.Factory factory = (params1, lookup, indexSearcher) -> new ScoreScript.LeafFactory() { @Override public boolean needs_score() { return false; @@ -101,7 +102,7 @@ public boolean needs_score() { @Override public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { - return new MyScript(params1, lookup, ctx); + return new MyScript(params1, lookup, indexSearcher, ctx); } }; return context.factoryClazz.cast(factory); @@ -117,8 +118,8 @@ public Set> getSupportedContexts() { static class MyScript extends ScoreScript implements ExplainableScoreScript { - MyScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { - super(params, lookup, leafContext); + MyScript(Map params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) { + super(params, lookup, indexSearcher, leafContext); } @Override diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java index 6a8b7d61fd095..95fbecc53f4ae 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java @@ -8,145 +8,15 @@ package org.opensearch.index.query.functionscore; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource; -import org.apache.lucene.queries.function.valuesource.TFValueSource; -import org.apache.lucene.queries.function.valuesource.TermFreqValueSource; -import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource; -import org.apache.lucene.search.IndexSearcher; -import org.opensearch.common.lucene.BytesRefs; - import java.io.IOException; -import java.util.Map; /** - * Abstract class representing a term frequency function. + * An interface representing a term frequency function used to compute document scores + * based on specific term frequency calculations. Implementations of this interface should + * provide a way to execute the term frequency function for a given document ID. + * + * @opensearch.internal */ -public abstract class TermFrequencyFunction { - - protected final String field; - protected final String term; - protected final int docId; - protected Map context; - - public TermFrequencyFunction(String field, String term, int docId, Map context) { - this.field = field; - this.term = term; - this.docId = docId; - this.context = context; - } - - public abstract Object execute(LeafReaderContext readerContext) throws IOException; - - /** - * Factory class to create term frequency functions. - */ - public static class TermFrequencyFunctionFactory { - public static TermFrequencyFunction createFunction( - TermFrequencyFunctionNamesEnum functionName, - String field, - String term, - int docId, - Map context - ) { - switch (functionName) { - case TERM_FREQ: - return new TermFreqFunction(field, term, docId, context); - case TF: - return new TFFunction(field, term, docId, context); - case TOTAL_TERM_FREQ: - return new TotalTermFreq(field, term, docId, context); - case SUM_TOTAL_TERM_FREQ: - return new SumTotalTermFreq(field, term, docId, context); - default: - throw new IllegalArgumentException("Unsupported function: " + functionName); - } - } - } - - /** - * TermFreqFunction computes the term frequency in a field. - */ - public static class TermFreqFunction extends TermFrequencyFunction { - - public TermFreqFunction(String field, String term, int docId, Map context) { - super(field, term, docId, context); - } - - @Override - public Integer execute(LeafReaderContext readerContext) throws IOException { - TermFreqValueSource valueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term)); - return valueSource.getValues(null, readerContext).intVal(docId); - } - } - - /** - * TFFunction computes the term frequency-inverse document frequency (tf-idf) in a field. - */ - public static class TFFunction extends TermFrequencyFunction { - - public TFFunction(String field, String term, int docId, Map context) { - super(field, term, docId, context); - } - - @Override - public Float execute(LeafReaderContext readerContext) throws IOException { - TFValueSource valueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term)); - return valueSource.getValues(context, readerContext).floatVal(docId); - } - } - - /** - * TotalTermFreq computes the total term frequency in a field. - */ - public static class TotalTermFreq extends TermFrequencyFunction { - - public TotalTermFreq(String field, String term, int docId, Map context) { - super(field, term, docId, context); - } - - @Override - public Long execute(LeafReaderContext readerContext) throws IOException { - TotalTermFreqValueSource valueSource = new TotalTermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term)); - valueSource.createWeight(context, (IndexSearcher) context.get("searcher")); - return valueSource.getValues(context, readerContext).longVal(docId); - } - } - - /** - * SumTotalTermFreq computes the sum of total term frequencies within a field. - */ - public static class SumTotalTermFreq extends TermFrequencyFunction { - - public SumTotalTermFreq(String field, String term, int docId, Map context) { - super(field, term, docId, context); - } - - @Override - public Long execute(LeafReaderContext readerContext) throws IOException { - SumTotalTermFreqValueSource valueSource = new SumTotalTermFreqValueSource(field); - valueSource.createWeight(context, (IndexSearcher) context.get("searcher")); - return valueSource.getValues(context, readerContext).longVal(docId); - } - } - - /** - * Enum representing the names of term frequency functions. - */ - public enum TermFrequencyFunctionNamesEnum { - TERM_FREQ("termFreq"), - TF("tf"), - TOTAL_TERM_FREQ("totalTermFreq"), - SUM_TOTAL_TERM_FREQ("sumTotalTermFreq"); - - private final String termFrequencyFunctionName; - - private TermFrequencyFunctionNamesEnum(String termFrequencyFunctionName) { - this.termFrequencyFunctionName = termFrequencyFunctionName; - } - - public String getTermFrequencyFunctionName() { - return termFrequencyFunctionName; - } - } +public interface TermFrequencyFunction { + Object execute(int docId) throws IOException; } diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java new file mode 100644 index 0000000000000..a00df0e626726 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query.functionscore; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource; +import org.apache.lucene.queries.function.valuesource.TFValueSource; +import org.apache.lucene.queries.function.valuesource.TermFreqValueSource; +import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource; +import org.apache.lucene.search.IndexSearcher; +import org.opensearch.common.lucene.BytesRefs; + +import java.io.IOException; +import java.util.Map; + +/** + * A factory class for creating instances of {@link TermFrequencyFunction}. + * This class provides methods for creating different term frequency functions based on + * the specified function name, field, and term. Each term frequency function is designed + * to compute document scores based on specific term frequency calculations. + * + * @opensearch.internal + */ +public class TermFrequencyFunctionFactory { + + public static TermFrequencyFunction createFunction( + TermFrequencyFunctionName functionName, + Map context, + String field, + String term, + LeafReaderContext readerContext + ) throws IOException { + switch (functionName) { + case TERM_FREQ: + TermFreqValueSource termFreqValueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term)); + return docId -> termFreqValueSource.getValues(null, readerContext).intVal(docId); + case TF: + TFValueSource tfValueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term)); + return docId -> tfValueSource.getValues(context, readerContext).floatVal(docId); + case TOTAL_TERM_FREQ: + TotalTermFreqValueSource totalTermFreqValueSource = new TotalTermFreqValueSource( + field, + term, + field, + BytesRefs.toBytesRef(term) + ); + totalTermFreqValueSource.createWeight(context, (IndexSearcher) context.get("searcher")); + return docId -> totalTermFreqValueSource.getValues(context, readerContext).longVal(docId); + case SUM_TOTAL_TERM_FREQ: + SumTotalTermFreqValueSource sumTotalTermFreqValueSource = new SumTotalTermFreqValueSource(field); + sumTotalTermFreqValueSource.createWeight(context, (IndexSearcher) context.get("searcher")); + return docId -> sumTotalTermFreqValueSource.getValues(context, readerContext).longVal(docId); + default: + throw new IllegalArgumentException("Unsupported function: " + functionName); + } + } + + /** + * An enumeration representing the names of supported term frequency functions. + */ + public enum TermFrequencyFunctionName { + TERM_FREQ("termFreq"), + TF("tf"), + TOTAL_TERM_FREQ("totalTermFreq"), + SUM_TOTAL_TERM_FREQ("sumTotalTermFreq"); + + private final String termFrequencyFunctionName; + + TermFrequencyFunctionName(String termFrequencyFunctionName) { + this.termFrequencyFunctionName = termFrequencyFunctionName; + } + + public String getTermFrequencyFunctionName() { + return termFrequencyFunctionName; + } + } +} diff --git a/server/src/main/java/org/opensearch/script/ScoreScript.java b/server/src/main/java/org/opensearch/script/ScoreScript.java index 34a035f389a3e..a81b7a5bb97cd 100644 --- a/server/src/main/java/org/opensearch/script/ScoreScript.java +++ b/server/src/main/java/org/opensearch/script/ScoreScript.java @@ -38,8 +38,7 @@ import org.opensearch.Version; import org.opensearch.common.logging.DeprecationLogger; import org.opensearch.index.fielddata.ScriptDocValues; -import org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum; -import org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionFactory; +import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName; import org.opensearch.search.lookup.LeafSearchLookup; import org.opensearch.search.lookup.SearchLookup; @@ -48,6 +47,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.HashMap; +import java.util.Locale; import java.util.Map; import java.util.function.DoubleSupplier; import java.util.function.Function; @@ -121,6 +121,8 @@ public Explanation get(double score, Explanation subQueryExplanation) { private final IndexSearcher indexSearcher; + private final Map termFreqCache = new HashMap<>(); + public ScoreScript(Map params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) { // null check needed b/c of expression engine subclass if (lookup == null) { @@ -152,16 +154,21 @@ public Map> getDoc() { return leafLookup.doc(); } - public Object getTermFrequency(TermFrequencyFunctionNamesEnum functionName, String field, String val) throws IOException { - // Fetch data from local cache - Map context = new HashMap<>() { - { - put("searcher", indexSearcher); - } - }; - return leafLookup.executeTermFrequencyFunction( - TermFrequencyFunctionFactory.createFunction(functionName, field, val, docId, context) - ); + public Object getTermFrequency(TermFrequencyFunctionName functionName, String field, String val) throws IOException { + String cacheKey = (val == null) ? String.format(Locale.ROOT, "%s-%s", functionName, field) : String.format(Locale.ROOT, "%s-%s-%s", functionName, field, val); + + if (!termFreqCache.containsKey(cacheKey)) { + Map context = new HashMap<>() { + { + put("searcher", indexSearcher); + } + }; + + Object termFrequency = leafLookup.getTermFrequency(functionName, context, field, val, docId); + termFreqCache.put(cacheKey, termFrequency); + } + + return termFreqCache.get(cacheKey); } /** Set the current document to run the script on next. */ diff --git a/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java b/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java index 8f19714c3ca99..f71fdeadfe8b0 100644 --- a/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java +++ b/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java @@ -47,10 +47,10 @@ import java.time.ZoneId; import static org.opensearch.common.util.BitMixer.mix32; -import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.SUM_TOTAL_TERM_FREQ; -import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TERM_FREQ; -import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TF; -import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TOTAL_TERM_FREQ; +import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.SUM_TOTAL_TERM_FREQ; +import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TERM_FREQ; +import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TF; +import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TOTAL_TERM_FREQ; /** * Utilities for scoring scripts diff --git a/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java b/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java index 26b3b34f929cd..7a513f311082e 100644 --- a/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java +++ b/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java @@ -34,6 +34,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.opensearch.index.query.functionscore.TermFrequencyFunction; +import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory; import java.io.IOException; import java.util.HashMap; @@ -90,7 +91,15 @@ public void setDocument(int docId) { fieldsLookup.setDocument(docId); } - public Object executeTermFrequencyFunction(TermFrequencyFunction function) throws IOException { - return function.execute(ctx); + public Object getTermFrequency( + TermFrequencyFunctionFactory.TermFrequencyFunctionName functionName, + Map context, + String field, + String val, + int docId + ) throws IOException { + TermFrequencyFunction termFreqFunction = TermFrequencyFunctionFactory.createFunction(functionName, context, field, val, ctx); + // execute the function + return termFreqFunction.execute(docId); } }