diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java index 000129f8b054d..7581aa3280b8c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java @@ -50,7 +50,8 @@ public WeightedTokensQueryBuilder(String fieldName, public WeightedTokensQueryBuilder(String fieldName, List tokens, - int ratioThreshold, float weightThreshold) { + int ratioThreshold, + float weightThreshold) { this.fieldName = Objects.requireNonNull(fieldName, "[" + NAME + "] requires a fieldName"); this.tokens = Objects.requireNonNull(tokens, "[" + NAME + "] requires tokens"); this.ratioThreshold = ratioThreshold; @@ -115,7 +116,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep builder.endObject(); } - private float getAverageTokenFreq(IndexReader reader) throws IOException { + private float getAverageTokenFreqRatio(IndexReader reader, int fieldDocCount) throws IOException { int numUniqueTokens = 0; for (var leaf : reader.getContext().leaves()) { var terms = leaf.reader().terms(fieldName); @@ -126,7 +127,7 @@ private float getAverageTokenFreq(IndexReader reader) throws IOException { if (numUniqueTokens == 0) { return 0; } - return (float) reader.getSumDocFreq(fieldName) / reader.getDocCount(fieldName) / numUniqueTokens; + return (float) reader.getSumDocFreq(fieldName) / fieldDocCount / numUniqueTokens; } /** @@ -136,21 +137,24 @@ private float getAverageTokenFreq(IndexReader reader) throws IOException { private boolean shouldKeepToken(IndexReader reader, WeightedToken token, int fieldDocCount, - float averageTokenFreq, + float averageTokenFreqRatio, float bestWeight) throws IOException { + if (ratioThreshold <= 0) { + return true; + } int docFreq = reader.docFreq(new Term(fieldName, token.token())); if (docFreq == 0) { return false; } - float tokenFreq = (float) docFreq / fieldDocCount; - return tokenFreq < ratioThreshold * averageTokenFreq + float tokenFreqRatio = (float) docFreq / fieldDocCount; + return tokenFreqRatio < ratioThreshold * averageTokenFreqRatio || token.weight() > weightThreshold * bestWeight; } @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { - final MappedFieldType mapper = context.getFieldType(fieldName); - if (mapper == null) { + final MappedFieldType ft = context.getFieldType(fieldName); + if (ft == null) { return new MatchNoDocsQuery("The \"" + getName() + "\" query is against a field that does not exist"); } var qb = new BooleanQuery.Builder(); @@ -159,13 +163,13 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { for (var t : tokens) { bestWeight = Math.max(t.weight(), bestWeight); } - float averageTokenFreq = getAverageTokenFreq(context.getIndexReader()); - if (averageTokenFreq == 0) { + float averageTokenFreqRatio = getAverageTokenFreqRatio(context.getIndexReader(), fieldDocCount); + if (averageTokenFreqRatio == 0) { return new MatchNoDocsQuery("The \"" + getName() + "\" query is against an empty field"); } for (var token : tokens) { - if (shouldKeepToken(context.getIndexReader(), token, fieldDocCount, averageTokenFreq, bestWeight)) { - qb.add(new BoostQuery(mapper.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD); + if (shouldKeepToken(context.getIndexReader(), token, fieldDocCount, averageTokenFreqRatio, bestWeight)) { + qb.add(new BoostQuery(ft.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD); } } return qb.build();