Skip to content

Commit

Permalink
Make search functions translation aware (elastic#118355)
Browse files Browse the repository at this point in the history
* Introduce TranslationAware interface

* Serialize query builder

* Fix EsqlNodeSubclassTests

* Add javadoc

* Address review comments

* Revert changes on making constructors private
  • Loading branch information
ioanatia authored Dec 13, 2024
1 parent 48c892c commit a765f89
Show file tree
Hide file tree
Showing 10 changed files with 234 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ADD_DATA_STREAM_OPTIONS_TO_TEMPLATES = def(8_805_00_0);
public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_806_00_0);
public static final TransportVersion SEMANTIC_QUERY_LENIENT = def(8_807_00_0);
public static final TransportVersion ESQL_QUERY_BUILDER_IN_SEARCH_FUNCTIONS = def(8_808_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.core.expression;

import org.elasticsearch.xpack.esql.core.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;

/**
* Expressions can implement this interface to control how they would be translated and pushed down as Lucene queries.
* When an expression implements {@link TranslationAware}, we call {@link #asQuery(TranslatorHandler)} to get the
* {@link Query} translation, instead of relying on the registered translators from EsqlExpressionTranslators.
*/
public interface TranslationAware {
Query asQuery(TranslatorHandler translatorHandler);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.core.querydsl.query;

import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.esql.core.tree.Source;

/**
* Expressions that store their own {@link QueryBuilder} and implement
* {@link org.elasticsearch.xpack.esql.core.expression.TranslationAware} can use {@link TranslationAwareExpressionQuery}
* to wrap their {@link QueryBuilder}, instead of using the other existing {@link Query} implementations.
*/
public class TranslationAwareExpressionQuery extends Query {
private final QueryBuilder queryBuilder;

public TranslationAwareExpressionQuery(Source source, QueryBuilder queryBuilder) {
super(source);
this.queryBuilder = queryBuilder;
}

@Override
public QueryBuilder asBuilder() {
return queryBuilder;
}

@Override
protected String innerToString() {
return queryBuilder.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ private static FunctionDefinition[][] functions() {
def(MvSum.class, MvSum::new, "mv_sum"),
def(Split.class, Split::new, "split") },
// fulltext functions
new FunctionDefinition[] { def(Match.class, Match::new, "match"), def(QueryString.class, QueryString::new, "qstr") } };
new FunctionDefinition[] { def(Match.class, bi(Match::new), "match"), def(QueryString.class, uni(QueryString::new), "qstr") } };

}

Expand All @@ -426,9 +426,9 @@ private static FunctionDefinition[][] snapshotFunctions() {
// The delay() function is for debug/snapshot environments only and should never be enabled in a non-snapshot build.
// This is an experimental function and can be removed without notice.
def(Delay.class, Delay::new, "delay"),
def(Kql.class, Kql::new, "kql"),
def(Kql.class, uni(Kql::new), "kql"),
def(Rate.class, Rate::withUnresolvedTimestamp, "rate"),
def(Term.class, Term::new, "term") } };
def(Term.class, bi(Term::new), "term") } };
}

public EsqlFunctionRegistry snapshotRegistry() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,21 @@
package org.elasticsearch.xpack.esql.expression.function.fulltext;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.TranslationAware;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.planner.ExpressionTranslator;
import org.elasticsearch.xpack.esql.core.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
import org.elasticsearch.xpack.esql.core.querydsl.query.TranslationAwareExpressionQuery;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;

