Commit
support spark prior to 3.5 with its extended table identifier (existing table identifier only has 2 parts)

Signed-off-by: YANGDB <[email protected]>
YANG-DB committed Oct 15, 2024
1 parent a9e7c6e commit 7bcce2f
Showing 7 changed files with 255 additions and 200 deletions.
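
The change concerns how qualified table names become Catalyst identifiers: Spark releases prior to 3.5 expose a table identifier with only two parts (database and table), while the tests here use fully qualified three-part names that also carry a catalog. As a minimal sketch of the idea, assuming the multipart Seq form of UnresolvedRelation (the helper name and fallback rule are illustrative, not code from this commit):

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation

// Illustrative helper: split a dotted name and build an UnresolvedRelation from its parts.
// On Spark < 3.5 a three-part name such as "spark_catalog.default.flint_ppl_test" cannot be
// carried by the two-part TableIdentifier, so the multipart Seq form is used instead.
def relationFor(qualifiedName: String): UnresolvedRelation = {
  val parts = qualifiedName.split('.').toSeq // e.g. Seq("spark_catalog", "default", "flint_ppl_test")
  UnresolvedRelation(parts)
}

// relationFor("default.flint_ppl_test")               -> two-part identifier
// relationFor("spark_catalog.default.flint_ppl_test") -> three-part identifier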
@@ -36,62 +36,79 @@ class FlintSparkPPLFieldSummaryITSuite
}
}

test("test fillnull with one null replacement value and one column") {
// val frame = sql(s"""
// | source = $testTable | fieldsummary includefields= status_code, id, response_time topvalues=5 nulls=true
// | """.stripMargin)

test("test fieldsummary with single field includefields(status_code) & nulls=true ") {
val frame = sql(s"""
| SELECT
| 'status_code' AS Field,
| COUNT(status_code) AS Count,
| COUNT(DISTINCT status_code) AS Distinct,
| MIN(status_code) AS Min,
| MAX(status_code) AS Max,
| AVG(CAST(status_code AS DOUBLE)) AS Avg,
| typeof(status_code) AS Type,
| (SELECT COLLECT_LIST(STRUCT(status_code, count_status))
| FROM (
| SELECT status_code, COUNT(*) AS count_status
| FROM $testTable
| GROUP BY status_code
| ORDER BY count_status DESC
| LIMIT 5
| )) AS top_values,
| COUNT(*) - COUNT(status_code) AS Nulls
| FROM $testTable
| GROUP BY typeof(status_code)
|
| UNION ALL
|
| SELECT
| 'id' AS Field,
| COUNT(id) AS Count,
| COUNT(DISTINCT id) AS Distinct,
| MIN(id) AS Min,
| MAX(id) AS Max,
| AVG(CAST(id AS DOUBLE)) AS Avg,
| typeof(id) AS Type,
| (SELECT COLLECT_LIST(STRUCT(id, count_id))
| FROM (
| SELECT id, COUNT(*) AS count_id
| FROM $testTable
| GROUP BY id
| ORDER BY count_id DESC
| LIMIT 5
| )) AS top_values,
| COUNT(*) - COUNT(id) AS Nulls
| FROM $testTable
| GROUP BY typeof(id)
|""".stripMargin)
| source = $testTable | fieldsummary includefields= status_code nulls=true
| """.stripMargin)

/*
val frame = sql(s"""
| SELECT
| 'status_code' AS Field,
| COUNT(status_code) AS Count,
| COUNT(DISTINCT status_code) AS Distinct,
| MIN(status_code) AS Min,
| MAX(status_code) AS Max,
| AVG(CAST(status_code AS DOUBLE)) AS Avg,
| typeof(status_code) AS Type,
| COUNT(*) - COUNT(status_code) AS Nulls
| FROM $testTable
| GROUP BY typeof(status_code)
| """.stripMargin)
*/

// val frame = sql(s"""
// | SELECT
// | 'status_code' AS Field,
// | COUNT(status_code) AS Count,
// | COUNT(DISTINCT status_code) AS Distinct,
// | MIN(status_code) AS Min,
// | MAX(status_code) AS Max,
// | AVG(CAST(status_code AS DOUBLE)) AS Avg,
// | typeof(status_code) AS Type,
// | (SELECT COLLECT_LIST(STRUCT(status_code, count_status))
// | FROM (
// | SELECT status_code, COUNT(*) AS count_status
// | FROM $testTable
// | GROUP BY status_code
// | ORDER BY count_status DESC
// | LIMIT 5
// | )) AS top_values,
// | COUNT(*) - COUNT(status_code) AS Nulls
// | FROM $testTable
// | GROUP BY typeof(status_code)
// |
// | UNION ALL
// |
// | SELECT
// | 'id' AS Field,
// | COUNT(id) AS Count,
// | COUNT(DISTINCT id) AS Distinct,
// | MIN(id) AS Min,
// | MAX(id) AS Max,
// | AVG(CAST(id AS DOUBLE)) AS Avg,
// | typeof(id) AS Type,
// | (SELECT COLLECT_LIST(STRUCT(id, count_id))
// | FROM (
// | SELECT id, COUNT(*) AS count_id
// | FROM $testTable
// | GROUP BY id
// | ORDER BY count_id DESC
// | LIMIT 5
// | )) AS top_values,
// | COUNT(*) - COUNT(id) AS Nulls
// | FROM $testTable
// | GROUP BY typeof(id)
// |""".stripMargin)

val results: Array[Row] = frame.collect()
// Print each row in a readable format
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// scalastyle:off println
results.foreach(row => println(row.mkString(", ")))
println(logicalPlan)
// scalastyle:on println

// val logicalPlan: LogicalPlan = frame.queryExecution.logical
// val expectedPlan = ?
// comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
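
The expected-plan assertion above is intentionally left as a TODO (expectedPlan = ?). For orientation only, a hedged sketch of how an expected plan for the simpler, no-top-values variant is typically assembled from Catalyst nodes; the column set is truncated and the names mirror the commented-out SQL, so this is not the suite's final fixture:

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate

// Rough shape only: summarize status_code grouped by typeof(status_code),
// mirroring the commented-out SQL above (Avg/Type/Nulls columns omitted for brevity).
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val status = UnresolvedAttribute("status_code")
val groupBy = UnresolvedFunction(Seq("typeof"), Seq(status), isDistinct = false)
val columns = Seq(
  Alias(Literal("status_code"), "Field")(),
  Alias(UnresolvedFunction(Seq("COUNT"), Seq(status), isDistinct = false), "Count")(),
  Alias(UnresolvedFunction(Seq("COUNT"), Seq(status), isDistinct = true), "Distinct")(),
  Alias(UnresolvedFunction(Seq("MIN"), Seq(status), isDistinct = false), "Min")(),
  Alias(UnresolvedFunction(Seq("MAX"), Seq(status), isDistinct = false), "Max")())
val expectedPlanSketch = Aggregate(Seq(groupBy), columns, table)
// comparePlans(frame.queryExecution.logical, expectedPlanSketch, checkAnalysis = false)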
@@ -8,7 +8,6 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.flint.spark.ppl.OpenSearchPPLParser;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.Field;
@@ -18,7 +17,6 @@
import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.List;
import java.util.stream.Collectors;

import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.INCLUDEFIELDS;
import static org.opensearch.flint.spark.ppl.OpenSearchPPLParser.NULLS;
@@ -30,7 +28,7 @@
public class FieldSummary extends UnresolvedPlan {
private List<Field> includeFields;
private int topValues;
private boolean nulls;
private boolean ignoreNull;
private List<UnresolvedExpression> collect;
private UnresolvedPlan child;

@@ -40,7 +38,7 @@ public FieldSummary(List<UnresolvedExpression> collect) {
.forEach(exp -> {
switch (((NamedExpression) exp).getExpressionId()) {
case NULLS:
this.nulls = (boolean) ((Literal) exp.getChild().get(0)).getValue();
this.ignoreNull = (boolean) ((Literal) exp.getChild().get(0)).getValue();
break;
case TOPVALUES:
this.topValues = (int) ((Literal) exp.getChild().get(0)).getValue();
@@ -160,6 +160,7 @@ public enum BuiltinFunctionName {
AVG(FunctionName.of("avg")),
SUM(FunctionName.of("sum")),
COUNT(FunctionName.of("count")),
COUNT_DISTINCT(FunctionName.of("count_distinct")),
MIN(FunctionName.of("min")),
MAX(FunctionName.of("max")),
// sample variance
@@ -383,6 +383,27 @@ public LogicalPlan visitHead(Head node, CatalystPlanContext context) {
}

@Override
/**
* 'Union false, false
* :- 'Aggregate ['typeof('status_code)], [status_code AS Field#20, 'COUNT('status_code) AS Count#21, 'COUNT(distinct 'status_code) AS Distinct#22, 'MIN('status_code) AS Min#23, 'MAX('status_code) AS Max#24, 'AVG(cast('status_code as double)) AS Avg#25, 'typeof('status_code) AS Type#26, scalar-subquery#28 [] AS top_values#29, ('COUNT(1) - 'COUNT('status_code)) AS Nulls#30]
* : : +- 'Project [unresolvedalias('COLLECT_LIST(struct(status_code, 'status_code, count_status, 'count_status)), None)]
* : : +- 'SubqueryAlias __auto_generated_subquery_name
* : : +- 'GlobalLimit 5
* : : +- 'LocalLimit 5
* : : +- 'Sort ['count_status DESC NULLS LAST], true
* : : +- 'Aggregate ['status_code], ['status_code, 'COUNT(1) AS count_status#27]
* : : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false
* : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false
* +- 'Aggregate ['typeof('id)], [id AS Field#31, 'COUNT('id) AS Count#32, 'COUNT(distinct 'id) AS Distinct#33, 'MIN('id) AS Min#34, 'MAX('id) AS Max#35, 'AVG(cast('id as double)) AS Avg#36, 'typeof('id) AS Type#37, scalar-subquery#39 [] AS top_values#40, ('COUNT(1) - 'COUNT('id)) AS Nulls#41]
* : +- 'Project [unresolvedalias('COLLECT_LIST(struct(id, 'id, count_id, 'count_id)), None)]
* : +- 'SubqueryAlias __auto_generated_subquery_name
* : +- 'GlobalLimit 5
* : +- 'LocalLimit 5
* : +- 'Sort ['count_id DESC NULLS LAST], true
* : +- 'Aggregate ['id], ['id, 'COUNT(1) AS count_id#38]
* : +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false
* +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false
*/
public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) {
fieldSummary.getChild().get(0).accept(this, context);
return FieldSummaryTransformer.translate(fieldSummary, context);
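
The plan documented in the comment above centers on a scalar subquery that collects the top values per field. The following is a hedged Scala sketch of that fragment assembled from Catalyst nodes, as a rough reconstruction of the comment rather than the exact output of FieldSummaryTransformer:

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Descending, Literal, ScalarSubquery, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, GlobalLimit, LocalLimit, Project, Sort, SubqueryAlias}

val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val field = UnresolvedAttribute("status_code")

// GROUP BY status_code producing (status_code, count_status), then ORDER BY count_status DESC LIMIT 5.
val grouped = Aggregate(
  Seq(field),
  Seq(field, Alias(UnresolvedFunction(Seq("COUNT"), Seq(Literal(1)), isDistinct = false), "count_status")()),
  relation)
val sorted = Sort(Seq(SortOrder(UnresolvedAttribute("count_status"), Descending)), global = true, grouped)
val limited = GlobalLimit(Literal(5), LocalLimit(Literal(5), sorted))

// COLLECT_LIST(STRUCT(status_code, count_status)) wrapped as a scalar subquery and aliased top_values.
val collectList = UnresolvedFunction(
  Seq("COLLECT_LIST"),
  Seq(CreateStruct(Seq(field, UnresolvedAttribute("count_status")))),
  isDistinct = false)
val topValues = Alias(
  ScalarSubquery(
    Project(Seq(Alias(collectList, "top_values")()),
      SubqueryAlias("__auto_generated_subquery_name", limited))),
  "top_values")()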
@@ -12,10 +12,12 @@
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

import java.util.List;
import java.util.Optional;

import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
import static scala.Option.empty;
@@ -27,30 +29,36 @@
*/
public interface AggregatorTranslator {

static String aggregationAlias(BuiltinFunctionName functionName, QualifiedName name) {
return functionName.name()+"("+name.toString()+")";
}

static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) {
if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty())
throw new IllegalStateException("Unexpected value: " + aggregateFunction.getFuncName());

boolean distinct = aggregateFunction.getDistinct();
// Additional aggregation function operators will be added here
switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) {
BuiltinFunctionName functionName = BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get();
switch (functionName) {
case MAX:
return new UnresolvedFunction(seq("MAX"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
return new UnresolvedFunction(seq("MAX"), seq(arg), distinct, empty(),false);
case MIN:
return new UnresolvedFunction(seq("MIN"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
return new UnresolvedFunction(seq("MIN"), seq(arg), distinct, empty(),false);
case AVG:
return new UnresolvedFunction(seq("AVG"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
return new UnresolvedFunction(seq("AVG"), seq(arg), distinct, empty(),false);
case COUNT:
return new UnresolvedFunction(seq("COUNT"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
return new UnresolvedFunction(seq("COUNT"), seq(arg), distinct, empty(),false);
case SUM:
return new UnresolvedFunction(seq("SUM"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
return new UnresolvedFunction(seq("SUM"), seq(arg), distinct, empty(),false);
case STDDEV_POP:
return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), distinct, empty(),false);
case STDDEV_SAMP:
return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), distinct, empty(),false);
case PERCENTILE:
return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false);
return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false);
case PERCENTILE_APPROX:
return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false);
return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false);
}
throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName());
}
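
The refactor above only hoists aggregateFunction.getDistinct() into the local distinct variable, but that flag is what decides whether Spark sees a plain or a DISTINCT aggregate. A small hedged Scala illustration of the COUNT case with the flag set (the field name is illustrative):

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}

// isDistinct = true is what the analyzer later resolves as COUNT(DISTINCT status_code),
// matching the 'COUNT(distinct 'status_code) entries in the visitFieldSummary plan comment earlier in this diff.
val distinctCount = UnresolvedFunction(
  Seq("COUNT"),
  Seq(UnresolvedAttribute("status_code")),
  isDistinct = true)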