Skip to content

Commit

Permalink
predict ordinal factors from ordinal regression models (#1217)
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo authored Oct 29, 2024
1 parent a4f9811 commit a9889f0
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 6 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

## Bug Fixes

* Make sure that parsnip does not convert ordered factor predictions to be unordered.

* Ensure that `knit_engine_docs()` has the required packages installed (#1156).

* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).
Expand Down
1 change: 1 addition & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
#' \itemize{
#' \item \code{lvl}: If the outcome is a factor, this contains
#' the factor levels at the time of model fitting.
#' \item \code{ordered}: If the outcome is a factor, was it an ordered factor?
#' \item \code{spec}: The model specification object
#' (\code{object} in the call to \code{fit})
#' \item \code{fit}: when the model is executed without error,
Expand Down
5 changes: 3 additions & 2 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ form_form <-
fit_call <- make_form_call(object, env = env)

res <- list(
lvl = y_levels,
lvl = y_levels$lvl,
ordered = y_levels$ordered,
spec = object
)

Expand Down Expand Up @@ -98,7 +99,7 @@ xy_xy <- function(object,

fit_call <- make_xy_call(object, target, env, call)

res <- list(lvl = levels(env$y), spec = object)
res <- list(lvl = levels(env$y), ordered = is.ordered(env$y), spec = object)

time <- proc.time()
res$fit <- eval_mod(
Expand Down
7 changes: 5 additions & 2 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,12 @@ convert_arg <- function(x) {

levels_from_formula <- function(f, dat) {
if (inherits(dat, "tbl_spark")) {
res <- NULL
res <- list(lvls = NULL, ordered = FALSE)
} else {
res <- levels(eval_tidy(rlang::f_lhs(f), dat))
res <- list()
y_data <- eval_tidy(rlang::f_lhs(f), dat)
res$lvls <- levels(y_data)
res$ordered <- is.ordered(y_data)
}
res
}
Expand Down
6 changes: 4 additions & 2 deletions R/predict_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ predict_class.model_fit <- function(object, new_data, ...) {

# coerce levels to those in `object`
if (is.vector(res) || is.factor(res)) {
res <- factor(as.character(res), levels = object$lvl)
res <- factor(as.character(res), levels = object$lvl, ordered = object$ordered)
} else {
if (!inherits(res, "tbl_spark")) {
# Now case where a parsnip model generated `res`
if (is.data.frame(res) && ncol(res) == 1 && is.factor(res[[1]])) {
res <- res[[1]]
} else {
res$values <- factor(as.character(res$values), levels = object$lvl)
res$values <- factor(as.character(res$values),
levels = object$lvl,
ordered = object$ordered)
}
}
}
Expand Down
1 change: 1 addition & 0 deletions man/fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 44 additions & 0 deletions tests/testthat/test-predict_formats.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,50 @@ test_that('classification predictions', {
c(".pred_high", ".pred_low"))
})


test_that('ordinal classification predictions', {
skip_if_not_installed("modeldata")
skip_if_not_installed("rpart")

set.seed(382)
dat_tr <-
modeldata::sim_multinomial(
200,
~ -0.5 + 0.6 * abs(A),
~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2),
~ -0.6 * A + 0.50 * B - A * B) %>%
dplyr::mutate(class = as.ordered(class))
dat_te <-
modeldata::sim_multinomial(
5,
~ -0.5 + 0.6 * abs(A),
~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2),
~ -0.6 * A + 0.50 * B - A * B) %>%
dplyr::mutate(class = as.ordered(class))

###

mod_f_fit <-
decision_tree() %>%
set_mode("classification") %>%
fit(class ~ ., data = dat_tr)
expect_true("ordered" %in% names(mod_f_fit))
mod_f_pred <- predict(mod_f_fit, dat_te)
expect_true(is.ordered(mod_f_pred$.pred_class))

###

mod_xy_fit <-
decision_tree() %>%
set_mode("classification") %>%
fit_xy(x = dat_tr %>% dplyr::select(-class), dat_tr$class)

expect_true("ordered" %in% names(mod_xy_fit))
mod_xy_pred <- predict(mod_xy_fit, dat_te)
expect_true(is.ordered(mod_f_pred$.pred_class))
})


test_that('non-standard levels', {
expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1])))
expect_true(is.factor(parsnip:::predict_class.model_fit(lr_fit, new_data = class_dat[1:5,-1])))
Expand Down

0 comments on commit a9889f0

Please sign in to comment.