
Support CAST function #952

Merged: 5 commits, Nov 29, 2024
docs/ppl-lang/functions/ppl-conversion.md (3 changes: 2 additions & 1 deletion)
@@ -16,13 +16,14 @@
+------------+--------+--------+---------+-------------+--------+--------+
| BOOLEAN    | Note1  | v?1:0  |         | N/A         | N/A    | N/A    |
+------------+--------+--------+---------+-------------+--------+--------+
| TIMESTAMP  | Note1  | N/A    | N/A     |             | DATE() | TIME() |
| TIMESTAMP  | Note1  | N/A    | N/A     |             | DATE() | N/A    |
+------------+--------+--------+---------+-------------+--------+--------+
| DATE       | Note1  | N/A    | N/A     | N/A         |        | N/A    |
+------------+--------+--------+---------+-------------+--------+--------+
| TIME       | Note1  | N/A    | N/A     | N/A         | N/A    |        |
+------------+--------+--------+---------+-------------+--------+--------+
```
Note: Spark does not support the `TIME` type. Using the `CAST` function will convert it to **STRING**.
Member:

IMO, TIME is not a supported data type in Spark PPL; how about deleting it from the user doc and the code?
`cast("12:34:56" as time)` should throw a syntax exception, IMO.

Contributor Author:

DataTypeTransformer translates Spark-unsupported data types into StringType by default. Let's see whether changing that to throw an exception for unsupported data types would hurt any existing functionality.
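The trade-off under discussion can be sketched roughly as follows. This is a simplified stand-in, not the actual `DataTypeTransformer` API; the class name, method names, and the mapping subset are all illustrative:

```java
import java.util.Map;

public class TypeFallbackSketch {
    // Illustrative subset of supported source -> Spark type mappings.
    private static final Map<String, String> SUPPORTED = Map.of(
            "INTEGER", "IntegerType",
            "STRING", "StringType",
            "DATE", "DateType");

    // Current behavior: unsupported source types silently fall back to StringType.
    public static String translateWithFallback(String source) {
        return SUPPORTED.getOrDefault(source, "StringType");
    }

    // Proposed behavior: reject unsupported types up front.
    public static String translateStrict(String source) {
        String target = SUPPORTED.get(source);
        if (target == null) {
            throw new UnsupportedOperationException("Unsupported data type: " + source);
        }
        return target;
    }
}
```

Under the fallback variant, `translateWithFallback("TIME")` quietly yields `"StringType"`; the strict variant surfaces the problem at translation time instead.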

Member:

+1 for throwing an exception. I don't see any reason to convert to String. @YANG-DB any thoughts? Might there be some special logic in the SQL repo?

Member:

I don't know of such a reason. I agree.


Cast to **string** example:

@@ -0,0 +1,122 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import java.sql.Date
import java.sql.Timestamp

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLCastITSuite
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {

  /** Test table name */
private val testTable = "spark_catalog.default.flint_ppl_test"

override def beforeAll(): Unit = {
super.beforeAll()
// Create test table
createNullableJsonContentTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()
// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

test("test cast number to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| id_string = cast(id as string),
| id_double = cast(id as double),
| id_long = cast(id as long),
| id_boolean = cast(id as boolean)
| | fields id, id_string, id_double, id_long, id_boolean | head 1
| """.stripMargin)

assert(
frame.dtypes.sameElements(
Array(
("id", "IntegerType"),
("id_string", "StringType"),
("id_double", "DoubleType"),
("id_long", "LongType"),
("id_boolean", "BooleanType"))))
assertSameRows(Seq(Row(1, "1", 1.0, 1L, true)), frame)
}

test("test cast string to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| id_int = cast(cast(id as string) as int),
| cast_true = cast("True" as boolean),
| cast_false = cast("false" as boolean),
| cast_timestamp = cast("2024-11-26 23:39:06" as timestamp),
| cast_date = cast("2024-11-26" as date),
| cast_time = cast("12:34:56" as time)
| | fields id_int, cast_true, cast_false, cast_timestamp, cast_date, cast_time | head 1
| """.stripMargin)

    // Note: Spark doesn't support a `Time` data type; such casts fall back to StringType by default.
assert(
frame.dtypes.sameElements(Array(
("id_int", "IntegerType"),
("cast_true", "BooleanType"),
("cast_false", "BooleanType"),
("cast_timestamp", "TimestampType"),
("cast_date", "DateType"),
("cast_time", "StringType"))))
assertSameRows(
Seq(
Row(
1,
true,
false,
Timestamp.valueOf("2024-11-26 23:39:06"),
Date.valueOf("2024-11-26"),
"12:34:56")),
frame)
}

test("test cast time related types to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| timestamp = cast("2024-11-26 23:39:06" as timestamp),
| ts_str = cast(timestamp as string),
| ts_date = cast(timestamp as date),
| date_str = cast(ts_date as string),
| date_ts = cast(ts_date as timestamp)
| | fields timestamp, ts_str, ts_date, date_str, date_ts | head 1
| """.stripMargin)

assert(
frame.dtypes.sameElements(
Array(
("timestamp", "TimestampType"),
("ts_str", "StringType"),
("ts_date", "DateType"),
("date_str", "StringType"),
("date_ts", "TimestampType"))))
assertSameRows(
Seq(
Row(
Timestamp.valueOf("2024-11-26 23:39:06"),
"2024-11-26 23:39:06",
Date.valueOf("2024-11-26"),
"2024-11-26",
Timestamp.valueOf("2024-11-26 00:00:00"))),
frame)
}

}
@@ -462,6 +462,7 @@ primaryExpression
: evalFunctionCall
| fieldExpression
| literalValue
| dataTypeFunctionCall
;

positionFunction
@@ -13,6 +13,7 @@
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Cidr;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.EqualTo;
@@ -188,6 +189,10 @@ public T visitFunction(Function node, C context) {
return visitChildren(node, context);
}

public T visitCast(Cast node, C context) {
return visitChildren(node, context);
}

public T visitLambdaFunction(LambdaFunction node, C context) {
return visitChildren(node, context);
}
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.Collections;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
 * Expression node of a CAST expression.
 */
@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class Cast extends UnresolvedExpression {
private final UnresolvedExpression expression;
private final DataType dataType;

@Override
public List<UnresolvedExpression> getChild() {
return Collections.singletonList(expression);
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitCast(this, context);
}

@Override
public String toString() {
return String.format("CAST(%s AS %s)", expression, dataType);
}
}
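As a quick illustration of the `toString` rendering above, here is a trimmed-down stand-in; the real node wraps an `UnresolvedExpression` tree and a `DataType` rather than plain strings:

```java
public class CastToStringDemo {
    // Simplified model of the Cast AST node; only the formatting is reproduced.
    record Cast(String expression, String dataType) {
        @Override
        public String toString() {
            // Same SQL-style shape as the real node's toString()
            return String.format("CAST(%s AS %s)", expression, dataType);
        }
    }

    public static void main(String[] args) {
        System.out.println(new Cast("id", "DOUBLE")); // prints CAST(id AS DOUBLE)
    }
}
```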
@@ -30,4 +30,15 @@ public enum DataType {
INTERVAL(ExprCoreType.INTERVAL);

@Getter private final ExprCoreType coreType;

public static DataType fromString(String name) {
String upperName = name.toUpperCase();
// Cover some data type aliases
switch (upperName) {
case "INT":
return INTEGER;
default:
return valueOf(upperName);
}
}
}
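To illustrate the alias handling in `fromString` above, a self-contained model with a reduced constant set (the real enum also carries an `ExprCoreType` per constant):

```java
public class DataTypeAliasDemo {
    // Reduced model of the PPL DataType enum's alias lookup.
    enum DataType {
        INTEGER, LONG, DOUBLE, STRING, BOOLEAN;

        static DataType fromString(String name) {
            String upperName = name.toUpperCase();
            switch (upperName) {
                case "INT": // alias accepted by the grammar, mapped onto INTEGER
                    return INTEGER;
                default:
                    // valueOf throws IllegalArgumentException for unknown names
                    return valueOf(upperName);
            }
        }
    }

    public static void main(String[] args) {
        System.out.println(DataType.fromString("int"));    // INTEGER
        System.out.println(DataType.fromString("string")); // STRING
    }
}
```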
@@ -10,6 +10,7 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.Cast$;
import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
import org.apache.spark.sql.catalyst.expressions.Exists$;
import org.apache.spark.sql.catalyst.expressions.Expression;
@@ -41,6 +42,7 @@
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.BinaryExpression;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.FieldsMapping;
@@ -466,6 +468,17 @@ public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext c
return context.getNamedParseExpressions().push(LambdaFunction$.MODULE$.apply(functionResult, seq(argsResult), false));
}

@Override
public Expression visitCast(Cast node, CatalystPlanContext context) {
analyze(node.getExpression(), context);
Optional<Expression> ret = context.popNamedParseExpressions();
if (ret.isEmpty()) {
throw new UnsupportedOperationException(
String.format("Invalid use of expression %s", node.getExpression()));
}
return context.getNamedParseExpressions().push(Cast$.MODULE$.apply(ret.get(), translate(node.getDataType()), false));
}

private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
return expressionList.isEmpty()
? emptyList()
@@ -19,6 +19,7 @@
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Cidr;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
@@ -279,9 +280,9 @@ public UnresolvedExpression visitEvalFunctionCall(OpenSearchPPLParser.EvalFuncti
return buildFunction(ctx.evalFunctionName().getText(), ctx.functionArgs().functionArg());
}

-  @Override
-  public UnresolvedExpression visitConvertedDataType(OpenSearchPPLParser.ConvertedDataTypeContext ctx) {
-    return new Literal(ctx.getText(), DataType.STRING);
+  @Override public UnresolvedExpression visitDataTypeFunctionCall(OpenSearchPPLParser.DataTypeFunctionCallContext ctx) {
+    // TODO: for long term consideration, needs to implement DataTypeBuilder/Visitor to parse all data types
+    return new Cast(this.visit(ctx.expression()), DataType.fromString(ctx.convertedDataType().getText()));
Member:

@qianheng-aws can you please check whether it's simpler to use DataTypeTransformer here instead?

Contributor Author:

I used DataTypeTransformer in CatalystExpressionVisitor. DataTypeTransformer translates OpenSearch data types into Spark data types, whereas here we are only dealing with the OpenSearch AST.

}

@Override
@@ -9,6 +9,7 @@
import org.apache.spark.sql.types.BooleanType$;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
@@ -49,8 +50,10 @@ static <T> Seq<T> seq(List<T> list) {

static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
switch (source.getCoreType()) {
case TIME:
case DATE:
return DateType$.MODULE$;
case TIMESTAMP:
return DataTypes.TimestampType;
case INTEGER:
return IntegerType$.MODULE$;
case LONG:
@@ -120,4 +123,4 @@ static String translate(SpanUnit unit) {
}
return "";
}
}
}