import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
Expand All @@ -26,13 +33,15 @@
* These functions needs to be pushed down to Lucene queries to be executed - there's no Evaluator for them, but depend on
* {@link org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer} to rewrite them into Lucene queries.
*/
public abstract class FullTextFunction extends Function {
public abstract class FullTextFunction extends Function implements TranslationAware {

private final Expression query;
private final QueryBuilder queryBuilder;

protected FullTextFunction(Source source, Expression query, List<Expression> children) {
protected FullTextFunction(Source source, Expression query, List<Expression> children, QueryBuilder queryBuilder) {
super(source, children);
this.query = query;
this.queryBuilder = queryBuilder;
}

@Override
Expand Down Expand Up @@ -116,4 +125,37 @@ public Nullability nullable() {
public String functionType() {
return "function";
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), queryBuilder);
}

@Override
public boolean equals(Object obj) {
if (false == super.equals(obj)) {
return false;
}

return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder);
}

@Override
public Query asQuery(TranslatorHandler translatorHandler) {
if (queryBuilder != null) {
return new TranslationAwareExpressionQuery(source(), queryBuilder);
}

ExpressionTranslator<? extends FullTextFunction> translator = translator();
return translator.translate(this, translatorHandler);
}

public QueryBuilder queryBuilder() {
return queryBuilder;
}

@SuppressWarnings("rawtypes")
protected abstract ExpressionTranslator<? extends FullTextFunction> translator();

public abstract Expression replaceQueryBuilder(QueryBuilder queryBuilder);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@

package org.elasticsearch.xpack.esql.expression.function.fulltext;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.planner.ExpressionTranslator;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.planner.EsqlExpressionTranslators;
import org.elasticsearch.xpack.esql.querydsl.query.KqlQuery;

import java.io.IOException;
Expand All @@ -26,7 +30,7 @@
* Full text function that performs a {@link KqlQuery} .
*/
public class Kql extends FullTextFunction {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Kql", Kql::new);
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Kql", Kql::readFrom);

@FunctionInfo(
returnType = "boolean",
Expand All @@ -42,17 +46,30 @@ public Kql(
description = "Query string in KQL query string format."
) Expression queryString
) {
super(source, queryString, List.of(queryString));
super(source, queryString, List.of(queryString), null);
}

private Kql(StreamInput in) throws IOException {
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class));
public Kql(Source source, Expression queryString, QueryBuilder queryBuilder) {
super(source, queryString, List.of(queryString), queryBuilder);
}

private static Kql readFrom(StreamInput in) throws IOException {
Source source = Source.readFrom((PlanStreamInput) in);
Expression query = in.readNamedWriteable(Expression.class);
QueryBuilder queryBuilder = null;
if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_QUERY_BUILDER_IN_SEARCH_FUNCTIONS)) {
queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
}
return new Kql(source, query, queryBuilder);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(query());
if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_QUERY_BUILDER_IN_SEARCH_FUNCTIONS)) {
out.writeOptionalNamedWriteable(queryBuilder());
}
}

@Override
Expand All @@ -62,12 +79,21 @@ public String getWriteableName() {

@Override
public Expression replaceChildren(List<Expression> newChildren) {
return new Kql(source(), newChildren.get(0));
return new Kql(source(), newChildren.get(0), queryBuilder());
}

@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, Kql::new, query());
return NodeInfo.create(this, Kql::new, query(), queryBuilder());
}

@Override
protected ExpressionTranslator<Kql> translator() {
return new EsqlExpressionTranslators.KqlFunctionTranslator();
}

@Override
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
return new Kql(source(), query(), queryBuilder);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@
package org.elasticsearch.xpack.esql.expression.function.fulltext;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.esql.capabilities.Validatable;
import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.planner.ExpressionTranslator;
import org.elasticsearch.xpack.esql.core.querydsl.query.QueryStringQuery;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
Expand All @@ -27,6 +30,7 @@
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.planner.EsqlExpressionTranslators;
import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter;

import java.io.IOException;
Expand Down Expand Up @@ -109,22 +113,33 @@ public Match(
description = "Value to find in the provided field."
) Expression matchQuery
) {
super(source, matchQuery, List.of(field, matchQuery));
this(source, field, matchQuery, null);
}

public Match(Source source, Expression field, Expression matchQuery, QueryBuilder queryBuilder) {
super(source, matchQuery, List.of(field, matchQuery), queryBuilder);
this.field = field;
}

private static Match readFrom(StreamInput in) throws IOException {
Source source = Source.readFrom((PlanStreamInput) in);
Expression field = in.readNamedWriteable(Expression.class);
Expression query = in.readNamedWriteable(Expression.class);
return new Match(source, field, query);
QueryBuilder queryBuilder = null;
if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_QUERY_BUILDER_IN_SEARCH_FUNCTIONS)) {
queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
}
return new Match(source, field, query, queryBuilder);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(field());
out.writeNamedWriteable(query());
if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_QUERY_BUILDER_IN_SEARCH_FUNCTIONS)) {
out.writeOptionalNamedWriteable(queryBuilder());
}
}

@Override
Expand Down Expand Up @@ -224,12 +239,12 @@ public Object queryAsObject() {

@Override
public Expression replaceChildren(List<Expression> newChildren) {
return new Match(source(), newChildren.get(0), newChildren.get(1));
return new Match(source(), newChildren.get(0), newChildren.get(1), queryBuilder());
}

@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, Match::new, field, query());
return NodeInfo.create(this, Match::new, field, query(), queryBuilder());
}

protected TypeResolutions.ParamOrdinal queryParamOrdinal() {
Expand All @@ -245,6 +260,16 @@ public String functionType() {
return isOperator() ? "operator" : super.functionType();
}

@Override
protected ExpressionTranslator<Match> translator() {
return new EsqlExpressionTranslators.MatchFunctionTranslator();
}

@Override
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
return new Match(source(), field, query(), queryBuilder);
}

@Override
public String functionName() {
return isOperator() ? ":" : super.functionName();
Expand Down
Loading

0 comments on commit a765f89

Please sign in to comment.