diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 26ddd3613..5a61992de 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -118,6 +118,7 @@ Assumptions: `a`, `b`, `c` are existing fields in `table` - `source = table | eval r = coalesce(a, b, c) | fields r` - `source = table | eval e = isempty(a) | fields e` - `source = table | eval e = isblank(a) | fields e` +- `source = table | eval e = cast(a as timestamp) | fields e` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')` - `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))` @@ -486,4 +487,11 @@ _- **Limitation: another command usage of (relation) subquery is in `appendcols` > ppl-correlation-command is an experimental command - it may be removed in future versions +#### **Cast** +[See additional command details](functions/ppl-conversion.md) +- `source = table | eval int_to_string = cast(1 as string) | fields int_to_string` +- `source = table | eval int_to_string = cast(int_col as string), string_to_int = cast(string_col as integer) | fields int_to_string, string_to_int` +- `source = table | eval cdate = CAST('2012-08-07' as date), ctime = cast('2012-08-07T08:07:06' as timestamp) | fields cdate, ctime` +- `source = table | eval chained_cast = cast(cast("true" as boolean) as integer) | fields chained_cast` + --- diff --git a/docs/ppl-lang/functions/ppl-conversion.md b/docs/ppl-lang/functions/ppl-conversion.md index 48e4106ca..7d3535936 100644 --- a/docs/ppl-lang/functions/ppl-conversion.md +++ b/docs/ppl-lang/functions/ppl-conversion.md @@ -7,22 +7,21 @@ `cast(expr as dateType)` cast the expr to dataType. return the value of dataType. 
The following conversion rules are used: ``` -+------------+--------+--------+---------+-------------+--------+--------+ -| Src/Target | STRING | NUMBER | BOOLEAN | TIMESTAMP | DATE | TIME | -+------------+--------+--------+---------+-------------+--------+--------+ -| STRING | | Note1 | Note1 | TIMESTAMP() | DATE() | TIME() | -+------------+--------+--------+---------+-------------+--------+--------+ -| NUMBER | Note1 | | v!=0 | N/A | N/A | N/A | -+------------+--------+--------+---------+-------------+--------+--------+ -| BOOLEAN | Note1 | v?1:0 | | N/A | N/A | N/A | -+------------+--------+--------+---------+-------------+--------+--------+ -| TIMESTAMP | Note1 | N/A | N/A | | DATE() | TIME() | -+------------+--------+--------+---------+-------------+--------+--------+ -| DATE | Note1 | N/A | N/A | N/A | | N/A | -+------------+--------+--------+---------+-------------+--------+--------+ -| TIME | Note1 | N/A | N/A | N/A | N/A | | -+------------+--------+--------+---------+-------------+--------+--------+ ++------------+--------+--------+---------+-------------+--------+ +| Src/Target | STRING | NUMBER | BOOLEAN | TIMESTAMP | DATE | ++------------+--------+--------+---------+-------------+--------+ +| STRING | | Note1 | Note1 | TIMESTAMP() | DATE() | ++------------+--------+--------+---------+-------------+--------+ +| NUMBER | Note1 | | v!=0 | N/A | N/A | ++------------+--------+--------+---------+-------------+--------+ +| BOOLEAN | Note1 | v?1:0 | | N/A | N/A | ++------------+--------+--------+---------+-------------+--------+ +| TIMESTAMP | Note1 | N/A | N/A | | DATE() | ++------------+--------+--------+---------+-------------+--------+ +| DATE | Note1 | N/A | N/A | N/A | | ++------------+--------+--------+---------+-------------+--------+ ``` +- `NUMBER` includes `INTEGER`, `LONG`, `FLOAT`, `DOUBLE`. 
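To make the `NUMBER`/`BOOLEAN` rows of the table concrete, here is a minimal sketch in the style of the integration suites added later in this change. The `people` index and a Spark session whose parser handles PPL are illustrative assumptions, not part of this PR:

```scala
// NUMBER -> BOOLEAN follows v != 0; BOOLEAN -> NUMBER follows v ? 1 : 0.
// Assumes `spark` is a SparkSession with the PPL parser extension enabled
// and that a `people` index exists -- both hypothetical here.
val frame = spark.sql("""
  | source=people | eval
  |   num_to_bool = cast(5 as boolean),
  |   bool_to_num = cast(true as integer)
  | | fields num_to_bool, bool_to_num | head 1
  | """.stripMargin)
// Expected single row per the table: num_to_bool = true, bool_to_num = 1.
```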
Cast to **string** example: @@ -36,7 +35,7 @@ Cast to **string** example: Cast to **number** example: - os> source=people | eval `cbool` = CAST(true as int), `cstring` = CAST('1' as int) | fields `cbool`, `cstring` + os> source=people | eval `cbool` = CAST(true as integer), `cstring` = CAST('1' as integer) | fields `cbool`, `cstring` fetched rows / total rows = 1/1 +---------+-----------+ | cbool | cstring | @@ -46,13 +45,13 @@ Cast to **number** example: Cast to **date** example: - os> source=people | eval `cdate` = CAST('2012-08-07' as date), `ctime` = CAST('01:01:01' as time), `ctimestamp` = CAST('2012-08-07 01:01:01' as timestamp) | fields `cdate`, `ctime`, `ctimestamp` + os> source=people | eval `cdate` = CAST('2012-08-07' as date), `ctimestamp` = CAST('2012-08-07 01:01:01' as timestamp) | fields `cdate`, `ctimestamp` fetched rows / total rows = 1/1 - +------------+----------+---------------------+ - | cdate | ctime | ctimestamp | - |------------+----------+---------------------| - | 2012-08-07 | 01:01:01 | 2012-08-07 01:01:01 | - +------------+----------+---------------------+ + +------------+---------------------+ + | cdate | ctimestamp | + |------------+---------------------| + | 2012-08-07 | 2012-08-07 01:01:01 | + +------------+---------------------+ Cast function can be **chained**: diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCastITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCastITSuite.scala new file mode 100644 index 000000000..a9b01b9e3 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCastITSuite.scala @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import java.sql.Date +import java.sql.Timestamp + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLCastITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + // Create test table + createNullableJsonContentTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test cast number to compatible data types") { + val frame = sql(s""" + | source=$testTable | eval + | id_string = cast(id as string), + | id_double = cast(id as double), + | id_long = cast(id as long), + | id_boolean = cast(id as boolean) + | | fields id, id_string, id_double, id_long, id_boolean | head 1 + | """.stripMargin) + + assert( + frame.dtypes.sameElements( + Array( + ("id", "IntegerType"), + ("id_string", "StringType"), + ("id_double", "DoubleType"), + ("id_long", "LongType"), + ("id_boolean", "BooleanType")))) + assertSameRows(Seq(Row(1, "1", 1.0, 1L, true)), frame) + } + + test("test cast string to compatible data types") { + val frame = sql(s""" + | source=$testTable | eval + | id_int = cast(cast(id as string) as integer), + | cast_true = cast("True" as boolean), + | cast_false = cast("false" as boolean), + | cast_timestamp = cast("2024-11-26 23:39:06" as timestamp), + | cast_date = cast("2024-11-26" as date) + | | fields id_int, cast_true, cast_false, cast_timestamp, cast_date | head 1 + | 
""".stripMargin) + + assert( + frame.dtypes.sameElements( + Array( + ("id_int", "IntegerType"), + ("cast_true", "BooleanType"), + ("cast_false", "BooleanType"), + ("cast_timestamp", "TimestampType"), + ("cast_date", "DateType")))) + assertSameRows( + Seq( + Row( + 1, + true, + false, + Timestamp.valueOf("2024-11-26 23:39:06"), + Date.valueOf("2024-11-26"))), + frame) + } + + test("test cast time related types to compatible data types") { + val frame = sql(s""" + | source=$testTable | eval + | timestamp = cast("2024-11-26 23:39:06" as timestamp), + | ts_str = cast(timestamp as string), + | ts_date = cast(timestamp as date), + | date_str = cast(ts_date as string), + | date_ts = cast(ts_date as timestamp) + | | fields timestamp, ts_str, ts_date, date_str, date_ts | head 1 + | """.stripMargin) + + assert( + frame.dtypes.sameElements( + Array( + ("timestamp", "TimestampType"), + ("ts_str", "StringType"), + ("ts_date", "DateType"), + ("date_str", "StringType"), + ("date_ts", "TimestampType")))) + assertSameRows( + Seq( + Row( + Timestamp.valueOf("2024-11-26 23:39:06"), + "2024-11-26 23:39:06", + Date.valueOf("2024-11-26"), + "2024-11-26", + Timestamp.valueOf("2024-11-26 00:00:00"))), + frame) + } + +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 4ce21040e..073e6a332 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -445,6 +445,7 @@ primaryExpression : evalFunctionCall | fieldExpression | literalValue + | dataTypeFunctionCall ; positionFunction diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 1fd6a56fc..8ff59dee0 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -13,6 +13,7 @@ import org.opensearch.sql.ast.expression.AttributeList; import org.opensearch.sql.ast.expression.Between; import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.Cidr; import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.EqualTo; @@ -193,6 +194,10 @@ public T visitFunction(Function node, C context) { return visitChildren(node, context); } + public T visitCast(Cast node, C context) { + return visitChildren(node, context); + } + public T visitLambdaFunction(LambdaFunction node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cast.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cast.java new file mode 100644 index 000000000..0668fbf7b --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cast.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import java.util.Collections; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * Expression node of cast + */ +@Getter +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class Cast extends 
UnresolvedExpression {
+    private final UnresolvedExpression expression;
+    private final DataType dataType;
+
+    @Override
+    public List<UnresolvedExpression> getChild() {
+        return Collections.singletonList(expression);
+    }
+
+    @Override
+    public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
+        return nodeVisitor.visitCast(this, context);
+    }
+
+    @Override
+    public String toString() {
+        return String.format("CAST(%s AS %s)", expression, dataType);
+    }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/DataType.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/DataType.java
index 9843158b4..6f0de02f5 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/DataType.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/DataType.java
@@ -30,4 +30,8 @@ public enum DataType {
   INTERVAL(ExprCoreType.INTERVAL);
 
   @Getter private final ExprCoreType coreType;
+
+  public static DataType fromString(String name) {
+    return valueOf(name.toUpperCase());
+  }
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
index 4c8d117b3..35ac7ed47 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java
@@ -10,6 +10,7 @@
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
 import org.apache.spark.sql.catalyst.expressions.CaseWhen;
+import org.apache.spark.sql.catalyst.expressions.Cast$;
 import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
 import org.apache.spark.sql.catalyst.expressions.Exists$;
 import org.apache.spark.sql.catalyst.expressions.Expression;
@@ -41,6 +42,7 @@
 import org.opensearch.sql.ast.expression.Between;
 import org.opensearch.sql.ast.expression.BinaryExpression;
 import org.opensearch.sql.ast.expression.Case;
+import org.opensearch.sql.ast.expression.Cast;
 import org.opensearch.sql.ast.expression.Compare;
 import org.opensearch.sql.ast.expression.DataType;
 import org.opensearch.sql.ast.expression.FieldsMapping;
@@ -466,6 +468,17 @@ public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext c
         return context.getNamedParseExpressions().push(LambdaFunction$.MODULE$.apply(functionResult, seq(argsResult), false));
     }
 
+    @Override
+    public Expression visitCast(Cast node, CatalystPlanContext context) {
+        analyze(node.getExpression(), context);
+        Optional<Expression> ret = context.popNamedParseExpressions();
+        if (ret.isEmpty()) {
+            throw new UnsupportedOperationException(
+                String.format("Invalid use of expression %s", node.getExpression()));
+        }
+        return context.getNamedParseExpressions().push(Cast$.MODULE$.apply(ret.get(), translate(node.getDataType()), false));
+    }
+
     private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
         return expressionList.isEmpty() ? emptyList()
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
index e5098e4a1..d7636f526 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
@@ -19,6 +19,7 @@
 import org.opensearch.sql.ast.expression.AttributeList;
 import org.opensearch.sql.ast.expression.Between;
 import org.opensearch.sql.ast.expression.Case;
+import org.opensearch.sql.ast.expression.Cast;
 import org.opensearch.sql.ast.expression.Cidr;
 import org.opensearch.sql.ast.expression.Compare;
 import org.opensearch.sql.ast.expression.DataType;
@@ -278,9 +279,9 @@ public UnresolvedExpression visitEvalFunctionCall(OpenSearchPPLParser.EvalFuncti
         return buildFunction(ctx.evalFunctionName().getText(), ctx.functionArgs().functionArg());
     }
 
-    @Override
-    public UnresolvedExpression visitConvertedDataType(OpenSearchPPLParser.ConvertedDataTypeContext ctx) {
-        return new Literal(ctx.getText(), DataType.STRING);
+    @Override public UnresolvedExpression visitDataTypeFunctionCall(OpenSearchPPLParser.DataTypeFunctionCallContext ctx) {
+        // TODO: longer term, implement a DataTypeBuilder/visitor so that all data types can be parsed
+        return new Cast(this.visit(ctx.expression()), DataType.fromString(ctx.convertedDataType().getText()));
     }
 
     @Override
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java
index e4defad52..f583d7847 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java
@@ -9,6 +9,7 @@
 import org.apache.spark.sql.types.BooleanType$;
 import org.apache.spark.sql.types.ByteType$;
 import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.DateType$;
 import org.apache.spark.sql.types.DoubleType$;
 import org.apache.spark.sql.types.FloatType$;
@@ -49,8 +50,12 @@ static <T> Seq<T> seq(List<T> list) {
     static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
         switch (source.getCoreType()) {
-            case TIME:
+            case DATE:
                 return DateType$.MODULE$;
+            case TIMESTAMP:
+                return DataTypes.TimestampType;
+            case STRING:
+                return DataTypes.StringType;
             case INTEGER:
                 return IntegerType$.MODULE$;
             case LONG:
@@ -68,7 +73,7 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
             case UNDEFINED:
                 return NullType$.MODULE$;
             default:
-                return StringType$.MODULE$;
+                throw new IllegalArgumentException("Unsupported data type for Spark: " + source);
         }
     }
@@ -120,4 +125,4 @@ static String translate(SpanUnit unit) {
         }
         return "";
     }
-}
\ No newline at end of file
+}
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCastTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCastTestSuite.scala
new file mode 100644
index 000000000..829b7ff1f
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCastTestSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.types.{IntegerType, StringType} + +class PPLLogicalPlanCastTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test cast with case sensitive") { + val table = UnresolvedRelation(Seq("t")) + val expectedPlan = Project( + seq(UnresolvedStar(None)), + Project( + seq(UnresolvedStar(None), Alias(Cast(UnresolvedAttribute("a"), StringType), "a")()), + table)) + + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, """source=t | eval a = cast(a as STRING)"""), context) + comparePlans(expectedPlan, logPlan, false) + + // test case insensitive + val context2 = new CatalystPlanContext + val logPlan2 = + planTransformer.visit( + plan(pplParser, """source=t | eval a = cast(a as string)"""), + context2) + comparePlans(expectedPlan, logPlan2, false) + } + + test("test cast literal") { + val table = UnresolvedRelation(Seq("t")) + val expectedPlan = Project( + seq(UnresolvedStar(None)), + Project( + seq( + UnresolvedStar(None), + Alias(Cast(Cast(Literal("a"), IntegerType), StringType), "a")()), + table)) + + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t | eval a = cast(cast("a" as INTEGER) as STRING)"""), + context) + comparePlans(expectedPlan, logPlan, false) + } + + test("test chained cast") { + val table = UnresolvedRelation(Seq("t")) + val expectedPlan = Project( + seq(UnresolvedStar(None)), + Project( + seq( + UnresolvedStar(None), + Alias(Cast(Cast(UnresolvedAttribute("a"), IntegerType), StringType), "a")()), + table)) + + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, """source=t | eval a = cast(cast(a as INTEGER) as STRING)"""), + context) + comparePlans(expectedPlan, logPlan, false) + } + + test("test cast with unsupported dataType") { + // Unsupported data type for opensearch parser + val context = new CatalystPlanContext + val exception = intercept[SyntaxCheckException] { + planTransformer.visit( + plan(pplParser, """source=t | eval a = cast(a as UNSUPPORTED_DATATYPE)"""), + context) + } + assert( + exception.getMessage.contains( + "Failed to parse query due to offending symbol [UNSUPPORTED_DATATYPE]")) + + // Unsupported data type for Spark + val context2 = new CatalystPlanContext + val exception2 = intercept[IllegalArgumentException] { + planTransformer.visit(plan(pplParser, """source=t | eval a = cast(a as time)"""), context2) + } + assert(exception2.getMessage == "Unsupported data type for Spark: TIME") + } + +}
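As a closing illustration of what these planner tests assert, the sketch below builds the Catalyst expression that `visitCast` pushes for `cast(a as integer)`. It assumes a Spark version providing the three-argument `Cast.apply(child, dataType, ansiEnabled)` overload, the same one the `Cast$.MODULE$.apply(..., false)` call in `CatalystExpressionVisitor` relies on:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types.IntegerType

// PPL `... | eval a = cast(a as integer)` is planned as a non-ANSI Catalyst
// Cast over the unresolved attribute, which is what comparePlans verifies above.
val expected = Cast(UnresolvedAttribute("a"), IntegerType, ansiEnabled = false)
```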