From dd099de285b5108eccb5cbbcde8cf615dc966615 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 30 Nov 2023 18:40:49 +0000 Subject: [PATCH] Move all thresholds to its own class --- .../ml/queries/TextExpansionQueryBuilder.java | 16 ++++++++-------- .../ml/queries/WeightedTokensQueryBuilder.java | 16 ++++++++-------- .../ml/queries/WeightedTokensThreshold.java | 8 ++++---- .../queries/TextExpansionQueryBuilderTests.java | 6 +++--- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index f9e4bc5643d25..6089755628ce7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -36,9 +36,9 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.ONLY_SCORE_PRUNED_TOKENS_FIELD; -import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.RATIO_THRESHOLD_FIELD; -import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.WEIGHT_THRESHOLD_FIELD; +import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.ONLY_SCORE_PRUNED_TOKENS_FIELD; +import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.RATIO_THRESHOLD_FIELD; +import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.WEIGHT_THRESHOLD_FIELD; public class TextExpansionQueryBuilder extends AbstractQueryBuilder { @@ -50,13 +50,13 @@ public class TextExpansionQueryBuilder extends AbstractQueryBuilder weightedTokensSupplier; - private final WeightedTokenThreshold threshold; + private final WeightedTokensThreshold threshold; public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId) { this(fieldName, modelText, modelId, null); } - public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId, @Nullable WeightedTokenThreshold threshold) { + public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId, @Nullable WeightedTokensThreshold threshold) { if (fieldName == null) { throw new IllegalArgumentException("[" + NAME + "] requires a fieldName"); } @@ -78,7 +78,7 @@ public TextExpansionQueryBuilder(StreamInput in) throws IOException { this.modelText = in.readString(); this.modelId = in.readString(); if (in.getTransportVersion().onOrAfter(TransportVersions.WEIGHTED_TOKENS_QUERY_ADDED)) { - this.threshold = in.readOptionalWriteable(WeightedTokenThreshold::new); + this.threshold = in.readOptionalWriteable(WeightedTokensThreshold::new); } else { this.threshold = null; } @@ -103,7 +103,7 @@ String getFieldName() { * Tokens whose frequency is more than ratio_threshold times the average frequency of all tokens in the specified * field are considered outliers and may be subject to removal from the query. */ - public WeightedTokenThreshold getThreshold() { + public WeightedTokensThreshold getThreshold() { return threshold; } @@ -296,7 +296,7 @@ public static TextExpansionQueryBuilder fromXContent(XContentParser parser) thro fieldName, modelText, modelId, - new WeightedTokenThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens) + new WeightedTokensThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens) ); queryBuilder.queryName(queryName); queryBuilder.boost(boost); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java index 82c19b4dab675..dafa732450336 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.*; +import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.*; public class WeightedTokensQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "weighted_tokens"; @@ -41,13 +41,13 @@ public class WeightedTokensQueryBuilder extends AbstractQueryBuilder tokens; - private final WeightedTokenThreshold threshold; + private final WeightedTokensThreshold threshold; public WeightedTokensQueryBuilder(String fieldName, List tokens) { this(fieldName, tokens, null); } - public WeightedTokensQueryBuilder(String fieldName, List tokens, @Nullable WeightedTokenThreshold threshold) { + public WeightedTokensQueryBuilder(String fieldName, List tokens, @Nullable WeightedTokensThreshold threshold) { this.fieldName = Objects.requireNonNull(fieldName, "[" + NAME + "] requires a fieldName"); this.tokens = Objects.requireNonNull(tokens, "[" + NAME + "] requires tokens"); this.threshold = threshold; @@ -57,7 +57,7 @@ public WeightedTokensQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.tokens = in.readCollectionAsList(WeightedToken::new); - this.threshold = in.readOptionalWriteable(WeightedTokenThreshold::new); + this.threshold = in.readOptionalWriteable(WeightedTokensThreshold::new); } public String getFieldName() { @@ -65,7 +65,7 @@ public String getFieldName() { } @Nullable - public WeightedTokenThreshold getThreshold() { + public WeightedTokensThreshold getThreshold() { return threshold; } @@ -186,8 +186,8 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr String currentFieldName = null; String fieldName = null; List tokens = new ArrayList<>(); - int ratioThreshold = 0; - float weightThreshold = 1f; + Integer ratioThreshold = null; + Float weightThreshold = 1f; boolean onlyScorePrunedTokens = false; float boost = AbstractQueryBuilder.DEFAULT_BOOST; String queryName = null; @@ -235,7 +235,7 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr var qb = new WeightedTokensQueryBuilder( fieldName, tokens, - new WeightedTokenThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens) + new WeightedTokensThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens) ); qb.queryName(queryName); qb.boost(boost); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java index 9a7f102d6f40e..9945b210f9ac3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensThreshold.java @@ -17,7 +17,7 @@ import java.io.IOException; import java.util.Objects; -public class WeightedTokenThreshold implements Writeable, ToXContentFragment { +public class WeightedTokensThreshold implements Writeable, ToXContentFragment { public static final ParseField RATIO_THRESHOLD_FIELD = new ParseField("ratio_threshold"); public static final ParseField WEIGHT_THRESHOLD_FIELD = new ParseField("weight_threshold"); public static final ParseField ONLY_SCORE_PRUNED_TOKENS_FIELD = new ParseField("only_score_pruned_tokens"); @@ -26,7 +26,7 @@ public class WeightedTokenThreshold implements Writeable, ToXContentFragment { private final float weightThreshold; private final boolean onlyScorePrunedTokens; - public WeightedTokenThreshold(float ratioThreshold, float weightThreshold, boolean onlyScorePrunedTokens) { + public WeightedTokensThreshold(float ratioThreshold, float weightThreshold, boolean onlyScorePrunedTokens) { if (ratioThreshold < 1) { throw new IllegalArgumentException( "[" + RATIO_THRESHOLD_FIELD.getPreferredName() + "] must be greater or equal to 1, got " + ratioThreshold @@ -40,7 +40,7 @@ public WeightedTokenThreshold(float ratioThreshold, float weightThreshold, boole this.onlyScorePrunedTokens = onlyScorePrunedTokens; } - public WeightedTokenThreshold(StreamInput in) throws IOException { + public WeightedTokensThreshold(StreamInput in) throws IOException { this.ratioThreshold = in.readFloat(); this.weightThreshold = in.readFloat(); this.onlyScorePrunedTokens = in.readBoolean(); @@ -86,7 +86,7 @@ public boolean isOnlyScorePrunedTokens() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - WeightedTokenThreshold that = (WeightedTokenThreshold) o; + WeightedTokensThreshold that = (WeightedTokensThreshold) o; return Float.compare(that.ratioThreshold, ratioThreshold) == 0 && Float.compare(that.weightThreshold, weightThreshold) == 0 && onlyScorePrunedTokens == that.onlyScorePrunedTokens; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index c22da74037c83..7f937a2ef461e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -46,8 +46,8 @@ public class TextExpansionQueryBuilderTests extends AbstractQueryTestCase