diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
index 7d0a452e1b..fa2cce2739 100644
--- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
+++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
@@ -11,18 +11,10 @@
 import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
 import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
 import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
-import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
-import static org.opensearch.sql.utils.MLCommonsConstants.MODELID;
-import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT;
 import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
 import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
 import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
-import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP;
-import static org.opensearch.sql.utils.MLCommonsConstants.STATUS;
-import static org.opensearch.sql.utils.MLCommonsConstants.TASKID;
 import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;
-import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN;
-import static org.opensearch.sql.utils.MLCommonsConstants.TRAINANDPREDICT;
 import static org.opensearch.sql.utils.SystemIndexUtils.CATALOGS_TABLE_NAME;
 
 import com.google.common.collect.ImmutableList;
@@ -76,6 +68,7 @@
 import org.opensearch.sql.exception.SemanticCheckException;
 import org.opensearch.sql.expression.DSL;
 import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.FunctionExpression;
 import org.opensearch.sql.expression.LiteralExpression;
 import org.opensearch.sql.expression.NamedExpression;
 import org.opensearch.sql.expression.ReferenceExpression;
@@ -83,6 +76,7 @@
 import org.opensearch.sql.expression.aggregation.NamedAggregator;
 import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
 import org.opensearch.sql.expression.function.FunctionName;
+import org.opensearch.sql.expression.function.OpenSearchFunctions;
 import org.opensearch.sql.expression.function.TableFunctionImplementation;
 import org.opensearch.sql.expression.parse.ParseExpression;
 import org.opensearch.sql.planner.logical.LogicalAD;
@@ -225,6 +219,13 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
     LogicalPlan child = node.getChild().get(0).accept(this, context);
     Expression condition = expressionAnalyzer.analyze(node.getCondition(), context);
 
+    // The analyzed condition is only known to be an Expression; it may be something other
+    // than a function call (e.g. a LiteralExpression or ReferenceExpression). Validate the
+    // relevance-function field list only when the cast cannot throw ClassCastException.
+    if (condition instanceof FunctionExpression) {
+      OpenSearchFunctions.validateFieldList((FunctionExpression) condition, context);
+    }
+
     ExpressionReferenceOptimizer optimizer =
         new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);
     Expression optimized = optimizer.optimize(condition, context);
diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java
index 97afe3675e..5433f031be 100644
--- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java
+++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java
@@ -8,10 +8,15 @@
 import static org.opensearch.sql.data.type.ExprCoreType.STRING;
 import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
 
-import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableList;
 import java.util.List;
 import java.util.stream.Collectors;
 import lombok.experimental.UtilityClass;
+import org.opensearch.sql.analysis.AnalysisContext;
+import org.opensearch.sql.analysis.TypeEnvironment;
+import org.opensearch.sql.analysis.symbol.Namespace;
+import org.opensearch.sql.analysis.symbol.Symbol;
+import org.opensearch.sql.common.utils.StringUtils;
 import org.opensearch.sql.data.model.ExprValue;
 import org.opensearch.sql.data.type.ExprCoreType;
 import
