diff --git a/DESCRIPTION b/DESCRIPTION index ed92130..685f68e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -42,7 +42,6 @@ Imports: tune, utils, vctrs, - vip, withr, zeallot Suggests: @@ -55,6 +54,7 @@ Suggests: testthat (>= 3.0.0), tidymodels, tidyverse, + vip, visdat, workflows, yardstick diff --git a/R/parsnip.R b/R/parsnip.R index 98bfb72..f205289 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -426,6 +426,7 @@ add_parsnip_tabnet <- function() { #' @param mode A single character string for the type of model. Possible values #' for this model are "unknown", "regression", or "classification". #' @inheritParams tabnet_config +#' @inheritParams tabnet_fit #' #' @inheritSection tabnet_fit Threading #' @seealso tabnet_fit @@ -444,8 +445,7 @@ add_parsnip_tabnet <- function() { #' #' @export tabnet <- function(mode = "unknown", cat_emb_dim = NULL, decision_width = NULL, attention_width = NULL, - num_steps = NULL, mask_type = NULL, mlp_hidden_multiplier = NULL, mlp_activation = NULL, - encoder_activation = NULL, num_independent = NULL, num_shared = NULL, + num_steps = NULL, mask_type = NULL, num_independent = NULL, num_shared = NULL, num_independent_decoder = NULL, num_shared_decoder = NULL, penalty = NULL, feature_reusage = NULL, momentum = NULL, epochs = NULL, batch_size = NULL, virtual_batch_size = NULL, learn_rate = NULL, optimizer = NULL, loss = NULL, @@ -471,9 +471,6 @@ tabnet <- function(mode = "unknown", cat_emb_dim = NULL, decision_width = NULL, attention_width = rlang::enquo(attention_width), num_steps = rlang::enquo(num_steps), mask_type = rlang::enquo(mask_type), - mlp_hidden_multiplier = rlang::enquo(mlp_hidden_multiplier), - mlp_activation = rlang::enquo(mlp_activation), - encoder_activation = rlang::enquo(encoder_activation), num_independent = rlang::enquo(num_independent), num_shared = rlang::enquo(num_shared), num_independent_decoder = rlang::enquo(num_independent_decoder), diff --git a/R/plot.R b/R/plot.R index d05fec4..c1b681c 100644 --- a/R/plot.R +++ b/R/plot.R @@ -40,8 +40,8 @@ autoplot.tabnet_fit <- function(object, ...) { if ("checkpoint" %in% names(collect_metrics)) { checkpoints <- collect_metrics %>% - dplyr::filter(checkpoint == TRUE, dataset == "train") %>% - dplyr::select(-checkpoint) %>% + dplyr::filter(.data$checkpoint == TRUE, dataset == "train") %>% + dplyr::select(-.data$checkpoint) %>% dplyr::mutate(size = 2) p + ggplot2::geom_point(data = checkpoints, ggplot2::aes(x = epoch, y = loss, color = dataset, size = .data$size )) diff --git a/man/tabnet.Rd b/man/tabnet.Rd index 5b19c9c..d4684b1 100644 --- a/man/tabnet.Rd +++ b/man/tabnet.Rd @@ -11,9 +11,6 @@ tabnet( attention_width = NULL, num_steps = NULL, mask_type = NULL, - mlp_hidden_multiplier = NULL, - mlp_activation = NULL, - encoder_activation = NULL, num_independent = NULL, num_shared = NULL, num_independent_decoder = NULL, @@ -142,6 +139,12 @@ display a warning.} \item{early_stopping_patience}{Number of epochs without improving until stopping training. (default=5)} \item{skip_importance}{if feature importance calculation should be skipped (default: \code{FALSE})} + +\item{tabnet_model}{A previously fitted TabNet model object to continue the fitting on. +if \code{NULL} (the default) a brand new model is initialized.} + +\item{from_epoch}{When a \code{tabnet_model} is provided, restore the network weights from a specific epoch. +Default is last available checkpoint for restored model, or last epoch for in-memory model.} } \value{ A TabNet \code{parsnip} instance. It can be used to fit tabnet models using