add decoder GLU layers specific number of blocks (#129)
* add decoder GLU layers specific number of parameters
cregouby authored Sep 7, 2023
1 parent d6411d2 commit 071291c
Showing 8 changed files with 109 additions and 72 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,6 +1,7 @@
# tabnet (development version)

## New features
* `tabnet_pretrain()` now allows a different number of GLU layers in the GLU blocks of the encoder and of the decoder, through the `tabnet_config()` parameters `num_independent_decoder` and `num_shared_decoder` (#129). (See the usage sketch below.)
* {tabnet} now allows hierarchical multi-label classification through {data.tree} hierarchical `Node` dataset. (#126)
* Add `reduce_on_plateau` as option for `lr_scheduler` at `tabnet_config()` (@SvenVw, #120)
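A minimal pretraining sketch using the new decoder-specific settings. The {modeldata} `ames` dataset, the formula, and the chosen values are illustrative assumptions, not part of the commit:

```r
library(tabnet)

# illustrative data: any data.frame with a handful of predictors works
data("ames", package = "modeldata")

# the encoder keeps num_independent / num_shared;
# the decoder now gets its own GLU layer counts
config <- tabnet_config(
  num_independent         = 2,
  num_shared              = 2,
  num_independent_decoder = 1,
  num_shared_decoder      = 1,
  pretraining_ratio       = 0.5
)

pretrained <- tabnet_pretrain(Sale_Price ~ ., data = ames, config = config)
```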

22 changes: 16 additions & 6 deletions R/model.R
@@ -83,12 +83,16 @@ resolve_data <- function(x, y) {
#' @param learn_rate initial learning rate for the optimizer.
#' @param optimizer the optimization method. currently only 'adam' is supported,
#' you can also pass any torch optimizer function.
#' @param valid_split (float) The fraction of the dataset used for validation.
#' @param valid_split (`[0, 1)`) The fraction of the dataset used for validation.
#' (default = 0 means no split)
#' @param num_independent Number of independent Gated Linear Units layers at each step.
#' @param num_independent Number of independent Gated Linear Units layers at each step of the encoder.
#' Usual values range from 1 to 5.
#' @param num_shared Number of shared Gated Linear Units at each step Usual values
#' range from 1 to 5
#' @param num_shared Number of shared Gated Linear Units at each step of the encoder. Usual values
#' range from 1 to 5.
#' @param num_independent_decoder For pretraining, number of independent Gated Linear Units layers
#' at each step of the decoder. Usual values range from 1 to 5.
#' @param num_shared_decoder For pretraining, number of shared Gated Linear Units at each step of the
#' decoder. Usual values range from 1 to 5.
#' @param verbose (logical) Whether to print progress and loss values during
#' training.
#' @param lr_scheduler if `NULL`, no learning rate decay is used. If "step"
@@ -101,7 +105,9 @@ resolve_data <- function(x, y) {
#' or `NULL`.
#' @param step_size the learning rate scheduler step size. Unused if
#' `lr_scheduler` is a `torch::lr_scheduler` or `NULL`.
#' @param cat_emb_dim Embedding size for categorical features (default=1)
#' @param cat_emb_dim Size of the embedding of categorical features. If an integer, all categorical
#' features will have the same embedding size; if a list of integers, each corresponding feature will
#' have a specific embedding size.
#' @param momentum Momentum for batch normalization, typically ranges from 0.01
#' to 0.4 (default=0.02)
#' @param pretraining_ratio Ratio of features to mask for reconstruction during
@@ -147,6 +153,8 @@ tabnet_config <- function(batch_size = 1024^2,
cat_emb_dim = 1,
num_independent = 2,
num_shared = 2,
num_independent_decoder = 1,
num_shared_decoder = 1,
momentum = 0.02,
pretraining_ratio = 0.5,
verbose = FALSE,
@@ -190,6 +198,8 @@ tabnet_config <- function(batch_size = 1024^2,
cat_emb_dim = cat_emb_dim,
n_independent = num_independent,
n_shared = num_shared,
n_independent_decoder = num_independent_decoder,
n_shared_decoder = num_shared_decoder,
momentum = momentum,
pretraining_ratio = pretraining_ratio,
verbose = verbose,
@@ -198,7 +208,7 @@ tabnet_config <- function(batch_size = 1024^2,
early_stopping_monitor = resolve_early_stop_monitor(early_stopping_monitor, valid_split),
early_stopping_tolerance = early_stopping_tolerance,
early_stopping_patience = early_stopping_patience,
early_stopping = !(early_stopping_tolerance==0 || early_stopping_patience==0),
early_stopping = !(early_stopping_tolerance == 0 || early_stopping_patience == 0),
num_workers = num_workers,
skip_importance = skip_importance
)
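The hunk above maps the user-facing `num_independent_decoder` / `num_shared_decoder` arguments of `tabnet_config()` onto the internal `n_independent_decoder` / `n_shared_decoder` fields of the returned config list. A quick sketch of that mapping, with arbitrary values:

```r
library(tabnet)

# arbitrary values, only to illustrate the argument -> field renaming
cfg <- tabnet_config(num_independent_decoder = 3, num_shared_decoder = 2)

cfg$n_independent_decoder  # 3, consumed later by the pretraining network
cfg$n_shared_decoder       # 2
```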
2 changes: 2 additions & 0 deletions R/pretraining.R
@@ -119,6 +119,8 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
cat_emb_dim = config$cat_emb_dim,
n_independent = config$n_independent,
n_shared = config$n_shared,
n_independent_decoder = config$n_independent_decoder,
n_shared_decoder = config$n_shared_decoder,
momentum = config$momentum
)

29 changes: 16 additions & 13 deletions R/tab-network.R
@@ -234,14 +234,15 @@ tabnet_decoder <- torch::nn_module(

tabnet_pretrainer <- torch::nn_module(
"tabnet_pretrainer",
initialize = function(input_dim, pretraining_ratio=0.2,
n_d=8, n_a=8,
n_steps=3, gamma=1.3,
cat_idxs=c(), cat_dims=c(),
cat_emb_dim=1, n_independent=2,
n_shared=2, epsilon=1e-15,
virtual_batch_size=128, momentum = 0.02,
mask_type="sparsemax") {
initialize = function(input_dim, pretraining_ratio = 0.2,
n_d = 8, n_a = 8,
n_steps = 3, gamma = 1.3,
cat_idxs = c(), cat_dims = c(),
cat_emb_dim = 1, n_independent = 2,
n_shared = 2, n_independent_decoder = 1,
n_shared_decoder = 1, epsilon = 1e-15,
virtual_batch_size = 128, momentum = 0.02,
mask_type = "sparsemax") {

self$input_dim <- input_dim
self$pretraining_ratio <- pretraining_ratio
@@ -259,6 +260,8 @@ tabnet_pretrainer <- torch::nn_module(
self$epsilon <- epsilon
self$n_independent <- n_independent
self$n_shared <- n_shared
self$n_independent_decoder <- n_independent_decoder
self$n_shared_decoder <- n_shared_decoder
self$mask_type <- mask_type
self$initial_bn <- torch::nn_batch_norm1d(self$input_dim, momentum = momentum)

@@ -290,8 +293,8 @@ tabnet_pretrainer <- torch::nn_module(
self$post_embed_dim,
n_d = n_d,
n_steps = n_steps,
n_independent = n_independent,
n_shared = n_shared,
n_independent = n_independent_decoder,
n_shared = n_shared_decoder,
virtual_batch_size = virtual_batch_size,
momentum = momentum
)
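With the updated `initialize()` signature above, the internal pretrainer module builds its decoder from `n_independent_decoder` / `n_shared_decoder` instead of reusing the encoder's values. A minimal construction sketch, assuming the module stays unexported (hence `:::`) and using arbitrary dimensions:

```r
library(tabnet)  # also loads torch, which the module needs

# tabnet_pretrainer is an internal torch::nn_module; values are illustrative
net <- tabnet:::tabnet_pretrainer(
  input_dim = 10,
  n_independent = 2,          # encoder: GLU layers per block
  n_shared = 2,
  n_independent_decoder = 1,  # decoder: its own GLU layer counts
  n_shared_decoder = 1
)
```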
@@ -409,14 +412,14 @@ tabnet_no_embedding <- torch::nn_module(
#' @param n_d Dimension of the prediction layer (usually between 4 and 64).
#' @param n_a Dimension of the attention layer (usually between 4 and 64).
#' @param n_steps Number of successive steps in the network (usually between 3 and 10).
#' @param gamma Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0).
#' @param gamma Float above 1, scaling factor for attention updates (usually between 1 and 2).
#' @param cat_idxs Index of each categorical column in the dataset.
#' @param cat_dims Number of categories in each categorical column.
#' @param cat_emb_dim Size of the embedding of categorical features. If an integer, all categorical
#' features will have the same embedding size; if a list of integers, each corresponding feature will have a
#' specific size.
#' @param n_independent Number of independent GLU layer in each GLU block (default 2)..
#' @param n_shared Number of independent GLU layer in each GLU block (default 2).
#' @param n_independent Number of independent GLU layers in each GLU block of the encoder.
#' @param n_shared Number of shared GLU layers in each GLU block of the encoder.
#' @param epsilon Avoid log(0), this should be kept very low.
#' @param virtual_batch_size Batch size for Ghost Batch Normalization.
#' @param momentum Float value between 0 and 1 which will be used for momentum in all batch norm.
6 changes: 3 additions & 3 deletions man/tabnet.Rd


20 changes: 15 additions & 5 deletions man/tabnet_config.Rd


6 changes: 3 additions & 3 deletions man/tabnet_nn.Rd


