diff --git a/src/main/java/com/snowflake/snowpark_java/DataFrame.java b/src/main/java/com/snowflake/snowpark_java/DataFrame.java index ec26efc0..806b2537 100644 --- a/src/main/java/com/snowflake/snowpark_java/DataFrame.java +++ b/src/main/java/com/snowflake/snowpark_java/DataFrame.java @@ -1312,6 +1312,18 @@ public DataFrame join( JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(orderBy)))); } + public DataFrame join(Column func) { + return new DataFrame(this.df.join(func.toScalaColumn())); + } + + public DataFrame join(Column func, Column[] partitionBy, Column[] orderBy) { + return new DataFrame( + this.df.join( + func.toScalaColumn(), + JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(partitionBy)), + JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(orderBy)))); + } + com.snowflake.snowpark.DataFrame getScalaDataFrame() { return this.df; } diff --git a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java index 267dc38d..f60e4cb1 100644 --- a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java +++ b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java @@ -41,7 +41,8 @@ public static TableFunction split_to_table() { } public static Column split_to_table(Column str, String delimiter) { - return new Column(com.snowflake.snowpark.tableFunctions.split_to_table(str.toScalaColumn(), delimiter)); + return new Column( + com.snowflake.snowpark.tableFunctions.split_to_table(str.toScalaColumn(), delimiter)); } /** @@ -81,4 +82,11 @@ public static Column split_to_table(Column str, String delimiter) { public static TableFunction flatten() { return new TableFunction(com.snowflake.snowpark.tableFunctions.flatten()); } + + public static Column flatten( + Column input, String path, boolean outer, boolean recursive, String mode) { + return new Column( + com.snowflake.snowpark.tableFunctions.flatten( + input.toScalaColumn(), path, outer, recursive, mode)); + } } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java index 225f0d5b..1d8c16f9 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java @@ -88,4 +88,31 @@ public void argumentInTableFunction() { getSession().tableFunction(new TableFunction("flatten").call(args)).select("value"), new Row[] {Row.create("77"), Row.create("88")}); } + + @Test + public void argumentInSplitToTable() { + DataFrame df = + getSession() + .createDataFrame( + new Row[] {Row.create("split by space")}, + StructType.create(new StructField("col", DataTypes.StringType))); + checkAnswer( + df.join(TableFunctions.split_to_table(df.col("col"), " ")).select("value"), + new Row[] {Row.create("split"), Row.create("by"), Row.create("space")}); + } + + @Test + public void argumentInFlatten() { + DataFrame df = + getSession() + .createDataFrame( + new Row[] {Row.create("{\"a\":1, \"b\":[77, 88]}")}, + StructType.create(new StructField("col", DataTypes.StringType))); + checkAnswer( + df.join( + TableFunctions.flatten( + Functions.parse_json(df.col("col")), "b", true, true, "both")) + .select("value"), + new Row[] {Row.create("77"), Row.create("88")}); + } }