diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 6089755628ce7..e340a6888458c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -36,9 +36,7 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.ONLY_SCORE_PRUNED_TOKENS_FIELD; -import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.RATIO_THRESHOLD_FIELD; -import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.WEIGHT_THRESHOLD_FIELD; +import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.TOKENS_THRESHOLD_FIELD; public class TextExpansionQueryBuilder extends AbstractQueryBuilder { @@ -98,11 +96,6 @@ String getFieldName() { return fieldName; } - /** - * Returns the frequency ratio threshold to apply on the query. - * Tokens whose frequency is more than ratio_threshold times the average frequency of all tokens in the specified - * field are considered outliers and may be subject to removal from the query. - */ public WeightedTokensThreshold getThreshold() { return threshold; } @@ -136,7 +129,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep builder.startObject(fieldName); builder.field(MODEL_TEXT.getPreferredName(), modelText); builder.field(MODEL_ID.getPreferredName(), modelId); - threshold.toXContent(builder, params); + if (threshold != null) { + threshold.toXContent(builder, params); + } boostAndQueryNameToXContent(builder); builder.endObject(); builder.endObject(); @@ -233,9 +228,7 @@ public static TextExpansionQueryBuilder fromXContent(XContentParser parser) thro String fieldName = null; String modelText = null; String modelId = null; - int ratioThreshold = 0; - float weightThreshold = 1f; - boolean onlyScorePrunedTokens = false; + WeightedTokensThreshold threshold = null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; String queryName = null; String currentFieldName = null; @@ -249,17 +242,20 @@ public static TextExpansionQueryBuilder fromXContent(XContentParser parser) thro while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.START_OBJECT) { + if (TOKENS_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + threshold = WeightedTokensThreshold.fromXContent(parser); + } else { + throw new ParsingException( + parser.getTokenLocation(), + "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" + ); + } } else if (token.isValue()) { if (MODEL_TEXT.match(currentFieldName, parser.getDeprecationHandler())) { modelText = parser.text(); } else if (MODEL_ID.match(currentFieldName, parser.getDeprecationHandler())) { modelId = parser.text(); - } else if (RATIO_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - ratioThreshold = parser.intValue(); - } else if (WEIGHT_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - weightThreshold = parser.floatValue(); - } else if (ONLY_SCORE_PRUNED_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - onlyScorePrunedTokens = parser.booleanValue(); } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { boost = parser.floatValue(); } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { @@ -296,7 +292,7 @@ public static TextExpansionQueryBuilder fromXContent(XContentParser parser) thro fieldName, modelText, modelId, - new WeightedTokensThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens) + threshold ); queryBuilder.queryName(queryName); queryBuilder.boost(boost); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java index dafa732450336..960def8d551be 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java @@ -33,7 +33,8 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.*; +import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.TOKENS_THRESHOLD_FIELD; + public class WeightedTokensQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "weighted_tokens"; @@ -81,7 +82,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep builder.startObject(NAME); builder.startObject(fieldName); builder.field(TOKENS_FIELD.getPreferredName(), tokens); - threshold.toXContent(builder, params); + if (threshold != null) { + threshold.toXContent(builder, params); + } boostAndQueryNameToXContent(builder); builder.endObject(); builder.endObject(); @@ -186,9 +189,7 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr String currentFieldName = null; String fieldName = null; List tokens = new ArrayList<>(); - Integer ratioThreshold = null; - Float weightThreshold = 1f; - boolean onlyScorePrunedTokens = false; + WeightedTokensThreshold threshold = null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; String queryName = null; XContentParser.Token token; @@ -201,17 +202,19 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { currentFieldName = parser.currentName(); - } else if (RATIO_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - ratioThreshold = parser.intValue(); - } else if (WEIGHT_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - weightThreshold = parser.floatValue(); + } else if (TOKENS_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + if (token != XContentParser.Token.START_OBJECT) { + throw new ParsingException( + parser.getTokenLocation(), + "[" + TOKENS_THRESHOLD_FIELD.getPreferredName() + "] should be an object" + ); + } + threshold = WeightedTokensThreshold.fromXContent(parser); } else if (TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { var tokensMap = parser.map(); for (var e : tokensMap.entrySet()) { tokens.add(new WeightedToken(e.getKey(), parseWeight(e.getKey(), e.getValue()))); } - } else if (ONLY_SCORE_PRUNED_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - onlyScorePrunedTokens = parser.booleanValue(); } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { boost = parser.floatValue(); } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { @@ -235,7 +238,7 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr var qb = new WeightedTokensQueryBuilder( fieldName, tokens, - new WeightedTokensThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens) + threshold ); qb.queryName(queryName); qb.boost(boost); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java index 9945b210f9ac3..e02dd335d633f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java @@ -7,25 +7,35 @@ package org.elasticsearch.xpack.ml.queries; +import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; import java.util.Objects; -public class WeightedTokensThreshold implements Writeable, ToXContentFragment { +public class WeightedTokensThreshold implements Writeable, ToXContentObject { + public static final ParseField TOKENS_THRESHOLD_FIELD = new ParseField("tokens_threshold"); public static final ParseField RATIO_THRESHOLD_FIELD = new ParseField("ratio_threshold"); public static final ParseField WEIGHT_THRESHOLD_FIELD = new ParseField("weight_threshold"); public static final ParseField ONLY_SCORE_PRUNED_TOKENS_FIELD = new ParseField("only_score_pruned_tokens"); + public static final float DEFAULT_RATIO_THRESHOLD = 5; + public static final float DEFAULT_WEIGHT_THRESHOLD = 0.4f; + private final float ratioThreshold; private final float weightThreshold; private final boolean onlyScorePrunedTokens; + public WeightedTokensThreshold() { + this(DEFAULT_RATIO_THRESHOLD, DEFAULT_WEIGHT_THRESHOLD, false); + } + public WeightedTokensThreshold(float ratioThreshold, float weightThreshold, boolean onlyScorePrunedTokens) { if (ratioThreshold < 1) { throw new IllegalArgumentException( @@ -99,11 +109,45 @@ public int hashCode() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(TOKENS_THRESHOLD_FIELD.getPreferredName()); builder.field(RATIO_THRESHOLD_FIELD.getPreferredName(), ratioThreshold); builder.field(WEIGHT_THRESHOLD_FIELD.getPreferredName(), weightThreshold); if (onlyScorePrunedTokens) { builder.field(ONLY_SCORE_PRUNED_TOKENS_FIELD.getPreferredName(), onlyScorePrunedTokens); } + builder.endObject(); return builder; } + + public static WeightedTokensThreshold fromXContent(XContentParser parser) throws IOException { + String currentFieldName = null; + XContentParser.Token token; + float ratioThreshold = DEFAULT_RATIO_THRESHOLD; + float weightThreshold = DEFAULT_WEIGHT_THRESHOLD; + boolean onlyScorePrunedTokens = false; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token.isValue()) { + if (RATIO_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + ratioThreshold = parser.intValue(); + } else if (WEIGHT_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + weightThreshold = parser.floatValue(); + } else if (ONLY_SCORE_PRUNED_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + onlyScorePrunedTokens = parser.booleanValue(); + } else { + throw new ParsingException( + parser.getTokenLocation(), + "[" + TOKENS_THRESHOLD_FIELD.getPreferredName() + "] does not support [" + currentFieldName + "]" + ); + } + } else { + throw new ParsingException( + parser.getTokenLocation(), + "[" + TOKENS_THRESHOLD_FIELD.getPreferredName() + "] unknown token [" + token + "] after [" + currentFieldName + "]" + ); + } + } + return new WeightedTokensThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index 7f937a2ef461e..2dd962bd4fe65 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -165,15 +165,17 @@ public void testToXContent() throws IOException { } public void testToXContentWithThresholds() throws IOException { - QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz", new WeightedTokensThreshold(4, 0.4f, false)); + QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz", new WeightedTokensThreshold(4, 0.3f, false)); checkGeneratedJson(""" { "text_expansion": { "foo": { "model_text": "bar", "model_id": "baz", - "ratio_threshold": 4, - "weight_threshold": 0.4 + "tokens_threshold": { + "ratio_threshold": 4, + "weight_threshold": 0.3 + } } } }""", query);