Skip to content

Commit

Permalink
[Feature] Expose term frequency in Painless script score context
Browse files Browse the repository at this point in the history
Signed-off-by: Louis Chu <[email protected]>
  • Loading branch information
noCharger committed Aug 2, 2023
1 parent 0003bd8 commit 7aff647
Show file tree
Hide file tree
Showing 10 changed files with 240 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public boolean needs_score() {

@Override
public ScoreScript newInstance(final LeafReaderContext leaf) throws IOException {
return new ScoreScript(null, null, null) {
return new ScoreScript(null, null, null, null) {
// Fake the scorer until setScorer is called.
DoubleValues values = source.getValues(leaf, new DoubleValues() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.lucene.expressions.js.JavascriptCompiler;
import org.apache.lucene.expressions.js.VariableContext;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.SpecialPermission;
import org.opensearch.common.Nullable;
import org.opensearch.index.fielddata.IndexFieldData;
Expand Down Expand Up @@ -110,7 +111,7 @@ public FilterScript.LeafFactory newFactory(Map<String, Object> params, SearchLoo

contexts.put(ScoreScript.CONTEXT, (Expression expr) -> new ScoreScript.Factory() {
@Override
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) {
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) {
return newScoreScript(expr, lookup, params);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,11 @@ static Response innerShardOperation(Request request, ScriptService scriptService
} else if (scriptContext == ScoreScript.CONTEXT) {
return prepareRamIndex(request, (context, leafReaderContext) -> {
ScoreScript.Factory factory = scriptService.compile(request.script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory leafFactory = factory.newFactory(request.getScript().getParams(), context.lookup());
ScoreScript.LeafFactory leafFactory = factory.newFactory(
request.getScript().getParams(),
context.lookup(),
context.searcher()
);
ScoreScript scoreScript = leafFactory.newInstance(leafReaderContext);
scoreScript.setDocument(0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class org.opensearch.script.ScoreScript @no_import {
}

static_import {
int termFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TermFreq
float tf(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TF
long totalTermFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TotalTermFreq
long sumTotalTermFreq(org.opensearch.script.ScoreScript, String) bound_to org.opensearch.script.ScoreScriptUtils$SumTotalTermFreq
double saturation(double, double) from_class org.opensearch.script.ScoreScriptUtils
double sigmoid(double, double, double) from_class org.opensearch.script.ScoreScriptUtils
double randomScore(org.opensearch.script.ScoreScript, int, String) bound_to org.opensearch.script.ScoreScriptUtils$RandomScoreField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ protected int doHashCode() {
protected ScoreFunction doToFunction(QueryShardContext context) {
try {
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup(), context.searcher());
return new ScriptScoreFunction(
script,
searchScript,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
);
}
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup());
ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup(), context.searcher());
final QueryBuilder queryBuilder = this.query;
Query query = queryBuilder.toQuery(context);
return new ScriptScoreQuery(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.query.functionscore;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TFValueSource;
import org.apache.lucene.queries.function.valuesource.TermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.lucene.BytesRefs;

import java.io.IOException;
import java.util.Map;

public abstract class TermFrequencyFunction {

protected final String field;
protected final String term;
protected final int docId;
protected Map<Object, Object> context;

public TermFrequencyFunction(String field, String term, int docId, Map<Object, Object> context) {
this.field = field;
this.term = term;
this.docId = docId;
this.context = context;
}

public abstract Object execute(LeafReaderContext readerContext) throws IOException;

public static class TermFrequencyFunctionFactory {
public static TermFrequencyFunction createFunction(
TermFrequencyFunctionNamesEnum functionName,
String field,
String term,
int docId,
Map<Object, Object> context
) {
switch (functionName) {
case TERM_FREQ:
return new TermFreqFunction(field, term, docId, context);
case TF:
return new TFFunction(field, term, docId, context);
case TOTAL_TERM_FREQ:
return new TotalTermFreq(field, term, docId, context);
case SUM_TOTAL_TERM_FREQ:
return new SumTotalTermFreq(field, term, docId, context);
default:
throw new IllegalArgumentException("Unsupported function: " + functionName);
}
}
}

public static class TermFreqFunction extends TermFrequencyFunction {

public TermFreqFunction(String field, String term, int docId, Map<Object, Object> context) {
super(field, term, docId, context);
}

@Override
public Integer execute(LeafReaderContext readerContext) throws IOException {
TermFreqValueSource valueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
return valueSource.getValues(null, readerContext).intVal(docId);
}
}

public static class TFFunction extends TermFrequencyFunction {

public TFFunction(String field, String term, int docId, Map<Object, Object> context) {
super(field, term, docId, context);
}

@Override
public Float execute(LeafReaderContext readerContext) throws IOException {
TFValueSource valueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term));
return valueSource.getValues(context, readerContext).floatVal(docId);
}
}

public static class TotalTermFreq extends TermFrequencyFunction {

public TotalTermFreq(String field, String term, int docId, Map<Object, Object> context) {
super(field, term, docId, context);
}

@Override
public Long execute(LeafReaderContext readerContext) throws IOException {
TotalTermFreqValueSource valueSource = new TotalTermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
valueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
return valueSource.getValues(context, readerContext).longVal(docId);
}
}

public static class SumTotalTermFreq extends TermFrequencyFunction {

public SumTotalTermFreq(String field, String term, int docId, Map<Object, Object> context) {
super(field, term, docId, context);
}

@Override
public Long execute(LeafReaderContext readerContext) throws IOException {
SumTotalTermFreqValueSource valueSource = new SumTotalTermFreqValueSource(field);
valueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
return valueSource.getValues(context, readerContext).longVal(docId);
}
}

public enum TermFrequencyFunctionNamesEnum {
TERM_FREQ("termFreq"),
TF("tf"),
TOTAL_TERM_FREQ("totalTermFreq"),
SUM_TOTAL_TERM_FREQ("sumTotalTermFreq");

private final String termFrequencyFunctionName;

private TermFrequencyFunctionNamesEnum(String termFrequencyFunctionName) {
this.termFrequencyFunctionName = termFrequencyFunctionName;
}

public String getTermFrequencyFunctionName() {
return termFrequencyFunctionName;
}
}
}
23 changes: 21 additions & 2 deletions server/src/main/java/org/opensearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Scorable;
import org.opensearch.Version;
import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum;
import org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionFactory;

import org.opensearch.search.lookup.LeafSearchLookup;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.lookup.SourceLookup;
Expand Down Expand Up @@ -115,20 +119,24 @@ public Explanation get(double score, Explanation subQueryExplanation) {
private String indexName = null;
private Version indexVersion = null;

public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
private final IndexSearcher indexSearcher;

public ScoreScript(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) {
// null check needed b/c of expression engine subclass
if (lookup == null) {
assert params == null;
assert leafContext == null;
this.params = null;
this.leafLookup = null;
this.docBase = 0;
this.indexSearcher = null;
} else {
this.leafLookup = lookup.getLeafSearchLookup(leafContext);
params = new HashMap<>(params);
params.putAll(leafLookup.asMap());
this.params = new DynamicMap(params, PARAMS_FUNCTIONS);
this.docBase = leafContext.docBase;
this.indexSearcher = indexSearcher;
}
}

Expand All @@ -144,6 +152,17 @@ public Map<String, ScriptDocValues<?>> getDoc() {
return leafLookup.doc();
}

public Object getTermFrequency(TermFrequencyFunctionNamesEnum functionName, String field, String val) throws IOException {
Map<Object, Object> context = new HashMap<>() {
{
put("searcher", indexSearcher);
}
};
return leafLookup.executeTermFrequencyFunction(
TermFrequencyFunctionFactory.createFunction(functionName, field, val, docId, context)
);
}

/** Set the current document to run the script on next. */
public void setDocument(int docid) {
this.docId = docid;
Expand Down Expand Up @@ -268,7 +287,7 @@ public interface LeafFactory {
*/
public interface Factory extends ScriptFactory {

ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup);
ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher);

}

Expand Down
68 changes: 68 additions & 0 deletions server/src/main/java/org/opensearch/script/ScoreScriptUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
import java.time.ZoneId;

import static org.opensearch.common.util.BitMixer.mix32;
import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.SUM_TOTAL_TERM_FREQ;
import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TERM_FREQ;
import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TF;
import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TOTAL_TERM_FREQ;

/**
* Utilities for scoring scripts
Expand All @@ -69,6 +73,70 @@ public static double sigmoid(double value, double k, double a) {
return Math.pow(value, a) / (Math.pow(k, a) + Math.pow(value, a));
}

public static final class TermFreq {
private final ScoreScript scoreScript;

public TermFreq(ScoreScript scoreScript) {
this.scoreScript = scoreScript;
}

public int termFreq(String field, String term) {
try {
return (int) scoreScript.getTermFrequency(TERM_FREQ, field, term);
} catch (Exception e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

public static final class TF {
private final ScoreScript scoreScript;

public TF(ScoreScript scoreScript) {
this.scoreScript = scoreScript;
}

public float tf(String field, String term) {
try {
return (float) scoreScript.getTermFrequency(TF, field, term);
} catch (Exception e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

public static final class TotalTermFreq {
private final ScoreScript scoreScript;

public TotalTermFreq(ScoreScript scoreScript) {
this.scoreScript = scoreScript;
}

public long totalTermFreq(String field, String term) {
try {
return (long) scoreScript.getTermFrequency(TOTAL_TERM_FREQ, field, term);
} catch (Exception e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

public static final class SumTotalTermFreq {
private final ScoreScript scoreScript;

public SumTotalTermFreq(ScoreScript scoreScript) {
this.scoreScript = scoreScript;
}

public long sumTotalTermFreq(String field) {
try {
return (long) scoreScript.getTermFrequency(SUM_TOTAL_TERM_FREQ, field, null);
} catch (Exception e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

/**
* random score based on the documents' values of the given field
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
package org.opensearch.search.lookup;

import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.index.query.functionscore.TermFrequencyFunction;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -87,4 +89,8 @@ public void setDocument(int docId) {
sourceLookup.setSegmentAndDocument(ctx, docId);
fieldsLookup.setDocument(docId);
}

public Object executeTermFrequencyFunction(TermFrequencyFunction function) throws IOException {
return function.execute(ctx);
}
}

0 comments on commit 7aff647

Please sign in to comment.