Merge pull request #106 from mlverse/udf
Adds support for R UDFs
edgararuiz authored Feb 2, 2024
2 parents 93439a6 + a300510, commit 93c022e
Showing 21 changed files with 388 additions and 49 deletions.
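Not part of the commit itself, but for context, a minimal sketch of what the new support enables, assuming a local Spark Connect service at `sc://localhost`, an existing pysparklyr Python environment with `rpy2` installed, and `mtcars` as sample data:

```r
library(sparklyr)

# Assumes a Spark Connect service is already running locally and the
# pysparklyr Python environment (including rpy2) is installed.
sc <- spark_connect(
  master  = "sc://localhost",
  method  = "spark_connect",
  version = "3.5"
)

tbl_mtcars <- copy_to(sc, mtcars)

# The R function runs on each batch of the Spark DataFrame via rpy2.
# With `columns` omitted, the schema is inferred from a 10-row sample and
# reported so it can be supplied explicitly on later calls.
spark_apply(tbl_mtcars, function(df) dplyr::mutate(df, mpg_per_cyl = mpg / cyl))
```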
4 changes: 3 additions & 1 deletion .github/workflows/spark-tests.yaml
@@ -38,7 +38,9 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::devtools
extra-packages: |
any::devtools
any::arrow
needs: check

- name: Set up Python 3.10
1 change: 1 addition & 0 deletions .github/workflows/test-coverage.yaml
@@ -33,6 +33,7 @@ jobs:
extra-packages: |
any::covr
any::devtools
any::arrow
needs: coverage

- name: Cache Spark
6 changes: 4 additions & 2 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: pysparklyr
Title: Provides a 'PySpark' Back-End for the 'sparklyr' Package
Version: 0.1.3
Version: 0.1.3.9000
Authors@R: c(
person("Edgar", "Ruiz", , "[email protected]", role = c("aut", "cre")),
person(given = "Posit Software, PBC", role = c("cph", "fnd"))
@@ -22,7 +22,7 @@ Imports:
reticulate (>= 1.33),
methods,
rlang,
sparklyr (>= 1.8.4),
sparklyr (>= 1.8.4.9004),
tidyselect,
fs,
magrittr,
@@ -41,3 +41,5 @@ Suggests:
tibble,
withr
Config/testthat/edition: 3
Remotes:
sparklyr/sparklyr
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -40,6 +40,7 @@ S3method(sdf_copy_to,pyspark_connection)
S3method(sdf_read_column,spark_pyjobj)
S3method(sdf_register,spark_pyobj)
S3method(sdf_schema,tbl_pyspark)
S3method(spark_apply,tbl_pyspark)
S3method(spark_connect_method,spark_method_databricks_connect)
S3method(spark_connect_method,spark_method_spark_connect)
S3method(spark_connection,pyspark_connection)
@@ -157,6 +158,7 @@ importFrom(sparklyr,sdf_copy_to)
importFrom(sparklyr,sdf_read_column)
importFrom(sparklyr,sdf_register)
importFrom(sparklyr,sdf_schema)
importFrom(sparklyr,spark_apply)
importFrom(sparklyr,spark_connect_method)
importFrom(sparklyr,spark_connection)
importFrom(sparklyr,spark_dataframe)
6 changes: 6 additions & 0 deletions NEWS.md
@@ -1,3 +1,9 @@
# pysparklyr dev

### New

* Adds support for `spark_apply()` via the `rpy2` Python library.

# pysparklyr 0.1.3

### New
30 changes: 5 additions & 25 deletions R/ml-utils.R
@@ -130,29 +130,9 @@ get_params <- function(x) {
}

ml_installed <- function(envname = NULL) {
ml_libraries <- pysparklyr_env$ml_libraries
installed_libraries <- py_list_packages(envname = envname)$package
find_ml <- map_lgl(ml_libraries, ~ .x %in% installed_libraries)
if (!all(find_ml)) {
cli_div(theme = cli_colors())
msg1 <- "Required Python libraries to run ML functions are missing"
if (check_interactive()) {
missing_ml <- ml_libraries[!find_ml]
cli_alert_warning(msg1)
cli_bullets(c(
" " = "{.header Could not find: {missing_ml}}",
" " = "Do you wish to install? {.class (This will be a one time operation)}"
))
choice <- menu(choices = c("Yes", "Cancel"))
if (choice == 1) {
py_install(missing_ml)
}
if (choice == 2) {
stop_quietly()
}
} else {
cli_abort(msg1)
}
cli_end()
}
py_check_installed(
envname = envname,
libraries = pysparklyr_env$ml_libraries,
msg = "Required Python libraries to run ML functions are missing"
)
}
2 changes: 1 addition & 1 deletion R/package.R
@@ -11,7 +11,7 @@
#' @importFrom sparklyr spark_write_orc spark_write_json spark_write_table
#' @importFrom sparklyr ml_pipeline ml_predict ml_transform ml_fit
#' @importFrom sparklyr ml_logistic_regression ft_standard_scaler ft_max_abs_scaler
#' @importFrom sparklyr ml_save ml_load spark_jobj spark_install_find
#' @importFrom sparklyr ml_save ml_load spark_jobj spark_install_find spark_apply
#' @importFrom tidyselect tidyselect_data_has_predicates
#' @importFrom dplyr tbl collect tibble same_src compute as_tibble group_vars
#' @importFrom dplyr sample_n sample_frac slice_sample select tbl_ptype group_by
3 changes: 2 additions & 1 deletion R/python-install.R
@@ -217,7 +217,8 @@ install_environment <- function(
"PyArrow",
"grpcio",
"google-api-python-client",
"grpcio_status"
"grpcio_status",
"rpy2"
)

if (add_torch && install_ml) {
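Because `rpy2` now appears in the default library list, environments created by the package's installer (e.g. `install_pyspark()`, which feeds `install_environment()` above) should include it. For an existing environment, a hedged sketch of adding it manually with reticulate; the `envname` below is hypothetical:

```r
# Hypothetical environment name; substitute the one your connection actually uses.
reticulate::py_install("rpy2", envname = "r-sparklyr-pyspark-3.5")
```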
212 changes: 212 additions & 0 deletions R/sparklyr-spark-apply.R
@@ -0,0 +1,212 @@
#' @export
spark_apply.tbl_pyspark <- function(
x,
f,
columns = NULL,
memory = TRUE,
group_by = NULL,
packages = NULL,
context = NULL,
name = NULL,
barrier = NULL,
fetch_result_as_sdf = TRUE,
partition_index_param = "",
arrow_max_records_per_batch = NULL,
auto_deps = FALSE,
...) {
py_check_installed(
libraries = "rpy2",
msg = "Requires an additional Python library"
)
cli_div(theme = cli_colors())
if (!is.null(packages)) {
cli_abort("`packages` is not yet supported for this backend")
}
if (!is.null(context)) {
cli_abort("`context` is not supported for this backend")
}
if (auto_deps) {
cli_abort("`auto_deps` is not supported for this backend")
}
if (partition_index_param != "") {
cli_abort("`partition_index_param` is not supported for this backend")
}
if (!is.null(arrow_max_records_per_batch)) {
sc <- python_sdf(x)$sparkSession
conf_name <- "spark.sql.execution.arrow.maxRecordsPerBatch"
conf_curr <- sc$conf$get(conf_name)
conf_req <- as.character(arrow_max_records_per_batch)
if(conf_curr != conf_req) {
cli_div(theme = cli_colors())
cli_inform(
"{.header Changing {.emph {conf_name}} to: {prettyNum(conf_req, big.mark = ',')}}"
)
cli_end()
sc$conf$set(conf_name, conf_req)
}
}
cli_end()
sa_in_pandas(
x = x,
.f = f,
.schema = columns,
.group_by = group_by,
.as_sdf = fetch_result_as_sdf,
.name = name,
.barrier = barrier,
... = ...
)
}

sa_in_pandas <- function(
x,
.f,
...,
.schema = NULL,
.schema_arg = "columns",
.group_by = NULL,
.as_sdf = TRUE,
.name = NULL,
.barrier = NULL) {
schema_msg <- FALSE
if (is.null(.schema)) {
r_fn <- .f %>%
sa_function_to_string(
.r_only = TRUE,
.group_by = .group_by,
.colnames = NULL,
... = ...
) %>%
rlang::parse_expr() %>%
eval()
r_df <- x %>%
head(10) %>%
collect()
r_exec <- r_fn(r_df)
col_names <- colnames(r_exec)
col_names <- gsub("\\.", "_", col_names)
colnames(r_exec) <- col_names
.schema <- r_exec %>%
imap(~ {
x_class <- class(.x)
if ("POSIXt" %in% x_class) x_class <- "timestamp"
if (x_class == "character") x_class <- "string"
if (x_class == "numeric") x_class <- "double"
if (x_class == "integer") x_class <- "long"
paste0(.y, " ", x_class)
}) %>%
paste0(collapse = ", ")
schema_msg <- TRUE
} else {
fields <- unlist(strsplit(.schema, ","))
col_names <- map_chr(fields, ~ unlist(strsplit(trimws(.x), " "))[[1]])
col_names <- gsub("\\.", "_", col_names)
}
.f %>%
sa_function_to_string(
.group_by = .group_by,
.colnames = col_names,
... = ...
) %>%
py_run_string()
main <- reticulate::import_main()
df <- python_sdf(x)
if (is.null(df)) {
df <- x %>%
compute() %>%
python_sdf()
}
if (!is.null(.group_by)) {
# TODO: Add support for multiple grouping columns
renamed_gp <- paste0("_", .group_by)
w_gp <- df$withColumn(colName = renamed_gp, col = df[.group_by])
tbl_gp <- w_gp$groupby(renamed_gp)
p_df <- tbl_gp$applyInPandas(
main$r_apply,
schema = .schema
)
} else {
p_df <- df$mapInPandas(
main$r_apply,
schema = .schema,
barrier = .barrier %||% FALSE
)
}
if (.as_sdf) {
ret <- tbl_pyspark_temp(
x = p_df,
conn = spark_connection(x),
tmp_name = .name
)
} else {
ret <- to_pandas_cleaned(p_df)
}
if(schema_msg) {
schema_arg <- .schema_arg
schema <- .schema
cli_div(theme = cli_colors())
cli_inform(c(
"{.header To increase performance, use the following schema:}",
"{.emph {schema_arg} = \"{schema}\" }"
))
cli_end()
}
ret
}

sa_function_to_string <- function(
.f,
.group_by = NULL,
.r_only = FALSE,
.colnames = NULL,
...
) {
path_scripts <- system.file("udf", package = "pysparklyr")
if(dir_exists("inst/udf")) {
path_scripts <- path_expand("inst/udf")
}
udf_fn <- ifelse(is.null(.group_by), "map", "apply")
fn_r <- paste0(
readLines(path(path_scripts, glue("udf-{udf_fn}.R"))),
collapse = ""
)
fn_python <- paste0(
readLines(path(path_scripts, glue("udf-{udf_fn}.py"))),
collapse = "\n"
)
if (!is.null(.group_by)) {
fn_r <- gsub(
"gp_field <- 'am'",
paste0("gp_field <- '", .group_by, "'"),
fn_r
)
}
if(is.null(.colnames)) {
.colnames <- "NULL"
} else {
.colnames <- paste0("'", .colnames, "'", collapse = ", ")
}
fn_r <- gsub(
"col_names <- c\\('am', 'x'\\)",
paste0("col_names <- c(", .colnames, ")"),
fn_r
)
fn <- purrr::as_mapper(.f = .f, ... = ...)
fn_str <- paste0(deparse(fn), collapse = "")
if (inherits(fn, "rlang_lambda_function")) {
fn_str <- paste0(
"function(...) {x <- (",
fn_str,
"); x(...)}"
)
}
fn_str <- gsub("\"", "'", fn_str)
fn_rep <- "function\\(\\.\\.\\.\\) 1"
fn_r_new <- gsub(fn_rep, fn_str, fn_r)
if (.r_only) {
ret <- fn_r_new
} else {
ret <- gsub(fn_rep, fn_r_new, fn_python)
}
ret
}
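The `'am'` and `c('am', 'x')` strings above are placeholders in the bundled `udf-map.R`/`udf-apply.R` templates that are substituted at run time. As a hedged sketch, modeled on sparklyr's documented grouped `spark_apply()` usage: supplying `columns` up front skips the 10-row sampling step and the schema hint message. It assumes `tbl_mtcars` is a `tbl_pyspark` holding `mtcars`:

```r
# Row count per transmission group; `am` is a double in mtcars and the
# integer count maps to Spark's `long`.
spark_apply(
  tbl_mtcars,
  nrow,
  group_by = "am",
  columns  = "am double, x long"
)
```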
6 changes: 6 additions & 0 deletions R/sparklyr-spark-connect.R
@@ -169,10 +169,16 @@ initialize_connection <- function(
message = "'SparkSession' object has no attribute 'setLocalProperty'",
module = "pyspark"
)
warnings$filterwarnings(
"ignore",
message = "Index.format is deprecated and will be removed in a future version"
)
session <- conn$getOrCreate()
get_version <- try(session$version, silent = TRUE)
if (inherits(get_version, "try-error")) databricks_dbr_error(get_version)
session$conf$set("spark.sql.session.localRelationCacheThreshold", 1048576L)
session$conf$set("spark.sql.execution.arrow.pyspark.enabled", "true")
session$conf$set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")

# do we need this `spark_context` object?
spark_context <- list(spark_context = session)
30 changes: 30 additions & 0 deletions R/utils.R
@@ -93,6 +93,36 @@ current_product_connect <- function() {
out
}

py_check_installed <- function(
envname = NULL,
libraries = "",
msg = ""
) {
installed_libraries <- py_list_packages(envname = envname)$package
find_libs <- map_lgl(libraries, ~ .x %in% installed_libraries)
if (!all(find_libs)) {
cli_div(theme = cli_colors())
if (check_interactive()) {
missing_lib <- libraries[!find_libs]
cli_alert_warning(msg)
cli_bullets(c(
" " = "{.header Could not find: {missing_lib}}",
" " = "Do you wish to install? {.class (This will be a one time operation)}"
))
choice <- menu(choices = c("Yes", "Cancel"))
if (choice == 1) {
py_install(missing_lib)
}
if (choice == 2) {
stop_quietly()
}
} else {
cli_abort(msg)
}
cli_end()
}
}

stop_quietly <- function() {
opt <- options(show.error.messages = FALSE)
on.exit(options(opt))
(Diffs for the remaining 10 changed files were not loaded on this page.)
