Skip to content

Commit

Permalink
clarify case weight support in show_model_info() (#1102)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Apr 4, 2024
1 parent e5c7f92 commit eb526fa
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 20 deletions.
14 changes: 10 additions & 4 deletions R/aaa_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ show_model_info <- function(model) {
) %>%
dplyr::select(engine, mode, has_wts)

engines %>%
engine_weight_info <- engines %>%
dplyr::left_join(weight_info, by = c("engine", "mode")) %>%
dplyr::mutate(
engine = paste0(engine, has_wts),
Expand All @@ -1005,9 +1005,15 @@ show_model_info <- function(model) {
lab = paste0(" ", mode, engine, "\n")
) %>%
dplyr::ungroup() %>%
dplyr::pull(lab) %>%
cat(sep = "")
cat("\n", cli::symbol$sup_1, "The model can use case weights.\n\n", sep = "")
dplyr::pull(lab)

cat(engine_weight_info, sep = "")

if (!all(weight_info$has_wts == "")) {
cat("\n", cli::symbol$sup_1, "The model can use case weights.", sep = "")
}

cat("\n\n")
} else {
cat(" no registered engines.\n\n")
}
Expand Down
98 changes: 98 additions & 0 deletions tests/testthat/_snaps/registration.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,101 @@
Error in `check_mode_for_new_engine()`:
! "regression" is not a known mode for model `sponge()`.

# showing model info

Code
show_model_info("rand_forest")
Output
Information for `rand_forest`
modes: unknown, classification, regression, censored regression
engines:
classification: randomForest, ranger1, spark
regression: randomForest, ranger1, spark
1The model can use case weights.
arguments:
ranger:
mtry --> mtry
trees --> num.trees
min_n --> min.node.size
randomForest:
mtry --> mtry
trees --> ntree
min_n --> nodesize
spark:
mtry --> feature_subset_strategy
trees --> num_trees
min_n --> min_instances_per_node
fit modules:
engine mode
ranger classification
ranger regression
randomForest classification
randomForest regression
spark classification
spark regression
prediction modules:
mode engine methods
classification randomForest class, prob, raw
classification ranger class, conf_int, prob, raw
classification spark class, prob
regression randomForest numeric, raw
regression ranger conf_int, numeric, raw
regression spark numeric

---

Code
show_model_info("mlp")
Output
Information for `mlp`
modes: unknown, classification, regression
engines:
classification: brulee, keras, nnet
regression: brulee, keras, nnet
arguments:
keras:
hidden_units --> hidden_units
penalty --> penalty
dropout --> dropout
epochs --> epochs
activation --> activation
nnet:
hidden_units --> size
penalty --> decay
epochs --> maxit
brulee:
hidden_units --> hidden_units
penalty --> penalty
epochs --> epochs
dropout --> dropout
learn_rate --> learn_rate
activation --> activation
fit modules:
engine mode
keras regression
keras classification
nnet regression
nnet classification
brulee regression
brulee classification
prediction modules:
mode engine methods
classification brulee class, prob
classification keras class, prob, raw
classification nnet class, prob, raw
regression brulee numeric
regression keras numeric, raw
regression nnet numeric, raw

21 changes: 5 additions & 16 deletions tests/testthat/test_registration.R
Original file line number Diff line number Diff line change
Expand Up @@ -496,21 +496,10 @@ test_that('adding a new predict method', {


test_that('showing model info', {
expect_output(
show_model_info("rand_forest"),
"Information for `rand_forest`"
)
expect_output(
show_model_info("rand_forest"),
"trees --> ntree"
)
expect_output(
show_model_info("rand_forest"),
"fit modules:"
)
expect_output(
show_model_info("rand_forest"),
"prediction modules:"
)
expect_snapshot(show_model_info("rand_forest"))

# ensure that we don't mention case weight support when the
# notation would be ambiguous (#1000)
expect_snapshot(show_model_info("mlp"))
})

0 comments on commit eb526fa

Please sign in to comment.