Skip to content

Commit

Permalink
Move fields validation to FunctionExpression interface and derived cl…
Browse files Browse the repository at this point in the history
…ass OpenSearchFunctions.

Signed-off-by: forestmvey <[email protected]>
  • Loading branch information
forestmvey committed Nov 17, 2022
1 parent 769e7da commit 9f5304c
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 46 deletions.
5 changes: 3 additions & 2 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.ReferenceExpression;
Expand Down Expand Up @@ -219,7 +218,9 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
Expression condition = expressionAnalyzer.analyze(node.getCondition(), context);

OpenSearchFunctions.validateFieldList((FunctionExpression)condition, context);
if (condition instanceof OpenSearchFunctions.OpenSearchFunction) {
((OpenSearchFunctions.OpenSearchFunction)condition).validateParameters(context);
}

ExpressionReferenceOptimizer optimizer =
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.analysis.AnalysisContext;
import org.opensearch.sql.expression.function.FunctionImplementation;
import org.opensearch.sql.expression.function.FunctionName;

Expand All @@ -32,4 +33,11 @@ public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
return visitor.visitFunction(this, context);
}

/**
 * Validation hook for the parameters of this function expression.
 * This implementation performs no checks and returns immediately;
 * subclasses may override it to validate their arguments.
 * @param context : Context of fields querying.
 */
public void validateParameters(AnalysisContext context) {
  // Intentionally empty: nothing is validated at this level.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
Expand All @@ -27,14 +26,14 @@

@UtilityClass
public class OpenSearchFunctions {
private final List<String> singleFieldFunctionNames = ImmutableList.of(
private final List<String> singleFieldFunctionNames = List.of(
BuiltinFunctionName.MATCH.name(),
BuiltinFunctionName.MATCH_BOOL_PREFIX.name(),
BuiltinFunctionName.MATCHPHRASE.name(),
BuiltinFunctionName.MATCH_PHRASE_PREFIX.name()
);

private final List<String> multiFieldFunctionNames = ImmutableList.of(
private final List<String> multiFieldFunctionNames = List.of(
BuiltinFunctionName.MULTI_MATCH.name(),
BuiltinFunctionName.SIMPLE_QUERY_STRING.name(),
BuiltinFunctionName.QUERY_STRING.name()
Expand All @@ -58,35 +57,6 @@ public static boolean isMultiFieldFunction(String funcName) {
return multiFieldFunctionNames.contains(funcName.toUpperCase());
}

/**
 * Resolve the field names referenced by a relevance function against the
 * type environment on top of the analysis context.
 *
 * <p>For a single-field function, the first {@code field} argument whose text
 * has no {@code *} is resolved. For a multi-field function, every key without
 * a {@code *} in the first {@code fields} argument is resolved. Any other
 * function name results in no lookups.
 * @param node : Function used in query.
 * @param context : Context of fields querying.
 */
public static void validateFieldList(FunctionExpression node, AnalysisContext context) {
  String funcName = node.getFunctionName().toString();
  TypeEnvironment typeEnv = context.peek();

  if (isSingleFieldFunction(funcName)) {
    for (Object rawArg : node.getArguments()) {
      NamedArgumentExpression candidate = NamedArgumentExpression.class.cast(rawArg);
      boolean isFieldArg = candidate.getArgName().equals("field");
      if (isFieldArg && !candidate.getValue().toString().contains("*")) {
        String fieldName = StringUtils.unquoteText(candidate.getValue().toString());
        typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, fieldName));
        // Only the first matching argument is checked.
        break;
      }
    }
    return;
  }

  if (isMultiFieldFunction(funcName)) {
    for (Object rawArg : node.getArguments()) {
      NamedArgumentExpression candidate = NamedArgumentExpression.class.cast(rawArg);
      if (!candidate.getArgName().equals("fields")) {
        continue;
      }
      // Wildcard keys are skipped; every remaining key is looked up in the type environment.
      candidate.getValue().valueOf(null).tupleValue().entrySet().stream()
          .map(entry -> entry.getKey())
          .filter(fieldName -> !fieldName.contains("*"))
          .forEach(fieldName -> typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, fieldName)));
      // Only the first "fields" argument is inspected.
      break;
    }
  }
}

/**
* Add functions specific to OpenSearch to repository.
*/
Expand Down Expand Up @@ -179,5 +149,35 @@ public String toString() {
.collect(Collectors.toList());
return String.format("%s(%s)", functionName, String.join(", ", args));
}

/**
 * Resolve the field names referenced by this function against the type
 * environment on top of the analysis context.
 *
 * <p>For a single-field function, the first {@code field} argument whose text
 * has no {@code *} is resolved. For a multi-field function, every key without
 * a {@code *} in the first {@code fields} argument is resolved. Any other
 * function name results in no lookups.
 * @param context : Context of fields querying.
 */
@Override
public void validateParameters(AnalysisContext context) {
  String funcName = this.getFunctionName().toString();
  TypeEnvironment typeEnv = context.peek();

  if (isSingleFieldFunction(funcName)) {
    for (Object rawArg : this.getArguments()) {
      NamedArgumentExpression candidate = NamedArgumentExpression.class.cast(rawArg);
      boolean isFieldArg = candidate.getArgName().equals("field");
      if (isFieldArg && !candidate.getValue().toString().contains("*")) {
        String fieldName = StringUtils.unquoteText(candidate.getValue().toString());
        typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, fieldName));
        // Only the first matching argument is checked.
        break;
      }
    }
    return;
  }

  if (isMultiFieldFunction(funcName)) {
    for (Object rawArg : this.getArguments()) {
      NamedArgumentExpression candidate = NamedArgumentExpression.class.cast(rawArg);
      if (!candidate.getArgName().equals("fields")) {
        continue;
      }
      // Wildcard keys are skipped; every remaining key is looked up in the type environment.
      candidate.getValue().valueOf(null).tupleValue().entrySet().stream()
          .map(entry -> entry.getKey())
          .filter(fieldName -> !fieldName.contains("*"))
          .forEach(fieldName -> typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, fieldName)));
      // Only the first "fields" argument is inspected.
      break;
    }
  }
}

}
}
60 changes: 48 additions & 12 deletions core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,17 @@
import org.opensearch.sql.ast.tree.ML;
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.HighlightExpression;
import org.opensearch.sql.expression.env.Environment;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.window.WindowDefinition;
import org.opensearch.sql.planner.logical.LogicalAD;
import org.opensearch.sql.planner.logical.LogicalMLCommons;
Expand Down Expand Up @@ -268,6 +274,23 @@ public void analyze_filter_aggregation_relation() {
aggregate("MIN", qualifiedName("integer_value")), intLiteral(10))));
}

