Support CAST function (#952)
* Support CAST command

Signed-off-by: Heng Qian <[email protected]>

* Add a TODO

Signed-off-by: Heng Qian <[email protected]>

* Address comments

Signed-off-by: Heng Qian <[email protected]>

* Remove support for DataType alias

Signed-off-by: Heng Qian <[email protected]>

* Address comments

Signed-off-by: Heng Qian <[email protected]>

---------

Signed-off-by: Heng Qian <[email protected]>
qianheng-aws authored Nov 29, 2024
1 parent 3ff2ef2 commit 3ad88d9
Showing 11 changed files with 330 additions and 28 deletions.
8 changes: 8 additions & 0 deletions docs/ppl-lang/PPL-Example-Commands.md
@@ -118,6 +118,7 @@ Assumptions: `a`, `b`, `c` are existing fields in `table`
- `source = table | eval r = coalesce(a, b, c) | fields r`
- `source = table | eval e = isempty(a) | fields e`
- `source = table | eval e = isblank(a) | fields e`
- `source = table | eval e = cast(a as timestamp) | fields e`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))`
@@ -486,4 +487,11 @@ _- **Limitation: another command usage of (relation) subquery is in `appendcols`

> ppl-correlation-command is an experimental command - it may be removed in future versions
#### **Cast**
[See additional command details](functions/ppl-conversion.md)
- `source = table | eval int_to_string = cast(1 as string) | fields int_to_string`
- `source = table | eval int_to_string = cast(int_col as string), string_to_int = cast(string_col as integer) | fields int_to_string, string_to_int`
- `source = table | eval cdate = CAST('2012-08-07' as date), ctime = cast('2012-08-07T08:07:06' as timestamp) | fields cdate, ctime`
- `source = table | eval chained_cast = cast(cast("true" as boolean) as integer) | fields chained_cast`

---
43 changes: 21 additions & 22 deletions docs/ppl-lang/functions/ppl-conversion.md
@@ -7,22 +7,21 @@
`cast(expr as dataType)` casts `expr` to `dataType` and returns a value of `dataType`. The following conversion rules are used:

```
+------------+--------+--------+---------+-------------+--------+--------+
| Src/Target | STRING | NUMBER | BOOLEAN | TIMESTAMP | DATE | TIME |
+------------+--------+--------+---------+-------------+--------+--------+
| STRING | | Note1 | Note1 | TIMESTAMP() | DATE() | TIME() |
+------------+--------+--------+---------+-------------+--------+--------+
| NUMBER | Note1 | | v!=0 | N/A | N/A | N/A |
+------------+--------+--------+---------+-------------+--------+--------+
| BOOLEAN | Note1 | v?1:0 | | N/A | N/A | N/A |
+------------+--------+--------+---------+-------------+--------+--------+
| TIMESTAMP | Note1 | N/A | N/A | | DATE() | TIME() |
+------------+--------+--------+---------+-------------+--------+--------+
| DATE | Note1 | N/A | N/A | N/A | | N/A |
+------------+--------+--------+---------+-------------+--------+--------+
| TIME | Note1 | N/A | N/A | N/A | N/A | |
+------------+--------+--------+---------+-------------+--------+--------+
+------------+--------+--------+---------+-------------+--------+
| Src/Target | STRING | NUMBER | BOOLEAN | TIMESTAMP | DATE |
+------------+--------+--------+---------+-------------+--------+
| STRING | | Note1 | Note1 | TIMESTAMP() | DATE() |
+------------+--------+--------+---------+-------------+--------+
| NUMBER | Note1 | | v!=0 | N/A | N/A |
+------------+--------+--------+---------+-------------+--------+
| BOOLEAN | Note1 | v?1:0 | | N/A | N/A |
+------------+--------+--------+---------+-------------+--------+
| TIMESTAMP | Note1 | N/A | N/A | | DATE() |
+------------+--------+--------+---------+-------------+--------+
| DATE | Note1 | N/A | N/A | N/A | |
+------------+--------+--------+---------+-------------+--------+
```
- `NUMBER` includes `INTEGER`, `LONG`, `FLOAT`, `DOUBLE`.

Cast to **string** example:

@@ -36,7 +35,7 @@ Cast to **string** example:

Cast to **number** example:

os> source=people | eval `cbool` = CAST(true as int), `cstring` = CAST('1' as int) | fields `cbool`, `cstring`
os> source=people | eval `cbool` = CAST(true as integer), `cstring` = CAST('1' as integer) | fields `cbool`, `cstring`
fetched rows / total rows = 1/1
+---------+-----------+
| cbool | cstring |
@@ -46,13 +45,13 @@ Cast to **number** example:

Cast to **date** example:

os> source=people | eval `cdate` = CAST('2012-08-07' as date), `ctime` = CAST('01:01:01' as time), `ctimestamp` = CAST('2012-08-07 01:01:01' as timestamp) | fields `cdate`, `ctime`, `ctimestamp`
os> source=people | eval `cdate` = CAST('2012-08-07' as date), `ctimestamp` = CAST('2012-08-07 01:01:01' as timestamp) | fields `cdate`, `ctimestamp`
fetched rows / total rows = 1/1
+------------+----------+---------------------+
| cdate | ctime | ctimestamp |
|------------+----------+---------------------|
| 2012-08-07 | 01:01:01 | 2012-08-07 01:01:01 |
+------------+----------+---------------------+
+------------+---------------------+
| cdate | ctimestamp |
|------------+---------------------|
| 2012-08-07 | 2012-08-07 01:01:01 |
+------------+---------------------+
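
Note: the Spark translation performs a non-ANSI cast (see `visitCast` in the planner changes below, which passes `ansiEnabled = false` to Catalyst's `Cast`), so an invalid conversion such as `cast('abc' as integer)` should evaluate to `null` rather than raise a runtime error.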

Cast function can be **chained**:

@@ -0,0 +1,119 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import java.sql.Date
import java.sql.Timestamp

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLCastITSuite
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"

override def beforeAll(): Unit = {
super.beforeAll()
// Create test table
createNullableJsonContentTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()
// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

test("test cast number to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| id_string = cast(id as string),
| id_double = cast(id as double),
| id_long = cast(id as long),
| id_boolean = cast(id as boolean)
| | fields id, id_string, id_double, id_long, id_boolean | head 1
| """.stripMargin)

assert(
frame.dtypes.sameElements(
Array(
("id", "IntegerType"),
("id_string", "StringType"),
("id_double", "DoubleType"),
("id_long", "LongType"),
("id_boolean", "BooleanType"))))
assertSameRows(Seq(Row(1, "1", 1.0, 1L, true)), frame)
}

test("test cast string to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| id_int = cast(cast(id as string) as integer),
| cast_true = cast("True" as boolean),
| cast_false = cast("false" as boolean),
| cast_timestamp = cast("2024-11-26 23:39:06" as timestamp),
| cast_date = cast("2024-11-26" as date)
| | fields id_int, cast_true, cast_false, cast_timestamp, cast_date | head 1
| """.stripMargin)

assert(
frame.dtypes.sameElements(
Array(
("id_int", "IntegerType"),
("cast_true", "BooleanType"),
("cast_false", "BooleanType"),
("cast_timestamp", "TimestampType"),
("cast_date", "DateType"))))
assertSameRows(
Seq(
Row(
1,
true,
false,
Timestamp.valueOf("2024-11-26 23:39:06"),
Date.valueOf("2024-11-26"))),
frame)
}

test("test cast time related types to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| timestamp = cast("2024-11-26 23:39:06" as timestamp),
| ts_str = cast(timestamp as string),
| ts_date = cast(timestamp as date),
| date_str = cast(ts_date as string),
| date_ts = cast(ts_date as timestamp)
| | fields timestamp, ts_str, ts_date, date_str, date_ts | head 1
| """.stripMargin)

assert(
frame.dtypes.sameElements(
Array(
("timestamp", "TimestampType"),
("ts_str", "StringType"),
("ts_date", "DateType"),
("date_str", "StringType"),
("date_ts", "TimestampType"))))
assertSameRows(
Seq(
Row(
Timestamp.valueOf("2024-11-26 23:39:06"),
"2024-11-26 23:39:06",
Date.valueOf("2024-11-26"),
"2024-11-26",
Timestamp.valueOf("2024-11-26 00:00:00"))),
frame)
}

}
@@ -462,6 +462,7 @@ primaryExpression
: evalFunctionCall
| fieldExpression
| literalValue
| dataTypeFunctionCall
;

positionFunction
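With `dataTypeFunctionCall` reachable from `primaryExpression`, a cast should parse anywhere an expression does: inside `eval`, in `where` predicates, or nested as an argument to another function.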
@@ -13,6 +13,7 @@
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Cidr;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.EqualTo;
@@ -188,6 +189,10 @@ public T visitFunction(Function node, C context) {
return visitChildren(node, context);
}

public T visitCast(Cast node, C context) {
return visitChildren(node, context);
}

public T visitLambdaFunction(LambdaFunction node, C context) {
return visitChildren(node, context);
}
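
As a quick illustration of the new hook, here is a hedged sketch of a downstream visitor that records the target type of every cast in a tree; `AbstractNodeVisitor`, `Cast`, and `visitChildren` come from this commit's code, while the collector class itself is hypothetical:

```java
import java.util.List;

import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Cast;

// Hypothetical helper: collect the target DataType name of every Cast node.
public class CastTypeCollector extends AbstractNodeVisitor<Void, List<String>> {
  @Override
  public Void visitCast(Cast node, List<String> collected) {
    collected.add(node.getDataType().name());
    // Descend into the wrapped expression so chained casts are found too.
    return visitChildren(node, collected);
  }
}
```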
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.Collections;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
* Expression node of cast
*/
@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class Cast extends UnresolvedExpression {
private final UnresolvedExpression expression;
private final DataType dataType;

@Override
public List<UnresolvedExpression> getChild() {
return Collections.singletonList(expression);
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitCast(this, context);
}

@Override
public String toString() {
return String.format("CAST(%s AS %s)", expression, dataType);
}
}
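
For orientation, a minimal sketch of building the node directly; the `Literal` constructor shape is borrowed from elsewhere in this commit, and the printed form follows from `toString` above (the literal's own rendering is a placeholder):

```java
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Literal;

// Hypothetical demo: build CAST('2012-08-07' AS DATE) as an AST node.
public class CastNodeDemo {
  public static void main(String[] args) {
    Cast cast = new Cast(new Literal("2012-08-07", DataType.STRING), DataType.DATE);
    System.out.println(cast);            // CAST(<literal> AS DATE)
    System.out.println(cast.getChild()); // singleton list holding the literal
  }
}
```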
@@ -30,4 +30,8 @@ public enum DataType {
INTERVAL(ExprCoreType.INTERVAL);

@Getter private final ExprCoreType coreType;

public static DataType fromString(String name) {
return valueOf(name.toUpperCase());
}
}
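
Since `fromString` is a plain case-insensitive `valueOf`, matching is exact after upper-casing and unknown names fail fast. A short illustration (the `INTEGER` and `TIMESTAMP` constants are implied by the rest of the commit rather than shown in this hunk):

```java
DataType a = DataType.fromString("integer");   // DataType.INTEGER
DataType b = DataType.fromString("TiMeStAmP"); // DataType.TIMESTAMP
DataType c = DataType.fromString("varchar");   // throws IllegalArgumentException
```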
@@ -10,6 +10,7 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.Cast$;
import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
import org.apache.spark.sql.catalyst.expressions.Exists$;
import org.apache.spark.sql.catalyst.expressions.Expression;
@@ -41,6 +42,7 @@
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.BinaryExpression;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.FieldsMapping;
@@ -466,6 +468,17 @@ public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext c
return context.getNamedParseExpressions().push(LambdaFunction$.MODULE$.apply(functionResult, seq(argsResult), false));
}

@Override
public Expression visitCast(Cast node, CatalystPlanContext context) {
analyze(node.getExpression(), context);
Optional<Expression> ret = context.popNamedParseExpressions();
if (ret.isEmpty()) {
throw new UnsupportedOperationException(
String.format("Invalid use of expression %s", node.getExpression()));
}
return context.getNamedParseExpressions().push(Cast$.MODULE$.apply(ret.get(), translate(node.getDataType()), false));
}
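
The `false` passed to `Cast$.MODULE$.apply` opts out of ANSI cast semantics, so a conversion that cannot be performed (for example, a malformed timestamp string) evaluates to `null` instead of throwing at runtime.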

private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
return expressionList.isEmpty()
? emptyList()
@@ -19,6 +19,7 @@
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Cidr;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
@@ -279,9 +280,9 @@ public UnresolvedExpression visitEvalFunctionCall(OpenSearchPPLParser.EvalFuncti
return buildFunction(ctx.evalFunctionName().getText(), ctx.functionArgs().functionArg());
}

@Override
public UnresolvedExpression visitConvertedDataType(OpenSearchPPLParser.ConvertedDataTypeContext ctx) {
return new Literal(ctx.getText(), DataType.STRING);
@Override public UnresolvedExpression visitDataTypeFunctionCall(OpenSearchPPLParser.DataTypeFunctionCallContext ctx) {
// TODO: for long term consideration, needs to implement DataTypeBuilder/Visitor to parse all data types
return new Cast(this.visit(ctx.expression()), DataType.fromString(ctx.convertedDataType().getText()));
}
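
Two consequences follow: only type names matching a `DataType` constant are accepted, since alias support was deliberately dropped (see the "Remove support for DataType alias" item in the commit message), and chained casts nest naturally. A hedged sketch of the node this builder should produce for the docs' chained example, using the commit's AST classes:

```java
// Illustrative shape for cast(cast("true" as boolean) as integer):
UnresolvedExpression chained =
    new Cast(new Cast(new Literal("true", DataType.STRING), DataType.BOOLEAN),
             DataType.INTEGER);
```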

@Override
@@ -9,6 +9,7 @@
import org.apache.spark.sql.types.BooleanType$;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
@@ -49,8 +50,12 @@ static <T> Seq<T> seq(List<T> list) {

static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
switch (source.getCoreType()) {
case TIME:
case DATE:
return DateType$.MODULE$;
case TIMESTAMP:
return DataTypes.TimestampType;
case STRING:
return DataTypes.StringType;
case INTEGER:
return IntegerType$.MODULE$;
case LONG:
@@ -68,7 +73,7 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
case UNDEFINED:
return NullType$.MODULE$;
default:
return StringType$.MODULE$;
throw new IllegalArgumentException("Unsupported data type for Spark: " + source);
}
}
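
A hedged usage sketch; the enclosing utility class is not named in this rendering, so the call is shown unqualified:

```java
// Illustrative: PPL TIMESTAMP now maps to Spark's TimestampType.
org.apache.spark.sql.types.DataType spark =
    translate(org.opensearch.sql.ast.expression.DataType.TIMESTAMP);
// spark == DataTypes.TimestampType. A type without an explicit case now fails
// fast with IllegalArgumentException instead of silently becoming StringType.
```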

Expand Down Expand Up @@ -120,4 +125,4 @@ static String translate(SpanUnit unit) {
}
return "";
}
}
}