From 7278800de51c39c988f7f0711b813ace871f27e8 Mon Sep 17 00:00:00 2001
From: Louis Chu
Date: Fri, 4 Aug 2023 09:48:46 -0700
Subject: [PATCH] Address comments from Froh and Ankit

Signed-off-by: Louis Chu
---
 .../functionscore/TermFrequencyFunction.java  | 152 +-----------------
 .../TermFrequencyFunctionFactory.java         |  83 ++++++++++
 .../org/opensearch/script/ScoreScript.java    |  30 ++--
 .../opensearch/script/ScoreScriptUtils.java   |   8 +-
 .../search/lookup/LeafSearchLookup.java       |  38 ++++-
 5 files changed, 156 insertions(+), 155 deletions(-)
 create mode 100644 server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java

diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java
index 6a8b7d61fd095..b03794e822b2f 100644
--- a/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java
+++ b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java
@@ -8,145 +8,23 @@
 
 package org.opensearch.index.query.functionscore;
 
-import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource;
-import org.apache.lucene.queries.function.valuesource.TFValueSource;
-import org.apache.lucene.queries.function.valuesource.TermFreqValueSource;
-import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource;
-import org.apache.lucene.search.IndexSearcher;
-import org.opensearch.common.lucene.BytesRefs;
-
 import java.io.IOException;
-import java.util.Map;
 
 /**
- * Abstract class representing a term frequency function.
+ * An interface representing a term frequency function used to compute document scores
+ * based on specific term frequency calculations. Implementations of this interface should
+ * provide a way to execute the term frequency function for a given document ID.
+ *
+ * @opensearch.internal
  */
-public abstract class TermFrequencyFunction {
-
-    protected final String field;
-    protected final String term;
-    protected final int docId;
-    protected Map<Object, Object> context;
-
-    public TermFrequencyFunction(String field, String term, int docId, Map<Object, Object> context) {
-        this.field = field;
-        this.term = term;
-        this.docId = docId;
-        this.context = context;
-    }
-
-    public abstract Object execute(LeafReaderContext readerContext) throws IOException;
-
-    /**
-     * Factory class to create term frequency functions.
-     */
-    public static class TermFrequencyFunctionFactory {
-        public static TermFrequencyFunction createFunction(
-            TermFrequencyFunctionNamesEnum functionName,
-            String field,
-            String term,
-            int docId,
-            Map<Object, Object> context
-        ) {
-            switch (functionName) {
-                case TERM_FREQ:
-                    return new TermFreqFunction(field, term, docId, context);
-                case TF:
-                    return new TFFunction(field, term, docId, context);
-                case TOTAL_TERM_FREQ:
-                    return new TotalTermFreq(field, term, docId, context);
-                case SUM_TOTAL_TERM_FREQ:
-                    return new SumTotalTermFreq(field, term, docId, context);
-                default:
-                    throw new IllegalArgumentException("Unsupported function: " + functionName);
-            }
-        }
-    }
-
-    /**
-     * TermFreqFunction computes the term frequency in a field.
-     */
-    public static class TermFreqFunction extends TermFrequencyFunction {
-
-        public TermFreqFunction(String field, String term, int docId, Map<Object, Object> context) {
-            super(field, term, docId, context);
-        }
-
-        @Override
-        public Integer execute(LeafReaderContext readerContext) throws IOException {
-            TermFreqValueSource valueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
-            return valueSource.getValues(null, readerContext).intVal(docId);
-        }
-    }
-
-    /**
-     * TFFunction computes the term frequency-inverse document frequency (tf-idf) in a field.
-     */
-    public static class TFFunction extends TermFrequencyFunction {
-
-        public TFFunction(String field, String term, int docId, Map<Object, Object> context) {
-            super(field, term, docId, context);
-        }
-
-        @Override
-        public Float execute(LeafReaderContext readerContext) throws IOException {
-            TFValueSource valueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term));
-            return valueSource.getValues(context, readerContext).floatVal(docId);
-        }
-    }
-
-    /**
-     * TotalTermFreq computes the total term frequency in a field.
-     */
-    public static class TotalTermFreq extends TermFrequencyFunction {
-
-        public TotalTermFreq(String field, String term, int docId, Map<Object, Object> context) {
-            super(field, term, docId, context);
-        }
-
-        @Override
-        public Long execute(LeafReaderContext readerContext) throws IOException {
-            TotalTermFreqValueSource valueSource = new TotalTermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
-            valueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
-            return valueSource.getValues(context, readerContext).longVal(docId);
-        }
-    }
-
-    /**
-     * SumTotalTermFreq computes the sum of total term frequencies within a field.
-     */
-    public static class SumTotalTermFreq extends TermFrequencyFunction {
-
-        public SumTotalTermFreq(String field, String term, int docId, Map<Object, Object> context) {
-            super(field, term, docId, context);
-        }
-
-        @Override
-        public Long execute(LeafReaderContext readerContext) throws IOException {
-            SumTotalTermFreqValueSource valueSource = new SumTotalTermFreqValueSource(field);
-            valueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
-            return valueSource.getValues(context, readerContext).longVal(docId);
-        }
-    }
-
-    /**
-     * Enum representing the names of term frequency functions.
-     */
-    public enum TermFrequencyFunctionNamesEnum {
-        TERM_FREQ("termFreq"),
-        TF("tf"),
-        TOTAL_TERM_FREQ("totalTermFreq"),
-        SUM_TOTAL_TERM_FREQ("sumTotalTermFreq");
-
-        private final String termFrequencyFunctionName;
-
-        private TermFrequencyFunctionNamesEnum(String termFrequencyFunctionName) {
-            this.termFrequencyFunctionName = termFrequencyFunctionName;
-        }
-
-        public String getTermFrequencyFunctionName() {
-            return termFrequencyFunctionName;
-        }
-    }
+@FunctionalInterface
+public interface TermFrequencyFunction {
+    /**
+     * Computes the value of this term frequency function for the given document.
+     *
+     * @param docId the ID of the document to evaluate
+     * @return the computed value; the functions built by the factory yield an Integer, a Float or a Long
+     * @throws IOException If an I/O error occurs while computing the term frequency.
+     */
+    Object execute(int docId) throws IOException;
 }
diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java
new file mode 100644
index 0000000000000..a00df0e626726
--- /dev/null
+++ b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java
@@ -0,0 +1,83 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.index.query.functionscore;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource;
+import org.apache.lucene.queries.function.valuesource.TFValueSource;
+import org.apache.lucene.queries.function.valuesource.TermFreqValueSource;
+import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource;
+import org.apache.lucene.search.IndexSearcher;
+import org.opensearch.common.lucene.BytesRefs;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * A factory class for creating instances of {@link TermFrequencyFunction}.
+ * This class provides methods for creating different term frequency functions based on
+ * the specified function name, field, and term. Each term frequency function is designed
+ * to compute document scores based on specific term frequency calculations.
+ *
+ * @opensearch.internal
+ */
+public class TermFrequencyFunctionFactory {
+
+    public static TermFrequencyFunction createFunction(
+        TermFrequencyFunctionName functionName,
+        Map<Object, Object> context,
+        String field,
+        String term,
+        LeafReaderContext readerContext
+    ) throws IOException {
+        switch (functionName) {
+            case TERM_FREQ:
+                TermFreqValueSource termFreqValueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
+                return docId -> termFreqValueSource.getValues(null, readerContext).intVal(docId);
+            case TF:
+                TFValueSource tfValueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term));
+                return docId -> tfValueSource.getValues(context, readerContext).floatVal(docId);
+            case TOTAL_TERM_FREQ:
+                TotalTermFreqValueSource totalTermFreqValueSource = new TotalTermFreqValueSource(
+                    field,
+                    term,
+                    field,
+                    BytesRefs.toBytesRef(term)
+                );
+                totalTermFreqValueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
+                return docId -> totalTermFreqValueSource.getValues(context, readerContext).longVal(docId);
+            case SUM_TOTAL_TERM_FREQ:
+                SumTotalTermFreqValueSource sumTotalTermFreqValueSource = new SumTotalTermFreqValueSource(field);
+                sumTotalTermFreqValueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
+                return docId -> sumTotalTermFreqValueSource.getValues(context, readerContext).longVal(docId);
+            default:
+                throw new IllegalArgumentException("Unsupported function: " + functionName);
+        }
+    }
+
+    /**
+     * An enumeration representing the names of supported term frequency functions.
+     */
+    public enum TermFrequencyFunctionName {
+        TERM_FREQ("termFreq"),
+        TF("tf"),
+        TOTAL_TERM_FREQ("totalTermFreq"),
+        SUM_TOTAL_TERM_FREQ("sumTotalTermFreq");
+
+        private final String termFrequencyFunctionName;
+
+        TermFrequencyFunctionName(String termFrequencyFunctionName) {
+            this.termFrequencyFunctionName = termFrequencyFunctionName;
+        }
+
+        public String getTermFrequencyFunctionName() {
+            return termFrequencyFunctionName;
+        }
+    }
+}
diff --git a/server/src/main/java/org/opensearch/script/ScoreScript.java b/server/src/main/java/org/opensearch/script/ScoreScript.java
index 34a035f389a3e..71b4b99b2bd4e 100644
--- a/server/src/main/java/org/opensearch/script/ScoreScript.java
+++ b/server/src/main/java/org/opensearch/script/ScoreScript.java
@@ -38,8 +38,7 @@
 import org.opensearch.Version;
 import org.opensearch.common.logging.DeprecationLogger;
 import org.opensearch.index.fielddata.ScriptDocValues;
