Skip to content

Commit

Permalink
Support CAST command
Browse files Browse the repository at this point in the history
Signed-off-by: Heng Qian <[email protected]>
  • Loading branch information
qianheng-aws committed Nov 27, 2024
1 parent 3ff2ef2 commit 1864f6e
Show file tree
Hide file tree
Showing 10 changed files with 323 additions and 6 deletions.
3 changes: 2 additions & 1 deletion docs/ppl-lang/functions/ppl-conversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
+------------+--------+--------+---------+-------------+--------+--------+
| BOOLEAN | Note1 | v?1:0 | | N/A | N/A | N/A |
+------------+--------+--------+---------+-------------+--------+--------+
| TIMESTAMP | Note1 | N/A | N/A | | DATE() | TIME() |
| TIMESTAMP | Note1 | N/A | N/A | | DATE() | N/A |
+------------+--------+--------+---------+-------------+--------+--------+
| DATE | Note1 | N/A | N/A | N/A | | N/A |
+------------+--------+--------+---------+-------------+--------+--------+
| TIME | Note1 | N/A | N/A | N/A | N/A | |
+------------+--------+--------+---------+-------------+--------+--------+
```
Note: Spark does not support the `TIME` type, so using the `CAST` function to convert a value to `TIME` will produce a **STRING** instead.

Cast to **string** example:

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import java.sql.Date
import java.sql.Timestamp

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.streaming.StreamTest

/**
 * Integration tests for the PPL `cast` function.
 *
 * Each test issues a PPL query through `sql(...)` and verifies both the
 * resolved Spark data types (`frame.dtypes`) and the evaluated row values.
 *
 * NOTE(review): assumes the table created by `createNullableJsonContentTable`
 * contains a row with integer `id = 1` — confirm against the FlintPPLSuite
 * fixture if tests are extended.
 */
class FlintSparkPPLCastITSuite
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"

override def beforeAll(): Unit = {
super.beforeAll()
// Create test table
createNullableJsonContentTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()
// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

// Numeric source: casting an integer column to string, double, long and
// boolean should yield the corresponding Spark types and values.
test("test cast number to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| id_string = cast(id as string),
| id_double = cast(id as double),
| id_long = cast(id as long),
| id_boolean = cast(id as boolean)
| | fields id, id_string, id_double, id_long, id_boolean | head 1
| """.stripMargin)

assert(
frame.dtypes.sameElements(
Array(
("id", "IntegerType"),
("id_string", "StringType"),
("id_double", "DoubleType"),
("id_long", "LongType"),
("id_boolean", "BooleanType"))))
assertSameRows(Seq(Row(1, "1", 1.0, 1L, true)), frame)
}

// String source: round-trips through string, boolean parsing (case-insensitive
// "True"/"false"), and temporal literals.
test("test cast string to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| id_int = cast(cast(id as string) as int),
| cast_true = cast("True" as boolean),
| cast_false = cast("false" as boolean),
| cast_timestamp = cast("2024-11-26 23:39:06" as timestamp),
| cast_date = cast("2024-11-26" as date),
| cast_time = cast("12:34:56" as time)
| | fields id_int, cast_true, cast_false, cast_timestamp, cast_date, cast_time | head 1
| """.stripMargin)

// Note: Spark has no TIME data type, so `cast(... as time)` falls back to
// StringType (see docs/ppl-lang/functions/ppl-conversion.md).
assert(
frame.dtypes.sameElements(Array(
("id_int", "IntegerType"),
("cast_true", "BooleanType"),
("cast_false", "BooleanType"),
("cast_timestamp", "TimestampType"),
("cast_date", "DateType"),
("cast_time", "StringType"))))
assertSameRows(
Seq(
Row(
1,
true,
false,
Timestamp.valueOf("2024-11-26 23:39:06"),
Date.valueOf("2024-11-26"),
"12:34:56")),
frame)
}

// Temporal sources: timestamp<->string, timestamp->date (truncation), and
// date->timestamp (midnight expansion).
test("test cast time related types to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| timestamp = cast("2024-11-26 23:39:06" as timestamp),
| ts_str = cast(timestamp as string),
| ts_date = cast(timestamp as date),
| date_str = cast(ts_date as string),
| date_ts = cast(ts_date as timestamp)
| | fields timestamp, ts_str, ts_date, date_str, date_ts | head 1
| """.stripMargin)

assert(
frame.dtypes.sameElements(
Array(
("timestamp", "TimestampType"),
("ts_str", "StringType"),
("ts_date", "DateType"),
("date_str", "StringType"),
("date_ts", "TimestampType"))))
assertSameRows(
Seq(
Row(
Timestamp.valueOf("2024-11-26 23:39:06"),
"2024-11-26 23:39:06",
Date.valueOf("2024-11-26"),
"2024-11-26",
Timestamp.valueOf("2024-11-26 00:00:00"))),
frame)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ primaryExpression
: evalFunctionCall
| fieldExpression
| literalValue
| dataTypeFunctionCall
;

positionFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Cidr;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.EqualTo;
Expand Down Expand Up @@ -188,6 +189,10 @@ public T visitFunction(Function node, C context) {
return visitChildren(node, context);
}

/**
 * Visits a {@link Cast} expression node.
 *
 * <p>Default implementation simply traverses the node's children; subclasses
 * override this to translate the cast into their target representation.
 *
 * @param node the cast expression being visited
 * @param context visitor context threaded through the traversal
 * @return the result of visiting the node's children
 */
public T visitCast(Cast node, C context) {
return visitChildren(node, context);
}

/**
 * Visits a {@link LambdaFunction} expression node.
 *
 * <p>Default implementation simply traverses the node's children; subclasses
 * override this to translate the lambda into their target representation.
 *
 * @param node the lambda function being visited
 * @param context visitor context threaded through the traversal
 * @return the result of visiting the node's children
 */
public T visitLambdaFunction(LambdaFunction node, C context) {
return visitChildren(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.Collections;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
 * AST expression node representing a type conversion, e.g.
 * {@code CAST(expr AS INT)}.
 *
 * <p>Getters, {@code equals}/{@code hashCode} and the all-args constructor are
 * generated by Lombok.
 */
@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class Cast extends UnresolvedExpression {
// The expression whose value is being converted.
private final UnresolvedExpression expression;
// The target type of the conversion.
private final DataType dataType;

@Override
public List<UnresolvedExpression> getChild() {
return Collections.singletonList(expression);
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitCast(this, context);
}

@Override
public String toString() {
return String.format("CAST(%s AS %s)", expression, dataType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,15 @@ public enum DataType {
INTERVAL(ExprCoreType.INTERVAL);

@Getter private final ExprCoreType coreType;

/**
 * Resolves a {@code DataType} constant from its textual name,
 * case-insensitively.
 *
 * <p>Recognized aliases (currently {@code INT} for {@code INTEGER}) are mapped
 * to their canonical constant before falling back to {@link #valueOf}.
 *
 * @param name the data type name as written in the query, in any case
 * @return the matching {@code DataType} constant
 * @throws IllegalArgumentException if {@code name} matches no constant or alias
 * @throws NullPointerException if {@code name} is null
 */
public static DataType fromString(String name) {
// Locale.ROOT makes the uppercasing locale-independent: with the default
// locale (e.g. Turkish), "int".toUpperCase() would yield "İNT" and the
// alias lookup below would silently miss.
String upperName = name.toUpperCase(java.util.Locale.ROOT);
// Cover some dataType alias
switch (upperName) {
case "INT":
return INTEGER;
default:
return valueOf(upperName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.Cast$;
import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
import org.apache.spark.sql.catalyst.expressions.Exists$;
import org.apache.spark.sql.catalyst.expressions.Expression;
Expand Down Expand Up @@ -41,6 +42,7 @@
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.BinaryExpression;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.FieldsMapping;
Expand Down Expand Up @@ -466,6 +468,17 @@ public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext c
return context.getNamedParseExpressions().push(LambdaFunction$.MODULE$.apply(functionResult, seq(argsResult), false));
}

/**
 * Translates a PPL {@link Cast} node into a Spark Catalyst {@code Cast}
 * expression.
 *
 * <p>The child expression is analyzed first, which pushes its translated form
 * onto the context's named-parse-expression stack; it is popped here to serve
 * as the cast operand, wrapped in {@code Cast} with the translated target
 * type, and pushed back for the caller to consume.
 *
 * @throws UnsupportedOperationException if analyzing the child expression
 *     produced no expression on the stack
 */
@Override
public Expression visitCast(Cast node, CatalystPlanContext context) {
analyze(node.getExpression(), context);
// analyze() pushes its result onto the expression stack; pop it as operand.
Optional<Expression> ret = context.popNamedParseExpressions();
if (ret.isEmpty()) {
throw new UnsupportedOperationException(
String.format("Invalid use of expression %s", node.getExpression()));
}
// NOTE(review): the trailing `false` selects a Cast overload — presumably the
// ansiEnabled flag (legacy, non-ANSI cast semantics); confirm against the
// Spark version on the classpath.
return context.getNamedParseExpressions().push(Cast$.MODULE$.apply(ret.get(), translate(node.getDataType()), false));
}

private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
return expressionList.isEmpty()
? emptyList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Cidr;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
Expand Down Expand Up @@ -279,9 +280,8 @@ public UnresolvedExpression visitEvalFunctionCall(OpenSearchPPLParser.EvalFuncti
return buildFunction(ctx.evalFunctionName().getText(), ctx.functionArgs().functionArg());
}

@Override
public UnresolvedExpression visitConvertedDataType(OpenSearchPPLParser.ConvertedDataTypeContext ctx) {
return new Literal(ctx.getText(), DataType.STRING);
@Override public UnresolvedExpression visitDataTypeFunctionCall(OpenSearchPPLParser.DataTypeFunctionCallContext ctx) {
return new Cast(this.visit(ctx.expression()), DataType.fromString(ctx.convertedDataType().getText()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.spark.sql.types.BooleanType$;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
Expand Down Expand Up @@ -49,8 +50,10 @@ static <T> Seq<T> seq(List<T> list) {

static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
switch (source.getCoreType()) {
case TIME:
case DATE:
return DateType$.MODULE$;
case TIMESTAMP:
return DataTypes.TimestampType;
case INTEGER:
return IntegerType$.MODULE$;
case LONG:
Expand Down Expand Up @@ -120,4 +123,4 @@ static String translate(SpanUnit unit) {
}
return "";
}
}
}
Loading

0 comments on commit 1864f6e

Please sign in to comment.