Support CAST function #952

Merged: 5 commits merged on Nov 29, 2024
8 changes: 8 additions & 0 deletions docs/ppl-lang/PPL-Example-Commands.md
@@ -118,6 +118,7 @@ Assumptions: `a`, `b`, `c` are existing fields in `table`
- `source = table | eval r = coalesce(a, b, c) | fields r`
- `source = table | eval e = isempty(a) | fields e`
- `source = table | eval e = isblank(a) | fields e`
- `source = table | eval e = cast(a as timestamp) | fields e`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))`
@@ -486,4 +487,11 @@ _- **Limitation: another command usage of (relation) subquery is in `appendcols`

> ppl-correlation-command is an experimental command - it may be removed in future versions

#### **Cast**
[See additional command details](functions/ppl-conversion.md)
- `source = table | eval int_to_string = cast(1 as string) | fields int_to_string`
- `source = table | eval int_to_string = cast(int_col as string), string_to_int = cast(string_col as integer) | fields int_to_string, string_to_int`
- `source = table | eval cdate = CAST('2012-08-07' as date), ctime = cast('2012-08-07T08:07:06' as timestamp) | fields cdate, ctime`
- `source = table | eval chained_cast = cast(cast("true" as boolean) as integer) | fields chained_cast`

---
43 changes: 21 additions & 22 deletions docs/ppl-lang/functions/ppl-conversion.md
@@ -7,22 +7,21 @@
`cast(expr as dataType)` casts `expr` to `dataType` and returns the value as `dataType`. The following conversion rules are used:

```
+------------+--------+--------+---------+-------------+--------+--------+
| Src/Target | STRING | NUMBER | BOOLEAN | TIMESTAMP | DATE | TIME |
+------------+--------+--------+---------+-------------+--------+--------+
| STRING | | Note1 | Note1 | TIMESTAMP() | DATE() | TIME() |
+------------+--------+--------+---------+-------------+--------+--------+
| NUMBER | Note1 | | v!=0 | N/A | N/A | N/A |
+------------+--------+--------+---------+-------------+--------+--------+
| BOOLEAN | Note1 | v?1:0 | | N/A | N/A | N/A |
+------------+--------+--------+---------+-------------+--------+--------+
| TIMESTAMP | Note1 | N/A | N/A | | DATE() | TIME() |
+------------+--------+--------+---------+-------------+--------+--------+
| DATE | Note1 | N/A | N/A | N/A | | N/A |
+------------+--------+--------+---------+-------------+--------+--------+
| TIME | Note1 | N/A | N/A | N/A | N/A | |
+------------+--------+--------+---------+-------------+--------+--------+
+------------+--------+--------+---------+-------------+--------+
| Src/Target | STRING | NUMBER | BOOLEAN | TIMESTAMP | DATE |
+------------+--------+--------+---------+-------------+--------+
| STRING | | Note1 | Note1 | TIMESTAMP() | DATE() |
+------------+--------+--------+---------+-------------+--------+
| NUMBER | Note1 | | v!=0 | N/A | N/A |
+------------+--------+--------+---------+-------------+--------+
| BOOLEAN | Note1 | v?1:0 | | N/A | N/A |
+------------+--------+--------+---------+-------------+--------+
| TIMESTAMP | Note1 | N/A | N/A | | DATE() |
+------------+--------+--------+---------+-------------+--------+
| DATE | Note1 | N/A | N/A | N/A | |
+------------+--------+--------+---------+-------------+--------+
```
- `NUMBER` includes `INTEGER`, `LONG`, `FLOAT`, `DOUBLE`.
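
The scalar rules above can be sketched in miniature. This is an illustrative model, not the plugin's implementation, with the `Note1` string-to-number conversion reduced to plain Java parsing:

```java
// Illustrative model of the scalar conversion rules above; not the plugin's code.
class CastRules {
    // NUMBER -> BOOLEAN: true iff v != 0
    static boolean numberToBoolean(double v) {
        return v != 0;
    }

    // BOOLEAN -> NUMBER: v ? 1 : 0
    static int booleanToNumber(boolean v) {
        return v ? 1 : 0;
    }

    // STRING -> NUMBER (Note1): parse; malformed input throws NumberFormatException
    static double stringToNumber(String s) {
        return Double.parseDouble(s);
    }
}
```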

Cast to **string** example:

@@ -36,7 +35,7 @@

Cast to **number** example:

os> source=people | eval `cbool` = CAST(true as int), `cstring` = CAST('1' as int) | fields `cbool`, `cstring`
os> source=people | eval `cbool` = CAST(true as integer), `cstring` = CAST('1' as integer) | fields `cbool`, `cstring`
fetched rows / total rows = 1/1
+---------+-----------+
| cbool | cstring |
@@ -46,13 +45,13 @@

Cast to **date** example:

os> source=people | eval `cdate` = CAST('2012-08-07' as date), `ctime` = CAST('01:01:01' as time), `ctimestamp` = CAST('2012-08-07 01:01:01' as timestamp) | fields `cdate`, `ctime`, `ctimestamp`
os> source=people | eval `cdate` = CAST('2012-08-07' as date), `ctimestamp` = CAST('2012-08-07 01:01:01' as timestamp) | fields `cdate`, `ctimestamp`
fetched rows / total rows = 1/1
+------------+----------+---------------------+
| cdate | ctime | ctimestamp |
|------------+----------+---------------------|
| 2012-08-07 | 01:01:01 | 2012-08-07 01:01:01 |
+------------+----------+---------------------+
+------------+---------------------+
| cdate | ctimestamp |
|------------+---------------------|
| 2012-08-07 | 2012-08-07 01:01:01 |
+------------+---------------------+

Cast function can be **chained**:

@@ -0,0 +1,119 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import java.sql.Date
import java.sql.Timestamp

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLCastITSuite
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"

override def beforeAll(): Unit = {
super.beforeAll()
// Create test table
createNullableJsonContentTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()
// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

test("test cast number to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| id_string = cast(id as string),
| id_double = cast(id as double),
| id_long = cast(id as long),
| id_boolean = cast(id as boolean)
| | fields id, id_string, id_double, id_long, id_boolean | head 1
| """.stripMargin)

assert(
frame.dtypes.sameElements(
Array(
("id", "IntegerType"),
("id_string", "StringType"),
("id_double", "DoubleType"),
("id_long", "LongType"),
("id_boolean", "BooleanType"))))
assertSameRows(Seq(Row(1, "1", 1.0, 1L, true)), frame)
}

test("test cast string to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| id_int = cast(cast(id as string) as integer),
| cast_true = cast("True" as boolean),
| cast_false = cast("false" as boolean),
| cast_timestamp = cast("2024-11-26 23:39:06" as timestamp),
| cast_date = cast("2024-11-26" as date)
| | fields id_int, cast_true, cast_false, cast_timestamp, cast_date | head 1
| """.stripMargin)

assert(
frame.dtypes.sameElements(
Array(
("id_int", "IntegerType"),
("cast_true", "BooleanType"),
("cast_false", "BooleanType"),
("cast_timestamp", "TimestampType"),
("cast_date", "DateType"))))
assertSameRows(
Seq(
Row(
1,
true,
false,
Timestamp.valueOf("2024-11-26 23:39:06"),
Date.valueOf("2024-11-26"))),
frame)
}

test("test cast time related types to compatible data types") {
val frame = sql(s"""
| source=$testTable | eval
| timestamp = cast("2024-11-26 23:39:06" as timestamp),
| ts_str = cast(timestamp as string),
| ts_date = cast(timestamp as date),
| date_str = cast(ts_date as string),
| date_ts = cast(ts_date as timestamp)
| | fields timestamp, ts_str, ts_date, date_str, date_ts | head 1
| """.stripMargin)

assert(
frame.dtypes.sameElements(
Array(
("timestamp", "TimestampType"),
("ts_str", "StringType"),
("ts_date", "DateType"),
("date_str", "StringType"),
("date_ts", "TimestampType"))))
assertSameRows(
Seq(
Row(
Timestamp.valueOf("2024-11-26 23:39:06"),
"2024-11-26 23:39:06",
Date.valueOf("2024-11-26"),
"2024-11-26",
Timestamp.valueOf("2024-11-26 00:00:00"))),
frame)
}

}
@@ -462,6 +462,7 @@ primaryExpression
: evalFunctionCall
| fieldExpression
| literalValue
| dataTypeFunctionCall
;

positionFunction
@@ -13,6 +13,7 @@
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Cidr;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.EqualTo;
@@ -188,6 +189,10 @@ public T visitFunction(Function node, C context) {
return visitChildren(node, context);
}

public T visitCast(Cast node, C context) {
return visitChildren(node, context);
}

public T visitLambdaFunction(LambdaFunction node, C context) {
return visitChildren(node, context);
}
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.Collections;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/** Expression node of cast. */
@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class Cast extends UnresolvedExpression {
private final UnresolvedExpression expression;
private final DataType dataType;

@Override
public List<UnresolvedExpression> getChild() {
return Collections.singletonList(expression);
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitCast(this, context);
}

@Override
public String toString() {
return String.format("CAST(%s AS %s)", expression, dataType);
}
}
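
The `accept`/`visitCast` pair above is the standard visitor double dispatch. A minimal self-contained sketch, with simplified stand-ins rather than the real AST classes:

```java
// Simplified stand-ins for the AST visitor double dispatch; illustrative only.
class VisitorDemo {
    interface Node {
        <R> R accept(Visitor<R> v);
    }

    interface Visitor<R> {
        R visitCast(CastNode n);
    }

    static final class CastNode implements Node {
        final String expression;
        final String dataType;

        CastNode(String expression, String dataType) {
            this.expression = expression;
            this.dataType = dataType;
        }

        // Double dispatch: the node selects the visitor method for its own type.
        public <R> R accept(Visitor<R> v) {
            return v.visitCast(this);
        }
    }

    static String render(Node n) {
        return n.accept(new Visitor<String>() {
            public String visitCast(CastNode c) {
                return String.format("CAST(%s AS %s)", c.expression, c.dataType);
            }
        });
    }
}
```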
@@ -30,4 +30,8 @@ public enum DataType {
INTERVAL(ExprCoreType.INTERVAL);

@Getter private final ExprCoreType coreType;

public static DataType fromString(String name) {
return valueOf(name.toUpperCase());
}
}
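
Since `fromString` delegates to `Enum.valueOf`, an unrecognized type name raises `IllegalArgumentException` at parse time. A small sketch of that behavior, using an illustrative subset of constants rather than the full enum:

```java
// Sketch of the case-insensitive enum lookup; the constants here are an
// illustrative subset, not the full DataType enum.
class FromStringDemo {
    enum SimpleType { STRING, INTEGER, LONG, DOUBLE, BOOLEAN, DATE, TIMESTAMP }

    static SimpleType fromString(String name) {
        // Enum.valueOf throws IllegalArgumentException for unknown names.
        return SimpleType.valueOf(name.toUpperCase());
    }
}
```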
@@ -10,6 +10,7 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.Cast$;
import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
import org.apache.spark.sql.catalyst.expressions.Exists$;
import org.apache.spark.sql.catalyst.expressions.Expression;
@@ -41,6 +42,7 @@
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.BinaryExpression;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.FieldsMapping;
@@ -466,6 +468,17 @@ public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext c
return context.getNamedParseExpressions().push(LambdaFunction$.MODULE$.apply(functionResult, seq(argsResult), false));
}

@Override
public Expression visitCast(Cast node, CatalystPlanContext context) {
analyze(node.getExpression(), context);
Optional<Expression> ret = context.popNamedParseExpressions();
if (ret.isEmpty()) {
throw new UnsupportedOperationException(
String.format("Invalid use of expression %s", node.getExpression()));
}
return context.getNamedParseExpressions().push(Cast$.MODULE$.apply(ret.get(), translate(node.getDataType()), false));
}
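
`visitCast` follows the visitor's stack discipline: analyzing the child pushes its Catalyst translation, which is then popped and wrapped in a Spark `Cast`. A simplified sketch with strings standing in for Catalyst expressions (the names here are illustrative, not the real API):

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Strings stand in for Catalyst Expression objects; illustrative only.
class CastTranslationDemo {
    private final Deque<String> expressions = new ArrayDeque<>();

    // A real visitor would resolve the child into a Catalyst expression
    // and push it onto the context's expression stack.
    void analyze(String expr) {
        expressions.push("resolved(" + expr + ")");
    }

    String visitCast(String expr, String targetType) {
        analyze(expr);
        if (expressions.isEmpty()) {
            throw new UnsupportedOperationException("Invalid use of expression " + expr);
        }
        String child = expressions.pop();
        String cast = "Cast(" + child + ", " + targetType + ")";
        expressions.push(cast); // result stays on the stack for enclosing expressions
        return cast;
    }
}
```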

private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
return expressionList.isEmpty()
? emptyList()
@@ -19,6 +19,7 @@
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Cidr;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
@@ -279,9 +280,9 @@ public UnresolvedExpression visitEvalFunctionCall(OpenSearchPPLParser.EvalFuncti
return buildFunction(ctx.evalFunctionName().getText(), ctx.functionArgs().functionArg());
}

@Override
public UnresolvedExpression visitConvertedDataType(OpenSearchPPLParser.ConvertedDataTypeContext ctx) {
return new Literal(ctx.getText(), DataType.STRING);
@Override
public UnresolvedExpression visitDataTypeFunctionCall(OpenSearchPPLParser.DataTypeFunctionCallContext ctx) {
// TODO: for long term consideration, needs to implement DataTypeBuilder/Visitor to parse all data types
return new Cast(this.visit(ctx.expression()), DataType.fromString(ctx.convertedDataType().getText()));
Member: @qianheng-aws can you please check whether it is simpler to use DataTypeTransformer here instead?

Contributor Author: I used DataTypeTransformer in CatalystExpressionVisitor. DataTypeTransformer translates the OpenSearch data type into a Spark data type, while here it is all about the OpenSearch AST.

}

@Override
@@ -9,6 +9,7 @@
import org.apache.spark.sql.types.BooleanType$;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
@@ -49,8 +50,12 @@ static <T> Seq<T> seq(List<T> list) {

static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
switch (source.getCoreType()) {
case TIME:
case DATE:
return DateType$.MODULE$;
case TIMESTAMP:
return DataTypes.TimestampType;
case STRING:
return DataTypes.StringType;
case INTEGER:
return IntegerType$.MODULE$;
case LONG:
@@ -68,7 +73,7 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
case UNDEFINED:
return NullType$.MODULE$;
default:
return StringType$.MODULE$;
throw new IllegalArgumentException("Unsupported data type for Spark: " + source);
}
}

@@ -120,4 +125,4 @@ static String translate(SpanUnit unit) {
}
return "";
}
}
}