-import org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum;
-import org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionFactory;
+import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName;
 import org.opensearch.search.lookup.LeafSearchLookup;
 import org.opensearch.search.lookup.SearchLookup;
 
@@ -152,16 +151,23 @@ public Map<String, ScriptDocValues<?>> getDoc() {
         return leafLookup.doc();
     }
 
-    public Object getTermFrequency(TermFrequencyFunctionNamesEnum functionName, String field, String val) throws IOException {
-        // Fetch data from local cache
-        Map<Object, Object> context = new HashMap<>() {
-            {
-                put("searcher", indexSearcher);
-            }
-        };
-        return leafLookup.executeTermFrequencyFunction(
-            TermFrequencyFunctionFactory.createFunction(functionName, field, val, docId, context)
-        );
+    /**
+     * Evaluates a term frequency function for the document this script is currently set to.
+     * <p>
+     * The result is never cached here because it depends on the current document; reuse of the
+     * underlying function across documents is left to the leaf lookup.
+     *
+     * @param functionName the term frequency function to evaluate
+     * @param field the field the function is evaluated on
+     * @param val the term the function is evaluated for; may be {@code null}
+     * @return the value the function yields for the current document
+     * @throws IOException if an I/O error occurs while computing the value
+     */
+    public Object getTermFrequency(TermFrequencyFunctionName functionName, String field, String val) throws IOException {
+        // Plain mutable map: it is handed on to the value sources together with the searcher.
+        Map<Object, Object> context = new HashMap<>();
+        context.put("searcher", indexSearcher);
+        return leafLookup.getTermFrequency(functionName, context, field, val, docId);
     }
 
     /** Set the current document to run the script on next. */
diff --git a/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java b/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java
index 8f19714c3ca99..f71fdeadfe8b0 100644
--- a/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java
+++ b/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java
@@ -47,10 +47,10 @@
 import java.time.ZoneId;
 
 import static org.opensearch.common.util.BitMixer.mix32;
-import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.SUM_TOTAL_TERM_FREQ;
-import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TERM_FREQ;
-import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TF;
-import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TOTAL_TERM_FREQ;
+import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.SUM_TOTAL_TERM_FREQ;
+import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TERM_FREQ;
+import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TF;
+import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TOTAL_TERM_FREQ;
 
 /**
  * Utilities for scoring scripts
diff --git a/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java b/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java
index 26b3b34f929cd..7a513f311082e 100644
--- a/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java
+++ b/server/src/main/java/org/opensearch/search/lookup/LeafSearchLookup.java
@@ -34,6 +34,7 @@
 
 import org.apache.lucene.index.LeafReaderContext;
 import org.opensearch.index.query.functionscore.TermFrequencyFunction;
+import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory;
 
 import java.io.IOException;
 import java.util.HashMap;
@@ -90,7 +91,40 @@ public void setDocument(int docId) {
         fieldsLookup.setDocument(docId);
     }
 
-    public Object executeTermFrequencyFunction(TermFrequencyFunction function) throws IOException {
-        return function.execute(ctx);
+    // Term frequency functions already created for this leaf. Only the function is kept, never a computed
+    // value, because the value differs from document to document.
+    private final Map<String, TermFrequencyFunction> termFreqFunctions = new HashMap<>();
+
+    /**
+     * Evaluates a term frequency function for one document of this leaf.
+     * <p>
+     * The function for a given name, field and term is created on first use and reused afterwards. It keeps
+     * the {@code context} it was created with, so the {@code context} of later calls for the same key is not used.
+     *
+     * @param functionName the term frequency function to evaluate
+     * @param context the map handed to the underlying value sources; the factory reads its "searcher" entry
+     * @param field the field the function is evaluated on
+     * @param val the term the function is evaluated for; may be {@code null}
+     * @param docId the document to evaluate the function for
+     * @return the value the function yields for {@code docId}
+     * @throws IOException if an I/O error occurs while creating or executing the function
+     */
+    public Object getTermFrequency(
+        TermFrequencyFunctionFactory.TermFrequencyFunctionName functionName,
+        Map<Object, Object> context,
+        String field,
+        String val,
+        int docId
+    ) throws IOException {
+        // The field is length-prefixed so that no field/term combination can yield the key of another one.
+        String fieldKey = String.valueOf(field);
+        String cacheKey = functionName.name() + ':' + fieldKey.length() + ':' + fieldKey + (val == null ? "" : ":" + val);
+        TermFrequencyFunction termFreqFunction = termFreqFunctions.get(cacheKey);
+        if (termFreqFunction == null) {
+            // Not computeIfAbsent: creating the function may throw the checked IOException.
+            termFreqFunction = TermFrequencyFunctionFactory.createFunction(functionName, context, field, val, ctx);
+            termFreqFunctions.put(cacheKey, termFreqFunction);
+        }
+        return termFreqFunction.execute(docId);
     }
 }