Allow tabnet_model = and from_epoch = in parsnip tabnet() models so that Stack can use a pre-trained TabNet model #142

Comments
Hello @cgoo4, FYI I'm about to release a parsnip interface for this; I'll let you know when it's available so that you can validate it on your own dataset.
Hello @cgoo4, could you try to install {tabnet} from the branch feature/parsnip_from_pretrain, e.g. with pak::pak("mlverse/tabnet@feature/parsnip_from_pretrain"), and tell me if that fits your need?
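A minimal sketch of that install step, assuming {pak} is already installed (packageVersion() is only there to confirm the branch build is in use):

# Install {tabnet} from the feature branch, then check which version got installed
pak::pak("mlverse/tabnet@feature/parsnip_from_pretrain")
packageVersion("tabnet")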
Hi @cregouby - Thank you! The toy example runs successfully per below. (I'm using an M2 machine, hence device = "cpu".) On my own data, though, I'm getting computational errors (see further down).

library(tabnet)
library(tidymodels)
library(modeldata)
library(stacks)
set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)
tab_rec <- recipe(Class ~ ., train) |>
step_normalize(all_numeric())
set.seed(1)
tab_pre <- tab_rec |>
tabnet_pretrain(
train,
valid_split = 0.2,
epochs = 100,
device = "cpu",
verbose = TRUE
)
#> [Epoch 001] Loss: 108.511490, Valid loss: 29.568882
#> [Epoch 002] Loss: 30.584080, Valid loss: 17.283676
#> [Epoch 003] Loss: 9.266540, Valid loss: 13.562984
#> [Epoch 004] Loss: 5.846971, Valid loss: 10.691854
#> [Epoch 005] Loss: 6.189508, Valid loss: 7.484963
#> [Epoch 006] Loss: 5.970494, Valid loss: 6.027362
#> [Epoch 007] Loss: 5.492610, Valid loss: 5.489138
#> [Epoch 008] Loss: 4.127725, Valid loss: 5.134202
#> [Epoch 009] Loss: 6.233320, Valid loss: 4.775207
#> [Epoch 010] Loss: 3.267596, Valid loss: 4.366734
#> [Epoch 011] Loss: 3.254126, Valid loss: 6.512185
#> [Epoch 012] Loss: 3.117017, Valid loss: 5.826352
#> [Epoch 013] Loss: 2.751130, Valid loss: 5.357126
#> [Epoch 014] Loss: 2.585568, Valid loss: 4.796397
#> [Epoch 015] Loss: 2.389232, Valid loss: 4.247961
#> [Epoch 016] Loss: 2.352181, Valid loss: 3.750866
#> [Epoch 017] Loss: 2.095828, Valid loss: 3.386385
#> [Epoch 018] Loss: 2.168191, Valid loss: 3.083429
#> [Epoch 019] Loss: 1.916646, Valid loss: 3.169604
#> [Epoch 020] Loss: 1.869341, Valid loss: 3.251067
#> [Epoch 021] Loss: 1.737848, Valid loss: 3.130211
#> [Epoch 022] Loss: 1.651209, Valid loss: 2.922243
#> [Epoch 023] Loss: 1.588258, Valid loss: 2.717718
#> [Epoch 024] Loss: 1.565487, Valid loss: 2.503327
#> [Epoch 025] Loss: 1.503044, Valid loss: 2.293428
#> [Epoch 026] Loss: 1.417796, Valid loss: 2.115288
#> [Epoch 027] Loss: 1.414081, Valid loss: 1.961773
#> [Epoch 028] Loss: 1.469063, Valid loss: 1.834276
#> [Epoch 029] Loss: 1.345036, Valid loss: 1.741629
#> [Epoch 030] Loss: 1.342513, Valid loss: 1.661885
#> [Epoch 031] Loss: 1.302708, Valid loss: 1.595290
#> [Epoch 032] Loss: 1.309536, Valid loss: 1.536704
#> [Epoch 033] Loss: 1.226201, Valid loss: 1.483173
#> [Epoch 034] Loss: 1.210594, Valid loss: 1.437791
#> [Epoch 035] Loss: 1.226383, Valid loss: 1.401923
#> [Epoch 036] Loss: 1.189032, Valid loss: 1.368132
#> [Epoch 037] Loss: 1.154899, Valid loss: 1.340692
#> [Epoch 038] Loss: 1.258244, Valid loss: 1.324930
#> [Epoch 039] Loss: 1.238932, Valid loss: 1.308550
#> [Epoch 040] Loss: 1.216526, Valid loss: 1.294628
#> [Epoch 041] Loss: 1.140847, Valid loss: 1.282725
#> [Epoch 042] Loss: 1.115530, Valid loss: 1.272120
#> [Epoch 043] Loss: 1.115844, Valid loss: 1.262455
#> [Epoch 044] Loss: 1.213085, Valid loss: 1.255370
#> [Epoch 045] Loss: 1.172749, Valid loss: 1.248636
#> [Epoch 046] Loss: 1.067289, Valid loss: 1.241004
#> [Epoch 047] Loss: 1.121537, Valid loss: 1.235389
#> [Epoch 048] Loss: 1.146737, Valid loss: 1.231018
#> [Epoch 049] Loss: 1.105253, Valid loss: 1.225921
#> [Epoch 050] Loss: 1.105276, Valid loss: 1.221406
#> [Epoch 051] Loss: 1.105012, Valid loss: 1.215876
#> [Epoch 052] Loss: 1.105758, Valid loss: 1.210291
#> [Epoch 053] Loss: 1.025887, Valid loss: 1.204283
#> [Epoch 054] Loss: 1.119924, Valid loss: 1.197808
#> [Epoch 055] Loss: 1.097717, Valid loss: 1.190847
#> [Epoch 056] Loss: 1.091637, Valid loss: 1.185091
#> [Epoch 057] Loss: 1.097960, Valid loss: 1.182087
#> [Epoch 058] Loss: 1.102373, Valid loss: 1.178165
#> [Epoch 059] Loss: 1.135786, Valid loss: 1.174002
#> [Epoch 060] Loss: 1.132688, Valid loss: 1.170313
#> [Epoch 061] Loss: 1.039531, Valid loss: 1.165211
#> [Epoch 062] Loss: 1.070144, Valid loss: 1.161195
#> [Epoch 063] Loss: 1.098994, Valid loss: 1.157930
#> [Epoch 064] Loss: 1.043513, Valid loss: 1.154124
#> [Epoch 065] Loss: 1.077071, Valid loss: 1.151235
#> [Epoch 066] Loss: 1.022600, Valid loss: 1.147900
#> [Epoch 067] Loss: 1.127478, Valid loss: 1.146163
#> [Epoch 068] Loss: 1.061102, Valid loss: 1.143913
#> [Epoch 069] Loss: 1.067844, Valid loss: 1.142187
#> [Epoch 070] Loss: 1.079090, Valid loss: 1.140889
#> [Epoch 071] Loss: 1.173972, Valid loss: 1.140199
#> [Epoch 072] Loss: 1.060898, Valid loss: 1.138638
#> [Epoch 073] Loss: 1.066270, Valid loss: 1.136953
#> [Epoch 074] Loss: 1.202267, Valid loss: 1.136177
#> [Epoch 075] Loss: 1.144205, Valid loss: 1.135328
#> [Epoch 076] Loss: 0.976403, Valid loss: 1.132538
#> [Epoch 077] Loss: 1.015884, Valid loss: 1.130507
#> [Epoch 078] Loss: 1.098029, Valid loss: 1.130162
#> [Epoch 079] Loss: 1.006299, Valid loss: 1.128126
#> [Epoch 080] Loss: 1.045715, Valid loss: 1.126469
#> [Epoch 081] Loss: 1.052163, Valid loss: 1.124908
#> [Epoch 082] Loss: 0.981435, Valid loss: 1.121406
#> [Epoch 083] Loss: 1.032371, Valid loss: 1.118657
#> [Epoch 084] Loss: 1.079322, Valid loss: 1.117403
#> [Epoch 085] Loss: 0.993631, Valid loss: 1.114072
#> [Epoch 086] Loss: 1.100637, Valid loss: 1.113469
#> [Epoch 087] Loss: 1.046454, Valid loss: 1.111916
#> [Epoch 088] Loss: 1.074078, Valid loss: 1.110664
#> [Epoch 089] Loss: 1.034861, Valid loss: 1.108330
#> [Epoch 090] Loss: 1.009123, Valid loss: 1.105716
#> [Epoch 091] Loss: 1.081966, Valid loss: 1.105542
#> [Epoch 092] Loss: 1.048128, Valid loss: 1.103355
#> [Epoch 093] Loss: 0.965812, Valid loss: 1.100522
#> [Epoch 094] Loss: 1.016202, Valid loss: 1.098672
#> [Epoch 095] Loss: 0.975355, Valid loss: 1.096074
#> [Epoch 096] Loss: 0.976712, Valid loss: 1.093621
#> [Epoch 097] Loss: 0.981546, Valid loss: 1.090942
#> [Epoch 098] Loss: 1.058951, Valid loss: 1.089792
#> [Epoch 099] Loss: 1.087424, Valid loss: 1.089323
#> [Epoch 100] Loss: 0.964833, Valid loss: 1.086923
autoplot(tab_pre)

tab_mod <- tabnet(epochs = 100, tabnet_model = tab_pre, from_epoch = 100) |>
set_engine("torch", device = "cpu") |>
set_mode("classification")
tab_wf <- workflow() |>
add_model(tab_mod) |>
add_recipe(tab_rec)
xgb_rec <- recipe(Class ~ ., train) |>
step_normalize(all_numeric()) |>
step_dummy(c("term", "sub_grade", "addr_state", "verification_status", "emp_length"))
xgb_mod <-
boost_tree(mode = "classification") |>
set_engine("xgboost")
xgb_wf <- workflow() |>
add_model(xgb_mod) |>
add_recipe(xgb_rec)
set.seed(123)
folds <- vfold_cv(train, v = 3)
fit_tab <- tab_wf |>
fit_resamples(folds, control = control_stack_resamples())
fit_xgb <- xgb_wf |>
fit_resamples(folds, control = control_stack_resamples())
collect_metrics(fit_tab)
#> # A tibble: 2 × 6
#> .metric .estimator mean n std_err .config
#> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 accuracy binary 0.943 3 0.00676 Preprocessor1_Model1
#> 2 roc_auc binary 0.635 3 0.0191 Preprocessor1_Model1
collect_metrics(fit_xgb)
#> # A tibble: 2 × 6
#> .metric .estimator mean n std_err .config
#> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 accuracy binary 0.947 3 0.00479 Preprocessor1_Model1
#> 2 roc_auc binary 0.706 3 0.0133 Preprocessor1_Model1
data_st <-
stacks() |>
add_candidates(fit_tab) |>
add_candidates(fit_xgb)
model_st <-
data_st |>
blend_predictions()
autoplot(model_st, type = "weights")

Created on 2024-01-11 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.3.2 (2023-10-31)
#> os macOS Sonoma 14.2.1
#> system aarch64, darwin20
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Europe/London
#> date 2024-01-11
#> pandoc 3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> backports 1.4.1 2021-12-13 [2] CRAN (R 4.3.0)
#> bit 4.0.5 2022-11-15 [2] CRAN (R 4.3.0)
#> bit64 4.0.5 2020-08-30 [2] CRAN (R 4.3.0)
#> broom * 1.0.5 2023-06-09 [2] CRAN (R 4.3.0)
#> butcher 0.3.3 2023-08-23 [1] CRAN (R 4.3.0)
#> callr 3.7.3 2022-11-02 [2] CRAN (R 4.3.0)
#> class 7.3-22 2023-05-03 [2] CRAN (R 4.3.2)
#> cli 3.6.2 2023-12-11 [1] CRAN (R 4.3.1)
#> codetools 0.2-19 2023-02-01 [2] CRAN (R 4.3.2)
#> colorspace 2.1-0 2023-01-23 [1] CRAN (R 4.3.0)
#> coro 1.0.3 2022-07-19 [2] CRAN (R 4.3.0)
#> crayon 1.5.2 2022-09-29 [1] CRAN (R 4.3.0)
#> curl 5.2.0 2023-12-08 [2] CRAN (R 4.3.1)
#> data.table 1.14.10 2023-12-08 [1] CRAN (R 4.3.1)
#> dials * 1.2.0 2023-04-03 [1] CRAN (R 4.3.0)
#> DiceDesign 1.10 2023-12-07 [1] CRAN (R 4.3.1)
#> digest 0.6.33 2023-07-07 [1] CRAN (R 4.3.0)
#> dplyr * 1.1.4 2023-11-17 [1] CRAN (R 4.3.1)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.3.0)
#> evaluate 0.23 2023-11-01 [2] CRAN (R 4.3.1)
#> fansi 1.0.6 2023-12-08 [1] CRAN (R 4.3.1)
#> farver 2.1.1 2022-07-06 [1] CRAN (R 4.3.0)
#> fastmap 1.1.1 2023-02-24 [2] CRAN (R 4.3.0)
#> foreach 1.5.2 2022-02-02 [1] CRAN (R 4.3.0)
#> fs 1.6.3 2023-07-20 [2] CRAN (R 4.3.0)
#> furrr 0.3.1 2022-08-15 [1] CRAN (R 4.3.0)
#> future 1.33.1 2023-12-22 [1] CRAN (R 4.3.1)
#> future.apply 1.11.1 2023-12-21 [1] CRAN (R 4.3.1)
#> generics 0.1.3 2022-07-05 [1] CRAN (R 4.3.0)
#> ggplot2 * 3.4.4 2023-10-12 [1] CRAN (R 4.3.1)
#> glmnet * 4.1-8 2023-08-22 [1] CRAN (R 4.3.0)
#> globals 0.16.2 2022-11-21 [1] CRAN (R 4.3.0)
#> glue 1.7.0 2024-01-09 [1] CRAN (R 4.3.1)
#> gower 1.0.1 2022-12-22 [1] CRAN (R 4.3.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.3.0)
#> gtable 0.3.4 2023-08-21 [1] CRAN (R 4.3.0)
#> hardhat 1.3.0 2023-03-30 [1] CRAN (R 4.3.0)
#> highr 0.10 2022-12-22 [2] CRAN (R 4.3.0)
#> hms 1.1.3 2023-03-21 [2] CRAN (R 4.3.0)
#> htmltools 0.5.7 2023-11-03 [2] CRAN (R 4.3.1)
#> infer * 1.0.5 2023-09-06 [2] CRAN (R 4.3.0)
#> ipred 0.9-14 2023-03-09 [1] CRAN (R 4.3.0)
#> iterators 1.0.14 2022-02-05 [1] CRAN (R 4.3.0)
#> jsonlite 1.8.8 2023-12-04 [2] CRAN (R 4.3.1)
#> knitr 1.45 2023-10-30 [2] CRAN (R 4.3.1)
#> labeling 0.4.3 2023-08-29 [1] CRAN (R 4.3.0)
#> lattice 0.22-5 2023-10-24 [2] CRAN (R 4.3.1)
#> lava 1.7.3 2023-11-04 [1] CRAN (R 4.3.1)
#> lhs 1.1.6 2022-12-17 [1] CRAN (R 4.3.0)
#> lifecycle 1.0.4 2023-11-07 [1] CRAN (R 4.3.1)
#> listenv 0.9.0 2022-12-16 [1] CRAN (R 4.3.0)
#> lubridate 1.9.3 2023-09-27 [1] CRAN (R 4.3.1)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.0)
#> MASS 7.3-60 2023-05-04 [2] CRAN (R 4.3.2)
#> Matrix * 1.6-4 2023-11-30 [2] CRAN (R 4.3.1)
#> modeldata * 1.2.0 2023-08-09 [2] CRAN (R 4.3.0)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.3.0)
#> nnet 7.3-19 2023-05-03 [2] CRAN (R 4.3.2)
#> parallelly 1.36.0 2023-05-26 [1] CRAN (R 4.3.0)
#> parsnip * 1.1.1 2023-08-17 [1] CRAN (R 4.3.0)
#> pillar 1.9.0 2023-03-22 [1] CRAN (R 4.3.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.3.0)
#> prettyunits 1.2.0 2023-09-24 [1] CRAN (R 4.3.1)
#> processx 3.8.3 2023-12-10 [2] CRAN (R 4.3.1)
#> prodlim 2023.08.28 2023-08-28 [1] CRAN (R 4.3.0)
#> progress 1.2.3 2023-12-06 [2] CRAN (R 4.3.1)
#> ps 1.7.5 2023-04-18 [2] CRAN (R 4.3.0)
#> purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.3.0)
#> R.cache 0.16.0 2022-07-21 [2] CRAN (R 4.3.0)
#> R.methodsS3 1.8.2 2022-06-13 [2] CRAN (R 4.3.0)
#> R.oo 1.25.0 2022-06-12 [2] CRAN (R 4.3.0)
#> R.utils 2.12.3 2023-11-18 [2] CRAN (R 4.3.1)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.0)
#> Rcpp 1.0.12 2024-01-09 [1] CRAN (R 4.3.1)
#> recipes * 1.0.9 2023-12-13 [1] CRAN (R 4.3.1)
#> reprex 2.0.2 2022-08-17 [2] CRAN (R 4.3.0)
#> rlang 1.1.3 2024-01-10 [1] CRAN (R 4.3.1)
#> rmarkdown 2.25 2023-09-18 [2] CRAN (R 4.3.1)
#> rpart 4.1.23 2023-12-05 [2] CRAN (R 4.3.1)
#> rsample * 1.2.0 2023-08-23 [1] CRAN (R 4.3.0)
#> rstudioapi 0.15.0 2023-07-07 [2] CRAN (R 4.3.0)
#> safetensors 0.1.2 2023-09-12 [2] CRAN (R 4.3.0)
#> scales * 1.2.1 2022-08-20 [1] CRAN (R 4.3.2)
#> sessioninfo 1.2.2 2021-12-06 [2] CRAN (R 4.3.0)
#> shape 1.4.6 2021-05-19 [1] CRAN (R 4.3.0)
#> stacks * 1.0.3.9000 2023-11-13 [1] Github (tidymodels/stacks@31199f8)
#> styler 1.10.2 2023-08-29 [2] CRAN (R 4.3.0)
#> survival 3.5-7 2023-08-14 [2] CRAN (R 4.3.2)
#> tabnet * 0.5.0.9000 2024-01-11 [1] Github (mlverse/tabnet@962bafa)
#> tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.3.0)
#> tidymodels * 1.1.1 2023-08-24 [2] CRAN (R 4.3.0)
#> tidyr * 1.3.0 2023-01-24 [1] CRAN (R 4.3.0)
#> tidyselect 1.2.0 2022-10-10 [1] CRAN (R 4.3.0)
#> timechange 0.2.0 2023-01-11 [1] CRAN (R 4.3.0)
#> timeDate 4032.109 2023-12-14 [1] CRAN (R 4.3.1)
#> torch 0.12.0 2024-01-05 [1] Github (mlverse/torch@23071c1)
#> tune * 1.1.2 2023-08-23 [1] CRAN (R 4.3.0)
#> utf8 1.2.4 2023-10-22 [1] CRAN (R 4.3.1)
#> vctrs 0.6.5 2023-12-01 [1] CRAN (R 4.3.1)
#> withr 2.5.2 2023-10-30 [1] CRAN (R 4.3.1)
#> workflows * 1.1.3 2023-02-22 [1] CRAN (R 4.3.0)
#> workflowsets * 1.0.1 2023-04-06 [2] CRAN (R 4.3.0)
#> xfun 0.41 2023-11-01 [2] CRAN (R 4.3.1)
#> xgboost * 1.7.6.1 2023-12-06 [2] CRAN (R 4.3.1)
#> xml2 1.3.6 2023-12-04 [2] CRAN (R 4.3.1)
#> yaml 2.3.8 2023-12-11 [2] CRAN (R 4.3.1)
#> yardstick * 1.2.0 2023-04-21 [1] CRAN (R 4.3.0)
#> zeallot 0.1.0 2018-01-28 [2] CRAN (R 4.3.0)
#>
#> [1] /Users/carlgoodwin/Library/R/arm64/4.3/library
#> [2] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────

First few lines running (It works fine with
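To turn the blended coefficients above into a usable ensemble, a minimal sketch of the usual {stacks} follow-up (it reuses model_st and test from the reprex above; fit_members() refits the retained members and predict() then works on the resulting model stack):

# Fit the member models kept by blend_predictions(), then score the test set
model_fit <- model_st |>
  fit_members()

test |>
  select(Class) |>
  bind_cols(predict(model_fit, test, type = "prob"))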
Good to see it works! BTW, you can use device = "mps" on your M2, though as the timings below show it is not necessarily faster than device = "cpu":

library(tabnet)
library(tidymodels)
#> Warning: package 'scales' was built under R version 4.3.1
#> Warning: package 'dplyr' was built under R version 4.3.1
#> Warning: package 'ggplot2' was built under R version 4.3.1
#> Warning: package 'recipes' was built under R version 4.3.1
set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)
tab_rec <- recipe(Class ~ ., train) |>
step_normalize(all_numeric())
set.seed(1)
tictoc::tic()
tab_pre <- tab_rec |>
tabnet_pretrain(
train,
valid_split = 0.2,
epochs = 100,
device = "cpu",
verbose = FALSE
)
tictoc::toc()
#> 45.689 sec elapsed
tictoc::tic()
tab_pre_gpu <- tab_rec |>
tabnet_pretrain(
train,
valid_split = 0.2,
epochs = 100,
device = "mps",
verbose = FALSE
)
tictoc::toc()
#> 69.116 sec elapsed

Created on 2024-01-12 with reprex v2.0.2

For the tensor-size error, I was not able to reproduce it, as I don't know what is behind your dataset.
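If you want to pick the device programmatically rather than hard-coding it, a small sketch (it assumes your installed {torch} exposes backends_mps_is_available(); as the timings above show, "mps" is not necessarily faster than "cpu" for this workload):

# Use Metal ("mps") when available, otherwise fall back to the CPU
device <- if (torch::backends_mps_is_available()) "mps" else "cpu"

tab_pre <- tab_rec |>
  tabnet_pretrain(
    train,
    valid_split = 0.2,
    epochs = 100,
    device = device,
    verbose = FALSE
  )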
@cregouby thank you. I'll take another look at it.

I've been trying to isolate the tensor error. I've found that, with my own data, if I convert my two character variables to factors, the stacking code all works fine.

The puzzling thing is that if I convert the lending data factors to character, the toy example doesn't fail when stacking! Puzzle aside, I think you can close this now :) It's a great enhancement.

If I were to open a separate feature request for importance weights (with an example), is that something you could consider? This would be to address a class imbalance and would enable equivalent TabNet and XGBoost results.
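A minimal sketch of that character-to-factor workaround, with my_data and Class standing in for the real dataset (they are placeholders, not the actual data):

# Convert any character predictors to factors before building the recipe
my_data_fct <- my_data |>
  mutate(across(where(is.character), as.factor))

tab_rec_own <- recipe(Class ~ ., my_data_fct) |>
  step_normalize(all_numeric())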
OK, it seems you hit #124 with the character/factor problem.
I'd like to be able to stack a TabNet model with, say, an XGBoost model to get blended predictions. The following tidymodels workflow (using the toy dataset) does achieve this. However, if the TabNet model requires pre-training, I can't currently see a way to incorporate that, i.e. tabnet() would need to support tabnet_model and from_epoch. Would that be feasible?

Created on 2024-01-09 with reprex v2.0.2
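For reference, the interface requested here (and exercised in the comments above) would look roughly like this, where tab_pre is the result of a prior tabnet_pretrain() call:

# Reuse a pre-trained TabNet fit inside a parsnip model spec so it can be stacked
tab_mod <- tabnet(epochs = 100, tabnet_model = tab_pre, from_epoch = 100) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")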