Skip to content

Commit

Permalink
Merge pull request #103 from mlverse/updates
Browse files Browse the repository at this point in the history
Adds coverage tests
  • Loading branch information
edgararuiz authored Jan 2, 2024
2 parents df25e7c + 12b5424 commit b4ee5a8
Show file tree
Hide file tree
Showing 14 changed files with 314 additions and 52 deletions.
1 change: 0 additions & 1 deletion .github/workflows/spark-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ jobs:
fail-fast: false
matrix:
config:
- {spark: '3.5.0', pyspark: '3.5', hadoop: '3', name: 'PySpark 3.5'}
- {spark: '3.4.1', pyspark: '3.4', hadoop: '3', name: 'PySpark 3.4'}

env:
Expand Down
4 changes: 2 additions & 2 deletions R/ml-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ get_params <- function(x) {
})
}

ml_installed <- function() {
ml_installed <- function(envname = NULL) {
ml_libraries <- pysparklyr_env$ml_libraries
installed_libraries <- py_list_packages()$package
installed_libraries <- py_list_packages(envname = envname)$package
find_ml <- map_lgl(ml_libraries, ~ .x %in% installed_libraries)
if (!all(find_ml)) {
cli_div(theme = cli_colors())
Expand Down
2 changes: 1 addition & 1 deletion R/python-use-envname.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ use_envname <- function(
choice <- menu(choices = c("Yes", "No", "Cancel"))
if (choice == 1) {
ret <- set_names(envname, "prompt")
rlang::exec(
exec(
.fn = glue("install_{backend}"),
version = version,
as_job = FALSE
Expand Down
29 changes: 17 additions & 12 deletions R/sparklyr-spark-connect.R
Original file line number Diff line number Diff line change
Expand Up @@ -216,29 +216,24 @@ build_user_agent <- function() {
}

if (is.null(product)) {
check_rstudio <- try(RStudio.Version(), silent = TRUE)
if (!inherits(check_rstudio, "try-error")) {
if (check_rstudio()) {
rstudio_version <- int_rstudio_version()
prod <- "rstudio"

edition <- check_rstudio$edition
edition <- rstudio_version$edition
if (length(edition) == 0) edition <- ""

mod <- check_rstudio$mode
mod <- rstudio_version$mode
if (length(mod) == 0) mod <- ""

if (edition == "Professional") {
if (mod == "server") {
prod <- "workbench-rstudio"
} else {
prod <- "rstudio-pro"
}
}

if (Sys.getenv("R_CONFIG_ACTIVE") == "rstudio_cloud") {
prod <- "cloud-rstudio"
}

product <- glue("posit-{prod}/{check_rstudio$long_version}")
product <- glue("posit-{prod}/{rstudio_version$long_version}")
}
}

Expand All @@ -250,6 +245,12 @@ build_user_agent <- function() {
)
}

# Returns the result of `RStudio.Version()`, or `NULL` when that call raises
# an error. The error itself is suppressed via `try(silent = TRUE)`.
int_rstudio_version <- function() {
  version_info <- try(RStudio.Version(), silent = TRUE)
  if (inherits(version_info, "try-error")) {
    return(NULL)
  }
  version_info
}

connection_label <- function(x) {
x <- x[[1]]
ret <- "Connection"
Expand All @@ -261,8 +262,12 @@ connection_label <- function(x) {
method <- con$method
}
if (!is.null(method)) {
if (method == "spark_connect" | method == "pyspark") ret <- "Spark Connect"
if (method == "databricks_connect" | method == "databricks") ret <- "Databricks Connect"
if (method == "spark_connect" | method == "pyspark") {
ret <- "Spark Connect"
}
if (method == "databricks_connect" | method == "databricks") {
ret <- "Databricks Connect"
}
}
ret
}
12 changes: 12 additions & 0 deletions tests/testthat/_snaps/dplyr.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@

# Misc functions

Code
tbl_am[1]
Output
# Source: spark<?> [?? x 1]
# Groups: am
am
<dbl>
1 0
2 1

---

Code
tbl_join
Output
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/_snaps/ml-utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# ml_formula() works

Code
ml_formula(am ~ mpg, mtcars)
Output
$label
[1] "am"
$features
[1] "mpg"

# ml_installed() works on simulated interactive session

Code
ml_installed(envname = test_env)
Message
! Required Python libraries to run ML functions are missing
Could not find: torch, torcheval, and scikit-learn
Do you wish to install? (This will be a one time operation)

23 changes: 20 additions & 3 deletions tests/testthat/helper-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ tests_disable_all <- function() {
r_scripts <- dir_ls(test_path(), glob = "*.R")
test_scripts <- r_scripts[substr(path_file(r_scripts), 1, 5) == "test-"]
map(
test_scripts, ~{
test_scripts, ~ {
ln <- readLines(.x)
writeLines(c("skip(\"temp\")", ln), con = .x)
}
Expand All @@ -89,7 +89,7 @@ tests_enable_all <- function() {
r_scripts <- dir_ls(test_path(), glob = "*.R")
test_scripts <- r_scripts[substr(path_file(r_scripts), 1, 5) == "test-"]
map(
test_scripts, ~{
test_scripts, ~ {
ln <- readLines(.x)
new_ln <- ln[ln != "skip(\"temp\")"]
writeLines(new_ln, con = .x)
Expand All @@ -102,7 +102,7 @@ test_databricks_cluster_id <- function() {
}

test_databricks_cluster_version <- function() {
if(is.null(.test_env$dbr)) {
if (is.null(.test_env$dbr)) {
dbr <- databricks_dbr_version(
cluster_id = test_databricks_cluster_id(),
host = databricks_host(),
Expand All @@ -112,3 +112,20 @@ test_databricks_cluster_version <- function() {
}
.test_env$dbr
}

# Test helper: resolves the Python environment name for the test Databricks
# cluster version via `use_envname()`, and returns the path
# `<test env folder>/<env name>/bin/python`.
#
# `use_envname()` returns a named value. Unless that name is "exact", a
# stand-in environment containing only `numpy` is installed at `env_path`
# (presumably so that a Python executable exists at the returned location
# without installing the full set of libraries - confirm against callers).
# NOTE(review): the returned path assumes a `bin/python` layout - confirm if
# these tests are expected to run on Windows.
test_databricks_stump_env <- function() {
  env_name <- use_envname(
    version = test_databricks_cluster_version(),
    backend = "databricks"
  )
  env_path <- path(use_test_env(), env_name)
  if (names(env_name) != "exact") {
    # `packages` is spelled out in full; the previous `package =` only worked
    # through R's partial argument matching.
    py_install(
      packages = "numpy",
      envname = env_path,
      pip = TRUE,
      python = Sys.which("python")
    )
  }
  path(env_path, "bin", "python")
}
17 changes: 3 additions & 14 deletions tests/testthat/test-deploy.R
Original file line number Diff line number Diff line change
@@ -1,20 +1,9 @@
skip_if_not_databricks()

# Test helper: returns the path of the Python executable of the virtual
# environment that `use_envname()` resolves for the test cluster version.
# When the resolved name is not tagged "exact", `numpy` is installed into
# that environment first.
test_databricks_deploy_env_path <- function() {
  env_name <- use_envname(
    version = test_databricks_cluster_version(),
    backend = "databricks"
  )
  is_exact <- names(env_name) == "exact"
  if (!is_exact) {
    py_install("numpy", env_name, pip = TRUE, python = Sys.which("python"))
  }
  python_path <- reticulate::virtualenv_python(env_name)
  path(python_path)
}

test_databricks_deploy_output <- function() {
list(
appDir = path(getwd()),
python = test_databricks_deploy_env_path(),
python = test_databricks_stump_env(),
envVars = c("DATABRICKS_HOST", "DATABRICKS_TOKEN"),
server = "my_server",
account = "my_account",
Expand Down Expand Up @@ -44,7 +33,7 @@ test_that("Basic use, passing DBR version works", {
accounts = function(...) accounts_df()
)
# Initializes environment
invisible(test_databricks_deploy_env_path())
invisible(test_databricks_stump_env())

expect_equal(
deploy_databricks(version = test_databricks_cluster_version()),
Expand Down Expand Up @@ -209,7 +198,7 @@ test_that("Rare cases for finding environments works", {
withr::with_envvar(
new = c("WORKON_HOME" = use_test_env()),
{
env_path <- test_databricks_deploy_env_path()
env_path <- test_databricks_stump_env()
local_mocked_bindings(
py_exe = function(...) {
return(NULL)
Expand Down
60 changes: 58 additions & 2 deletions tests/testthat/test-dplyr.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
test_that("head() works", {
  tbl_mtcars <- use_test_table_mtcars()

  # Collecting the result of `head(5)` should yield exactly five rows
  top_rows <- tbl_mtcars %>%
    head(5) %>%
    collect()

  expect_equal(nrow(top_rows), 5)
})

test_that("copy_to() works", {
tbl_ordered <- use_test_table_mtcars() %>%
arrange(mpg, qsec, hp)
Expand All @@ -8,21 +20,27 @@ test_that("copy_to() works", {
})

test_that("Sampling functions works", {
tbl_n <- use_test_table_mtcars() %>%
tbl_mtcars <- use_test_table_mtcars()
tbl_n <- tbl_mtcars %>%
sample_n(5) %>%
collect() %>%
count() %>%
pull()

expect_equal(tbl_n, 5)

tbl_frac <- use_test_table_mtcars() %>%
tbl_frac <- tbl_mtcars %>%
sample_frac(0.2) %>%
collect() %>%
count() %>%
pull()

expect_lt(tbl_frac, 30)

expect_error(
sample_frac(tbl_mtcars, size = 0.5, weight = mpg),
"`weight` is not supported"
)
})

test_that("Misc functions", {
Expand All @@ -33,9 +51,47 @@ test_that("Misc functions", {

expect_silent(compute(tbl_am, name = "am"))

expect_silent(compute(tbl_am, name = NULL))

tbl_join <- use_test_table_mtcars() %>%
left_join(tbl_am, by = "am") %>%
arrange(mpg, qsec, hp)

expect_error(tbl_ptype(tbl_am))

expect_snapshot(tbl_am[1])

expect_snapshot(tbl_join)
})

test_that("sdf_copy_to() works", {
  sc <- use_test_spark_connect()
  test_df <- data.frame(a = 1:1000, b = 1:1000)

  # First copy creates the temp table
  expect_s3_class(
    sdf_copy_to(sc, test_df, name = "test_df"),
    "tbl_pyspark"
  )

  # Copying again under the same name without `overwrite` must fail
  expect_error(
    sdf_copy_to(sc, test_df, name = "test_df"),
    "Temp table test_df already exists, use `overwrite = TRUE` to replace"
  )

  # `overwrite = TRUE` replaces the existing table
  expect_s3_class(
    sdf_copy_to(sc, test_df, name = "test_df", overwrite = TRUE),
    "tbl_pyspark"
  )

  # Overwrite combined with an explicit `repartition` value
  expect_s3_class(
    sdf_copy_to(
      sc, test_df,
      name = "test_df",
      overwrite = TRUE,
      repartition = 2
    ),
    "tbl_pyspark"
  )
})

test_that("sdf_register() works", {
  tbl_mtcars <- use_test_table_mtcars()
  sc <- use_test_spark_connect()

  # Wrap the table's Python data frame as a Spark pyobj, then register it
  py_obj <- tbl_mtcars %>%
    python_sdf() %>%
    as_spark_pyobj(sc)
  registered <- sdf_register(py_obj)

  expect_s3_class(registered, "tbl_pyspark")
})
56 changes: 56 additions & 0 deletions tests/testthat/test-ml-utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
test_that("ml_formula() works", {
  # Output for a single-predictor formula is pinned by the snapshot file
  expect_snapshot(ml_formula(am ~ mpg, mtcars))

  # A formula with an interaction term is expected to be rejected
  expect_error(
    object = ml_formula(am ~ mpg * cyl, mtcars),
    regexp = "Formula resulted in an invalid parameter set"
  )
})

test_that("snake_to_camel() works", {
  # "var_one" is expected to be converted to "varOne"
  converted <- snake_to_camel("var_one")
  expect_equal(converted, "varOne")
})

test_that("ml_connect_not_supported() works", {
  unsupported_args <- c(
    "elastic_net_param", "reg_param", "threshold",
    "aggregation_depth", "fit_intercept",
    "raw_prediction_col", "uid", "weight_col"
  )

  # No argument supplied, so nothing should be reported
  expect_silent(
    ml_connect_not_supported(
      args = list(),
      not_supported = unsupported_args
    )
  )

  # The expected message is the `regexp` of `expect_error()`. It was
  # previously passed, by mistake, as an extra argument to
  # `ml_connect_not_supported()`, so the message was never checked.
  # `fixed = TRUE` is needed because the text contains the literal "(s)".
  # NOTE(review): message text copied from the original test - confirm it
  # matches what `ml_connect_not_supported()` actually emits.
  expect_error(
    ml_connect_not_supported(
      args = list(reg_param = 1),
      not_supported = unsupported_args
    ),
    "The following argument(s) are not supported by Spark Connect:",
    fixed = TRUE
  )
})

test_that("ml_installed() works on simulated interactive session", {
  skip_if_not_databricks()

  # The helper returns `<env>/bin/python`; going up two directory levels
  # yields the environment folder itself.
  test_env <- test_databricks_stump_env() %>%
    path_dir() %>%
    path_dir()

  # Simulate an interactive RStudio session in which the user answers with
  # the first menu option, and turn the actual installation into a no-op.
  local_mocked_bindings(
    check_interactive = function(...) TRUE,
    check_rstudio = function(...) TRUE,
    menu = function(...) 1,
    py_install = function(...) invisible()
  )

  expect_snapshot(
    ml_installed(envname = test_env)
  )
})
Loading

0 comments on commit b4ee5a8

Please sign in to comment.