add decoder GLU layers specific number of blocks #129

Merged
merged 4 commits into from
Sep 7, 2023
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,6 +1,7 @@
# tabnet (development version)

## New features
* `tabnet_pretrain()` now allows the decoder to use a different number of independent and shared GLU layers than the encoder, through the `tabnet_config()` parameters `num_independent_decoder` and `num_shared_decoder` (#129)
* {tabnet} now allows hierarchical multi-label classification through {data.tree} hierarchical `Node` dataset. (#126)
* Add `reduce_on_plateau` as option for `lr_scheduler` at `tabnet_config()` (@SvenVw, #120)

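A minimal usage sketch of the new options from the user-facing API (not part of the diff); the dataset and epoch count are illustrative only, assuming the {modeldata} `ames` data is available:

```r
library(tabnet)
data("ames", package = "modeldata")

# Pretrain with 2 + 2 GLU layers per step in the encoder, but a lighter
# 1 + 1 configuration in the reconstruction decoder (the new defaults).
pretrained <- tabnet_pretrain(
  Sale_Price ~ ., data = ames,
  num_independent = 2, num_shared = 2,
  num_independent_decoder = 1, num_shared_decoder = 1,
  epochs = 5
)
```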
22 changes: 16 additions & 6 deletions R/model.R
@@ -83,12 +83,16 @@ resolve_data <- function(x, y) {
#' @param learn_rate initial learning rate for the optimizer.
#' @param optimizer the optimization method. currently only 'adam' is supported,
#' you can also pass any torch optimizer function.
#' @param valid_split (float) The fraction of the dataset used for validation.
#' @param valid_split (`[0, 1)`) The fraction of the dataset used for validation.
#' (default = 0 means no split)
#' @param num_independent Number of independent Gated Linear Units layers at each step.
#' @param num_independent Number of independent Gated Linear Units layers at each step of the encoder.
#' Usual values range from 1 to 5.
#' @param num_shared Number of shared Gated Linear Units at each step Usual values
#' range from 1 to 5
#' @param num_shared Number of shared Gated Linear Units at each step of the encoder. Usual values
#' range from 1 to 5
#' @param num_independent_decoder For pretraining, number of independent Gated Linear Units layers
#' at each step of the decoder. Usual values range from 1 to 5.
#' @param num_shared_decoder For pretraining, number of shared Gated Linear Units at each step of the
#' decoder. Usual values range from 1 to 5.
#' @param verbose (logical) Whether to print progress and loss values during
#' training.
#' @param lr_scheduler if `NULL`, no learning rate decay is used. If "step"
@@ -101,7 +105,9 @@ resolve_data <- function(x, y) {
#' or `NULL`.
#' @param step_size the learning rate scheduler step size. Unused if
#' `lr_scheduler` is a `torch::lr_scheduler` or `NULL`.
#' @param cat_emb_dim Embedding size for categorical features (default=1)
#' @param cat_emb_dim Size of the embedding of categorical features. If a single int, all categorical
#' features will have the same embedding size; if a list of int, each corresponding feature will have
#' a specific embedding size.
#' @param momentum Momentum for batch normalization, typically ranges from 0.01
#' to 0.4 (default=0.02)
#' @param pretraining_ratio Ratio of features to mask for reconstruction during
@@ -147,6 +153,8 @@ tabnet_config <- function(batch_size = 1024^2,
cat_emb_dim = 1,
num_independent = 2,
num_shared = 2,
num_independent_decoder = 1,
num_shared_decoder = 1,
momentum = 0.02,
pretraining_ratio = 0.5,
verbose = FALSE,
@@ -190,6 +198,8 @@ cat_emb_dim = cat_emb_dim,
cat_emb_dim = cat_emb_dim,
n_independent = num_independent,
n_shared = num_shared,
n_independent_decoder = num_independent_decoder,
n_shared_decoder = num_shared_decoder,
momentum = momentum,
pretraining_ratio = pretraining_ratio,
verbose = verbose,
@@ -198,7 +208,7 @@
early_stopping_monitor = resolve_early_stop_monitor(early_stopping_monitor, valid_split),
early_stopping_tolerance = early_stopping_tolerance,
early_stopping_patience = early_stopping_patience,
early_stopping = !(early_stopping_tolerance==0 || early_stopping_patience==0),
early_stopping = !(early_stopping_tolerance == 0 || early_stopping_patience == 0),
num_workers = num_workers,
skip_importance = skip_importance
)
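Pulling the documented parameters together, a hedged `tabnet_config()` sketch; the concrete values (validation split, per-feature embedding sizes) are invented for illustration and assume two categorical predictors:

```r
library(tabnet)

config <- tabnet_config(
  valid_split = 0.2,              # hold out 20% of rows for validation
  cat_emb_dim = c(3, 2),          # per-feature embedding sizes (assumes 2 categorical predictors)
  num_independent = 2,            # encoder: independent GLU layers per step
  num_shared = 2,                 # encoder: shared GLU layers per step
  num_independent_decoder = 1,    # decoder (pretraining): independent GLU layers per step
  num_shared_decoder = 1          # decoder (pretraining): shared GLU layers per step
)
```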
2 changes: 2 additions & 0 deletions R/pretraining.R
@@ -119,6 +119,8 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
cat_emb_dim = config$cat_emb_dim,
n_independent = config$n_independent,
n_shared = config$n_shared,
n_independent_decoder = config$n_independent_decoder,
n_shared_decoder = config$n_shared_decoder,
momentum = config$momentum
)

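The unsupervised trainer simply forwards the two new config entries to the network constructor. A rough mirror of that call with an invented input layout (20 numeric columns, no categorical ones); `tabnet_pretrainer` is an internal module, so this is illustrative rather than a supported entry point:

```r
# Illustrative only: mirrors the call made inside tabnet_train_unsupervised().
network <- tabnet:::tabnet_pretrainer(
  input_dim = 20,
  cat_idxs = c(), cat_dims = c(), cat_emb_dim = 1,
  n_independent = 2, n_shared = 2,
  n_independent_decoder = 1, n_shared_decoder = 1,
  momentum = 0.02
)
```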
29 changes: 16 additions & 13 deletions R/tab-network.R
@@ -234,14 +234,15 @@ tabnet_decoder <- torch::nn_module(

tabnet_pretrainer <- torch::nn_module(
"tabnet_pretrainer",
initialize = function(input_dim, pretraining_ratio=0.2,
n_d=8, n_a=8,
n_steps=3, gamma=1.3,
cat_idxs=c(), cat_dims=c(),
cat_emb_dim=1, n_independent=2,
n_shared=2, epsilon=1e-15,
virtual_batch_size=128, momentum = 0.02,
mask_type="sparsemax") {
initialize = function(input_dim, pretraining_ratio = 0.2,
n_d = 8, n_a = 8,
n_steps = 3, gamma = 1.3,
cat_idxs = c(), cat_dims = c(),
cat_emb_dim = 1, n_independent = 2,
n_shared = 2, n_independent_decoder = 1,
n_shared_decoder = 1, epsilon = 1e-15,
virtual_batch_size = 128, momentum = 0.02,
mask_type = "sparsemax") {

self$input_dim <- input_dim
self$pretraining_ratio <- pretraining_ratio
@@ -259,6 +260,8 @@ tabnet_pretrainer <- torch::nn_module(
self$epsilon <- epsilon
self$n_independent <- n_independent
self$n_shared <- n_shared
self$n_independent_decoder <- n_independent_decoder
self$n_shared_decoder <- n_shared_decoder
self$mask_type <- mask_type
self$initial_bn <- torch::nn_batch_norm1d(self$input_dim, momentum = momentum)

@@ -290,8 +293,8 @@ self$post_embed_dim,
self$post_embed_dim,
n_d = n_d,
n_steps = n_steps,
n_independent = n_independent,
n_shared = n_shared,
n_independent = n_independent_decoder,
n_shared = n_shared_decoder,
virtual_batch_size = virtual_batch_size,
momentum = momentum
)
@@ -409,14 +412,14 @@ tabnet_no_embedding <- torch::nn_module(
#' @param n_d Dimension of the prediction layer (usually between 4 and 64).
#' @param n_a Dimension of the attention layer (usually between 4 and 64).
#' @param n_steps Number of successive steps in the network (usually between 3 and 10).
#' @param gamma Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0).
#' @param gamma Float above 1, scaling factor for attention updates (usually between 1 and 2).
#' @param cat_idxs Index of each categorical column in the dataset.
#' @param cat_dims Number of categories in each categorical column.
#' @param cat_emb_dim Size of the embedding of categorical features. If a single int, all categorical
#' features will have the same embedding size; if a list of int, each corresponding feature will have
#' a specific size.
#' @param n_independent Number of independent GLU layer in each GLU block (default 2)..
#' @param n_shared Number of independent GLU layer in each GLU block (default 2).
#' @param n_independent Number of independent GLU layers in each GLU block of the encoder.
#' @param n_shared Number of shared GLU layers in each GLU block of the encoder.
#' @param epsilon Avoid log(0), this should be kept very low.
#' @param virtual_batch_size Batch size for Ghost Batch Normalization.
#' @param momentum Float value between 0 and 1 which will be used for momentum in all batch norm.
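The parameter documentation above is the roxygen source for the exported network constructor. A minimal instantiation sketch with the documented defaults, assuming `tabnet_nn()` takes `input_dim` and `output_dim` as its first arguments; the input layout (10 predictors, one 3-level categorical column at index 2) is invented:

```r
library(tabnet)

# Sketch: build the network directly with the documented hyper-parameters.
net <- tabnet_nn(
  input_dim = 10, output_dim = 1,
  n_d = 8, n_a = 8, n_steps = 3, gamma = 1.3,
  cat_idxs = c(2), cat_dims = c(3), cat_emb_dim = 1,
  n_independent = 2, n_shared = 2
)
```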
6 changes: 3 additions & 3 deletions man/tabnet.Rd


20 changes: 15 additions & 5 deletions man/tabnet_config.Rd


6 changes: 3 additions & 3 deletions man/tabnet_nn.Rd