org.opensearch.sql.data.type.ExprType;
@@ -22,6 +27,76 @@
 
 @UtilityClass
 public class OpenSearchFunctions {
+  /** Names of the relevance functions that take a single {@code field} argument. */
+  private final List<String> singleFieldFunctionNames = ImmutableList.of(
+      BuiltinFunctionName.MATCH.name(),
+      BuiltinFunctionName.MATCH_BOOL_PREFIX.name(),
+      BuiltinFunctionName.MATCHPHRASE.name(),
+      BuiltinFunctionName.MATCH_PHRASE_PREFIX.name()
+  );
+
+  /** Names of the relevance functions that take a {@code fields} list argument. */
+  private final List<String> multiFieldFunctionNames = ImmutableList.of(
+      BuiltinFunctionName.MULTI_MATCH.name(),
+      BuiltinFunctionName.SIMPLE_QUERY_STRING.name(),
+      BuiltinFunctionName.QUERY_STRING.name()
+  );
+
+  /**
+   * Check if supplied function name is valid SingleFieldRelevanceFunction.
+   * The name is upper-cased with {@code Locale.ROOT} so the lookup does not depend on the
+   * JVM default locale.
+   * @param funcName : Name of function
+   * @return : True if function is single-field function
+   */
+  public static boolean isSingleFieldFunction(String funcName) {
+    return singleFieldFunctionNames.contains(funcName.toUpperCase(java.util.Locale.ROOT));
+  }
+
+  /**
+   * Check if supplied function name is valid MultiFieldRelevanceFunction.
+   * The name is upper-cased with {@code Locale.ROOT} so the lookup does not depend on the
+   * JVM default locale.
+   * @param funcName : Name of function
+   * @return : True if function is multi-field function
+   */
+  public static boolean isMultiFieldFunction(String funcName) {
+    return multiFieldFunctionNames.contains(funcName.toUpperCase(java.util.Locale.ROOT));
+  }
+
+  /**
+   * Verify if function queries fields available in type environment.
+   * Only {@code node} itself is inspected; field names containing {@code *} are skipped,
+   * and functions that are neither single-field nor multi-field are ignored.
+   * An unknown field surfaces as whatever {@code TypeEnvironment.resolve} throws.
+   * @param node : Function used in query.
+   * @param context : Context of fields querying.
+   */
+  public static void validateFieldList(FunctionExpression node, AnalysisContext context) {
+    String funcName = node.getFunctionName().toString();
+    TypeEnvironment typeEnv = context.peek();
+
+    if (isSingleFieldFunction(funcName)) {
+      node.getArguments().stream()
+          .map(NamedArgumentExpression.class::cast)
+          .filter(arg -> arg.getArgName().equals("field")
+              && !arg.getValue().toString().contains("*"))
+          .findFirst()
+          // Strip any surrounding quotes from the argument's string form before resolving.
+          .ifPresent(arg -> typeEnv.resolve(new Symbol(Namespace.FIELD_NAME,
+              StringUtils.unquoteText(arg.getValue().toString()))));
+    } else if (isMultiFieldFunction(funcName)) {
+      node.getArguments().stream()
+          .map(NamedArgumentExpression.class::cast)
+          .filter(arg -> arg.getArgName().equals("fields"))
+          .findFirst()
+          .ifPresent(fields -> fields.getValue().valueOf(null).tupleValue().keySet().stream()
+              .filter(fieldName -> !fieldName.contains("*"))
+              .forEach(fieldName ->
+                  typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, fieldName))));
+    }
+  }
+
   /**
    * Add functions specific to OpenSearch to repository.
*/
diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
index 723ab736da..c2eeb46bae 100644
--- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
+++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
@@ -58,6 +58,7 @@
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import org.apache.commons.lang3.tuple.ImmutablePair;
@@ -71,11 +72,14 @@
 import org.opensearch.sql.ast.expression.HighlightFunction;
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.ParseMethod;
+import org.opensearch.sql.ast.expression.RelevanceFieldList;
 import org.opensearch.sql.ast.expression.SpanUnit;
 import org.opensearch.sql.ast.tree.AD;
 import org.opensearch.sql.ast.tree.Kmeans;
 import org.opensearch.sql.ast.tree.ML;
 import org.opensearch.sql.ast.tree.RareTopN.CommandType;
+import org.opensearch.sql.data.model.ExprTupleValue;
+import org.opensearch.sql.data.model.ExprValueUtils;
 import org.opensearch.sql.exception.ExpressionEvaluationException;
 import org.opensearch.sql.exception.SemanticCheckException;
 import org.opensearch.sql.expression.DSL;
