diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 6da4d84613b97..f9e4bc5643d25 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -35,9 +36,9 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder.ONLY_SCORE_PRUNED_TOKENS_FIELD; -import static org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder.RATIO_THRESHOLD_FIELD; -import static org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder.WEIGHT_THRESHOLD_FIELD; +import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.ONLY_SCORE_PRUNED_TOKENS_FIELD; +import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.RATIO_THRESHOLD_FIELD; +import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.WEIGHT_THRESHOLD_FIELD; public class TextExpansionQueryBuilder extends AbstractQueryBuilder { @@ -49,22 +50,13 @@ public class TextExpansionQueryBuilder extends AbstractQueryBuilder weightedTokensSupplier; - private final int ratioThreshold; - private final float weightThreshold; - private final boolean onlyScorePrunedTokens; + private final WeightedTokenThreshold threshold; public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId) { - this(fieldName, modelText, modelId, -1, 0f, false); + this(fieldName, modelText, modelId, null); } - public TextExpansionQueryBuilder( - String fieldName, - String modelText, - String modelId, - int ratioThreshold, - float weightThreshold, - boolean onlyScorePrunedTokens - ) { + public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId, @Nullable WeightedTokenThreshold threshold) { if (fieldName == null) { throw new IllegalArgumentException("[" + NAME + "] requires a fieldName"); } @@ -74,23 +66,10 @@ public TextExpansionQueryBuilder( if (modelId == null) { throw new IllegalArgumentException("[" + NAME + "] requires a " + MODEL_ID.getPreferredName() + " value"); } - if (weightThreshold < 0 || weightThreshold > 1) { - throw new IllegalArgumentException( - "[" - + NAME - + "] requires the " - + WEIGHT_THRESHOLD_FIELD.getPreferredName() - + " to be between 0 and 1, got " - + weightThreshold - ); - } - this.fieldName = fieldName; this.modelText = modelText; this.modelId = modelId; - this.ratioThreshold = ratioThreshold; - this.weightThreshold = weightThreshold; - this.onlyScorePrunedTokens = onlyScorePrunedTokens; + this.threshold = threshold; } public TextExpansionQueryBuilder(StreamInput in) throws IOException { @@ -99,13 +78,9 @@ public TextExpansionQueryBuilder(StreamInput in) throws IOException { this.modelText = in.readString(); this.modelId = in.readString(); if (in.getTransportVersion().onOrAfter(TransportVersions.WEIGHTED_TOKENS_QUERY_ADDED)) { - this.ratioThreshold = in.readInt(); - this.weightThreshold = in.readFloat(); - this.onlyScorePrunedTokens = in.readBoolean(); + this.threshold = in.readOptionalWriteable(WeightedTokenThreshold::new); } else { - this.ratioThreshold = 0; - this.weightThreshold = 1f; - this.onlyScorePrunedTokens = false; + this.threshold = null; } } @@ -113,9 +88,7 @@ private TextExpansionQueryBuilder(TextExpansionQueryBuilder other, SetOnce 0) { - builder.field(RATIO_THRESHOLD_FIELD.getPreferredName(), ratioThreshold); - } - if (weightThreshold != 0) { - builder.field(WEIGHT_THRESHOLD_FIELD.getPreferredName(), weightThreshold); - } - if (onlyScorePrunedTokens) { - builder.field(ONLY_SCORE_PRUNED_TOKENS_FIELD.getPreferredName(), onlyScorePrunedTokens); - } + threshold.toXContent(builder, params); boostAndQueryNameToXContent(builder); builder.endObject(); builder.endObject(); @@ -256,23 +199,14 @@ private QueryBuilder weightedTokensToQuery( TextExpansionResults textExpansionResults, QueryRewriteContext queryRewriteContext ) throws IOException { - if (ratioThreshold > 0) { - return new WeightedTokensQueryBuilder( - fieldName, - textExpansionResults.getWeightedTokens(), - ratioThreshold, - weightThreshold, - onlyScorePrunedTokens - ); + if (threshold != null) { + return new WeightedTokensQueryBuilder(fieldName, textExpansionResults.getWeightedTokens(), threshold); } var boolQuery = QueryBuilders.boolQuery(); for (var weightedToken : textExpansionResults.getWeightedTokens()) { boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight())); } boolQuery.minimumShouldMatch(1); - if (onlyScorePrunedTokens && ratioThreshold <= 0) { - return QueryBuilders.boolQuery().filter(boolQuery); - } return boolQuery; } @@ -286,15 +220,13 @@ protected boolean doEquals(TextExpansionQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) && Objects.equals(modelText, other.modelText) && Objects.equals(modelId, other.modelId) - && ratioThreshold == other.ratioThreshold - && Float.compare(weightThreshold, other.weightThreshold) == 0 - && onlyScorePrunedTokens == other.onlyScorePrunedTokens + && Objects.equals(threshold, other.threshold) && Objects.equals(weightedTokensSupplier, other.weightedTokensSupplier); } @Override protected int doHashCode() { - return Objects.hash(fieldName, modelText, modelId, ratioThreshold, weightThreshold, onlyScorePrunedTokens, weightedTokensSupplier); + return Objects.hash(fieldName, modelText, modelId, threshold, weightedTokensSupplier); } public static TextExpansionQueryBuilder fromXContent(XContentParser parser) throws IOException { @@ -364,9 +296,7 @@ public static TextExpansionQueryBuilder fromXContent(XContentParser parser) thro fieldName, modelText, modelId, - ratioThreshold, - weightThreshold, - onlyScorePrunedTokens + new WeightedTokenThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens) ); queryBuilder.queryName(queryName); queryBuilder.boost(boost); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java index 109b1b0c33b8f..82c19b4dab675 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; @@ -32,97 +33,47 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.*; + public class WeightedTokensQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "weighted_tokens"; public static final ParseField TOKENS_FIELD = new ParseField("tokens"); - public static final ParseField RATIO_THRESHOLD_FIELD = new ParseField("ratio_threshold"); - public static final ParseField WEIGHT_THRESHOLD_FIELD = new ParseField("weight_threshold"); - public static final ParseField ONLY_SCORE_PRUNED_TOKENS_FIELD = new ParseField("only_score_pruned_tokens"); private final String fieldName; private final List tokens; - private final int ratioThreshold; - private final float weightThreshold; - private final boolean onlyScorePrunedTokens; + private final WeightedTokenThreshold threshold; public WeightedTokensQueryBuilder(String fieldName, List tokens) { - this(fieldName, tokens, -1, 0f, false); + this(fieldName, tokens, null); } - public WeightedTokensQueryBuilder( - String fieldName, - List tokens, - int ratioThreshold, - float weightThreshold, - boolean onlyScorePrunedTokens - ) { + public WeightedTokensQueryBuilder(String fieldName, List tokens, @Nullable WeightedTokenThreshold threshold) { this.fieldName = Objects.requireNonNull(fieldName, "[" + NAME + "] requires a fieldName"); this.tokens = Objects.requireNonNull(tokens, "[" + NAME + "] requires tokens"); - this.ratioThreshold = ratioThreshold; - this.weightThreshold = weightThreshold; - this.onlyScorePrunedTokens = onlyScorePrunedTokens; - - if (weightThreshold < 0 || weightThreshold > 1) { - throw new IllegalArgumentException( - "[" - + NAME - + "] requires the " - + WEIGHT_THRESHOLD_FIELD.getPreferredName() - + " to be between 0 and 1, got " - + weightThreshold - ); - } + this.threshold = threshold; } public WeightedTokensQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.tokens = in.readCollectionAsList(WeightedToken::new); - this.ratioThreshold = in.readInt(); - this.weightThreshold = in.readFloat(); - this.onlyScorePrunedTokens = in.readBoolean(); + this.threshold = in.readOptionalWriteable(WeightedTokenThreshold::new); } public String getFieldName() { return fieldName; } - /** - * Returns the frequency ratio threshold to apply on the query. - * Tokens whose frequency is more than ratio_threshold times the average frequency of all tokens in the specified - * field are considered outliers and may be subject to removal from the query. - */ - public int getRatioThreshold() { - return ratioThreshold; - } - - /** - * Returns the weight threshold to apply on the query. - * Tokens whose weight is more than (weightThreshold * best_weight) of the highest weight in the query are not - * considered outliers, even if their frequency exceeds the specified ratio_threshold. - * This threshold ensures that important tokens, as indicated by their weight, are retained in the query. - */ - public float getWeightThreshold() { - return weightThreshold; - } - - /** - * Returns whether the filtering process retains tokens identified as non-relevant based on the specified thresholds - * (ratio and weight). When {@code true}, only non-relevant tokens are considered for matching and scoring documents. - * Enabling this option is valuable for re-scoring top hits retrieved from a {@link WeightedTokensQueryBuilder} with - * active thresholds. - */ - public boolean onlyScorePrunedTokens() { - return onlyScorePrunedTokens; + @Nullable + public WeightedTokenThreshold getThreshold() { + return threshold; } @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeCollection(tokens); - out.writeInt(ratioThreshold); - out.writeFloat(weightThreshold); - out.writeBoolean(onlyScorePrunedTokens); + out.writeOptionalWriteable(threshold); } @Override @@ -130,15 +81,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep builder.startObject(NAME); builder.startObject(fieldName); builder.field(TOKENS_FIELD.getPreferredName(), tokens); - if (ratioThreshold > 0) { - builder.field(RATIO_THRESHOLD_FIELD.getPreferredName(), ratioThreshold); - } - if (weightThreshold != 0) { - builder.field(WEIGHT_THRESHOLD_FIELD.getPreferredName(), weightThreshold); - } - if (onlyScorePrunedTokens) { - builder.field(ONLY_SCORE_PRUNED_TOKENS_FIELD.getPreferredName(), onlyScorePrunedTokens); - } + threshold.toXContent(builder, params); boostAndQueryNameToXContent(builder); builder.endObject(); builder.endObject(); @@ -169,7 +112,7 @@ private boolean shouldKeepToken( float averageTokenFreqRatio, float bestWeight ) throws IOException { - if (ratioThreshold <= 0) { + if (threshold == null) { return true; } int docFreq = reader.docFreq(new Term(fieldName, token.token())); @@ -177,7 +120,8 @@ private boolean shouldKeepToken( return false; } float tokenFreqRatio = (float) docFreq / fieldDocCount; - return tokenFreqRatio < ratioThreshold * averageTokenFreqRatio || token.weight() > weightThreshold * bestWeight; + return tokenFreqRatio < threshold.getRatioThreshold() * averageTokenFreqRatio + || token.weight() > threshold.getWeightThreshold() * bestWeight; } @Override @@ -197,8 +141,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { return new MatchNoDocsQuery("The \"" + getName() + "\" query is against an empty field"); } for (var token : tokens) { - boolean keep = shouldKeepToken(context.getIndexReader(), token, fieldDocCount, averageTokenFreqRatio, bestWeight) - ^ onlyScorePrunedTokens; + boolean keep = shouldKeepToken(context.getIndexReader(), token, fieldDocCount, averageTokenFreqRatio, bestWeight) ^ threshold + .isOnlyScorePrunedTokens(); if (keep) { qb.add(new BoostQuery(ft.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD); } @@ -208,15 +152,12 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { @Override protected boolean doEquals(WeightedTokensQueryBuilder other) { - return Float.compare(weightThreshold, other.weightThreshold) == 0 - && ratioThreshold == other.ratioThreshold - && tokens.equals(other.tokens) - && onlyScorePrunedTokens == other.onlyScorePrunedTokens; + return Objects.equals(fieldName, other.fieldName) && Objects.equals(threshold, other.threshold) && tokens.equals(other.tokens); } @Override protected int doHashCode() { - return Objects.hash(tokens, ratioThreshold, weightThreshold, onlyScorePrunedTokens); + return Objects.hash(fieldName, tokens, threshold); } @Override @@ -291,7 +232,11 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr throw new ParsingException(parser.getTokenLocation(), "No fieldname specified for query"); } - var qb = new WeightedTokensQueryBuilder(fieldName, tokens, ratioThreshold, weightThreshold, onlyScorePrunedTokens); + var qb = new WeightedTokensQueryBuilder( + fieldName, + tokens, + new WeightedTokenThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens) + ); qb.queryName(queryName); qb.boost(boost); return qb; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java new file mode 100644 index 0000000000000..9a7f102d6f40e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.queries; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class WeightedTokenThreshold implements Writeable, ToXContentFragment { + public static final ParseField RATIO_THRESHOLD_FIELD = new ParseField("ratio_threshold"); + public static final ParseField WEIGHT_THRESHOLD_FIELD = new ParseField("weight_threshold"); + public static final ParseField ONLY_SCORE_PRUNED_TOKENS_FIELD = new ParseField("only_score_pruned_tokens"); + + private final float ratioThreshold; + private final float weightThreshold; + private final boolean onlyScorePrunedTokens; + + public WeightedTokenThreshold(float ratioThreshold, float weightThreshold, boolean onlyScorePrunedTokens) { + if (ratioThreshold < 1) { + throw new IllegalArgumentException( + "[" + RATIO_THRESHOLD_FIELD.getPreferredName() + "] must be greater or equal to 1, got " + ratioThreshold + ); + } + if (weightThreshold < 0 || weightThreshold > 1) { + throw new IllegalArgumentException("[" + WEIGHT_THRESHOLD_FIELD.getPreferredName() + "] must be between 0 and 1"); + } + this.ratioThreshold = ratioThreshold; + this.weightThreshold = weightThreshold; + this.onlyScorePrunedTokens = onlyScorePrunedTokens; + } + + public WeightedTokenThreshold(StreamInput in) throws IOException { + this.ratioThreshold = in.readFloat(); + this.weightThreshold = in.readFloat(); + this.onlyScorePrunedTokens = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeFloat(ratioThreshold); + out.writeFloat(weightThreshold); + out.writeBoolean(onlyScorePrunedTokens); + } + + /** + * Returns the frequency ratio threshold to apply on the query. + * Tokens whose frequency is more than ratio_threshold times the average frequency of all tokens in the specified + * field are considered outliers and may be subject to removal from the query. + */ + public float getRatioThreshold() { + return ratioThreshold; + } + + /** + * Returns the weight threshold to apply on the query. + * Tokens whose weight is more than (weightThreshold * best_weight) of the highest weight in the query are not + * considered outliers, even if their frequency exceeds the specified ratio_threshold. + * This threshold ensures that important tokens, as indicated by their weight, are retained in the query. + */ + public float getWeightThreshold() { + return weightThreshold; + } + + /** + * Returns whether the filtering process retains tokens identified as non-relevant based on the specified thresholds + * (ratio and weight). When {@code true}, only non-relevant tokens are considered for matching and scoring documents. + * Enabling this option is valuable for re-scoring top hits retrieved from a {@link WeightedTokensQueryBuilder} with + * active thresholds. + */ + public boolean isOnlyScorePrunedTokens() { + return onlyScorePrunedTokens; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WeightedTokenThreshold that = (WeightedTokenThreshold) o; + return Float.compare(that.ratioThreshold, ratioThreshold) == 0 + && Float.compare(that.weightThreshold, weightThreshold) == 0 + && onlyScorePrunedTokens == that.onlyScorePrunedTokens; + } + + @Override + public int hashCode() { + return Objects.hash(ratioThreshold, weightThreshold, onlyScorePrunedTokens); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(RATIO_THRESHOLD_FIELD.getPreferredName(), ratioThreshold); + builder.field(WEIGHT_THRESHOLD_FIELD.getPreferredName(), weightThreshold); + if (onlyScorePrunedTokens) { + builder.field(ONLY_SCORE_PRUNED_TOKENS_FIELD.getPreferredName(), onlyScorePrunedTokens); + } + return builder; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index 03901566d4bb5..c22da74037c83 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -46,15 +46,10 @@ public class TextExpansionQueryBuilderTests extends AbstractQueryTestCase new TextExpansionQueryBuilder("field name", "model text", "model id", 10, 4, false) - ); - assertEquals("[text_expansion] requires the weight_threshold to be between 0 and 1, got 4", e.getMessage()); - } } public void testToXContent() throws IOException { @@ -177,7 +165,7 @@ public void testToXContent() throws IOException { } public void testToXContentWithThresholds() throws IOException { - QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz", 4, 0.4f, false); + QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz", new WeightedTokenThreshold(4, 0.4f, false)); checkGeneratedJson(""" { "text_expansion": {