Skip to content

Commit

Permalink
Add the thresholds in a high level object
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Nov 30, 2023
1 parent dd099de commit 5dea9d2
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.ONLY_SCORE_PRUNED_TOKENS_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.RATIO_THRESHOLD_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.WEIGHT_THRESHOLD_FIELD;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.TOKENS_THRESHOLD_FIELD;

public class TextExpansionQueryBuilder extends AbstractQueryBuilder<TextExpansionQueryBuilder> {

Expand Down Expand Up @@ -98,11 +96,6 @@ String getFieldName() {
return fieldName;
}

/**
* Returns the frequency ratio threshold to apply on the query.
* Tokens whose frequency is more than ratio_threshold times the average frequency of all tokens in the specified
* field are considered outliers and may be subject to removal from the query.
*/
public WeightedTokensThreshold getThreshold() {
return threshold;
}
Expand Down Expand Up @@ -136,7 +129,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
builder.startObject(fieldName);
builder.field(MODEL_TEXT.getPreferredName(), modelText);
builder.field(MODEL_ID.getPreferredName(), modelId);
threshold.toXContent(builder, params);
if (threshold != null) {
threshold.toXContent(builder, params);
}
boostAndQueryNameToXContent(builder);
builder.endObject();
builder.endObject();
Expand Down Expand Up @@ -233,9 +228,7 @@ public static TextExpansionQueryBuilder fromXContent(XContentParser parser) thro
String fieldName = null;
String modelText = null;
String modelId = null;
int ratioThreshold = 0;
float weightThreshold = 1f;
boolean onlyScorePrunedTokens = false;
WeightedTokensThreshold threshold = null;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
String queryName = null;
String currentFieldName = null;
Expand All @@ -249,17 +242,20 @@ public static TextExpansionQueryBuilder fromXContent(XContentParser parser) thro
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (token == XContentParser.Token.START_OBJECT) {
if (TOKENS_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
threshold = WeightedTokensThreshold.fromXContent(parser);
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"
);
}
} else if (token.isValue()) {
if (MODEL_TEXT.match(currentFieldName, parser.getDeprecationHandler())) {
modelText = parser.text();
} else if (MODEL_ID.match(currentFieldName, parser.getDeprecationHandler())) {
modelId = parser.text();
} else if (RATIO_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
ratioThreshold = parser.intValue();
} else if (WEIGHT_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
weightThreshold = parser.floatValue();
} else if (ONLY_SCORE_PRUNED_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
onlyScorePrunedTokens = parser.booleanValue();
} else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
boost = parser.floatValue();
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
Expand Down Expand Up @@ -296,7 +292,7 @@ public static TextExpansionQueryBuilder fromXContent(XContentParser parser) thro
fieldName,
modelText,
modelId,
new WeightedTokensThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens)
threshold
);
queryBuilder.queryName(queryName);
queryBuilder.boost(boost);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.*;
import static org.elasticsearch.xpack.ml.queries.WeightedTokensThreshold.TOKENS_THRESHOLD_FIELD;


public class WeightedTokensQueryBuilder extends AbstractQueryBuilder<WeightedTokensQueryBuilder> {
public static final String NAME = "weighted_tokens";
Expand Down Expand Up @@ -81,7 +82,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
builder.startObject(NAME);
builder.startObject(fieldName);
builder.field(TOKENS_FIELD.getPreferredName(), tokens);
threshold.toXContent(builder, params);
if (threshold != null) {
threshold.toXContent(builder, params);
}
boostAndQueryNameToXContent(builder);
builder.endObject();
builder.endObject();
Expand Down Expand Up @@ -186,9 +189,7 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr
String currentFieldName = null;
String fieldName = null;
List<WeightedToken> tokens = new ArrayList<>();
Integer ratioThreshold = null;
Float weightThreshold = 1f;
boolean onlyScorePrunedTokens = false;
WeightedTokensThreshold threshold = null;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
String queryName = null;
XContentParser.Token token;
Expand All @@ -201,17 +202,19 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (RATIO_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
ratioThreshold = parser.intValue();
} else if (WEIGHT_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
weightThreshold = parser.floatValue();
} else if (TOKENS_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
if (token != XContentParser.Token.START_OBJECT) {
throw new ParsingException(
parser.getTokenLocation(),
"[" + TOKENS_THRESHOLD_FIELD.getPreferredName() + "] should be an object"
);
}
threshold = WeightedTokensThreshold.fromXContent(parser);
} else if (TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
var tokensMap = parser.map();
for (var e : tokensMap.entrySet()) {
tokens.add(new WeightedToken(e.getKey(), parseWeight(e.getKey(), e.getValue())));
}
} else if (ONLY_SCORE_PRUNED_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
onlyScorePrunedTokens = parser.booleanValue();
} else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
boost = parser.floatValue();
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
Expand All @@ -235,7 +238,7 @@ public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) thr
var qb = new WeightedTokensQueryBuilder(
fieldName,
tokens,
new WeightedTokensThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens)
threshold
);
qb.queryName(queryName);
qb.boost(boost);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,35 @@

