Skip to content

Commit

Permalink
Move all thresholds to its own class
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Nov 30, 2023
1 parent 3d95e80 commit dd099de
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.ONLY_SCORE_PRUNED_TOKENS_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.RATIO_THRESHOLD_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.WEIGHT_THRESHOLD_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.ONLY_SCORE_PRUNED_TOKENS_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.RATIO_THRESHOLD_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.WEIGHT_THRESHOLD_FIELD;

public class TextExpansionQueryBuilder extends AbstractQueryBuilder<TextExpansionQueryBuilder> {

Expand All @@ -50,13 +50,13 @@ public class TextExpansionQueryBuilder extends AbstractQueryBuilder<TextExpansio
private final String modelText;
private final String modelId;
private SetOnce<TextExpansionResults> weightedTokensSupplier;
private final WeightedTokenThreshold threshold;
private final WeightedTokensThreshold threshold;

public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId) {
this(fieldName, modelText, modelId, null);
}

public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId, @Nullable WeightedTokenThreshold threshold) {
public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId, @Nullable WeightedTokensThreshold threshold) {
if (fieldName == null) {
throw new IllegalArgumentException("[" + NAME + "] requires a fieldName");
}
Expand All @@ -78,7 +78,7 @@ public TextExpansionQueryBuilder(StreamInput in) throws IOException {
this.modelText = in.readString();
this.modelId = in.readString();
if (in.getTransportVersion().onOrAfter(TransportVersions.WEIGHTED_TOKENS_QUERY_ADDED)) {
this.threshold = in.readOptionalWriteable(WeightedTokenThreshold::new);
this.threshold = in.readOptionalWriteable(WeightedTokensThreshold::new);
} else {
this.threshold = null;
}
Expand All @@ -103,7 +103,7 @@ String getFieldName() {
* Tokens whose frequency is more than ratio_threshold times the average frequency of all tokens in the specified
* field are considered outliers and may be subject to removal from the query.
*/
public WeightedTokenThreshold getThreshold() {
public WeightedTokensThreshold getThreshold() {
return threshold;
}

Expand Down Expand Up @@ -296,7 +296,7 @@ public static TextExpansionQueryBuilder fromXContent(XContentParser parser) thro
fieldName,
modelText,
modelId,
new WeightedTokenThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens)
new WeightedTokensThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens)
);
queryBuilder.queryName(queryName);
queryBuilder.boost(boost);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,21 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.ml.queries.WeightedTokenThreshold.*;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.*;

public class WeightedTokensQueryBuilder extends AbstractQueryBuilder<WeightedTokensQueryBuilder> {
public static final String NAME = "weighted_tokens";

public static final ParseField TOKENS_FIELD = new ParseField("tokens");
private final String fieldName;
private final List<WeightedToken> tokens;
private final WeightedTokenThreshold threshold;
private final WeightedTokensThreshold threshold;

public WeightedTokensQueryBuilder(String fieldName, List<WeightedToken> tokens) {
this(fieldName, tokens, null);
}

public WeightedTokensQueryBuilder(String fieldName, List<WeightedToken> tokens, @Nullable WeightedTokenThreshold threshold) {
public WeightedTokensQueryBuilder(String fieldName, List<WeightedToken> tokens, @Nullable WeightedTokensThreshold threshold) {
this.fieldName = Objects.requireNonNull(fieldName, "[" + NAME + "] requires a fieldName");
this.tokens = Objects.requireNonNull(tokens, "[" + NAME + "] requires tokens");
this.threshold = threshold;
Expand All @@ -57,15 +57,15 @@ public WeightedTokensQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.tokens = in.readCollectionAsList(WeightedToken::new);
this.threshold = in.readOptionalWriteable(WeightedTokenThreshold::new);
this.threshold = in.readOptionalWriteable(WeightedTokensThreshold::new);
}

public String getFieldName() {
return fieldName;
}

@Nullable
public WeightedTokenThreshold getThreshold() {
public WeightedTokensThreshold getThreshold() {
return threshold;
}

Expand Down Expand Up @@ -186,8 +186,8 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr
String currentFieldName = null;
String fieldName = null;
List<WeightedToken> tokens = new ArrayList<>();
int ratioThreshold = 0;
float weightThreshold = 1f;
Integer ratioThreshold = null;
Float weightThreshold = 1f;
boolean onlyScorePrunedTokens = false;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
String queryName = null;
Expand Down Expand Up @@ -235,7 +235,7 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr
var qb = new WeightedTokensQueryBuilder(
fieldName,
tokens,
new WeightedTokenThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens)
new WeightedTokensThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens)
);
qb.queryName(queryName);
qb.boost(boost);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import java.io.IOException;
import java.util.Objects;

public class WeightedTokenThreshold implements Writeable, ToXContentFragment {
public class WeightedTokensThreshold implements Writeable, ToXContentFragment {
public static final ParseField RATIO_THRESHOLD_FIELD = new ParseField("ratio_threshold");
public static final ParseField WEIGHT_THRESHOLD_FIELD = new ParseField("weight_threshold");
public static final ParseField ONLY_SCORE_PRUNED_TOKENS_FIELD = new ParseField("only_score_pruned_tokens");
Expand All @@ -26,7 +26,7 @@ public class WeightedTokenThreshold implements Writeable, ToXContentFragment {
private final float weightThreshold;
private final boolean onlyScorePrunedTokens;

public WeightedTokenThreshold(float ratioThreshold, float weightThreshold, boolean onlyScorePrunedTokens) {
public WeightedTokensThreshold(float ratioThreshold, float weightThreshold, boolean onlyScorePrunedTokens) {
if (ratioThreshold < 1) {
throw new IllegalArgumentException(
"[" + RATIO_THRESHOLD_FIELD.getPreferredName() + "] must be greater or equal to 1, got " + ratioThreshold
Expand All @@ -40,7 +40,7 @@ public WeightedTokenThreshold(float ratioThreshold, float weightThreshold, boole
this.onlyScorePrunedTokens = onlyScorePrunedTokens;
}

public WeightedTokenThreshold(StreamInput in) throws IOException {
public WeightedTokensThreshold(StreamInput in) throws IOException {
this.ratioThreshold = in.readFloat();
this.weightThreshold = in.readFloat();
this.onlyScorePrunedTokens = in.readBoolean();
Expand Down Expand Up @@ -86,7 +86,7 @@ public boolean isOnlyScorePrunedTokens() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedTokenThreshold that = (WeightedTokenThreshold) o;
WeightedTokensThreshold that = (WeightedTokensThreshold) o;
return Float.compare(that.ratioThreshold, ratioThreshold) == 0
&& Float.compare(that.weightThreshold, weightThreshold) == 0
&& onlyScorePrunedTokens == that.onlyScorePrunedTokens;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ public class TextExpansionQueryBuilderTests extends AbstractQueryTestCase<TextEx

@Override
protected TextExpansionQueryBuilder doCreateTestQueryBuilder() {
WeightedTokenThreshold threshold = rarely()
? new WeightedTokenThreshold(randomIntBetween(1, 100), randomFloat(), randomBoolean())
WeightedTokensThreshold threshold = rarely()
? new WeightedTokensThreshold(randomIntBetween(1, 100), randomFloat(), randomBoolean())
: null;
var builder = new TextExpansionQueryBuilder(RANK_FEATURES_FIELD, randomAlphaOfLength(4), randomAlphaOfLength(4), threshold);
if (randomBoolean()) {
Expand Down Expand Up @@ -165,7 +165,7 @@ public void testToXContent() throws IOException {
}

public void testToXContentWithThresholds() throws IOException {
QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz", new WeightedTokenThreshold(4, 0.4f, false));
QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz", new WeightedTokensThreshold(4, 0.4f, false));
checkGeneratedJson("""
{
"text_expansion": {
Expand Down

0 comments on commit dd099de

Please sign in to comment.