@@ -264,6 +268,138 @@ public void analyze_filter_aggregation_relation() {
                 aggregate("MIN", qualifiedName("integer_value")), intLiteral(10))));
   }
 
+  @Test
+  public void single_field_relevance_query_semantic_exception() {
+    // match() on a field that cannot be resolved in the type env must fail analysis.
+    SemanticCheckException thrown =
+        assertThrows(
+            SemanticCheckException.class,
+            () ->
+                analyze(
+                    AstDSL.filter(
+                        AstDSL.relation("schema"),
+                        AstDSL.function("match",
+                            AstDSL.unresolvedArg("field", stringLiteral("missing_value")),
+                            AstDSL.unresolvedArg("query", stringLiteral("query_value"))))));
+    assertEquals(
+        "can't resolve Symbol(namespace=FIELD_NAME, name=missing_value) in type env",
+        thrown.getMessage());
+  }
+
+  @Test
+  public void single_field_relevance_query() {
+    // match() on a resolvable field analyzes into a filter over the relation.
+    assertAnalyzeEqual(
+        LogicalPlanDSL.filter(
+            LogicalPlanDSL.relation("schema", table),
+            dsl.match(
+                dsl.namedArgument("field", DSL.literal("string_value")),
+                dsl.namedArgument("query", DSL.literal("query_value")))),
+        AstDSL.filter(
+            AstDSL.relation("schema"),
+            AstDSL.function("match",
+                AstDSL.unresolvedArg("field", stringLiteral("string_value")),
+                AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
+  }
+
+  @Test
+  public void single_field_wildcard_relevance_query() {
+    // A field name containing '*' is accepted without being resolved.
+    assertAnalyzeEqual(
+        LogicalPlanDSL.filter(
+            LogicalPlanDSL.relation("schema", table),
+            dsl.match(
+                dsl.namedArgument("field", DSL.literal("wildcard_field*")),
+                dsl.namedArgument("query", DSL.literal("query_value")))),
+        AstDSL.filter(
+            AstDSL.relation("schema"),
+            AstDSL.function("match",
+                AstDSL.unresolvedArg("field", stringLiteral("wildcard_field*")),
+                AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
+  }
+
+  @Test
+  public void multi_field_relevance_query_semantic_exception() {
+    // query_string() fails on the first unresolvable entry of its fields list.
+    SemanticCheckException thrown =
+        assertThrows(
+            SemanticCheckException.class,
+            () ->
+                analyze(
+                    AstDSL.filter(
+                        AstDSL.relation("schema"),
+                        AstDSL.function("query_string",
+                            AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
+                                "missing_value1", 1.0F, "missing_value2", 0.3F))),
+                            AstDSL.unresolvedArg("query", stringLiteral("query_value"))))));
+    assertEquals(
+        "can't resolve Symbol(namespace=FIELD_NAME, name=missing_value1) in type env",
+        thrown.getMessage());
+  }
+
+  @Test
+  public void multi_field_relevance_query_mixed_fields_semantic_exception() {
+    // One unresolvable field is enough to fail, even next to a resolvable one.
+    SemanticCheckException thrown =
+        assertThrows(
+            SemanticCheckException.class,
+            () ->
+                analyze(
+                    AstDSL.filter(
+                        AstDSL.relation("schema"),
+                        AstDSL.function("query_string",
+                            AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
+                                "string_value", 1.0F, "missing_value", 0.3F))),
+                            AstDSL.unresolvedArg("query", stringLiteral("query_value"))))));
+    assertEquals(
+        "can't resolve Symbol(namespace=FIELD_NAME, name=missing_value) in type env",
+        thrown.getMessage());
+  }
+
+  @Test
+  public void multi_field_relevance_query() {
+    // query_string() over resolvable fields analyzes with the fields tuple intact.
+    assertAnalyzeEqual(
+        LogicalPlanDSL.filter(
+            LogicalPlanDSL.relation("schema", table),
+            dsl.query_string(
+                dsl.namedArgument("fields", DSL.literal(
+                    new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
+                        "string_value", ExprValueUtils.floatValue(1.0F),
+                        "integer_value", ExprValueUtils.floatValue(0.3F))
+                    ))
+                )),
+                dsl.namedArgument("query", DSL.literal("query_value")))),
+        AstDSL.filter(
+            AstDSL.relation("schema"),
+            AstDSL.function("query_string",
+                AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
+                    "string_value", 1.0F, "integer_value", 0.3F))),
+                AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
+  }
+
+  @Test
+  public void multi_field_wildcard_relevance_query() {
+    // Entries containing '*' in the fields list are accepted without being resolved.
+    assertAnalyzeEqual(
+        LogicalPlanDSL.filter(
+            LogicalPlanDSL.relation("schema", table),
+            dsl.query_string(
+                dsl.namedArgument("fields", DSL.literal(
+                    new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
+                        "wildcard_field1*", ExprValueUtils.floatValue(1.0F),
+                        "wildcard_field2*", ExprValueUtils.floatValue(0.3F))
+                    ))
+                )),
+                dsl.namedArgument("query", DSL.literal("query_value")))),
+        AstDSL.filter(
+            AstDSL.relation("schema"),
+            AstDSL.function("query_string",
+                AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
+                    "wildcard_field1*", 1.0F, "wildcard_field2*", 0.3F))),
+                AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
+  }
+
   @Test
   public void rename_relation() {
     assertAnalyzeEqual(
diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java
index 813f6b7a9c..36019ee873 100644
--- a/integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java
+++ b/integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java
@@ -15,6 +15,7 @@
 import
org.json.JSONObject;
 import org.junit.Test;
 import org.opensearch.sql.legacy.SQLIntegTestCase;
+import org.opensearch.sql.legacy.utils.StringUtils;
 
 public class MatchIT extends SQLIntegTestCase {
   @Override
@@ -35,4 +36,17 @@ public void match_in_having() throws IOException {
     verifySchema(result, schema("lastname", "text"));
     verifyDataRows(result, rows("Bates"));
   }
+
+  @Test
+  public void missing_field_test() {
+    String query = StringUtils.format(
+        "SELECT * FROM %s WHERE match(invalid, 'Bates')", TEST_INDEX_ACCOUNT);
+    final RuntimeException exception =
+        expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));
+    // Assert each fragment separately so a failure shows which expectation was not met.
+    final String message = exception.getMessage();
+    assertTrue(message.contains(
+        "can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));
+    assertTrue(message.contains("SemanticCheckException"));
+  }
 }
diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/QueryStringIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/QueryStringIT.java
index 398a7a9d94..8562a001a5 100644
--- a/integ-test/src/test/java/org/opensearch/sql/sql/QueryStringIT.java
+++ b/integ-test/src/test/java/org/opensearch/sql/sql/QueryStringIT.java
@@ -11,6 +11,7 @@
 import org.json.JSONObject;
 import org.junit.Test;
 import org.opensearch.sql.legacy.SQLIntegTestCase;
+import org.opensearch.sql.legacy.utils.StringUtils;
 
 public class QueryStringIT extends SQLIntegTestCase {
   @Override
@@ -65,4 +66,17 @@ public void wildcard_test() throws IOException {
     JSONObject result3 = executeJdbcRequest(query3);
     assertEquals(10, result3.getInt("total"));
   }
+
+  @Test
+  public void missing_field_test() {
+    String query = StringUtils.format(
+        "SELECT * FROM %s WHERE query_string([invalid], 'beer')", TEST_INDEX_BEER);
+    final RuntimeException exception =
+        expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));
+    // Assert each fragment separately so a failure shows which expectation was not met.
+    final String message = exception.getMessage();
+    assertTrue(message.contains(
+        "can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));
+    assertTrue(message.contains("SemanticCheckException"));
+  }
 }