@Test
public void test_base_class_validate_parameters_method_does_nothing() {
  // Minimal anonymous subclass: only the abstract members are stubbed, so the
  // validateParameters call below runs the implementation inherited from FunctionExpression.
  FunctionExpression baseFunction =
      new FunctionExpression(FunctionName.of("func_name"), List.of()) {
        @Override
        public ExprType type() {
          return null;
        }

        @Override
        public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
          return null;
        }
      };

  // No assertion is made; the test passes as long as the call returns without throwing.
  baseFunction.validateParameters(analysisContext);
}

@Test
public void single_field_relevance_query_semantic_exception() {
SemanticCheckException exception =
Expand All @@ -290,9 +313,9 @@ public void single_field_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
dsl.match(
dsl.namedArgument("field", DSL.literal("string_value")),
dsl.namedArgument("query", DSL.literal("query_value")))),
DSL.match(
DSL.namedArgument("field", DSL.literal("string_value")),
DSL.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("match",
Expand All @@ -305,9 +328,9 @@ public void single_field_wildcard_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
dsl.match(
dsl.namedArgument("field", DSL.literal("wildcard_field*")),
dsl.namedArgument("query", DSL.literal("query_value")))),
DSL.match(
DSL.namedArgument("field", DSL.literal("wildcard_field*")),
DSL.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("match",
Expand Down Expand Up @@ -351,19 +374,32 @@ public void multi_field_relevance_query_mixed_fields_semantic_exception() {
exception.getMessage());
}

@Test
public void no_field_relevance_query_semantic_exception() {
  // NOTE(review): despite the "_semantic_exception" suffix, this test asserts that a
  // query() call carrying only a "query" argument analyzes to the expected plan.
  var expectedPlan =
      LogicalPlanDSL.filter(
          LogicalPlanDSL.relation("schema", table),
          DSL.query(
              DSL.namedArgument("query", DSL.literal("string_value:query_value"))));

  var unresolvedPlan =
      AstDSL.filter(
          AstDSL.relation("schema"),
          AstDSL.function("query",
              AstDSL.unresolvedArg("query", stringLiteral("string_value:query_value"))));

  assertAnalyzeEqual(expectedPlan, unresolvedPlan);
}

@Test
public void multi_field_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
dsl.query_string(
dsl.namedArgument("fields", DSL.literal(
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"string_value", ExprValueUtils.floatValue(1.F),
"integer_value", ExprValueUtils.floatValue(.3F))
))
)),
dsl.namedArgument("query", DSL.literal("query_value")))),
DSL.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("query_string",
Expand All @@ -377,14 +413,14 @@ public void multi_field_wildcard_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
dsl.query_string(
dsl.namedArgument("fields", DSL.literal(
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"wildcard_field1*", ExprValueUtils.floatValue(1.F),
"wildcard_field2*", ExprValueUtils.floatValue(.3F))
))
)),
dsl.namedArgument("query", DSL.literal("query_value")))),
DSL.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("query_string",
Expand Down

0 comments on commit 9f5304c

Please sign in to comment.