Allow tabnet_model = and from_epoch = in parsnip tabnet() models so that Stack can use a pre-trained TabNet model #142

Closed
cgoo4 opened this issue Jan 9, 2024 · 6 comments · Fixed by #143

cgoo4 commented Jan 9, 2024

I'd like to be able to stack a TabNet model with, say, an XGBoost model to get blended predictions. The following tidymodels workflow (using the toy dataset) does achieve this. However, if the TabNet model requires pre-training, I can't currently see a way to incorporate that, i.e. tabnet() would need to support tabnet_model and from_epoch. Would that be feasible?
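To make the ask concrete, the interface I have in mind would look something like the sketch below. It is hypothetical at the time of writing (tabnet() does not currently accept these arguments), and tab_rec and train are as defined in the full reprex that follows.

# Hypothetical interface (not currently supported): pass a pre-trained model
# and the checkpoint epoch straight into the parsnip spec
tab_pre <- tab_rec |>
  tabnet_pretrain(train, valid_split = 0.2, epochs = 50, device = "cpu")

tab_mod <- tabnet(epochs = 10, batch_size = 128,
                  tabnet_model = tab_pre, from_epoch = 50) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

The full reprex, without pre-training, is below: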

library(tabnet)
library(tidymodels)
library(modeldata)
library(stacks)

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

tab_rec <- recipe(Class ~ ., train) |> 
  step_normalize(all_numeric())

tab_mod <- tabnet(epochs = 10, batch_size = 128) |> 
  set_engine("torch", verbose = TRUE, device = "cpu") |> 
  set_mode("classification")

tab_wf <- workflow() |> 
  add_model(tab_mod) |> 
  add_recipe(tab_rec)

xgb_rec <- recipe(Class ~ ., train) |> 
  step_normalize(all_numeric()) |> 
  step_dummy(c('term', 'sub_grade', 'addr_state', 'verification_status', 'emp_length'))

xgb_mod <-
  boost_tree(mode = "classification") |> 
  set_engine("xgboost")

xgb_wf <- workflow() |> 
  add_model(xgb_mod) |> 
  add_recipe(xgb_rec)

set.seed(123)
folds <- vfold_cv(train, v = 3)

fit_tab <- tab_wf |> 
  fit_resamples(folds, control = control_stack_resamples())
#> [Epoch 001] Loss: 0.294646
#> [Epoch 002] Loss: 0.222646
#> [Epoch 003] Loss: 0.212307
#> [Epoch 004] Loss: 0.212186
#> [Epoch 005] Loss: 0.207021
#> [Epoch 006] Loss: 0.204229
#> [Epoch 007] Loss: 0.203046
#> [Epoch 008] Loss: 0.198770
#> [Epoch 009] Loss: 0.199130
#> [Epoch 010] Loss: 0.201151
#> [Epoch 001] Loss: 0.260944
#> [Epoch 002] Loss: 0.213966
#> [Epoch 003] Loss: 0.205509
#> [Epoch 004] Loss: 0.199555
#> [Epoch 005] Loss: 0.198867
#> [Epoch 006] Loss: 0.194544
#> [Epoch 007] Loss: 0.192421
#> [Epoch 008] Loss: 0.194404
#> [Epoch 009] Loss: 0.196428
#> [Epoch 010] Loss: 0.195690
#> [Epoch 001] Loss: 0.255175
#> [Epoch 002] Loss: 0.192142
#> [Epoch 003] Loss: 0.188243
#> [Epoch 004] Loss: 0.187549
#> [Epoch 005] Loss: 0.182563
#> [Epoch 006] Loss: 0.181317
#> [Epoch 007] Loss: 0.179819
#> [Epoch 008] Loss: 0.180077
#> [Epoch 009] Loss: 0.180719
#> [Epoch 010] Loss: 0.181777

fit_xgb <- xgb_wf |> 
  fit_resamples(folds, control = control_stack_resamples())

data_st <- 
  stacks() |> 
  add_candidates(fit_tab) |> 
  add_candidates(fit_xgb)

model_st <-
  data_st |> 
  blend_predictions()

autoplot(model_st, type = "weights")

model_st <-
  model_st |> 
  fit_members()
#> [Epoch 001] Loss: 0.288016
#> [Epoch 002] Loss: 0.204822
#> [Epoch 003] Loss: 0.196032
#> [Epoch 004] Loss: 0.192840
#> [Epoch 005] Loss: 0.192554
#> [Epoch 006] Loss: 0.192038
#> [Epoch 007] Loss: 0.190172
#> [Epoch 008] Loss: 0.188615
#> [Epoch 009] Loss: 0.189580
#> [Epoch 010] Loss: 0.189883

Created on 2024-01-09 with reprex v2.0.2


cregouby commented Jan 10, 2024

Hello @cgoo4
That's an interesting request and it should not be too hard to implement.

FYI, I'm about to release a parsnip tabnet() with an extended list of parameters here: https://github.com/mlverse/tabnet/blob/feature/interpretabnet/R/parsnip.R. But it doesn't cover from_epoch or tabnet_model yet.
You can try to modify it on your own, or wait for me to finish the interpretabnet story.

I'll let you know when it's available so that you can validate it on your own dataset.
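If you do want to experiment before the branch lands, the usual parsnip mechanism for exposing extra model arguments is parsnip::set_model_arg(). A rough, hypothetical sketch only (this is not the code in R/parsnip.R; the dials constructors named in func are placeholders that would still need to be written, or the arguments could be left as plain engine arguments):

library(parsnip)

# Hypothetical sketch only -- the dials functions named below do not exist,
# they just mark where a parameter constructor would go.
set_model_arg(
  model = "tabnet", eng = "torch",
  parsnip = "tabnet_model", original = "tabnet_model",
  func = list(pkg = "dials", fun = "tabnet_model"),
  has_submodel = FALSE
)
set_model_arg(
  model = "tabnet", eng = "torch",
  parsnip = "from_epoch", original = "from_epoch",
  func = list(pkg = "dials", fun = "from_epoch"),
  has_submodel = FALSE
)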

@cregouby cregouby added the enhancement New feature or request label Jan 10, 2024
@cregouby cregouby changed the title Stack a pre-trained TabNet model Allow tabnet_model = and from_epoch = in parsnip tabnet() models so that Stack can use a pre-trained TabNet model Jan 10, 2024
@cregouby

Hello @cgoo4,

Could you try to install {tabnet} from the feature/parsnip_from_pretrain branch, e.g. with

pak::pak("mlverse/tabnet@feature/parsnip_from_pretrain")

and tell me if that fits your need?


cgoo4 commented Jan 11, 2024

Hi @cregouby - Thank you!

The toy example runs successfully, as shown below. (I'm using an M2 machine, hence device = "cpu".)

On my own data though, I'm getting computational errors (see further down).

library(tabnet)
library(tidymodels)
library(modeldata)
library(stacks)

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

tab_rec <- recipe(Class ~ ., train) |> 
  step_normalize(all_numeric())

set.seed(1)

tab_pre <- tab_rec |>
  tabnet_pretrain(
    train, 
    valid_split = 0.2,
    epochs = 100,
    device = "cpu",
    verbose = TRUE
    )
#> [Epoch 001] Loss: 108.511490, Valid loss: 29.568882
#> [Epoch 002] Loss: 30.584080, Valid loss: 17.283676
#> [Epoch 003] Loss: 9.266540, Valid loss: 13.562984
#> [Epoch 004] Loss: 5.846971, Valid loss: 10.691854
#> [Epoch 005] Loss: 6.189508, Valid loss: 7.484963
#> [Epoch 006] Loss: 5.970494, Valid loss: 6.027362
#> [Epoch 007] Loss: 5.492610, Valid loss: 5.489138
#> [Epoch 008] Loss: 4.127725, Valid loss: 5.134202
#> [Epoch 009] Loss: 6.233320, Valid loss: 4.775207
#> [Epoch 010] Loss: 3.267596, Valid loss: 4.366734
#> [Epoch 011] Loss: 3.254126, Valid loss: 6.512185
#> [Epoch 012] Loss: 3.117017, Valid loss: 5.826352
#> [Epoch 013] Loss: 2.751130, Valid loss: 5.357126
#> [Epoch 014] Loss: 2.585568, Valid loss: 4.796397
#> [Epoch 015] Loss: 2.389232, Valid loss: 4.247961
#> [Epoch 016] Loss: 2.352181, Valid loss: 3.750866
#> [Epoch 017] Loss: 2.095828, Valid loss: 3.386385
#> [Epoch 018] Loss: 2.168191, Valid loss: 3.083429
#> [Epoch 019] Loss: 1.916646, Valid loss: 3.169604
#> [Epoch 020] Loss: 1.869341, Valid loss: 3.251067
#> [Epoch 021] Loss: 1.737848, Valid loss: 3.130211
#> [Epoch 022] Loss: 1.651209, Valid loss: 2.922243
#> [Epoch 023] Loss: 1.588258, Valid loss: 2.717718
#> [Epoch 024] Loss: 1.565487, Valid loss: 2.503327
#> [Epoch 025] Loss: 1.503044, Valid loss: 2.293428
#> [Epoch 026] Loss: 1.417796, Valid loss: 2.115288
#> [Epoch 027] Loss: 1.414081, Valid loss: 1.961773
#> [Epoch 028] Loss: 1.469063, Valid loss: 1.834276
#> [Epoch 029] Loss: 1.345036, Valid loss: 1.741629
#> [Epoch 030] Loss: 1.342513, Valid loss: 1.661885
#> [Epoch 031] Loss: 1.302708, Valid loss: 1.595290
#> [Epoch 032] Loss: 1.309536, Valid loss: 1.536704
#> [Epoch 033] Loss: 1.226201, Valid loss: 1.483173
#> [Epoch 034] Loss: 1.210594, Valid loss: 1.437791
#> [Epoch 035] Loss: 1.226383, Valid loss: 1.401923
#> [Epoch 036] Loss: 1.189032, Valid loss: 1.368132
#> [Epoch 037] Loss: 1.154899, Valid loss: 1.340692
#> [Epoch 038] Loss: 1.258244, Valid loss: 1.324930
#> [Epoch 039] Loss: 1.238932, Valid loss: 1.308550
#> [Epoch 040] Loss: 1.216526, Valid loss: 1.294628
#> [Epoch 041] Loss: 1.140847, Valid loss: 1.282725
#> [Epoch 042] Loss: 1.115530, Valid loss: 1.272120
#> [Epoch 043] Loss: 1.115844, Valid loss: 1.262455
#> [Epoch 044] Loss: 1.213085, Valid loss: 1.255370
#> [Epoch 045] Loss: 1.172749, Valid loss: 1.248636
#> [Epoch 046] Loss: 1.067289, Valid loss: 1.241004
#> [Epoch 047] Loss: 1.121537, Valid loss: 1.235389
#> [Epoch 048] Loss: 1.146737, Valid loss: 1.231018
#> [Epoch 049] Loss: 1.105253, Valid loss: 1.225921
#> [Epoch 050] Loss: 1.105276, Valid loss: 1.221406
#> [Epoch 051] Loss: 1.105012, Valid loss: 1.215876
#> [Epoch 052] Loss: 1.105758, Valid loss: 1.210291
#> [Epoch 053] Loss: 1.025887, Valid loss: 1.204283
#> [Epoch 054] Loss: 1.119924, Valid loss: 1.197808
#> [Epoch 055] Loss: 1.097717, Valid loss: 1.190847
#> [Epoch 056] Loss: 1.091637, Valid loss: 1.185091
#> [Epoch 057] Loss: 1.097960, Valid loss: 1.182087
#> [Epoch 058] Loss: 1.102373, Valid loss: 1.178165
#> [Epoch 059] Loss: 1.135786, Valid loss: 1.174002
#> [Epoch 060] Loss: 1.132688, Valid loss: 1.170313
#> [Epoch 061] Loss: 1.039531, Valid loss: 1.165211
#> [Epoch 062] Loss: 1.070144, Valid loss: 1.161195
#> [Epoch 063] Loss: 1.098994, Valid loss: 1.157930
#> [Epoch 064] Loss: 1.043513, Valid loss: 1.154124
#> [Epoch 065] Loss: 1.077071, Valid loss: 1.151235
#> [Epoch 066] Loss: 1.022600, Valid loss: 1.147900
#> [Epoch 067] Loss: 1.127478, Valid loss: 1.146163
#> [Epoch 068] Loss: 1.061102, Valid loss: 1.143913
#> [Epoch 069] Loss: 1.067844, Valid loss: 1.142187
#> [Epoch 070] Loss: 1.079090, Valid loss: 1.140889
#> [Epoch 071] Loss: 1.173972, Valid loss: 1.140199
#> [Epoch 072] Loss: 1.060898, Valid loss: 1.138638
#> [Epoch 073] Loss: 1.066270, Valid loss: 1.136953
#> [Epoch 074] Loss: 1.202267, Valid loss: 1.136177
#> [Epoch 075] Loss: 1.144205, Valid loss: 1.135328
#> [Epoch 076] Loss: 0.976403, Valid loss: 1.132538
#> [Epoch 077] Loss: 1.015884, Valid loss: 1.130507
#> [Epoch 078] Loss: 1.098029, Valid loss: 1.130162
#> [Epoch 079] Loss: 1.006299, Valid loss: 1.128126
#> [Epoch 080] Loss: 1.045715, Valid loss: 1.126469
#> [Epoch 081] Loss: 1.052163, Valid loss: 1.124908
#> [Epoch 082] Loss: 0.981435, Valid loss: 1.121406
#> [Epoch 083] Loss: 1.032371, Valid loss: 1.118657
#> [Epoch 084] Loss: 1.079322, Valid loss: 1.117403
#> [Epoch 085] Loss: 0.993631, Valid loss: 1.114072
#> [Epoch 086] Loss: 1.100637, Valid loss: 1.113469
#> [Epoch 087] Loss: 1.046454, Valid loss: 1.111916
#> [Epoch 088] Loss: 1.074078, Valid loss: 1.110664
#> [Epoch 089] Loss: 1.034861, Valid loss: 1.108330
#> [Epoch 090] Loss: 1.009123, Valid loss: 1.105716
#> [Epoch 091] Loss: 1.081966, Valid loss: 1.105542
#> [Epoch 092] Loss: 1.048128, Valid loss: 1.103355
#> [Epoch 093] Loss: 0.965812, Valid loss: 1.100522
#> [Epoch 094] Loss: 1.016202, Valid loss: 1.098672
#> [Epoch 095] Loss: 0.975355, Valid loss: 1.096074
#> [Epoch 096] Loss: 0.976712, Valid loss: 1.093621
#> [Epoch 097] Loss: 0.981546, Valid loss: 1.090942
#> [Epoch 098] Loss: 1.058951, Valid loss: 1.089792
#> [Epoch 099] Loss: 1.087424, Valid loss: 1.089323
#> [Epoch 100] Loss: 0.964833, Valid loss: 1.086923

autoplot(tab_pre)

tab_mod <- tabnet(epochs = 100, tabnet_model = tab_pre, from_epoch = 100) |> 
  set_engine("torch", device = "cpu") |> 
  set_mode("classification")

tab_wf <- workflow() |> 
  add_model(tab_mod) |> 
  add_recipe(tab_rec)

xgb_rec <- recipe(Class ~ ., train) |> 
  step_normalize(all_numeric()) |> 
  step_dummy(c("term", "sub_grade", "addr_state", "verification_status", "emp_length"))

xgb_mod <-
  boost_tree(mode = "classification") |> 
  set_engine("xgboost")

xgb_wf <- workflow() |> 
  add_model(xgb_mod) |> 
  add_recipe(xgb_rec)

set.seed(123)
folds <- vfold_cv(train, v = 3)

fit_tab <- tab_wf |> 
  fit_resamples(folds, control = control_stack_resamples())

fit_xgb <- xgb_wf |> 
  fit_resamples(folds, control = control_stack_resamples())

collect_metrics(fit_tab)
#> # A tibble: 2 × 6
#>   .metric  .estimator  mean     n std_err .config             
#>   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 accuracy binary     0.943     3 0.00676 Preprocessor1_Model1
#> 2 roc_auc  binary     0.635     3 0.0191  Preprocessor1_Model1
collect_metrics(fit_xgb)
#> # A tibble: 2 × 6
#>   .metric  .estimator  mean     n std_err .config             
#>   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 accuracy binary     0.947     3 0.00479 Preprocessor1_Model1
#> 2 roc_auc  binary     0.706     3 0.0133  Preprocessor1_Model1

data_st <- 
  stacks() |> 
  add_candidates(fit_tab) |> 
  add_candidates(fit_xgb)

model_st <-
  data_st |> 
  blend_predictions()

autoplot(model_st, type = "weights")

Created on 2024-01-11 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.3.2 (2023-10-31)
#>  os       macOS Sonoma 14.2.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/London
#>  date     2024-01-11
#>  pandoc   3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [2] CRAN (R 4.3.0)
#>  bit            4.0.5      2022-11-15 [2] CRAN (R 4.3.0)
#>  bit64          4.0.5      2020-08-30 [2] CRAN (R 4.3.0)
#>  broom        * 1.0.5      2023-06-09 [2] CRAN (R 4.3.0)
#>  butcher        0.3.3      2023-08-23 [1] CRAN (R 4.3.0)
#>  callr          3.7.3      2022-11-02 [2] CRAN (R 4.3.0)
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.3.2)
#>  cli            3.6.2      2023-12-11 [1] CRAN (R 4.3.1)
#>  codetools      0.2-19     2023-02-01 [2] CRAN (R 4.3.2)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.3.0)
#>  coro           1.0.3      2022-07-19 [2] CRAN (R 4.3.0)
#>  crayon         1.5.2      2022-09-29 [1] CRAN (R 4.3.0)
#>  curl           5.2.0      2023-12-08 [2] CRAN (R 4.3.1)
#>  data.table     1.14.10    2023-12-08 [1] CRAN (R 4.3.1)
#>  dials        * 1.2.0      2023-04-03 [1] CRAN (R 4.3.0)
#>  DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.3.1)
#>  digest         0.6.33     2023-07-07 [1] CRAN (R 4.3.0)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.3.1)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.3.0)
#>  evaluate       0.23       2023-11-01 [2] CRAN (R 4.3.1)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.3.1)
#>  farver         2.1.1      2022-07-06 [1] CRAN (R 4.3.0)
#>  fastmap        1.1.1      2023-02-24 [2] CRAN (R 4.3.0)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.0)
#>  fs             1.6.3      2023-07-20 [2] CRAN (R 4.3.0)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.0)
#>  future         1.33.1     2023-12-22 [1] CRAN (R 4.3.1)
#>  future.apply   1.11.1     2023-12-21 [1] CRAN (R 4.3.1)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.0)
#>  ggplot2      * 3.4.4      2023-10-12 [1] CRAN (R 4.3.1)
#>  glmnet       * 4.1-8      2023-08-22 [1] CRAN (R 4.3.0)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.3.0)
#>  glue           1.7.0      2024-01-09 [1] CRAN (R 4.3.1)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.3.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.0)
#>  gtable         0.3.4      2023-08-21 [1] CRAN (R 4.3.0)
#>  hardhat        1.3.0      2023-03-30 [1] CRAN (R 4.3.0)
#>  highr          0.10       2022-12-22 [2] CRAN (R 4.3.0)
#>  hms            1.1.3      2023-03-21 [2] CRAN (R 4.3.0)
#>  htmltools      0.5.7      2023-11-03 [2] CRAN (R 4.3.1)
#>  infer        * 1.0.5      2023-09-06 [2] CRAN (R 4.3.0)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.0)
#>  jsonlite       1.8.8      2023-12-04 [2] CRAN (R 4.3.1)
#>  knitr          1.45       2023-10-30 [2] CRAN (R 4.3.1)
#>  labeling       0.4.3      2023-08-29 [1] CRAN (R 4.3.0)
#>  lattice        0.22-5     2023-10-24 [2] CRAN (R 4.3.1)
#>  lava           1.7.3      2023-11-04 [1] CRAN (R 4.3.1)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.0)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.3.1)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.3.0)
#>  lubridate      1.9.3      2023-09-27 [1] CRAN (R 4.3.1)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.0)
#>  MASS           7.3-60     2023-05-04 [2] CRAN (R 4.3.2)
#>  Matrix       * 1.6-4      2023-11-30 [2] CRAN (R 4.3.1)
#>  modeldata    * 1.2.0      2023-08-09 [2] CRAN (R 4.3.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.3.0)
#>  nnet           7.3-19     2023-05-03 [2] CRAN (R 4.3.2)
#>  parallelly     1.36.0     2023-05-26 [1] CRAN (R 4.3.0)
#>  parsnip      * 1.1.1      2023-08-17 [1] CRAN (R 4.3.0)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.3.0)
#>  prettyunits    1.2.0      2023-09-24 [1] CRAN (R 4.3.1)
#>  processx       3.8.3      2023-12-10 [2] CRAN (R 4.3.1)
#>  prodlim        2023.08.28 2023-08-28 [1] CRAN (R 4.3.0)
#>  progress       1.2.3      2023-12-06 [2] CRAN (R 4.3.1)
#>  ps             1.7.5      2023-04-18 [2] CRAN (R 4.3.0)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.3.0)
#>  R.cache        0.16.0     2022-07-21 [2] CRAN (R 4.3.0)
#>  R.methodsS3    1.8.2      2022-06-13 [2] CRAN (R 4.3.0)
#>  R.oo           1.25.0     2022-06-12 [2] CRAN (R 4.3.0)
#>  R.utils        2.12.3     2023-11-18 [2] CRAN (R 4.3.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.0)
#>  Rcpp           1.0.12     2024-01-09 [1] CRAN (R 4.3.1)
#>  recipes      * 1.0.9      2023-12-13 [1] CRAN (R 4.3.1)
#>  reprex         2.0.2      2022-08-17 [2] CRAN (R 4.3.0)
#>  rlang          1.1.3      2024-01-10 [1] CRAN (R 4.3.1)
#>  rmarkdown      2.25       2023-09-18 [2] CRAN (R 4.3.1)
#>  rpart          4.1.23     2023-12-05 [2] CRAN (R 4.3.1)
#>  rsample      * 1.2.0      2023-08-23 [1] CRAN (R 4.3.0)
#>  rstudioapi     0.15.0     2023-07-07 [2] CRAN (R 4.3.0)
#>  safetensors    0.1.2      2023-09-12 [2] CRAN (R 4.3.0)
#>  scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.3.2)
#>  sessioninfo    1.2.2      2021-12-06 [2] CRAN (R 4.3.0)
#>  shape          1.4.6      2021-05-19 [1] CRAN (R 4.3.0)
#>  stacks       * 1.0.3.9000 2023-11-13 [1] Github (tidymodels/stacks@31199f8)
#>  styler         1.10.2     2023-08-29 [2] CRAN (R 4.3.0)
#>  survival       3.5-7      2023-08-14 [2] CRAN (R 4.3.2)
#>  tabnet       * 0.5.0.9000 2024-01-11 [1] Github (mlverse/tabnet@962bafa)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.0)
#>  tidymodels   * 1.1.1      2023-08-24 [2] CRAN (R 4.3.0)
#>  tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.3.0)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.3.0)
#>  timechange     0.2.0      2023-01-11 [1] CRAN (R 4.3.0)
#>  timeDate       4032.109   2023-12-14 [1] CRAN (R 4.3.1)
#>  torch          0.12.0     2024-01-05 [1] Github (mlverse/torch@23071c1)
#>  tune         * 1.1.2      2023-08-23 [1] CRAN (R 4.3.0)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.3.1)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.3.1)
#>  withr          2.5.2      2023-10-30 [1] CRAN (R 4.3.1)
#>  workflows    * 1.1.3      2023-02-22 [1] CRAN (R 4.3.0)
#>  workflowsets * 1.0.1      2023-04-06 [2] CRAN (R 4.3.0)
#>  xfun           0.41       2023-11-01 [2] CRAN (R 4.3.1)
#>  xgboost      * 1.7.6.1    2023-12-06 [2] CRAN (R 4.3.1)
#>  xml2           1.3.6      2023-12-04 [2] CRAN (R 4.3.1)
#>  yaml           2.3.8      2023-12-11 [2] CRAN (R 4.3.1)
#>  yardstick    * 1.2.0      2023-04-21 [1] CRAN (R 4.3.0)
#>  zeallot        0.1.0      2018-01-28 [2] CRAN (R 4.3.0)
#> 
#>  [1] /Users/carlgoodwin/Library/R/arm64/4.3/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

First few lines running tabnet() after tabnet_pretrain() on my own data:

(It works fine with tabnet_pretrain() followed by tabnet_fit(). Also okay using tabnet() without the pre-trained model.)

reg_tab_resamples <- reg_tab_wflow |> 
 fit_resamples(reg_folds, control = control_stack_resamples())
→ A | error:   The size of tensor a (234) must match the size of tensor b (240) at non-singleton dimension 0
               Exception raised from infer_size_impl at /Users/dfalbel/Documents/actions-runner/mlverse-m1/_work/libtorch-mac-m1/libtorch-mac-m1/pytorch/aten/src/ATen/ExpandUtils.cpp:35 (most recent call first):
               frame #0: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) + 92 (0x1393a0834 in libc10.dylib)
               frame #1: at::infer_size_dimvector(c10::ArrayRef<long long>, c10::ArrayRef<long long>) + 424 (0x2df4811e4 in libtorch_cpu.dylib)
               frame #2: at::TensorIteratorBase::compute_shape(at::TensorIteratorConfig const&) + 460 (0x2df51bc74 in libtorch_cpu.dylib)
               frame #3: at::TensorIteratorBase::build(at::TensorIteratorConfig&) + 524 (0x2df51702c in libtorch_cpu.dylib)
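
For reference, the non-parsnip path that does work is roughly the sketch below (the reg_* objects are placeholders from my own project, shown only for shape):

# Sketch of the working path: pre-train, then pass the pre-trained model
# straight to tabnet_fit() (object names are placeholders from my own data)
reg_tab_pre <- reg_tab_rec |>
  tabnet_pretrain(reg_train, valid_split = 0.2, epochs = 100, device = "cpu")

reg_tab_fit <- reg_tab_rec |>
  tabnet_fit(reg_train,
             tabnet_model = reg_tab_pre,
             from_epoch = 100,
             epochs = 20,
             device = "cpu")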

@cregouby

Good to see it works!

BTW you can use device = "mps", with some weird messages that are only warnings (as far as I can tell in #141) and do not prevent the computation from completing correctly.
Rerunning your toy example, I have to admit it is not that efficient to use "mps", as the time spent moving data is higher than the time saved in GPU computing:

library(tabnet)
library(tidymodels)
#> Warning: package 'scales' was built under R version 4.3.1
#> Warning: package 'dplyr' was built under R version 4.3.1
#> Warning: package 'ggplot2' was built under R version 4.3.1
#> Warning: package 'recipes' was built under R version 4.3.1
set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

tab_rec <- recipe(Class ~ ., train) |> 
  step_normalize(all_numeric())

set.seed(1)

tictoc::tic()
tab_pre <- tab_rec |>
  tabnet_pretrain(
    train, 
    valid_split = 0.2,
    epochs = 100,
    device = "cpu",
    verbose = FALSE
    )
tictoc::toc()
#> 45.689 sec elapsed

tictoc::tic()
tab_pre_gpu <- tab_rec |>
  tabnet_pretrain(
    train, 
    valid_split = 0.2,
    epochs = 100,
    device = "mps",
    verbose = FALSE
    )
tictoc::toc()
#> 69.116 sec elapsed

Created on 2024-01-12 with reprex v2.0.2

For the tensor-size error, I was not able to reproduce it as I don't know what is behind reg_tab_wflow. Could you please open a new issue?
When you are happy with this tabnet() evolution, please close the issue or tell me to do it.


cgoo4 commented Jan 12, 2024

@cregouby thank you. I'll take another look at device = "mps". I didn't previously pursue it after seeing the messages.

I've been trying to isolate the tensor error. I've found that, with my own data, if I convert my two character variables to factors, then the stacking code all works fine, including fit_resamples().
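The conversion itself is nothing fancy, roughly the sketch below (reg_train is a placeholder name for my own data); a recipe-based alternative would be recipes::step_string2factor().

# Workaround sketch: convert character predictors to factors before stacking
library(dplyr)

reg_train <- reg_train |>
  mutate(across(where(is.character), as.factor))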

The puzzling thing is that if I convert the lending_club factors to character, the toy example doesn't fail when stacking!

But puzzle aside, I think you can close this now :) It's a great enhancement.

If I were to open a separate feature request for importance weights (with an example), is that something you could consider? It would address a class imbalance and would allow equivalent TabNet and XGBoost results.
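
For context, what I have in mind is something along these lines (a hypothetical sketch; as far as I can tell tabnet() does not consume case weights today, which is the point of the request, and the weight of 5 for the minority class is illustrative only):

# Hypothetical sketch of importance weights to counter the class imbalance
library(tidymodels)

train_wt <- train |>
  mutate(case_wt = hardhat::importance_weights(ifelse(Class == "bad", 5, 1)))

tab_wf_wt <- workflow() |>
  add_model(tab_mod) |>
  add_recipe(recipe(Class ~ ., train_wt)) |>
  add_case_weights(case_wt)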

@cregouby

OK, it seems you hit #124 with the character/factor problem.
Of course, feel free to open issues for feature requests!
