Skip to content

Commit

Permalink
Move thresholds to its own class
Browse files (browse the repository at this point in the history)
  • Loading branch information
jimczi committed Nov 30, 2023
1 parent 6d678ac commit 3d95e80
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 187 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
Expand All @@ -35,9 +36,9 @@

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder.ONLY_SCORE_PRUNED_TOKENS_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder.RATIO_THRESHOLD_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder.WEIGHT_THRESHOLD_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.ONLY_SCORE_PRUNED_TOKENS_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.RATIO_THRESHOLD_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.WEIGHT_THRESHOLD_FIELD;

public class TextExpansionQueryBuilder extends AbstractQueryBuilder<TextExpansionQueryBuilder> {

Expand All @@ -49,22 +50,13 @@ public class TextExpansionQueryBuilder extends AbstractQueryBuilder<TextExpansio
private final String modelText;
private final String modelId;
private SetOnce<TextExpansionResults> weightedTokensSupplier;
private final int ratioThreshold;
private final float weightThreshold;
private final boolean onlyScorePrunedTokens;
private final WeightedTokenThreshold threshold;

public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId) {
this(fieldName, modelText, modelId, -1, 0f, false);
this(fieldName, modelText, modelId, null);
}

public TextExpansionQueryBuilder(
String fieldName,
String modelText,
String modelId,
int ratioThreshold,
float weightThreshold,
boolean onlyScorePrunedTokens
) {
public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId, @Nullable WeightedTokenThreshold threshold) {
if (fieldName == null) {
throw new IllegalArgumentException("[" + NAME + "] requires a fieldName");
}
Expand All @@ -74,23 +66,10 @@ public TextExpansionQueryBuilder(
if (modelId == null) {
throw new IllegalArgumentException("[" + NAME + "] requires a " + MODEL_ID.getPreferredName() + " value");
}
if (weightThreshold < 0 || weightThreshold > 1) {
throw new IllegalArgumentException(
"["
+ NAME
+ "] requires the "
+ WEIGHT_THRESHOLD_FIELD.getPreferredName()
+ " to be between 0 and 1, got "
+ weightThreshold
);
}

this.fieldName = fieldName;
this.modelText = modelText;
this.modelId = modelId;
this.ratioThreshold = ratioThreshold;
this.weightThreshold = weightThreshold;
this.onlyScorePrunedTokens = onlyScorePrunedTokens;
this.threshold = threshold;
}

public TextExpansionQueryBuilder(StreamInput in) throws IOException {
Expand All @@ -99,23 +78,17 @@ public TextExpansionQueryBuilder(StreamInput in) throws IOException {
this.modelText = in.readString();
this.modelId = in.readString();
if (in.getTransportVersion().onOrAfter(TransportVersions.WEIGHTED_TOKENS_QUERY_ADDED)) {
this.ratioThreshold = in.readInt();
this.weightThreshold = in.readFloat();
this.onlyScorePrunedTokens = in.readBoolean();
this.threshold = in.readOptionalWriteable(WeightedTokenThreshold::new);
} else {
this.ratioThreshold = 0;
this.weightThreshold = 1f;
this.onlyScorePrunedTokens = false;
this.threshold = null;
}
}

private TextExpansionQueryBuilder(TextExpansionQueryBuilder other, SetOnce<TextExpansionResults> weightedTokensSupplier) {
this.fieldName = other.fieldName;
this.modelText = other.modelText;
this.modelId = other.modelId;
this.ratioThreshold = other.ratioThreshold;
this.weightThreshold = other.weightThreshold;
this.onlyScorePrunedTokens = other.onlyScorePrunedTokens;
this.threshold = other.threshold;
this.boost = other.boost;
this.queryName = other.queryName;
this.weightedTokensSupplier = weightedTokensSupplier;
Expand All @@ -130,28 +103,8 @@ String getFieldName() {
* Tokens whose frequency is more than ratio_threshold times the average frequency of all tokens in the specified
* field are considered outliers and may be subject to removal from the query.
*/
public int getRatioThreshold() {
return ratioThreshold;
}

/**
* Returns the weight threshold to apply on the query.
* Tokens whose weight is more than (weightThreshold * best_weight) of the highest weight in the query are not
* considered outliers, even if their frequency exceeds the specified ratio_threshold.
* This threshold ensures that important tokens, as indicated by their weight, are retained in the query.
*/
public float getWeightThreshold() {
return weightThreshold;
}

/**
* Returns whether the filtering process retains tokens identified as non-relevant based on the specified thresholds
* (ratio and weight). When {@code true}, only non-relevant tokens are considered for matching and scoring documents.
* Enabling this option is valuable for re-scoring top hits retrieved from a {@link WeightedTokensQueryBuilder} with
* active thresholds.
*/
public boolean onlyScorePrunedTokens() {
return onlyScorePrunedTokens;
public WeightedTokenThreshold getThreshold() {
return threshold;
}

@Override
Expand All @@ -173,9 +126,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(modelText);
out.writeString(modelId);
if (out.getTransportVersion().onOrAfter(TransportVersions.WEIGHTED_TOKENS_QUERY_ADDED)) {
out.writeInt(ratioThreshold);
out.writeFloat(weightThreshold);
out.writeBoolean(onlyScorePrunedTokens);
out.writeOptionalWriteable(threshold);
}
}

Expand All @@ -185,15 +136,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
builder.startObject(fieldName);
builder.field(MODEL_TEXT.getPreferredName(), modelText);
builder.field(MODEL_ID.getPreferredName(), modelId);
if (ratioThreshold > 0) {
builder.field(RATIO_THRESHOLD_FIELD.getPreferredName(), ratioThreshold);
}
if (weightThreshold != 0) {
builder.field(WEIGHT_THRESHOLD_FIELD.getPreferredName(), weightThreshold);
}
if (onlyScorePrunedTokens) {
builder.field(ONLY_SCORE_PRUNED_TOKENS_FIELD.getPreferredName(), onlyScorePrunedTokens);
}
threshold.toXContent(builder, params);
boostAndQueryNameToXContent(builder);
builder.endObject();
builder.endObject();
Expand Down Expand Up @@ -256,23 +199,14 @@ private QueryBuilder weightedTokensToQuery(
TextExpansionResults textExpansionResults,
QueryRewriteContext queryRewriteContext
) throws IOException {
if (ratioThreshold > 0) {
return new WeightedTokensQueryBuilder(
fieldName,
textExpansionResults.getWeightedTokens(),
ratioThreshold,
weightThreshold,
onlyScorePrunedTokens
);
if (threshold != null) {
return new WeightedTokensQueryBuilder(fieldName, textExpansionResults.getWeightedTokens(), threshold);
}
var boolQuery = QueryBuilders.boolQuery();
for (var weightedToken : textExpansionResults.getWeightedTokens()) {
boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight()));
}
boolQuery.minimumShouldMatch(1);
if (onlyScorePrunedTokens && ratioThreshold <= 0) {
return QueryBuilders.boolQuery().filter(boolQuery);
}
return boolQuery;
}

Expand All @@ -286,15 +220,13 @@ protected boolean doEquals(TextExpansionQueryBuilder other) {
return Objects.equals(fieldName, other.fieldName)
&& Objects.equals(modelText, other.modelText)
&& Objects.equals(modelId, other.modelId)
&& ratioThreshold == other.ratioThreshold
&& Float.compare(weightThreshold, other.weightThreshold) == 0
&& onlyScorePrunedTokens == other.onlyScorePrunedTokens
&& Objects.equals(threshold, other.threshold)
&& Objects.equals(weightedTokensSupplier, other.weightedTokensSupplier);
}

@Override
protected int doHashCode() {
return Objects.hash(fieldName, modelText, modelId, ratioThreshold, weightThreshold, onlyScorePrunedTokens, weightedTokensSupplier);
return Objects.hash(fieldName, modelText, modelId, threshold, weightedTokensSupplier);
}

public static TextExpansionQueryBuilder fromXContent(XContentParser parser) throws IOException {
Expand Down Expand Up @@ -364,9 +296,7 @@ public static TextExpansionQueryBuilder fromXContent(XContentParser parser) thro
fieldName,
modelText,
modelId,
ratioThreshold,
weightThreshold,
onlyScorePrunedTokens
new WeightedTokenThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens)
);
queryBuilder.queryName(queryName);
queryBuilder.boost(boost);
Expand Down
Loading

0 comments on commit 3d95e80

Please sign in to comment.