package org.elasticsearch.xpack.ml.queries;

import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

public class WeightedTokensThreshold implements Writeable, ToXContentFragment {
public class WeightedTokensThreshold implements Writeable, ToXContentObject {
public static final ParseField TOKENS_THRESHOLD_FIELD = new ParseField("tokens_threshold");
public static final ParseField RATIO_THRESHOLD_FIELD = new ParseField("ratio_threshold");
public static final ParseField WEIGHT_THRESHOLD_FIELD = new ParseField("weight_threshold");
public static final ParseField ONLY_SCORE_PRUNED_TOKENS_FIELD = new ParseField("only_score_pruned_tokens");

public static final float DEFAULT_RATIO_THRESHOLD = 5;
public static final float DEFAULT_WEIGHT_THRESHOLD = 0.4f;

private final float ratioThreshold;
private final float weightThreshold;
private final boolean onlyScorePrunedTokens;

public WeightedTokensThreshold() {
this(DEFAULT_RATIO_THRESHOLD, DEFAULT_WEIGHT_THRESHOLD, false);
}

public WeightedTokensThreshold(float ratioThreshold, float weightThreshold, boolean onlyScorePrunedTokens) {
if (ratioThreshold < 1) {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -99,11 +109,45 @@ public int hashCode() {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(TOKENS_THRESHOLD_FIELD.getPreferredName());
builder.field(RATIO_THRESHOLD_FIELD.getPreferredName(), ratioThreshold);
builder.field(WEIGHT_THRESHOLD_FIELD.getPreferredName(), weightThreshold);
if (onlyScorePrunedTokens) {
builder.field(ONLY_SCORE_PRUNED_TOKENS_FIELD.getPreferredName(), onlyScorePrunedTokens);
}
builder.endObject();
return builder;
}

public static WeightedTokensThreshold fromXContent(XContentParser parser) throws IOException {
String currentFieldName = null;
XContentParser.Token token;
float ratioThreshold = DEFAULT_RATIO_THRESHOLD;
float weightThreshold = DEFAULT_WEIGHT_THRESHOLD;
boolean onlyScorePrunedTokens = false;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (token.isValue()) {
if (RATIO_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
ratioThreshold = parser.intValue();
} else if (WEIGHT_THRESHOLD_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
weightThreshold = parser.floatValue();
} else if (ONLY_SCORE_PRUNED_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
onlyScorePrunedTokens = parser.booleanValue();
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + TOKENS_THRESHOLD_FIELD.getPreferredName() + "] does not support [" + currentFieldName + "]"
);
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + TOKENS_THRESHOLD_FIELD.getPreferredName() + "] unknown token [" + token + "] after [" + currentFieldName + "]"
);
}
}
return new WeightedTokensThreshold(ratioThreshold, weightThreshold, onlyScorePrunedTokens);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,17 @@ public void testToXContent() throws IOException {
}

public void testToXContentWithThresholds() throws IOException {
QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz", new WeightedTokensThreshold(4, 0.4f, false));
QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz", new WeightedTokensThreshold(4, 0.3f, false));
checkGeneratedJson("""
{
"text_expansion": {
"foo": {
"model_text": "bar",
"model_id": "baz",
"ratio_threshold": 4,
"weight_threshold": 0.4
"tokens_threshold": {
"ratio_threshold": 4,
"weight_threshold": 0.3
}
}
}
}""", query);
Expand Down

0 comments on commit 5dea9d2

Please sign in to comment.