Merge pull request #89 from mlverse/updates
Updates
edgararuiz authored Dec 14, 2023
2 parents 532d4d6 + c6ddffb commit 2c96dce
Showing 4 changed files with 64 additions and 38 deletions.
NAMESPACE (1 change: 1 addition & 0 deletions)
@@ -118,6 +118,7 @@ importFrom(rlang,abort)
importFrom(rlang,arg_match)
importFrom(rlang,as_utf8_character)
importFrom(rlang,enquo)
importFrom(rlang,exec)
importFrom(rlang,is_character)
importFrom(rlang,is_string)
importFrom(rlang,parse_exprs)
R/databricks-utils.R (36 changes: 23 additions & 13 deletions)
@@ -51,30 +51,29 @@ databricks_token <- function(token = NULL, fail = FALSE) {
set_names(token, name)
}

- databricks_dbr_version <- function(cluster_id,
- host = NULL,
- token = NULL) {
+ databricks_dbr_version_name <- function(cluster_id,
+ host = NULL,
+ token = NULL) {
bullets <- NULL
cli_div(theme = cli_colors())
- cli_alert_warning(
- "{.header Retrieving version from cluster }{.emph '{cluster_id}'}"
+ cli_progress_step(
+ "{.header Retrieving info for cluster:}{.emph '{cluster_id}'}"
)
cluster_info <- databricks_dbr_info(
cluster_id = cluster_id,
host = host,
token = token
)
cluster_name <- substr(cluster_info$cluster_name, 1, 100)
sp_version <- cluster_info$spark_version
if (!is.null(sp_version)) {
sp_sep <- unlist(strsplit(sp_version, "\\."))
version <- paste0(sp_sep[1], ".", sp_sep[2])
cli_bullets(c(
" " = "{.class Cluster version: }{.emph '{version}'}"
))
} else {
version <- ""
}
cli_end()
- version
+ list(version = version, name = cluster_name)
}

databricks_dbr_info <- function(cluster_id,
@@ -124,6 +123,17 @@ databricks_dbr_info <- function(cluster_id,
out
}

databricks_dbr_version <- function(cluster_id,
host = NULL,
token = NULL) {
vn <- databricks_dbr_version_name(
cluster_id = cluster_id,
host = host,
token = token
)
vn$version
}

databricks_cluster_get <- function(cluster_id,
host = NULL,
token = NULL) {
@@ -192,14 +202,14 @@ sanitize_host <- function(url) {
if (ret != url) {
cli_div(theme = cli_colors())
cli_alert_warning(
"{.header Sanitizing Databricks Host ({.code master}) entry:}"
"{.header Changing host URL to:} {.emph {ret}}"
)
cli_bullets(c(
" " = "{.header Original: {.emph {url}}}",
" " = "{.header Using:} {.emph {ret}}",
# " " = "{.header Original: {.emph {url}}}",
# " " = "{.header Using:} {.emph {ret}}",
" " = paste0(
"{.class Set {.code host_sanitize = FALSE} ",
"in {.code spark_connect()} to avoid this change}"
"in {.code spark_connect()} to avoid changing it}"
)
))

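The change above splits the old databricks_dbr_version() helper in two: databricks_dbr_version_name() retrieves the cluster information once and returns both the DBR version and the (truncated) cluster name, while databricks_dbr_version() becomes a thin wrapper that keeps the old return value. A minimal usage sketch under stated assumptions: these helpers are package-internal, the cluster ID below is a placeholder, and DATABRICKS_HOST / DATABRICKS_TOKEN are assumed to be set in the environment.

```r
# Hypothetical call; "0123-456789-abcdefgh" is a placeholder cluster ID.
info <- databricks_dbr_version_name(
  cluster_id = "0123-456789-abcdefgh",
  host = Sys.getenv("DATABRICKS_HOST"),
  token = Sys.getenv("DATABRICKS_TOKEN")
)
info$version # e.g. "14.1" -- the major.minor portion of spark_version
info$name    # cluster name, truncated to 100 characters

# The original helper still returns only the version string:
databricks_dbr_version(
  cluster_id = "0123-456789-abcdefgh",
  host = Sys.getenv("DATABRICKS_HOST"),
  token = Sys.getenv("DATABRICKS_TOKEN")
)
```

Returning both values from one call lets the connection code further down reuse the cluster name for its progress message without a second request to the cluster API.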
R/package.R (5 changes: 3 additions & 2 deletions)
@@ -18,8 +18,9 @@
#' @importFrom dplyr filter mutate
#' @importFrom purrr map_lgl map_chr map pmap_chr imap discard
#' @importFrom purrr map_lgl map_chr map pmap_chr imap
- #' @importFrom rlang enquo `!!` `!!!` quo_is_null sym arg_match warn abort `%||%`
- #' @importFrom rlang is_string is_character as_utf8_character parse_exprs set_names
+ #' @importFrom rlang enquo `!!` `!!!` quo_is_null sym warn abort `%||%`
+ #' @importFrom rlang is_string is_character parse_exprs set_names
+ #' @importFrom rlang exec arg_match as_utf8_character
#' @importFrom methods new is setOldClass
#' @importFrom tidyselect matches
#' @importFrom utils head type.convert compareVersion
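For context, and as an assumption about the usual roxygen2 workflow rather than something shown in this commit: the importFrom(rlang,exec) entry added to NAMESPACE above is generated from these @importFrom tags when the package documentation is rebuilt.

```r
# Rebuilding the docs regenerates NAMESPACE from the roxygen tags in R/package.R,
# turning `#' @importFrom rlang exec ...` into `importFrom(rlang,exec)`.
roxygen2::roxygenise()
# or, within a devtools-based workflow:
# devtools::document()
```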
R/sparklyr-spark-connect.R (60 changes: 37 additions & 23 deletions)
@@ -11,8 +11,7 @@ spark_connect_method.spark_method_spark_connect <- function(
extensions,
scala_version,
...) {

- version <- version %||% Sys.getenv("SPARK_VERSION")
+ version <- version %||% Sys.getenv("SPARK_VERSION")

if (version == "") {
cli_abort("Spark `version` is required, please provide")
@@ -67,18 +66,19 @@ spark_connect_method.spark_method_databricks_connect <- function(
host_sanitize <- args$host_sanitize %||% TRUE

method <- method[[1]]
- token <- databricks_token(token, fail = TRUE)
+ token <- databricks_token(token, fail = FALSE)
cluster_id <- cluster_id %||% Sys.getenv("DATABRICKS_CLUSTER_ID")
- master <- databricks_host(master)
- if (host_sanitize) {
+ master <- databricks_host(master, fail = FALSE)
+ if (host_sanitize && master != "") {
master <- sanitize_host(master)
}
- if (is.null(version) && !is.null(cluster_id)) {
- version <- databricks_dbr_version(
- cluster_id = cluster_id,
- host = master,
- token = token
- )
+
+ cluster_info <- NULL
+ if (cluster_id != "" && master != "" && token != "") {
+ cluster_info <- databricks_dbr_version_name(cluster_id, master, token)
+ if (is.null(version)) {
+ version <- cluster_info$version
+ }
}

envname <- use_envname(
@@ -90,21 +90,35 @@
)

db <- import_check("databricks.connect", envname)
- remote <- db$DatabricksSession$builder$remote(
- host = master,
- token = token,
- cluster_id = cluster_id
- )
- user_agent <- build_user_agent()
- conn <- remote$userAgent(user_agent)
- con_class <- "connect_databricks"
- cluster_info <- databricks_dbr_info(cluster_id, master, token)
- cluster_name <- substr(cluster_info$cluster_name, 1, 100)
- master_label <- glue("{cluster_name} ({cluster_id})")
+
+ if (!is.null(cluster_info)) {
+ msg <- "{.header Connecting to} '{.emph {cluster_info$name}}' {.header (DBR '{.emph {version}}')}"
+ master_label <- glue("{cluster_info$name} ({cluster_id})")
+ } else {
+ msg <- "{.header Connecting to} '{.emph {cluster_id}}'"
+ master_label <- glue("Databricks Connect - Cluster: {cluster_id}")
+ }
+
+ cli_div(theme = cli_colors())
+ cli_progress_step(msg)
+ cli_end()
+
+ remote_args <- list()
+ if (master != "") remote_args$host <- master
+ if (token != "") remote_args$token <- token
+ if (cluster_id != "") remote_args$cluster_id <- cluster_id
+
+ databricks_session <- function(...) {
+ user_agent <- build_user_agent()
+ db$DatabricksSession$builder$remote(...)$userAgent(user_agent)
+ }
+
+ conn <- exec(databricks_session, !!!remote_args)

initialize_connection(
conn = conn,
master_label = master_label,
- con_class = con_class,
+ con_class = "connect_databricks",
cluster_id = cluster_id,
method = method,
config = config
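The rewritten connection code collects only the arguments that were actually supplied and splices them into the session builder with rlang::exec() and !!!, so empty strings are never forwarded to DatabricksSession$builder$remote(). The sketch below illustrates that pattern in isolation; connect_sketch() and builder() are hypothetical stand-ins, not part of the package.

```r
library(rlang)

# Stand-in for db$DatabricksSession$builder$remote(): it just echoes its arguments.
builder <- function(...) list(...)

connect_sketch <- function(host = "", token = "", cluster_id = "") {
  # Keep only the arguments the caller actually provided.
  args <- list()
  if (host != "") args$host <- host
  if (token != "") args$token <- token
  if (cluster_id != "") args$cluster_id <- cluster_id

  # exec() accepts dynamic dots, so !!! splices the named list into the call.
  exec(builder, !!!args)
}

connect_sketch(host = "https://example.cloud.databricks.com")
#> $host
#> [1] "https://example.cloud.databricks.com"
```

Dropping empty arguments, together with the softer token <- databricks_token(token, fail = FALSE) lookup, presumably lets databricks.connect fall back to its own configuration when credentials are not supplied explicitly.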
