diff --git a/.github/workflows/spark-tests.yaml b/.github/workflows/spark-tests.yaml
index a726221..6cd45cf 100644
--- a/.github/workflows/spark-tests.yaml
+++ b/.github/workflows/spark-tests.yaml
@@ -38,7 +38,9 @@ jobs:
       - uses: r-lib/actions/setup-r-dependencies@v2
         with:
-          extra-packages: any::devtools
+          extra-packages: |
+            any::devtools
+            any::arrow
           needs: check
 
       - name: Set up Python 3.10
diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml
index dd88f03..a45c4f8 100644
--- a/.github/workflows/test-coverage.yaml
+++ b/.github/workflows/test-coverage.yaml
@@ -33,6 +33,7 @@ jobs:
           extra-packages: |
             any::covr
             any::devtools
+            any::arrow
           needs: coverage
 
       - name: Cache Spark
diff --git a/DESCRIPTION b/DESCRIPTION
index 5fcc279..e499b23 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: pysparklyr
 Title: Provides a 'PySpark' Back-End for the 'sparklyr' Package
-Version: 0.1.3
+Version: 0.1.3.9000
 Authors@R: c(
     person("Edgar", "Ruiz", , "edgar@posit.co", role = c("aut", "cre")),
     person(given = "Posit Software, PBC", role = c("cph", "fnd"))
     )
@@ -22,7 +22,7 @@ Imports:
     reticulate (>= 1.33),
     methods,
     rlang,
-    sparklyr (>= 1.8.4),
+    sparklyr (>= 1.8.4.9004),
     tidyselect,
     fs,
     magrittr,
@@ -41,3 +41,5 @@ Suggests:
     tibble,
     withr
 Config/testthat/edition: 3
+Remotes:
+    sparklyr/sparklyr
diff --git a/NAMESPACE b/NAMESPACE
index 65e278f..8bb4839 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -40,6 +40,7 @@ S3method(sdf_copy_to,pyspark_connection)
 S3method(sdf_read_column,spark_pyjobj)
 S3method(sdf_register,spark_pyobj)
 S3method(sdf_schema,tbl_pyspark)
+S3method(spark_apply,tbl_pyspark)
 S3method(spark_connect_method,spark_method_databricks_connect)
 S3method(spark_connect_method,spark_method_spark_connect)
 S3method(spark_connection,pyspark_connection)
@@ -157,6 +158,7 @@ importFrom(sparklyr,sdf_copy_to)
 importFrom(sparklyr,sdf_read_column)
 importFrom(sparklyr,sdf_register)
 importFrom(sparklyr,sdf_schema)
+importFrom(sparklyr,spark_apply)
 importFrom(sparklyr,spark_connect_method)
 importFrom(sparklyr,spark_connection)
 importFrom(sparklyr,spark_dataframe)
diff --git a/NEWS.md b/NEWS.md
index 52eafe4..47310e0 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,3 +1,9 @@
+# pysparklyr dev
+
+### New
+
+* Adds support for `spark_apply()` via the `rpy2` Python library.
+
 # pysparklyr 0.1.3
 
 ### New
diff --git a/R/ml-utils.R b/R/ml-utils.R
index 3fb5ba8..e16a9ad 100644
--- a/R/ml-utils.R
+++ b/R/ml-utils.R
@@ -130,29 +130,9 @@ get_params <- function(x) {
 }
 
 ml_installed <- function(envname = NULL) {
-  ml_libraries <- pysparklyr_env$ml_libraries
-  installed_libraries <- py_list_packages(envname = envname)$package
-  find_ml <- map_lgl(ml_libraries, ~ .x %in% installed_libraries)
-  if (!all(find_ml)) {
-    cli_div(theme = cli_colors())
-    msg1 <- "Required Python libraries to run ML functions are missing"
-    if (check_interactive()) {
-      missing_ml <- ml_libraries[!find_ml]
-      cli_alert_warning(msg1)
-      cli_bullets(c(
-        " " = "{.header Could not find: {missing_ml}}",
-        " " = "Do you wish to install? {.class (This will be a one time operation)}"
-      ))
-      choice <- menu(choices = c("Yes", "Cancel"))
-      if (choice == 1) {
-        py_install(missing_ml)
-      }
-      if (choice == 2) {
-        stop_quietly()
-      }
-    } else {
-      cli_abort(msg1)
-    }
-    cli_end()
-  }
+  py_check_installed(
+    envname = envname,
+    libraries = pysparklyr_env$ml_libraries,
+    msg = "Required Python libraries to run ML functions are missing"
+  )
 }
diff --git a/R/package.R b/R/package.R
index bd0fadf..b6bb07f 100644
--- a/R/package.R
+++ b/R/package.R
@@ -11,7 +11,7 @@
 #' @importFrom sparklyr spark_write_orc spark_write_json spark_write_table
 #' @importFrom sparklyr ml_pipeline ml_predict ml_transform ml_fit
 #' @importFrom sparklyr ml_logistic_regression ft_standard_scaler ft_max_abs_scaler
-#' @importFrom sparklyr ml_save ml_load spark_jobj spark_install_find
+#' @importFrom sparklyr ml_save ml_load spark_jobj spark_install_find spark_apply
 #' @importFrom tidyselect tidyselect_data_has_predicates
 #' @importFrom dplyr tbl collect tibble same_src compute as_tibble group_vars
 #' @importFrom dplyr sample_n sample_frac slice_sample select tbl_ptype group_by
diff --git a/R/python-install.R b/R/python-install.R
index 5bdc4f4..05286ae 100644
--- a/R/python-install.R
+++ b/R/python-install.R
@@ -217,7 +217,8 @@ install_environment <- function(
     "PyArrow",
     "grpcio",
     "google-api-python-client",
-    "grpcio_status"
+    "grpcio_status",
+    "rpy2"
   )
 
   if (add_torch && install_ml) {
diff --git a/R/sparklyr-spark-apply.R b/R/sparklyr-spark-apply.R
new file mode 100644
index 0000000..264e545
--- /dev/null
+++ b/R/sparklyr-spark-apply.R
@@ -0,0 +1,212 @@
+#' @export
+spark_apply.tbl_pyspark <- function(
+    x,
+    f,
+    columns = NULL,
+    memory = TRUE,
+    group_by = NULL,
+    packages = NULL,
+    context = NULL,
+    name = NULL,
+    barrier = NULL,
+    fetch_result_as_sdf = TRUE,
+    partition_index_param = "",
+    arrow_max_records_per_batch = NULL,
+    auto_deps = FALSE,
+    ...) {
+  py_check_installed(
+    libraries = "rpy2",
+    msg = "Requires an additional Python library"
+  )
+  cli_div(theme = cli_colors())
+  if (!is.null(packages)) {
+    cli_abort("`packages` is not yet supported for this backend")
+  }
+  if (!is.null(context)) {
+    cli_abort("`context` is not supported for this backend")
+  }
+  if (auto_deps) {
+    cli_abort("`auto_deps` is not supported for this backend")
+  }
+  if (partition_index_param != "") {
+    cli_abort("`partition_index_param` is not supported for this backend")
+  }
+  if (!is.null(arrow_max_records_per_batch)) {
+    sc <- python_sdf(x)$sparkSession
+    conf_name <- "spark.sql.execution.arrow.maxRecordsPerBatch"
+    conf_curr <- sc$conf$get(conf_name)
+    conf_req <- as.character(arrow_max_records_per_batch)
+    if (conf_curr != conf_req) {
+      cli_div(theme = cli_colors())
+      cli_inform(
+        "{.header Changing {.emph {conf_name}} to: {prettyNum(conf_req, big.mark = ',')}}"
+      )
+      cli_end()
+      sc$conf$set(conf_name, conf_req)
+    }
+  }
+  cli_end()
+  sa_in_pandas(
+    x = x,
+    .f = f,
+    .schema = columns,
+    .group_by = group_by,
+    .as_sdf = fetch_result_as_sdf,
+    .name = name,
+    .barrier = barrier,
+    ... = ...
+  )
+}
+
+sa_in_pandas <- function(
+    x,
+    .f,
+    ...,
+    .schema = NULL,
+    .schema_arg = "columns",
+    .group_by = NULL,
+    .as_sdf = TRUE,
+    .name = NULL,
+    .barrier = NULL) {
+  schema_msg <- FALSE
+  if (is.null(.schema)) {
+    r_fn <- .f %>%
+      sa_function_to_string(
+        .r_only = TRUE,
+        .group_by = .group_by,
+        .colnames = NULL,
+        ... = ...
+      ) %>%
+      rlang::parse_expr() %>%
+      eval()
+    r_df <- x %>%
+      head(10) %>%
+      collect()
+    r_exec <- r_fn(r_df)
+    col_names <- colnames(r_exec)
+    col_names <- gsub("\\.", "_", col_names)
+    colnames(r_exec) <- col_names
+    .schema <- r_exec %>%
+      imap(~ {
+        x_class <- class(.x)
+        if ("POSIXt" %in% x_class) x_class <- "timestamp"
+        if (x_class == "character") x_class <- "string"
+        if (x_class == "numeric") x_class <- "double"
+        if (x_class == "integer") x_class <- "long"
+        paste0(.y, " ", x_class)
+      }) %>%
+      paste0(collapse = ", ")
+    schema_msg <- TRUE
+  } else {
+    fields <- unlist(strsplit(.schema, ","))
+    col_names <- map_chr(fields, ~ unlist(strsplit(trimws(.x), " "))[[1]])
+    col_names <- gsub("\\.", "_", col_names)
+  }
+  .f %>%
+    sa_function_to_string(
+      .group_by = .group_by,
+      .colnames = col_names,
+      ... = ...
+    ) %>%
+    py_run_string()
+  main <- reticulate::import_main()
+  df <- python_sdf(x)
+  if (is.null(df)) {
+    df <- x %>%
+      compute() %>%
+      python_sdf()
+  }
+  if (!is.null(.group_by)) {
+    # TODO: Add support for multiple grouping columns
+    renamed_gp <- paste0("_", .group_by)
+    w_gp <- df$withColumn(colName = renamed_gp, col = df[.group_by])
+    tbl_gp <- w_gp$groupby(renamed_gp)
+    p_df <- tbl_gp$applyInPandas(
+      main$r_apply,
+      schema = .schema
+    )
+  } else {
+    p_df <- df$mapInPandas(
+      main$r_apply,
+      schema = .schema,
+      barrier = .barrier %||% FALSE
+    )
+  }
+  if (.as_sdf) {
+    ret <- tbl_pyspark_temp(
+      x = p_df,
+      conn = spark_connection(x),
+      tmp_name = .name
+    )
+  } else {
+    ret <- to_pandas_cleaned(p_df)
+  }
+  if (schema_msg) {
+    schema_arg <- .schema_arg
+    schema <- .schema
+    cli_div(theme = cli_colors())
+    cli_inform(c(
+      "{.header To increase performance, use the following schema:}",
+      "{.emph {schema_arg} = \"{schema}\" }"
+    ))
+    cli_end()
+  }
+  ret
+}
+
+sa_function_to_string <- function(
+    .f,
+    .group_by = NULL,
+    .r_only = FALSE,
+    .colnames = NULL,
+    ...
+    ) {
+  path_scripts <- system.file("udf", package = "pysparklyr")
+  if (dir_exists("inst/udf")) {
+    path_scripts <- path_expand("inst/udf")
+  }
+  udf_fn <- ifelse(is.null(.group_by), "map", "apply")
+  fn_r <- paste0(
+    readLines(path(path_scripts, glue("udf-{udf_fn}.R"))),
+    collapse = ""
+  )
+  fn_python <- paste0(
+    readLines(path(path_scripts, glue("udf-{udf_fn}.py"))),
+    collapse = "\n"
+  )
+  if (!is.null(.group_by)) {
+    fn_r <- gsub(
+      "gp_field <- 'am'",
+      paste0("gp_field <- '", .group_by, "'"),
+      fn_r
+    )
+  }
+  if (is.null(.colnames)) {
+    .colnames <- "NULL"
+  } else {
+    .colnames <- paste0("'", .colnames, "'", collapse = ", ")
+  }
+  fn_r <- gsub(
+    "col_names <- c\\('am', 'x'\\)",
+    paste0("col_names <- c(", .colnames, ")"),
+    fn_r
+  )
+  fn <- purrr::as_mapper(.f = .f, ... = ...)
+  fn_str <- paste0(deparse(fn), collapse = "")
+  if (inherits(fn, "rlang_lambda_function")) {
+    fn_str <- paste0(
+      "function(...) {x <- (",
+      fn_str,
+      "); x(...)}"
+    )
+  }
+  fn_str <- gsub("\"", "'", fn_str)
+  fn_rep <- "function\\(\\.\\.\\.\\) 1"
+  fn_r_new <- gsub(fn_rep, fn_str, fn_r)
+  if (.r_only) {
+    ret <- fn_r_new
+  } else {
+    ret <- gsub(fn_rep, fn_r_new, fn_python)
+  }
+  ret
+}
diff --git a/R/sparklyr-spark-connect.R b/R/sparklyr-spark-connect.R
index 41847ae..74c97a5 100644
--- a/R/sparklyr-spark-connect.R
+++ b/R/sparklyr-spark-connect.R
@@ -169,10 +169,16 @@ initialize_connection <- function(
     message = "'SparkSession' object has no attribute 'setLocalProperty'",
     module = "pyspark"
   )
+  warnings$filterwarnings(
+    "ignore",
+    message = "Index.format is deprecated and will be removed in a future version"
+  )
   session <- conn$getOrCreate()
   get_version <- try(session$version, silent = TRUE)
   if (inherits(get_version, "try-error")) databricks_dbr_error(get_version)
   session$conf$set("spark.sql.session.localRelationCacheThreshold", 1048576L)
+  session$conf$set("spark.sql.execution.arrow.pyspark.enabled", "true")
+  session$conf$set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")
 
   # do we need this `spark_context` object?
   spark_context <- list(spark_context = session)
diff --git a/R/utils.R b/R/utils.R
index 57de702..50a72aa 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -93,6 +93,36 @@ current_product_connect <- function() {
   out
 }
 
+py_check_installed <- function(
+    envname = NULL,
+    libraries = "",
+    msg = ""
+    ) {
+  installed_libraries <- py_list_packages(envname = envname)$package
+  find_libs <- map_lgl(libraries, ~ .x %in% installed_libraries)
+  if (!all(find_libs)) {
+    cli_div(theme = cli_colors())
+    if (check_interactive()) {
+      missing_lib <- libraries[!find_libs]
+      cli_alert_warning(msg)
+      cli_bullets(c(
+        " " = "{.header Could not find: {missing_lib}}",
+        " " = "Do you wish to install? {.class (This will be a one time operation)}"
+      ))
+      choice <- menu(choices = c("Yes", "Cancel"))
+      if (choice == 1) {
+        py_install(missing_lib)
+      }
+      if (choice == 2) {
+        stop_quietly()
+      }
+    } else {
+      cli_abort(msg)
+    }
+    cli_end()
+  }
+}
+
 stop_quietly <- function() {
   opt <- options(show.error.messages = FALSE)
   on.exit(options(opt))
diff --git a/inst/udf/udf-apply.R b/inst/udf/udf-apply.R
new file mode 100644
index 0000000..ea6a62b
--- /dev/null
+++ b/inst/udf/udf-apply.R
@@ -0,0 +1,20 @@
+function(df = mtcars) {
+  library(arrow);
+  fn <- function(...) 1;
+  fn_run <- fn(df);
+  gp_field <- 'am';
+  col_names <- c('am', 'x');
+  gp <- df[1, gp_field];
+  if(is.vector(fn_run)) {
+    ret <- data.frame(x = fn_run);
+  } else {
+    ret <- as.data.frame(fn_run);
+  };
+  ret[gp_field] <- gp;
+  cols <- colnames(ret);
+  gp_cols <- cols == gp_field;
+  new_cols <- c(cols[gp_cols], cols[!gp_cols]);
+  ret <- ret[, new_cols];
+  if(!is.null(col_names)) colnames(ret) <- col_names;
+  ret
+}
diff --git a/inst/udf/udf-apply.py b/inst/udf/udf-apply.py
new file mode 100644
index 0000000..50d4648
--- /dev/null
+++ b/inst/udf/udf-apply.py
@@ -0,0 +1,8 @@
+import pandas as pd
+import rpy2.robjects as robjects
+from rpy2.robjects import pandas2ri
+def r_apply(pdf: pd.DataFrame) -> pd.DataFrame:
+    pandas2ri.activate()
+    r_func = robjects.r("function(...) 1")
+    ret = r_func(pdf)
+    return pandas2ri.rpy2py_dataframe(ret)
diff --git a/inst/udf/udf-map.R b/inst/udf/udf-map.R
new file mode 100644
index 0000000..1638d27
--- /dev/null
+++ b/inst/udf/udf-map.R
@@ -0,0 +1,13 @@
+function(df = mtcars) {
+  library(arrow);
+  fn <- function(...) 1;
+  col_names <- c('am', 'x');
+  fn_run <- fn(df);
+  if(is.vector(fn_run)) {
+    ret <- data.frame(x = fn_run);
+  } else {
+    ret <- as.data.frame(fn_run);
+  };
+  if(!is.null(col_names)) colnames(ret) <- col_names;
+  ret
+}
diff --git a/inst/udf/udf-map.py b/inst/udf/udf-map.py
new file mode 100644
index 0000000..47c02e2
--- /dev/null
+++ b/inst/udf/udf-map.py
@@ -0,0 +1,8 @@
+import rpy2.robjects as robjects
+from rpy2.robjects import pandas2ri
+def r_apply(iterator):
+    for pdf in iterator:
+        pandas2ri.activate()
+        r_func = robjects.r("function(...) 1")
+        ret = r_func(pdf)
+        yield pandas2ri.rpy2py_dataframe(ret)
diff --git a/tests/testthat/_snaps/data-write.md b/tests/testthat/_snaps/data-write.md
index 6777464..d8cc2e1 100644
--- a/tests/testthat/_snaps/data-write.md
+++ b/tests/testthat/_snaps/data-write.md
@@ -4,7 +4,8 @@
       spark_read_csv(sc = sc, name = "csv_1", path = file_name, overwrite = TRUE,
         repartition = 2)
     Output
-      # Source: spark [?? x 11]
+      # Source: table [?? x 11]
+      # Database: spark_connection
          mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
       1   21     6   160   110   3.9  2.62  16.5     0     1     4     4
@@ -25,7 +26,8 @@
       spark_read_csv(sc = sc, name = "csv_2", path = file_name, overwrite = TRUE,
         columns = paste0(names(mtcars), "t"))
     Output
-      # Source: spark [?? x 11]
+      # Source: table [?? x 11]
+      # Database: spark_connection
          mpgt  cylt dispt   hpt dratt   wtt qsect   vst   amt geart carbt
       1  21.0   6.0 160.0 110.0   3.9  2.62 16.46   0.0   1.0   4.0   4.0
@@ -46,7 +48,8 @@
       spark_read_csv(sc = sc, name = "csv_3", path = file_name, overwrite = TRUE,
         memory = TRUE)
    Output
-      # Source: spark [?? x 11]
+      # Source: table [?? x 11]
+      # Database: spark_connection
          mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
       1   21     6   160   110   3.9  2.62  16.5     0     1     4     4
@@ -66,7 +69,8 @@
     Code
       spark_read_parquet(sc, "csv_1", file_name, overwrite = TRUE)
     Output
-      # Source: spark [?? x 11]
+      # Source: table [?? x 11]
+      # Database: spark_connection
          mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
       1   21     6   160   110   3.9  2.62  16.5     0     1     4     4
@@ -86,7 +90,8 @@
     Code
       spark_read_orc(sc, "csv_1", file_name, overwrite = TRUE)
     Output
-      # Source: spark [?? x 11]
+      # Source: table [?? x 11]
+      # Database: spark_connection
          mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
       1   21     6   160   110   3.9  2.62  16.5     0     1     4     4
@@ -106,7 +111,8 @@
     Code
       spark_read_json(sc, "csv_1", file_name, overwrite = TRUE)
     Output
-      # Source: spark [?? x 11]
+      # Source: table [?? x 11]
+      # Database: spark_connection
          am  carb   cyl  disp  drat  gear    hp   mpg  qsec    vs    wt
       1    1     4     6   160   3.9     4   110    21  16.5     0  2.62
diff --git a/tests/testthat/_snaps/dplyr.md b/tests/testthat/_snaps/dplyr.md
index 736d66e..ab91029 100644
--- a/tests/testthat/_snaps/dplyr.md
+++ b/tests/testthat/_snaps/dplyr.md
@@ -3,7 +3,8 @@
     Code
       tbl_ordered
     Output
-      # Source: spark [?? x 11]
+      # Source: SQL [?? x 11]
+      # Database: spark_connection
       # Ordered by: mpg, qsec, hp
          mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
       1   21     6   160   110   3.9  2.62  16.5     0     1     4     4
@@ -24,7 +25,8 @@
     Code
       print(head(tbl_ordered))
     Output
-      # Source: spark [?? x 11]
+      # Source: SQL [6 x 11]
+      # Database: spark_connection
       # Ordered by: mpg, qsec, hp
          mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
       1   21     6   160   110   3.9  2.62  16.5     0     1     4     4
@@ -40,8 +42,9 @@
     Code
      tbl_am[1]
    Output
-      # Source: spark [?? x 1]
-      # Groups: am
+      # Source: SQL [2 x 1]
+      # Database: spark_connection
+      # Groups: am
          am
       1   0
@@ -52,7 +55,8 @@
     Code
      tbl_join
    Output
-      # Source: spark [?? x 11]
+      # Source: SQL [?? x 11]
+      # Database: spark_connection
       # Ordered by: mpg, qsec, hp
          mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
diff --git a/tests/testthat/_snaps/pivot-longer.md b/tests/testthat/_snaps/pivot-longer.md
index 7fe0842..775f2e5 100644
--- a/tests/testthat/_snaps/pivot-longer.md
+++ b/tests/testthat/_snaps/pivot-longer.md
@@ -1,9 +1,10 @@
 # Pivot longer
 
     Code
-      tbl_pivot %>% tidyr::pivot_longer(-id, names_to = c(".value", "n"), names_sep = "_")
+      tbl_pivot %>% tidyr::pivot_longer(-id, names_to = c(".value", "n"), names_sep = "_") %>%
+        collect()
     Output
-      # Source: spark [?? x 5]
+      # A tibble: 4 x 5
         id    n         z     y     x
       1 A     1         1     2     3
diff --git a/tests/testthat/_snaps/python-install.md b/tests/testthat/_snaps/python-install.md
index f0c1aec..7cad6c0 100644
--- a/tests/testthat/_snaps/python-install.md
+++ b/tests/testthat/_snaps/python-install.md
@@ -4,11 +4,11 @@
       x
     Output
      $packages
-     [1] "pyspark==3.5.0"           "pandas!=2.1.0"
-     [3] "PyArrow"                  "grpcio"
-     [5] "google-api-python-client" "grpcio_status"
-     [7] "torch"                    "torcheval"
-     [9] "scikit-learn"
+     [1] "pyspark==3.5.0"           "pandas!=2.1.0"
+     [3] "PyArrow"                  "grpcio"
+     [5] "google-api-python-client" "grpcio_status"
+     [7] "rpy2"                     "torch"
+     [9] "torcheval"                "scikit-learn"
 
      $envname
      unavailable
@@ -37,6 +37,7 @@
      [1] "pyspark==3.5.*"           "pandas!=2.1.0"
      [3] "PyArrow"                  "grpcio"
      [5] "google-api-python-client" "grpcio_status"
+     [7] "rpy2"
 
      $envname
      unavailable
diff --git a/tests/testthat/test-pivot-longer.R b/tests/testthat/test-pivot-longer.R
index 91f9937..31a049a 100644
--- a/tests/testthat/test-pivot-longer.R
+++ b/tests/testthat/test-pivot-longer.R
@@ -12,7 +12,8 @@ test_that("Pivot longer", {
       tidyr::pivot_longer(
         -id,
         names_to = c(".value", "n"), names_sep = "_"
-      )
+      ) %>%
+      collect()
   )
 })
diff --git a/tests/testthat/test-sparklyr-spark-apply.R b/tests/testthat/test-sparklyr-spark-apply.R
new file mode 100644
index 0000000..3784239
--- /dev/null
+++ b/tests/testthat/test-sparklyr-spark-apply.R
@@ -0,0 +1,35 @@
+test_that("spark_apply() works", {
+  tbl_mtcars <- use_test_table_mtcars()
+  expect_s3_class(
+    spark_apply(tbl_mtcars, nrow, group_by = "am", columns = "am double, x long"),
+    "tbl_spark"
+  )
+  skip_spark_min_version(3.5)
+  expect_s3_class(
+    spark_apply(tbl_mtcars, function(x) x),
+    "tbl_spark"
+  )
+  expect_s3_class(
+    spark_apply(tbl_mtcars, function(x) x, fetch_result_as_sdf = FALSE),
+    "data.frame"
+  )
+  expect_s3_class(
+    spark_apply(tbl_mtcars, function(x) x, arrow_max_records_per_batch = 5000),
+    "tbl_spark"
+  )
+  expect_s3_class(
+    spark_apply(tbl_mtcars, ~ .x),
+    "tbl_spark"
+  )
+  expect_s3_class(
+    spark_apply(dplyr::filter(tbl_mtcars, am == 0), ~ .x),
+    "tbl_spark"
+  )
+})
+
+test_that("Errors are output by specific params", {
+  expect_error(spark_apply(tbl_mtcars, nrow, packages = "test"))
+  expect_error(spark_apply(tbl_mtcars, nrow, context = ""))
+  expect_error(spark_apply(tbl_mtcars, nrow, auto_deps = TRUE))
+  expect_error(spark_apply(tbl_mtcars, nrow, partition_index_param = "test"))
+})
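
Usage sketch (not part of the diff): the calls below mirror the tests added in tests/testthat/test-sparklyr-spark-apply.R and show how the new rpy2-backed `spark_apply()` method is expected to be called. The connection setup (`sc://localhost`, `copy_to()`) is an illustrative assumption, not something this change prescribes; any connection that yields a `tbl_pyspark` should behave the same way.

# Assumes a Spark Connect cluster is reachable at sc://localhost and that the
# Python environment created by install_pyspark() includes the rpy2 library.
library(sparklyr)
library(pysparklyr)

sc <- spark_connect(master = "sc://localhost", method = "spark_connect")
tbl_mtcars <- dplyr::copy_to(sc, mtcars)

# Row count per `am` group. Supplying `columns` skips the sampling step the
# backend otherwise uses to infer the schema, and silences the performance hint.
spark_apply(
  tbl_mtcars,
  nrow,
  group_by = "am",
  columns = "am double, x long"
)

# Identity transform, returned as a local data.frame instead of a Spark DataFrame.
spark_apply(tbl_mtcars, ~ .x, fetch_result_as_sdf = FALSE)

As the second test block above covers, `packages`, `context`, `auto_deps`, and a non-empty `partition_index_param` are rejected by this backend.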