diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index fa86674e24..228b54ba0c 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -75,7 +75,6 @@ import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.expression.function.TableFunctionImplementation; import org.opensearch.sql.expression.parse.ParseExpression; import org.opensearch.sql.planner.logical.LogicalAD; @@ -221,10 +220,6 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); Expression condition = expressionAnalyzer.analyze(node.getCondition(), context); - if (condition instanceof OpenSearchFunctions.OpenSearchFunction) { - ((OpenSearchFunctions.OpenSearchFunction)condition).validateParameters(context); - } - ExpressionReferenceOptimizer optimizer = new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); Expression optimized = optimizer.optimize(condition, context); diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 719c3adbce..08b2446190 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -3,9 +3,11 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; +import static org.opensearch.sql.expression.function.OpenSearchFunctions.isMultiFieldFunction; +import static org.opensearch.sql.expression.function.OpenSearchFunctions.isSingleFieldFunction; + import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -44,6 +46,7 @@ import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.SemanticCheckException; @@ -180,6 +183,25 @@ public Expression visitFunction(Function node, AnalysisContext context) { node.getFuncArgs().stream() .map(unresolvedExpression -> analyze(unresolvedExpression, context)) .collect(Collectors.toList()); + + // Resolve fields and field parameters in TypeEnv + if (isSingleFieldFunction(functionName.getFunctionName()) + || isMultiFieldFunction(functionName.getFunctionName())) { + for (Expression arg : arguments) { + if (((NamedArgumentExpression) arg).getArgName().equals("field") + && !((NamedArgumentExpression)arg).getValue().toString().contains("*")) { + visitQualifiedName(new QualifiedName(StringUtils.unquoteText( + ((NamedArgumentExpression)arg).getValue().toString())), context); + } else if (((NamedArgumentExpression)arg).getArgName().equals("fields")) { + ((NamedArgumentExpression) arg).getValue().valueOf().tupleValue().entrySet().stream() + .filter(field -> !field.getKey().contains("*") + ).forEach( + entry -> visitQualifiedName(new QualifiedName(entry.getKey()), context) + ); + } + } + } + return (Expression) repository.compile(context.getFunctionProperties(), functionName, arguments); } diff --git a/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java b/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java index 8d3f3b56b1..2a695f26e6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java @@ -11,7 +11,6 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; -import org.opensearch.sql.analysis.AnalysisContext; import org.opensearch.sql.expression.function.FunctionImplementation; import org.opensearch.sql.expression.function.FunctionName; @@ -33,11 +32,4 @@ public T accept(ExpressionNodeVisitor visitor, C context) { return visitor.visitFunction(this, context); } - /** - * Verify if function queries fields available in type environment. - * @param context : Context of fields querying. - */ - public void validateParameters(AnalysisContext context) { - return; - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index b4a95493ef..91dcf1161b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -11,11 +11,6 @@ import java.util.List; import java.util.stream.Collectors; import lombok.experimental.UtilityClass; -import org.opensearch.sql.analysis.AnalysisContext; -import org.opensearch.sql.analysis.TypeEnvironment; -import org.opensearch.sql.analysis.symbol.Namespace; -import org.opensearch.sql.analysis.symbol.Symbol; -import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; @@ -160,35 +155,5 @@ public String toString() { .collect(Collectors.toList()); return String.format("%s(%s)", functionName, String.join(", ", args)); } - - /** - * Verify if function queries fields available in type environment. - * @param context : Context of fields querying. - */ - @Override - public void validateParameters(AnalysisContext context) { - String funcName = this.getFunctionName().toString(); - - TypeEnvironment typeEnv = context.peek(); - if (isSingleFieldFunction(funcName)) { - this.getArguments().stream().map(NamedArgumentExpression.class::cast).filter(arg -> - ((arg.getArgName().equals("field") - && !arg.getValue().toString().contains("*")) - )).findFirst().ifPresent(arg -> - typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, - StringUtils.unquoteText(arg.getValue().toString())) - ) - ); - } else if (isMultiFieldFunction(funcName)) { - this.getArguments().stream().map(NamedArgumentExpression.class::cast).filter(arg -> - arg.getArgName().equals("fields") - ).findFirst().ifPresent(fields -> - fields.getValue().valueOf(null).tupleValue() - .entrySet().stream().filter(k -> !(k.getKey().contains("*")) - ).forEach(key -> typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, key.getKey()))) - ); - } - } - } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index a21a61efe9..b4ebd96c5b 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -275,23 +275,6 @@ public void analyze_filter_aggregation_relation() { aggregate("MIN", qualifiedName("integer_value")), intLiteral(10)))); } - @Test - public void test_base_class_validate_parameters_method_does_nothing() { - var funcExpr = new FunctionExpression(FunctionName.of("func_name"), - List.of()) { - @Override - public ExprValue valueOf(Environment valueEnv) { - return null; - } - - @Override - public ExprType type() { - return null; - } - }; - funcExpr.validateParameters(analysisContext); - } - @Test public void single_field_relevance_query_semantic_exception() { SemanticCheckException exception = @@ -324,21 +307,6 @@ public void single_field_relevance_query() { AstDSL.unresolvedArg("query", stringLiteral("query_value"))))); } - @Test - public void single_field_wildcard_relevance_query() { - assertAnalyzeEqual( - LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema", table), - DSL.match( - DSL.namedArgument("field", DSL.literal("wildcard_field*")), - DSL.namedArgument("query", DSL.literal("query_value")))), - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("match", - AstDSL.unresolvedArg("field", stringLiteral("wildcard_field*")), - AstDSL.unresolvedArg("query", stringLiteral("query_value"))))); - } - @Test public void multi_field_relevance_query_semantic_exception() { SemanticCheckException exception = @@ -409,27 +377,6 @@ public void multi_field_relevance_query() { AstDSL.unresolvedArg("query", stringLiteral("query_value"))))); } - @Test - public void multi_field_wildcard_relevance_query() { - assertAnalyzeEqual( - LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema", table), - DSL.query_string( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "wildcard_field1*", ExprValueUtils.floatValue(1.F), - "wildcard_field2*", ExprValueUtils.floatValue(.3F)) - )) - )), - DSL.namedArgument("query", DSL.literal("query_value")))), - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "wildcard_field1*", 1.F, "wildcard_field2*", .3F))), - AstDSL.unresolvedArg("query", stringLiteral("query_value"))))); - } - @Test public void rename_relation() { assertAnalyzeEqual( diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 7114b220ab..c8b40d1562 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -358,10 +358,10 @@ public void named_non_parse_expression() { void match_bool_prefix_expression() { assertAnalyzeEqual( DSL.match_bool_prefix( - DSL.namedArgument("field", DSL.literal("fieldA")), + DSL.namedArgument("field", DSL.literal("field_value1")), DSL.namedArgument("query", DSL.literal("sample query"))), AstDSL.function("match_bool_prefix", - AstDSL.unresolvedArg("field", stringLiteral("fieldA")), + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -402,11 +402,11 @@ void multi_match_expression() { DSL.multi_match( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), AstDSL.function("multi_match", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -416,12 +416,12 @@ void multi_match_expression_with_params() { DSL.multi_match( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query")), DSL.namedArgument("analyzer", DSL.literal("keyword"))), AstDSL.function("multi_match", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")), AstDSL.unresolvedArg("analyzer", stringLiteral("keyword")))); } @@ -432,12 +432,12 @@ void multi_match_expression_two_fields() { DSL.multi_match( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field1", ExprValueUtils.floatValue(1.F), - "field2", ExprValueUtils.floatValue(.3F)))))), + "field_value1", ExprValueUtils.floatValue(1.F), + "field_value2", ExprValueUtils.floatValue(.3F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), AstDSL.function("multi_match", AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "field1", 1.F, "field2", .3F))), + "field_value1", 1.F, "field_value2", .3F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -447,11 +447,11 @@ void simple_query_string_expression() { DSL.simple_query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), AstDSL.function("simple_query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -461,12 +461,12 @@ void simple_query_string_expression_with_params() { DSL.simple_query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query")), DSL.namedArgument("analyzer", DSL.literal("keyword"))), AstDSL.function("simple_query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")), AstDSL.unresolvedArg("analyzer", stringLiteral("keyword")))); } @@ -477,12 +477,12 @@ void simple_query_string_expression_two_fields() { DSL.simple_query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field1", ExprValueUtils.floatValue(1.F), - "field2", ExprValueUtils.floatValue(.3F)))))), + "field_value1", ExprValueUtils.floatValue(1.F), + "field_value2", ExprValueUtils.floatValue(.3F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), AstDSL.function("simple_query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "field1", 1.F, "field2", .3F))), + "field_value1", 1.F, "field_value2", .3F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -501,11 +501,11 @@ void query_string_expression() { DSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("query_value"))), AstDSL.function("query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("query_value")))); } @@ -515,12 +515,12 @@ void query_string_expression_with_params() { DSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("query_value")), DSL.namedArgument("escape", DSL.literal("false"))), AstDSL.function("query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("query_value")), AstDSL.unresolvedArg("escape", stringLiteral("false")))); } @@ -531,12 +531,12 @@ void query_string_expression_two_fields() { DSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field1", ExprValueUtils.floatValue(1.F), - "field2", ExprValueUtils.floatValue(.3F)))))), + "field_value1", ExprValueUtils.floatValue(1.F), + "field_value2", ExprValueUtils.floatValue(.3F)))))), DSL.namedArgument("query", DSL.literal("query_value"))), AstDSL.function("query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "field1", 1.F, "field2", .3F))), + "field_value1", 1.F, "field_value2", .3F))), AstDSL.unresolvedArg("query", stringLiteral("query_value")))); } @@ -572,7 +572,7 @@ void wildcard_query_expression_all_params() { public void match_phrase_prefix_all_params() { assertAnalyzeEqual( DSL.match_phrase_prefix( - DSL.namedArgument("field", "test"), + DSL.namedArgument("field", "field_value1"), DSL.namedArgument("query", "search query"), DSL.namedArgument("slop", "3"), DSL.namedArgument("boost", "1.5"), @@ -581,7 +581,7 @@ public void match_phrase_prefix_all_params() { DSL.namedArgument("zero_terms_query", "NONE") ), AstDSL.function("match_phrase_prefix", - unresolvedArg("field", stringLiteral("test")), + unresolvedArg("field", stringLiteral("field_value1")), unresolvedArg("query", stringLiteral("search query")), unresolvedArg("slop", stringLiteral("3")), unresolvedArg("boost", stringLiteral("1.5")), diff --git a/core/src/test/java/org/opensearch/sql/config/TestConfig.java b/core/src/test/java/org/opensearch/sql/config/TestConfig.java index a0ef436162..4159ae12ff 100644 --- a/core/src/test/java/org/opensearch/sql/config/TestConfig.java +++ b/core/src/test/java/org/opensearch/sql/config/TestConfig.java @@ -57,6 +57,8 @@ public class TestConfig { .put("struct_value", ExprCoreType.STRUCT) .put("array_value", ExprCoreType.ARRAY) .put("timestamp_value", ExprCoreType.TIMESTAMP) + .put("field_value1", ExprCoreType.STRING) + .put("field_value2", ExprCoreType.STRING) .build(); @Bean