Feature suggestion: Extract splits from tune results as a resampling object #947

Open
jrosell opened this issue Oct 9, 2024 · 4 comments

@jrosell

jrosell commented Oct 9, 2024

Feature suggestion

Now that we have the new {tailor} package for post-processing in tidymodels, I find myself needing to reuse the splits from tune_results as a resampling object.

I believe this new extract_resamples function (or whatever name you prefer) could improve the interactive usage of tidymodels.

Here is a minimal reproducible example to demonstrate its use:

# pak::pak(
#   paste0(
#     "tidymodels/",
#     c("tune", "workflows", "rsample", "tailor")
#   )
# )
library(tidyverse)
library(tidymodels)
library(probably)
#> 
#> Attaching package: 'probably'
#> The following objects are masked from 'package:base':
#> 
#>     as.factor, as.ordered
library(tailor)
library(stacks)

# How well are our predictions calibrated?  Not so well
data(deliveries)
set.seed(1)
delivery_split <- initial_split(deliveries)
delivery_train <- training(delivery_split)
delivery_test  <- testing(delivery_split)
set.seed(1)
delivery_folds <- vfold_cv(delivery_train)
delivery_res <-
  workflow() |>
  add_formula(time_to_delivery ~ .) |>
  add_model(boost_tree(mode = "regression", trees = 3)) |>
  fit_resamples(
    delivery_folds, 
    control = control_stack_resamples()
  )
delivery_res |> 
  collect_predictions() |> 
  cal_plot_regression(truth = time_to_delivery, estimate = .pred)

delivery_res |> collect_metrics()
#> # A tibble: 2 × 6
#>   .metric .estimator  mean     n std_err .config             
#>   <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 rmse    standard   9.52     10 0.0533  Preprocessor1_Model1
#> 2 rsq     standard   0.853    10 0.00357 Preprocessor1_Model1

# We want to reuse the already saved splits in the tune results as rset
extract_resamples <- \(x) {
  stopifnot(inherits(x, "tune_results"))
  result_rset <- manual_rset(x$splits, x$id)
  new_attrs <- attributes(result_rset)[c("names", "row.names")]
  existing_attrs <- attributes(x)$rset_info$att
  att <- modifyList(existing_attrs, new_attrs)
  desired_classes <- c(att$class, "rset", "tbl_df", "tbl", "data.frame")  
  att$class <- NULL  
  attributes(result_rset) <- att  
  class(result_rset) <- desired_classes
  result_rset
}
waldo::compare(delivery_folds, extract_resamples(delivery_res))
#> ✔ No differences

# Let's adjust numeric calibration extracting the saved splits
delivery_res_improved <-
  delivery_res |> 
  extract_workflow() |> 
  add_tailor(tailor() |> adjust_numeric_calibration()) |>
  fit_resamples(
    extract_resamples(delivery_res), 
    control = control_stack_resamples()
  )
delivery_res_improved |> collect_metrics()
#> # A tibble: 2 × 6
#>   .metric .estimator  mean     n std_err .config             
#>   <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 rmse    standard   2.71     10 0.0300  Preprocessor1_Model1
#> 2 rsq     standard   0.846    10 0.00432 Preprocessor1_Model1

# Much better
delivery_res_improved |> 
  collect_predictions() |>
  cal_plot_regression(truth = time_to_delivery, estimate = .pred)



sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.3.3 (2024-02-29)
#>  os       Ubuntu 22.04.4 LTS
#>  system   x86_64, linux-gnu
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/Madrid
#>  date     2024-10-09
#>  pandoc   2.9.2.1 @ /bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.3.0)
#>  broom        * 1.0.5      2023-06-09 [1] CRAN (R 4.3.1)
#>  butcher        0.3.3      2023-08-23 [1] CRAN (R 4.3.2)
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.3.3)
#>  cli            3.6.2      2023-12-11 [1] CRAN (R 4.3.2)
#>  codetools      0.2-19     2023-02-01 [2] CRAN (R 4.3.3)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.3.0)
#>  data.table     1.15.99    2024-02-20 [1] Github (Rdatatable/data.table@8f8ef93)
#>  dials        * 1.3.0      2024-07-30 [1] RSPM
#>  DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.3.2)
#>  digest         0.6.35     2024-03-11 [1] RSPM (R 4.3.0)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.3.2)
#>  evaluate       0.23       2023-11-01 [1] CRAN (R 4.3.2)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.3.2)
#>  farver         2.1.1      2022-07-06 [1] CRAN (R 4.3.0)
#>  fastmap        1.1.1      2023-02-24 [1] CRAN (R 4.3.0)
#>  forcats      * 1.0.0      2023-01-29 [1] CRAN (R 4.3.2)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.0)
#>  fs             1.6.3      2023-07-20 [1] CRAN (R 4.3.1)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.0)
#>  future         1.33.1     2023-12-22 [1] CRAN (R 4.3.2)
#>  future.apply   1.11.1     2023-12-21 [1] CRAN (R 4.3.2)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.0)
#>  ggplot2      * 3.5.0      2024-02-23 [1] RSPM (R 4.3.0)
#>  globals        0.16.3     2024-03-08 [1] RSPM (R 4.3.0)
#>  glue           1.7.0      2024-01-09 [1] RSPM (R 4.3.0)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.3.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.0)
#>  gtable         0.3.4      2023-08-21 [1] CRAN (R 4.3.1)
#>  hardhat        1.4.0      2024-06-02 [1] RSPM
#>  hms            1.1.3      2023-03-21 [1] CRAN (R 4.3.0)
#>  htmltools      0.5.8      2024-03-25 [1] RSPM (R 4.3.0)
#>  infer        * 1.0.7      2024-03-25 [1] RSPM (R 4.3.0)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.0)
#>  jsonlite       1.8.8      2023-12-04 [1] CRAN (R 4.3.2)
#>  knitr          1.45       2023-10-30 [1] CRAN (R 4.3.2)
#>  labeling       0.4.3      2023-08-29 [1] CRAN (R 4.3.1)
#>  lattice        0.22-5     2023-10-24 [2] CRAN (R 4.3.3)
#>  lava           1.8.0      2024-03-05 [1] RSPM (R 4.3.0)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.0)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.3.2)
#>  listenv        0.9.1      2024-01-29 [1] RSPM (R 4.3.0)
#>  lubridate    * 1.9.3      2023-09-27 [1] CRAN (R 4.3.2)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.0)
#>  MASS           7.3-60.0.1 2024-01-13 [2] CRAN (R 4.3.3)
#>  Matrix         1.6-5      2024-01-11 [1] RSPM (R 4.3.0)
#>  mgcv           1.9-1      2023-12-21 [2] CRAN (R 4.3.3)
#>  modeldata    * 1.3.0      2024-01-21 [1] RSPM (R 4.3.0)
#>  modelenv       0.1.1      2023-03-08 [1] CRAN (R 4.3.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.3.0)
#>  nlme           3.1-164    2023-11-27 [2] CRAN (R 4.3.3)
#>  nnet           7.3-19     2023-05-03 [2] CRAN (R 4.3.3)
#>  parallelly     1.37.1     2024-02-29 [1] RSPM (R 4.3.0)
#>  parsnip      * 1.2.1.9002 2024-10-08 [1] Github (tidymodels/parsnip@5ce414e)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.3.0)
#>  probably     * 1.0.3.9001 2024-10-08 [1] Github (tidymodels/probably@545f9ab)
#>  prodlim        2023.08.28 2023-08-28 [1] CRAN (R 4.3.2)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.3.1)
#>  R.cache        0.16.0     2022-07-21 [1] CRAN (R 4.3.1)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.3.1)
#>  R.oo           1.26.0     2024-01-24 [1] CRAN (R 4.3.2)
#>  R.utils        2.12.3     2023-11-18 [1] CRAN (R 4.3.2)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.0)
#>  Rcpp           1.0.12     2024-01-09 [1] RSPM (R 4.3.0)
#>  readr        * 2.1.5      2024-01-10 [1] RSPM (R 4.3.0)
#>  recipes      * 1.0.10     2024-02-18 [1] RSPM (R 4.3.0)
#>  reprex         2.1.0.9000 2024-01-18 [1] Github (tidyverse/reprex@e1f65e9)
#>  rlang          1.1.3      2024-01-10 [1] RSPM (R 4.3.0)
#>  rmarkdown      2.26       2024-03-05 [1] RSPM (R 4.3.0)
#>  rpart          4.1.23     2023-12-05 [1] RSPM
#>  rsample      * 1.2.1.9000 2024-10-08 [1] Github (tidymodels/rsample@f799dba)
#>  scales       * 1.3.0      2023-11-28 [1] CRAN (R 4.3.2)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.3.0)
#>  sparsevctrs    0.1.0.9002 2024-10-08 [1] Github (r-lib/sparsevctrs@b29b723)
#>  stacks       * 1.0.4      2024-03-21 [1] RSPM (R 4.3.0)
#>  stringi        1.8.3      2023-12-11 [1] CRAN (R 4.3.2)
#>  stringr      * 1.5.1      2023-11-14 [1] CRAN (R 4.3.2)
#>  styler         1.10.2     2023-08-29 [1] CRAN (R 4.3.2)
#>  survival       3.5-8      2024-02-14 [2] CRAN (R 4.3.3)
#>  tailor       * 0.0.0.9001 2024-10-08 [1] Github (tidymodels/tailor@317a4db)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.0)
#>  tidymodels   * 1.2.0      2024-03-25 [1] RSPM (R 4.3.0)
#>  tidyr        * 1.3.1      2024-01-24 [1] CRAN (R 4.3.2)
#>  tidyselect     1.2.1      2024-03-11 [1] RSPM (R 4.3.0)
#>  tidyverse    * 2.0.0.9000 2024-02-20 [1] Github (tidyverse/tidyverse@62f32d4)
#>  timechange     0.3.0      2024-01-18 [1] RSPM (R 4.3.0)
#>  timeDate       4032.109   2023-12-14 [1] CRAN (R 4.3.2)
#>  tune         * 1.2.1.9000 2024-10-08 [1] Github (tidymodels/tune@f8d734a)
#>  tzdb           0.4.0      2023-05-12 [1] CRAN (R 4.3.0)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.3.2)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.3.2)
#>  waldo          0.5.2      2023-11-02 [1] CRAN (R 4.3.2)
#>  withr          3.0.0      2024-01-16 [1] CRAN (R 4.3.2)
#>  workflows    * 1.1.4.9000 2024-10-08 [1] Github (tidymodels/workflows@78aa5df)
#>  workflowsets * 1.1.0      2024-03-21 [1] RSPM (R 4.3.0)
#>  xfun           0.43       2024-03-25 [1] RSPM (R 4.3.0)
#>  xgboost      * 1.7.7.1    2024-01-25 [1] RSPM (R 4.3.0)
#>  yaml           2.3.8      2023-12-11 [1] CRAN (R 4.3.2)
#>  yardstick    * 1.3.1      2024-03-21 [1] RSPM (R 4.3.0)
#> 
#>  [1] /home/jordi/R/x86_64-pc-linux-gnu-library/4.3
#>  [2] /opt/R/4.3.3/lib/R/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Created on 2024-10-09 with [reprex v2.1.0.9000](https://reprex.tidyverse.org/)

This implementation seems to give identical results for my vfold_cv example, but other types of rset objects should be tested as well.
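As a rough illustration of what such a check could look like for another rset type, here is a minimal sketch using `bootstraps()`. The `mtcars` data, the linear model, and the object names `boots`, `boot_res`, and `boots_extracted` are assumptions chosen only to keep the run short; they are not part of the reprex above, and whether `waldo::compare()` reports differences here has not been verified.

```r
library(tidymodels)

# Helper as proposed in this issue, repeated so the sketch is self-contained.
extract_resamples <- function(x) {
  stopifnot(inherits(x, "tune_results"))
  result_rset <- rsample::manual_rset(x$splits, x$id)
  new_attrs <- attributes(result_rset)[c("names", "row.names")]
  existing_attrs <- attributes(x)$rset_info$att
  att <- modifyList(existing_attrs, new_attrs)
  desired_classes <- c(att$class, "rset", "tbl_df", "tbl", "data.frame")
  att$class <- NULL
  attributes(result_rset) <- att
  class(result_rset) <- desired_classes
  result_rset
}

# Assumption: mtcars and a plain linear model, used only for speed.
set.seed(1)
boots <- bootstraps(mtcars, times = 5)

boot_res <-
  workflow() |>
  add_formula(mpg ~ wt) |>
  add_model(linear_reg()) |>
  fit_resamples(boots)

boots_extracted <- extract_resamples(boot_res)

# Report any differences between the original and the extracted rset.
waldo::compare(boots, boots_extracted)
```

Schemes with more than one id column, such as repeated cross-validation (`id` and `id2`), are likely to need extra handling, since the helper only passes `x$id` on to `manual_rset()`.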

@simonpcouch
Contributor

Could you say a bit more about why you'd need to extract the splits from the tune_results rather than just reusing the splits you already have?

Note to self: FWIW, we did find a use for a similar helper in stacks:::.set_splits().

@jrosell
Author

jrosell commented Oct 10, 2024

Well, in my pipelines I usually have one process for fitting resamples and tuning, and sometimes I only save the tune_results object and not the rset. Then, oops, I need the rset too because I want to check something and I didn't save it. {tailor} could make this situation more likely.

Furthermore, I want to try the AutoGluon inference approach, and this function could help.

@simonpcouch
Contributor

Gotcha, thanks for the reply! I will leave this open as we can see some use cases for this, though it may not be at the top of our to-do for a bit.

@jrosell
Author

jrosell commented Nov 8, 2024

If you let me know which tests you would like included, I can open a proper PR so we can merge it.
