Skip to content

Commit

Permalink
fix tabnet_fit in vignette
Browse files Browse the repository at this point in the history
  • Loading branch information
cregouby committed Oct 14, 2024
1 parent fe56a04 commit dd04552
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
7 changes: 4 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

## Bugfixes

* improve function documentation consistency before translation
* fix ".... is not an exported object from 'namespace:dials'" error when using tune() on tabnet parameters. (#160 @cphaarmeyer)

* fix `tabet_pretrain` wrongly used instead of `tabnet_fit` in Missing data predictor vignette
* improve message related to case_weights not being used as predictors.
* improve function documentation consistency before translation.
* fix "..." is not an exported object from 'namespace:dials'" error when using tune() on tabnet parameters. (#160 @cphaarmeyer)

# tabnet 0.6.0

Expand Down
32 changes: 16 additions & 16 deletions vignettes/Missing_data_predictors.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,16 @@ Now we capture the columns with missings, and create a convenience function to c

```{r}
col_with_missings <- ames_missing %>%
summarise_all(~sum(is.na(.))>0) %>%
t %>% enframe(name="Variable") %>%
rename(has_missing="value")
summarise_all(~sum(is.na(.)) > 0) %>%
t %>% enframe(name = "Variable") %>%
rename(has_missing = "value")
vip_color <- function(object, col_has_missing) {
vip_data <- vip::vip(object)$data %>% arrange(Importance)
vis_miss_plus <- left_join(vip_data, col_has_missing , by="Variable") %>%
mutate(Variable=factor(Variable, levels = vip_data$Variable))
vis_miss_plus <- left_join(vip_data, col_has_missing , by = "Variable") %>%
mutate(Variable = factor(Variable, levels = vip_data$Variable))
vis_miss_plus
ggplot(vis_miss_plus, aes(x=Variable, y=Importance, fill=has_missing)) +
ggplot(vis_miss_plus, aes(x = Variable, y = Importance, fill = has_missing)) +
geom_col() + coord_flip() + scale_fill_grey()
}
vip_color(ames_pretrain, col_with_missings)
Expand All @@ -145,12 +145,12 @@ Let's pretrain a new model with the same hyperparameter, but now using the `ames
In order to compensate the 13% missingness already present in the `ames_missing` dataset, we adjust the `pretraining_ratio` parameter to `0.5 - 0.13 = 0.37`

```{r}
ames_missing_rec <- recipe(Sale_Price ~ ., data=ames_missing) %>%
ames_missing_rec <- recipe(Sale_Price ~ ., data = ames_missing) %>%
step_normalize(all_numeric())
ames_missing_pretrain <- tabnet_pretrain(ames_missing_rec, data=ames_missing, epoch=50,
ames_missing_pretrain <- tabnet_pretrain(ames_missing_rec, data = ames_missing, epoch = 50,
cat_emb_dim = cat_emb_dim,
valid_split = 0.2, verbose=TRUE, batch=2930,
pretraining_ratio=0.37,
valid_split = 0.2, verbose = TRUE, batch = 2930,
pretraining_ratio = 0.37,
early_stopping_patience = 3L, early_stopping_tolerance = 1e-4)
autoplot(ames_missing_pretrain)
vip_color(ames_missing_pretrain, col_with_missings)
Expand Down Expand Up @@ -183,9 +183,9 @@ We can see here no variables with high missingness is present in the top 10 impo
## Variable importance with raw `ames` dataset

```{r}
ames_fit <- tabnet_pretrain(ames_rec, data=ames, tabnet_model = ames_pretrain,
epoch=50, cat_emb_dim = cat_emb_dim,
valid_split = 0.2, verbose=TRUE, batch=2930,
ames_fit <- tabnet_fit(ames_rec, data = ames, tabnet_model = ames_pretrain,
epoch = 50, cat_emb_dim = cat_emb_dim,
valid_split = 0.2, verbose = TRUE, batch = 2930,
early_stopping_patience = 5L, early_stopping_tolerance = 1e-4)
autoplot(ames_fit)
vip_color(ames_fit, col_with_missings)
Expand All @@ -201,9 +201,9 @@ Here again, the model uses two predictors `BasmFin_SF_2` and `Garage_Finish` tha
## Variable importance with `ames_missing` dataset

```{r}
ames_missing_fit <- tabnet_pretrain(ames_rec, data=ames_missing, tabnet_model = ames_missing_pretrain,
epoch=50, cat_emb_dim = cat_emb_dim,
valid_split = 0.2, verbose=TRUE, batch=2930,
ames_missing_fit <- tabnet_fit(ames_rec, data = ames_missing, tabnet_model = ames_missing_pretrain,
epoch = 50, cat_emb_dim = cat_emb_dim,
valid_split = 0.2, verbose = TRUE, batch = 2930,
early_stopping_patience = 5L, early_stopping_tolerance = 1e-4)
autoplot(ames_missing_fit)
vip_color(ames_missing_fit, col_with_missings)
Expand Down

0 comments on commit dd04552

Please sign in to comment.