From e32b733f5206747a58684bd751557f0b6cf602dd Mon Sep 17 00:00:00 2001 From: forestmvey Date: Thu, 8 Dec 2022 11:32:51 -0800 Subject: [PATCH] Validating field and fields parameters in TypeEnv as part of ExpressionAnalyzer. Signed-off-by: forestmvey --- .../org/opensearch/sql/analysis/Analyzer.java | 5 - .../sql/expression/FunctionExpression.java | 8 - .../function/OpenSearchFunctions.java | 87 +-------- .../function/RelevanceFunctionResolver.java | 11 -- .../opensearch/sql/analysis/AnalyzerTest.java | 165 ------------------ .../sql/analysis/ExpressionAnalyzerTest.java | 50 +++--- .../org/opensearch/sql/config/TestConfig.java | 2 + .../RelevanceFunctionResolverTest.java | 11 +- .../lucene/relevance/SingleFieldQuery.java | 6 +- .../script/filter/FilterQueryBuilderTest.java | 8 - .../relevance/SingleFieldQueryTest.java | 17 ++ .../sql/sql/parser/AstExpressionBuilder.java | 2 +- .../sql/parser/AstExpressionBuilderTest.java | 18 +- 13 files changed, 69 insertions(+), 321 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index fa86674e24..228b54ba0c 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -75,7 +75,6 @@ import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.expression.function.TableFunctionImplementation; import org.opensearch.sql.expression.parse.ParseExpression; import org.opensearch.sql.planner.logical.LogicalAD; @@ -221,10 +220,6 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); Expression condition = expressionAnalyzer.analyze(node.getCondition(), context); - if (condition instanceof OpenSearchFunctions.OpenSearchFunction) { - ((OpenSearchFunctions.OpenSearchFunction)condition).validateParameters(context); - } - ExpressionReferenceOptimizer optimizer = new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); Expression optimized = optimizer.optimize(condition, context); diff --git a/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java b/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java index 8d3f3b56b1..2a695f26e6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java @@ -11,7 +11,6 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; -import org.opensearch.sql.analysis.AnalysisContext; import org.opensearch.sql.expression.function.FunctionImplementation; import org.opensearch.sql.expression.function.FunctionName; @@ -33,11 +32,4 @@ public T accept(ExpressionNodeVisitor visitor, C context) { return visitor.visitFunction(this, context); } - /** - * Verify if function queries fields available in type environment. - * @param context : Context of fields querying. - */ - public void validateParameters(AnalysisContext context) { - return; - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index b4a95493ef..842cf25cd6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -5,17 +5,9 @@ package org.opensearch.sql.expression.function; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; - import java.util.List; import java.util.stream.Collectors; import lombok.experimental.UtilityClass; -import org.opensearch.sql.analysis.AnalysisContext; -import org.opensearch.sql.analysis.TypeEnvironment; -import org.opensearch.sql.analysis.symbol.Namespace; -import org.opensearch.sql.analysis.symbol.Symbol; -import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; @@ -26,37 +18,6 @@ @UtilityClass public class OpenSearchFunctions { - private final List singleFieldFunctionNames = List.of( - BuiltinFunctionName.MATCH.name(), - BuiltinFunctionName.MATCH_BOOL_PREFIX.name(), - BuiltinFunctionName.MATCHPHRASE.name(), - BuiltinFunctionName.MATCH_PHRASE_PREFIX.name() - ); - - private final List multiFieldFunctionNames = List.of( - BuiltinFunctionName.MULTI_MATCH.name(), - BuiltinFunctionName.SIMPLE_QUERY_STRING.name(), - BuiltinFunctionName.QUERY_STRING.name() - ); - - /** - * Check if supplied function name is valid SingleFieldRelevanceFunction. - * @param funcName : Name of function - * @return : True if function is single-field function - */ - public static boolean isSingleFieldFunction(String funcName) { - return singleFieldFunctionNames.contains(funcName.toUpperCase()); - } - - /** - * Check if supplied function name is valid MultiFieldRelevanceFunction. - * @param funcName : Name of function - * @return : True if function is multi-field function - */ - public static boolean isMultiFieldFunction(String funcName) { - return multiFieldFunctionNames.contains(funcName.toUpperCase()); - } - /** * Add functions specific to OpenSearch to repository. */ @@ -83,46 +44,46 @@ public void register(BuiltinFunctionRepository repository) { private static FunctionResolver match_bool_prefix() { FunctionName name = BuiltinFunctionName.MATCH_BOOL_PREFIX.getName(); - return new RelevanceFunctionResolver(name, STRING); + return new RelevanceFunctionResolver(name); } private static FunctionResolver match(BuiltinFunctionName match) { FunctionName funcName = match.getName(); - return new RelevanceFunctionResolver(funcName, STRING); + return new RelevanceFunctionResolver(funcName); } private static FunctionResolver match_phrase_prefix() { FunctionName funcName = BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName(); - return new RelevanceFunctionResolver(funcName, STRING); + return new RelevanceFunctionResolver(funcName); } private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) { FunctionName funcName = matchPhrase.getName(); - return new RelevanceFunctionResolver(funcName, STRING); + return new RelevanceFunctionResolver(funcName); } private static FunctionResolver multi_match(BuiltinFunctionName multiMatchName) { - return new RelevanceFunctionResolver(multiMatchName.getName(), STRUCT); + return new RelevanceFunctionResolver(multiMatchName.getName()); } private static FunctionResolver simple_query_string() { FunctionName funcName = BuiltinFunctionName.SIMPLE_QUERY_STRING.getName(); - return new RelevanceFunctionResolver(funcName, STRUCT); + return new RelevanceFunctionResolver(funcName); } private static FunctionResolver query() { FunctionName funcName = BuiltinFunctionName.QUERY.getName(); - return new RelevanceFunctionResolver(funcName, STRING); + return new RelevanceFunctionResolver(funcName); } private static FunctionResolver query_string() { FunctionName funcName = BuiltinFunctionName.QUERY_STRING.getName(); - return new RelevanceFunctionResolver(funcName, STRUCT); + return new RelevanceFunctionResolver(funcName); } private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery) { FunctionName funcName = wildcardQuery.getName(); - return new RelevanceFunctionResolver(funcName, STRING); + return new RelevanceFunctionResolver(funcName); } public static class OpenSearchFunction extends FunctionExpression { @@ -160,35 +121,5 @@ public String toString() { .collect(Collectors.toList()); return String.format("%s(%s)", functionName, String.join(", ", args)); } - - /** - * Verify if function queries fields available in type environment. - * @param context : Context of fields querying. - */ - @Override - public void validateParameters(AnalysisContext context) { - String funcName = this.getFunctionName().toString(); - - TypeEnvironment typeEnv = context.peek(); - if (isSingleFieldFunction(funcName)) { - this.getArguments().stream().map(NamedArgumentExpression.class::cast).filter(arg -> - ((arg.getArgName().equals("field") - && !arg.getValue().toString().contains("*")) - )).findFirst().ifPresent(arg -> - typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, - StringUtils.unquoteText(arg.getValue().toString())) - ) - ); - } else if (isMultiFieldFunction(funcName)) { - this.getArguments().stream().map(NamedArgumentExpression.class::cast).filter(arg -> - arg.getArgName().equals("fields") - ).findFirst().ifPresent(fields -> - fields.getValue().valueOf(null).tupleValue() - .entrySet().stream().filter(k -> !(k.getKey().contains("*")) - ).forEach(key -> typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, key.getKey()))) - ); - } - } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java index 7066622e1b..ef0ac9226c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java @@ -20,9 +20,6 @@ public class RelevanceFunctionResolver @Getter private final FunctionName functionName; - @Getter - private final ExprType declaredFirstParamType; - @Override public Pair resolve(FunctionSignature unresolvedSignature) { if (!unresolvedSignature.getFunctionName().equals(functionName)) { @@ -30,14 +27,6 @@ public Pair resolve(FunctionSignature unreso functionName.getFunctionName(), unresolvedSignature.getFunctionName().getFunctionName())); } List paramTypes = unresolvedSignature.getParamTypeList(); - ExprType providedFirstParamType = paramTypes.get(0); - - // Check if the first parameter is of the specified type. - if (!declaredFirstParamType.equals(providedFirstParamType)) { - throw new SemanticCheckException( - getWrongParameterErrorMessage(0, providedFirstParamType, declaredFirstParamType)); - } - // Check if all but the first parameter are of type STRING. for (int i = 1; i < paramTypes.size(); i++) { ExprType paramType = paramTypes.get(i); diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index a21a61efe9..044949ea35 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -59,7 +59,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -73,24 +72,15 @@ import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.ParseMethod; -import org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.RareTopN.CommandType; -import org.opensearch.sql.data.model.ExprTupleValue; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.HighlightExpression; -import org.opensearch.sql.expression.env.Environment; -import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalMLCommons; @@ -275,161 +265,6 @@ public void analyze_filter_aggregation_relation() { aggregate("MIN", qualifiedName("integer_value")), intLiteral(10)))); } - @Test - public void test_base_class_validate_parameters_method_does_nothing() { - var funcExpr = new FunctionExpression(FunctionName.of("func_name"), - List.of()) { - @Override - public ExprValue valueOf(Environment valueEnv) { - return null; - } - - @Override - public ExprType type() { - return null; - } - }; - funcExpr.validateParameters(analysisContext); - } - - @Test - public void single_field_relevance_query_semantic_exception() { - SemanticCheckException exception = - assertThrows( - SemanticCheckException.class, - () -> - analyze( - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("match", - AstDSL.unresolvedArg("field", stringLiteral("missing_value")), - AstDSL.unresolvedArg("query", stringLiteral("query_value")))))); - assertEquals( - "can't resolve Symbol(namespace=FIELD_NAME, name=missing_value) in type env", - exception.getMessage()); - } - - @Test - public void single_field_relevance_query() { - assertAnalyzeEqual( - LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema", table), - DSL.match( - DSL.namedArgument("field", DSL.literal("string_value")), - DSL.namedArgument("query", DSL.literal("query_value")))), - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("match", - AstDSL.unresolvedArg("field", stringLiteral("string_value")), - AstDSL.unresolvedArg("query", stringLiteral("query_value"))))); - } - - @Test - public void single_field_wildcard_relevance_query() { - assertAnalyzeEqual( - LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema", table), - DSL.match( - DSL.namedArgument("field", DSL.literal("wildcard_field*")), - DSL.namedArgument("query", DSL.literal("query_value")))), - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("match", - AstDSL.unresolvedArg("field", stringLiteral("wildcard_field*")), - AstDSL.unresolvedArg("query", stringLiteral("query_value"))))); - } - - @Test - public void multi_field_relevance_query_semantic_exception() { - SemanticCheckException exception = - assertThrows( - SemanticCheckException.class, - () -> - analyze( - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "missing_value1", 1.F, "missing_value2", .3F))), - AstDSL.unresolvedArg("query", stringLiteral("query_value")))))); - assertEquals( - "can't resolve Symbol(namespace=FIELD_NAME, name=missing_value1) in type env", - exception.getMessage()); - } - - @Test - public void multi_field_relevance_query_mixed_fields_semantic_exception() { - SemanticCheckException exception = - assertThrows( - SemanticCheckException.class, - () -> - analyze( - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "string_value", 1.F, "missing_value", .3F))), - AstDSL.unresolvedArg("query", stringLiteral("query_value")))))); - assertEquals( - "can't resolve Symbol(namespace=FIELD_NAME, name=missing_value) in type env", - exception.getMessage()); - } - - @Test - public void no_field_relevance_query_semantic_exception() { - assertAnalyzeEqual( - LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema", table), - DSL.query( - DSL.namedArgument("query", DSL.literal("string_value:query_value")))), - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("query", - AstDSL.unresolvedArg("query", stringLiteral("string_value:query_value"))))); - } - - @Test - public void multi_field_relevance_query() { - assertAnalyzeEqual( - LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema", table), - DSL.query_string( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "string_value", ExprValueUtils.floatValue(1.F), - "integer_value", ExprValueUtils.floatValue(.3F)) - )) - )), - DSL.namedArgument("query", DSL.literal("query_value")))), - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "string_value", 1.F, "integer_value", .3F))), - AstDSL.unresolvedArg("query", stringLiteral("query_value"))))); - } - - @Test - public void multi_field_wildcard_relevance_query() { - assertAnalyzeEqual( - LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema", table), - DSL.query_string( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "wildcard_field1*", ExprValueUtils.floatValue(1.F), - "wildcard_field2*", ExprValueUtils.floatValue(.3F)) - )) - )), - DSL.namedArgument("query", DSL.literal("query_value")))), - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "wildcard_field1*", 1.F, "wildcard_field2*", .3F))), - AstDSL.unresolvedArg("query", stringLiteral("query_value"))))); - } - @Test public void rename_relation() { assertAnalyzeEqual( diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 7114b220ab..c8b40d1562 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -358,10 +358,10 @@ public void named_non_parse_expression() { void match_bool_prefix_expression() { assertAnalyzeEqual( DSL.match_bool_prefix( - DSL.namedArgument("field", DSL.literal("fieldA")), + DSL.namedArgument("field", DSL.literal("field_value1")), DSL.namedArgument("query", DSL.literal("sample query"))), AstDSL.function("match_bool_prefix", - AstDSL.unresolvedArg("field", stringLiteral("fieldA")), + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -402,11 +402,11 @@ void multi_match_expression() { DSL.multi_match( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), AstDSL.function("multi_match", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -416,12 +416,12 @@ void multi_match_expression_with_params() { DSL.multi_match( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query")), DSL.namedArgument("analyzer", DSL.literal("keyword"))), AstDSL.function("multi_match", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")), AstDSL.unresolvedArg("analyzer", stringLiteral("keyword")))); } @@ -432,12 +432,12 @@ void multi_match_expression_two_fields() { DSL.multi_match( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field1", ExprValueUtils.floatValue(1.F), - "field2", ExprValueUtils.floatValue(.3F)))))), + "field_value1", ExprValueUtils.floatValue(1.F), + "field_value2", ExprValueUtils.floatValue(.3F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), AstDSL.function("multi_match", AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "field1", 1.F, "field2", .3F))), + "field_value1", 1.F, "field_value2", .3F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -447,11 +447,11 @@ void simple_query_string_expression() { DSL.simple_query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), AstDSL.function("simple_query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -461,12 +461,12 @@ void simple_query_string_expression_with_params() { DSL.simple_query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query")), DSL.namedArgument("analyzer", DSL.literal("keyword"))), AstDSL.function("simple_query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")), AstDSL.unresolvedArg("analyzer", stringLiteral("keyword")))); } @@ -477,12 +477,12 @@ void simple_query_string_expression_two_fields() { DSL.simple_query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field1", ExprValueUtils.floatValue(1.F), - "field2", ExprValueUtils.floatValue(.3F)))))), + "field_value1", ExprValueUtils.floatValue(1.F), + "field_value2", ExprValueUtils.floatValue(.3F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), AstDSL.function("simple_query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "field1", 1.F, "field2", .3F))), + "field_value1", 1.F, "field_value2", .3F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -501,11 +501,11 @@ void query_string_expression() { DSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("query_value"))), AstDSL.function("query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("query_value")))); } @@ -515,12 +515,12 @@ void query_string_expression_with_params() { DSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field", ExprValueUtils.floatValue(1.F)))))), + "field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("query_value")), DSL.namedArgument("escape", DSL.literal("false"))), AstDSL.function("query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field", 1.F))), + "field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("query_value")), AstDSL.unresolvedArg("escape", stringLiteral("false")))); } @@ -531,12 +531,12 @@ void query_string_expression_two_fields() { DSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field1", ExprValueUtils.floatValue(1.F), - "field2", ExprValueUtils.floatValue(.3F)))))), + "field_value1", ExprValueUtils.floatValue(1.F), + "field_value2", ExprValueUtils.floatValue(.3F)))))), DSL.namedArgument("query", DSL.literal("query_value"))), AstDSL.function("query_string", AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "field1", 1.F, "field2", .3F))), + "field_value1", 1.F, "field_value2", .3F))), AstDSL.unresolvedArg("query", stringLiteral("query_value")))); } @@ -572,7 +572,7 @@ void wildcard_query_expression_all_params() { public void match_phrase_prefix_all_params() { assertAnalyzeEqual( DSL.match_phrase_prefix( - DSL.namedArgument("field", "test"), + DSL.namedArgument("field", "field_value1"), DSL.namedArgument("query", "search query"), DSL.namedArgument("slop", "3"), DSL.namedArgument("boost", "1.5"), @@ -581,7 +581,7 @@ public void match_phrase_prefix_all_params() { DSL.namedArgument("zero_terms_query", "NONE") ), AstDSL.function("match_phrase_prefix", - unresolvedArg("field", stringLiteral("test")), + unresolvedArg("field", stringLiteral("field_value1")), unresolvedArg("query", stringLiteral("search query")), unresolvedArg("slop", stringLiteral("3")), unresolvedArg("boost", stringLiteral("1.5")), diff --git a/core/src/test/java/org/opensearch/sql/config/TestConfig.java b/core/src/test/java/org/opensearch/sql/config/TestConfig.java index a0ef436162..4159ae12ff 100644 --- a/core/src/test/java/org/opensearch/sql/config/TestConfig.java +++ b/core/src/test/java/org/opensearch/sql/config/TestConfig.java @@ -57,6 +57,8 @@ public class TestConfig { .put("struct_value", ExprCoreType.STRUCT) .put("array_value", ExprCoreType.ARRAY) .put("timestamp_value", ExprCoreType.TIMESTAMP) + .put("field_value1", ExprCoreType.STRING) + .put("field_value2", ExprCoreType.STRING) .build(); @Bean diff --git a/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java index d8547057c4..deba721481 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java @@ -24,7 +24,7 @@ class RelevanceFunctionResolverTest { @BeforeEach void setUp() { - resolver = new RelevanceFunctionResolver(sampleFuncName, STRING); + resolver = new RelevanceFunctionResolver(sampleFuncName); } @Test @@ -44,15 +44,6 @@ void resolve_invalid_name_test() { exception.getMessage()); } - @Test - void resolve_invalid_first_param_type_test() { - var sig = new FunctionSignature(sampleFuncName, List.of(INTEGER)); - Exception exception = assertThrows(SemanticCheckException.class, - () -> resolver.resolve(sig)); - assertEquals("Expected type STRING instead of INTEGER for parameter #1", - exception.getMessage()); - } - @Test void resolve_invalid_third_param_type_test() { var sig = new FunctionSignature(sampleFuncName, List.of(STRING, STRING, INTEGER, STRING)); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java index a7d7584d4f..d90b3e83ac 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java @@ -9,7 +9,9 @@ import java.util.Map; import org.opensearch.index.query.QueryBuilder; import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.ReferenceExpression; /** * Base class to represent builder class for relevance queries like match_query, match_bool_prefix, @@ -36,7 +38,9 @@ protected T createQueryBuilder(List arguments) { .orElseThrow(() -> new SemanticCheckException("'query' parameter is missing")); return createBuilder( - field.getValue().valueOf().stringValue(), + (field.getValue() instanceof LiteralExpression) + ? field.getValue().valueOf().stringValue() + : field.getValue().toString(), query.getValue().valueOf().stringValue()); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java index cea4e2488a..4ad46509d3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java @@ -1177,14 +1177,6 @@ void should_build_match_bool_prefix_query_with_default_parameters() { DSL.namedArgument("query", literal("search query"))))); } - @Test - void multi_match_missing_fields() { - var msg = assertThrows(SemanticCheckException.class, () -> - DSL.multi_match( - DSL.namedArgument("query", literal("search query")))).getMessage(); - assertEquals("Expected type STRUCT instead of STRING for parameter #1", msg); - } - @Test void multi_match_missing_fields_even_with_struct() { FunctionExpression expr = DSL.multi_match( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java index b2d650602b..da3720640d 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java @@ -16,9 +16,12 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mockito; +import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; class SingleFieldQueryTest { SingleFieldQuery query; @@ -47,4 +50,18 @@ void createQueryBuilderTest() { verify(query).createBuilder(eq(sampleField), eq(sampleQuery)); } + + @Test + void createQueryBuilderQualifiedNameTest() { + String sampleQuery = "sample query"; + String sampleField = "fieldA"; + + query.createQueryBuilder(List.of(DSL.namedArgument("field", + new ReferenceExpression(sampleField, OpenSearchDataType.OPENSEARCH_TEXT_KEYWORD)), + DSL.namedArgument("query", + new LiteralExpression(ExprValueUtils.stringValue(sampleQuery))))); + + verify(query).createBuilder(eq(sampleField), + eq(sampleQuery)); + } } diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index bae22595ca..9045870c04 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -475,7 +475,7 @@ private List singleFieldRelevanceArguments( // to skip environment resolving and function signature resolving ImmutableList.Builder builder = ImmutableList.builder(); builder.add(new UnresolvedArgument("field", - new Literal(StringUtils.unquoteText(ctx.field.getText()), DataType.STRING))); + new QualifiedName(StringUtils.unquoteText(ctx.field.getText())))); builder.add(new UnresolvedArgument("query", new Literal(StringUtils.unquoteText(ctx.query.getText()), DataType.STRING))); fillRelevanceArgs(ctx.relevanceArg(), builder); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index 9af4119fdf..c4223fa3aa 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -463,7 +463,7 @@ public void filteredDistinctCount() { public void matchPhraseQueryAllParameters() { assertEquals( AstDSL.function("matchphrasequery", - unresolvedArg("field", stringLiteral("test")), + unresolvedArg("field", qualifiedName("test")), unresolvedArg("query", stringLiteral("search query")), unresolvedArg("slop", stringLiteral("3")), unresolvedArg("analyzer", stringLiteral("standard")), @@ -479,7 +479,7 @@ public void matchPhraseQueryAllParameters() { public void matchPhrasePrefixAllParameters() { assertEquals( AstDSL.function("match_phrase_prefix", - unresolvedArg("field", stringLiteral("test")), + unresolvedArg("field", qualifiedName("test")), unresolvedArg("query", stringLiteral("search query")), unresolvedArg("slop", stringLiteral("3")), unresolvedArg("boost", stringLiteral("1.5")), @@ -496,13 +496,13 @@ public void matchPhrasePrefixAllParameters() { @Test public void relevanceMatch() { assertEquals(AstDSL.function("match", - unresolvedArg("field", stringLiteral("message")), + unresolvedArg("field", qualifiedName("message")), unresolvedArg("query", stringLiteral("search query"))), buildExprAst("match('message', 'search query')") ); assertEquals(AstDSL.function("match", - unresolvedArg("field", stringLiteral("message")), + unresolvedArg("field", qualifiedName("message")), unresolvedArg("query", stringLiteral("search query")), unresolvedArg("analyzer", stringLiteral("keyword")), unresolvedArg("operator", stringLiteral("AND"))), @@ -512,13 +512,13 @@ public void relevanceMatch() { @Test public void relevanceMatchQuery() { assertEquals(AstDSL.function("matchquery", - unresolvedArg("field", stringLiteral("message")), + unresolvedArg("field", qualifiedName("message")), unresolvedArg("query", stringLiteral("search query"))), buildExprAst("matchquery('message', 'search query')") ); assertEquals(AstDSL.function("matchquery", - unresolvedArg("field", stringLiteral("message")), + unresolvedArg("field", qualifiedName("message")), unresolvedArg("query", stringLiteral("search query")), unresolvedArg("analyzer", stringLiteral("keyword")), unresolvedArg("operator", stringLiteral("AND"))), @@ -528,13 +528,13 @@ public void relevanceMatchQuery() { @Test public void relevanceMatch_Query() { assertEquals(AstDSL.function("match_query", - unresolvedArg("field", stringLiteral("message")), + unresolvedArg("field", qualifiedName("message")), unresolvedArg("query", stringLiteral("search query"))), buildExprAst("match_query('message', 'search query')") ); assertEquals(AstDSL.function("match_query", - unresolvedArg("field", stringLiteral("message")), + unresolvedArg("field", qualifiedName("message")), unresolvedArg("query", stringLiteral("search query")), unresolvedArg("analyzer", stringLiteral("keyword")), unresolvedArg("operator", stringLiteral("AND"))), @@ -640,7 +640,7 @@ public void relevanceQuery_string() { @Test public void relevanceWildcard_query() { assertEquals(AstDSL.function("wildcard_query", - unresolvedArg("field", stringLiteral("field")), + unresolvedArg("field", qualifiedName("field")), unresolvedArg("query", stringLiteral("search query*")), unresolvedArg("boost", stringLiteral("1.5")), unresolvedArg("case_insensitive", stringLiteral("true")),