Commit

Add french FR translation (#131)
* switch to base:: messaging system for the translation and add FR translation
cregouby authored Sep 27, 2023
1 parent 071291c commit 9da9e4e
Showing 15 changed files with 571 additions and 63 deletions.
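All nine files shown below apply one mechanical pattern, sketched here with hypothetical `fail_old()`/`fail_new()` helpers (not part of the commit). `glue::glue()` interpolates values into the message before any translation can happen, so every emitted string is unique and no catalog can match it; `gettextf()` keeps a fixed format string as the msgid, translates it first (for example into the new French), and only then substitutes the values sprintf-style.

```r
# Hypothetical before/after for the commit-wide pattern.

# Before: interpolation happens first, so the final string has no stable
# msgid for a translation catalog to match.
fail_old <- function(pkg) {
  rlang::abort(glue::glue("Package {pkg} is required."))
}

# After: "Package %s is required." is a fixed msgid; gettextf() looks up its
# translation (e.g. from po/R-fr.po) and then fills in %s.
fail_new <- function(pkg) {
  stop(gettextf("Package %s is required.", pkg), call. = FALSE)
}
```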
1 change: 0 additions & 1 deletion DESCRIPTION
@@ -22,7 +22,6 @@ Imports:
   torch (>= 0.4.0),
   hardhat (>= 1.3.0),
   magrittr,
-  glue,
   progress,
   rlang,
   methods,
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,6 +1,7 @@
 # tabnet (development version)
 
 ## New features
+* add FR translation (#131)
 * `tabnet_pretrain()` now allows different GLU blocks in GLU layers in encoder and in decoder through the `config()` parameters `num_idependant_decoder` and `num_shared_decoder` (#129)
 * {tabnet} now allows hierarchical multi-label classification through {data.tree} hierarchical `Node` dataset. (#126)
 * Add `reduce_on_plateau` as option for `lr_scheduler` at `tabnet_config()` (@SvenVw, #120)
2 changes: 1 addition & 1 deletion R/dials.R
@@ -1,6 +1,6 @@
 check_dials <- function() {
   if (!requireNamespace("dials", quietly = TRUE))
-    rlang::abort("Package \"dials\" needed for this function to work. Please install it.")
+    stop("Package \"dials\" needed for this function to work. Please install it.", call. = FALSE)
 }
 
 
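A side note on why this plain literal needs no `gettextf()`: base R's `stop()`, `warning()` and `message()` already pass character arguments through `gettext()` in the calling package's domain, so a fixed string is a translatable msgid by itself. A minimal sketch of the distinction, with hypothetical function names:

```r
check_pkg_like <- function() {
  # Fixed string: stop() itself performs the catalog lookup.
  stop("Package \"dials\" needed for this function to work. Please install it.",
       call. = FALSE)
}

check_ncol_like <- function(n) {
  # Value-bearing string: gettextf() makes the format the msgid, so the
  # lookup happens before %d is filled in.
  stop(gettextf("Expected %d columns, got %d.", 2L, n), call. = FALSE)
}
```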
7 changes: 3 additions & 4 deletions R/explain.R
@@ -44,10 +44,9 @@ tabnet_explain <- function(object, new_data) {
 #' @export
 #' @rdname tabnet_explain
 tabnet_explain.default <- function(object, new_data) {
-  stop(
-    "`tabnet_explain()` is not defined for a '", class(object)[1], "'.",
-    call. = FALSE
-  )
+  stop(domain=NA,
+       gettextf("`tabnet_explain()` is not defined for a '%s'.", class(object)[1]),
+       call. = FALSE)
 }
 
 #' @export
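A note on the `domain=NA` that shows up in these rewritten default methods (an interpretation, not stated in the commit): `gettextf()` has already translated the fixed format string, so `domain=NA` tells `stop()` not to run the assembled, value-bearing message through `gettext()` a second time. Sketch with a hypothetical `f()`:

```r
f <- function(object) {
  # The catalog lookup happens here, on the fixed format string.
  msg <- gettextf("`f()` is not defined for a '%s'.", class(object)[1])
  # domain = NA: do not attempt to translate the already-assembled message.
  stop(domain = NA, msg, call. = FALSE)
}
```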
49 changes: 24 additions & 25 deletions R/hardhat.R
@@ -104,10 +104,9 @@ tabnet_fit <- function(x, ...) {
 #' @export
 #' @rdname tabnet_fit
 tabnet_fit.default <- function(x, ...) {
-  stop(
-    "`tabnet_fit()` is not defined for a '", class(x)[1], "'.",
-    call. = FALSE
-  )
+  stop(domain=NA,
+       gettextf("`tabnet_fit()` is not defined for a '%s'.", class(x)[1]),
+       call. = FALSE)
 }
 
 #' @export
@@ -293,10 +292,9 @@ tabnet_pretrain <- function(x, ...) {
 #' @export
 #' @rdname tabnet_pretrain
 tabnet_pretrain.default <- function(x, ...) {
-  stop(
-    "`tabnet_pretrain()` is not defined for a '", class(x)[1], "'.",
-    call. = FALSE
-  )
+  stop(domain=NA,
+       gettextf("`tabnet_pretrain()` is not defined for a '%s'.", class(x)[1]),
+       call. = FALSE)
 }
 
 
@@ -390,13 +388,14 @@ tabnet_bridge <- function(processed, config = tabnet_config(), tabnet_model, fro
   epoch_shift <- 0L
 
   if (!(is.null(tabnet_model) || inherits(tabnet_model, "tabnet_fit") || inherits(tabnet_model, "tabnet_pretrain")))
-    rlang::abort(glue::glue("{tabnet_model} is not recognised as a proper TabNet model"))
+    stop(gettextf("'%s' is not recognised as a proper TabNet model", tabnet_model),
+         call. = FALSE)
 
   if (!is.null(from_epoch) && !is.null(tabnet_model)) {
     # model must be loaded from checkpoint
 
     if (from_epoch > (length(tabnet_model$fit$checkpoints) * tabnet_model$fit$config$checkpoint_epoch))
-      rlang::abort(glue::glue("The model was trained for less than {from_epoch} epochs"))
+      stop(gettextf("The model was trained for less than '%s' epochs", from_epoch), call. = FALSE)
 
     # find closest checkpoint for that epoch
     closest_checkpoint <- from_epoch %/% tabnet_model$fit$config$checkpoint_epoch
@@ -408,7 +407,7 @@ tabnet_bridge <- function(processed, config = tabnet_config(), tabnet_model, fro
   }
   if (task == "supervised") {
     if (sum(is.na(outcomes)) > 0) {
-      rlang::abort(glue::glue("Error: found missing values in the `{names(outcomes)}` outcome column."))
+      stop(gettextf("Found missing values in the `%s` outcome column.", names(outcomes)), call. = FALSE)
     }
     if (is.null(tabnet_model)) {
       # new supervised model needs network initialization
@@ -418,7 +417,7 @@ tabnet_bridge <- function(processed, config = tabnet_config(), tabnet_model, fro
     } else if (!check_net_is_empty_ptr(tabnet_model) && inherits(tabnet_model, "tabnet_fit")) {
       # resume training from supervised
       if (!identical(processed$blueprint, tabnet_model$blueprint))
-        rlang::abort("Model dimensions don't match.")
+        stop("Model dimensions don't match.", call. = FALSE)
 
       # model is available from tabnet_model$serialized_net
       m <- reload_model(tabnet_model$serialized_net)
@@ -443,7 +442,7 @@ tabnet_bridge <- function(processed, config = tabnet_config(), tabnet_model, fro
       tabnet_model$fit$network <- reload_model(tabnet_model$fit$checkpoints[[last_checkpoint]])
       epoch_shift <- last_checkpoint * tabnet_model$fit$config$checkpoint_epoch
 
-  } else rlang::abort(glue::glue("No model serialized weight can be found in {tabnet_model} check the model history"))
+  } else stop(gettextf("No model serialized weight can be found in `%s`, check the model history", tabnet_model), call. = FALSE)
 
   fit_lst <- tabnet_train_supervised(tabnet_model, predictors, outcomes, config = config, epoch_shift)
   return(new_tabnet_fit(fit_lst, blueprint = processed$blueprint))
@@ -485,7 +484,7 @@ predict_tabnet_bridge <- function(type, object, predictors, epoch, batch_size) {
   if (!is.null(epoch)) {
 
     if (epoch > (length(object$fit$checkpoints) * object$fit$config$checkpoint_epoch))
-      rlang::abort(glue::glue("The model was trained for less than {epoch} epochs"))
+      stop(gettextf("The model was trained for less than `%s` epochs", epoch), call. = FALSE)
 
     # find closest checkpoint for that epoch
     ind <- epoch %/% object$fit$config$checkpoint_epoch
@@ -530,7 +529,7 @@ model_pretrain_to_fit <- function(obj, x, y, config = tabnet_config()) {
   m <- reload_model(obj$serialized_net)
 
   if (m$input_dim != tabnet_model_lst$network$input_dim)
-    rlang::abort("Model dimensions don't match.")
+    stop("Model dimensions don't match.", call. = FALSE)
 
   # perform update of selected weights into new tabnet_model
   m_stat_dict <- m$state_dict()
@@ -582,25 +581,25 @@ check_type <- function(outcome_ptype, type = NULL) {
   outcome_all_numeric <- all(purrr::map_lgl(outcome_ptype, is.numeric))
 
   if (!outcome_all_numeric && !outcome_all_factor)
-    rlang::abort(glue::glue("Mixed multi-outcome type '{unique(purrr::map_chr(outcome_ptype, ~class(.x)[[1]]))}' is not supported"))
+    stop(gettextf("Mixed multi-outcome type '%s' is not supported", unique(purrr::map_chr(outcome_ptype, ~class(.x)[[1]]))), call. = FALSE)
 
   if (is.null(type)) {
     if (outcome_all_factor)
       type <- "class"
     else if (outcome_all_numeric)
       type <- "numeric"
     else if (ncol(outcome_ptype) == 1)
-      rlang::abort(glue::glue("Unknown outcome type '{class(outcome_ptype)}'"))
+      stop(gettextf("Unknown outcome type '%s'", class(outcome_ptype)), call. = FALSE)
   }
 
   type <- rlang::arg_match(type, c("numeric", "prob", "class"))
 
   if (outcome_all_factor) {
     if (!type %in% c("prob", "class"))
-      rlang::abort(glue::glue("Outcome is factor and the prediction type is '{type}'."))
+      stop(gettextf("Outcome is factor and the prediction type is '%s'.", type), call. = FALSE)
   } else if (outcome_all_numeric) {
     if (type != "numeric")
-      rlang::abort(glue::glue("Outcome is numeric and the prediction type is '{type}'."))
+      stop(gettextf("Outcome is numeric and the prediction type is '%s'.", type), call. = FALSE)
   }
 
   invisible(type)
@@ -638,15 +637,15 @@ check_compliant_node <- function(node) {
     reserved_names <- c(paste0("level_", c(1:node_height)), data.tree::NODE_RESERVED_NAMES_CONST)
     actual_names <- colnames(node)[!colnames(node) %in% "pathString"]
   } else {
-    rlang::abort("The provided hierarchical object is not recognized with a valid format that can be checked")
+    stop("The provided hierarchical object is not recognized with a valid format that can be checked", call. = FALSE)
   }
 
   if (any(actual_names %in% reserved_names)) {
-    rlang::abort(paste0(
-      "The attributes or colnames in the provided hierarchical object use the following reserved names : '",
-      paste(actual_names[actual_names %in% reserved_names], collapse = "', '"),
-      "'. Please change those names as they will lead to unexpected tabnet behavior."
-    ))
+    stop(domain=NA,
+         gettextf("The attributes or colnames in the provided hierarchical object use the following reserved names : '%s'. Please change those names as they will lead to unexpected tabnet behavior.",
+                  paste(actual_names[actual_names %in% reserved_names], collapse = "', '")
+         ),
+         call. = FALSE)
   }
 
   invisible(node)
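The reserved-names rewrite above also shows how multi-valued messages are handled: the vector is collapsed into a single string first, so the whole sentence stays one translatable format with a single `%s`. A condensed sketch (hypothetical `check_names_like()`):

```r
check_names_like <- function(bad) {
  stop(domain = NA,
       gettextf("The following reserved names are in use: '%s'.",
                paste(bad, collapse = "', '")),
       call. = FALSE)
}
# check_names_like(c("levelName", "height"))
# Error: The following reserved names are in use: 'levelName', 'height'.
```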
26 changes: 13 additions & 13 deletions R/model.R
@@ -239,7 +239,7 @@ resolve_loss <- function(config, dtype) {
     # cross entropy loss is required
     loss_fn <- torch::nn_cross_entropy_loss()
   else
-    rlang::abort(glue::glue("{loss} is not a valid loss for outcome of type {dtype}"))
+    stop(gettextf("`%s` is not a valid loss for outcome of type %s", loss, dtype), call. = FALSE)
 
   loss_fn
 }
@@ -250,7 +250,7 @@ resolve_early_stop_monitor <- function(early_stopping_monitor, valid_split) {
   else if (early_stopping_monitor %in% c("train_loss", "auto"))
     early_stopping_monitor <- "train_loss"
   else
-    rlang::abort(glue::glue("{early_stopping_monitor} is not a valid early-stopping metric to monitor with `valid_split` = {valid_split}"))
+    stop(gettextf("%s is not a valid early-stopping metric to monitor with `valid_split` = %s", early_stopping_monitor, valid_split), call. = FALSE)
 
   early_stopping_monitor
 }
@@ -506,7 +506,7 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
     if (config$optimizer == "adam")
       optimizer <- torch::optim_adam(network$parameters, lr = config$learn_rate)
     else
-      rlang::abort("Currently only the 'adam' optimizer is supported.")
+      stop("Currently only the 'adam' optimizer is supported.", call. = FALSE)
 
   }
 
@@ -520,7 +520,7 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
   } else if (config$lr_scheduler == "step") {
     scheduler <- torch::lr_step(optimizer, config$step_size, config$lr_decay)
   } else {
-    rlang::abort("Currently only the 'step' and 'reduce_on_plateau' scheduler are supported.")
+    stop("Currently only the 'step' and 'reduce_on_plateau' scheduler are supported.", call. = FALSE)
   }
 
   # restore previous metrics & checkpoints
@@ -565,12 +565,11 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
       metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics)
     }
 
-    message <- sprintf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train$loss))
-    if (has_valid)
-      message <- paste0(message, sprintf(" Valid loss: %3f", mean(metrics[[epoch]]$valid$loss)))
+    if (config$verbose & !has_valid)
+      message(gettextf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train$loss)))
+    if (config$verbose & has_valid)
+      message(gettextf("[Epoch %03d] Loss: %3f, Valid loss: %3f", epoch, mean(metrics[[epoch]]$train$loss), mean(metrics[[epoch]]$valid$loss)))
 
-    if (config$verbose)
-      rlang::inform(message)
 
     # Early-stopping checks
     if (config$early_stopping && config$early_stopping_monitor=="valid_loss"){
@@ -585,7 +584,7 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
         patience_counter <- patience_counter + 1
         if (patience_counter >= config$early_stopping_patience){
           if (config$verbose)
-            rlang::inform(sprintf("Early stopping at epoch %03d", epoch))
+            message(gettextf("Early stopping at epoch %03d", epoch))
           break
         }
       } else {
@@ -610,9 +609,10 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
   if(!config$skip_importance) {
     importance_sample_size <- config$importance_sample_size
     if (is.null(config$importance_sample_size) && train_ds$.length() > 1e5) {
-      rlang::warn(c(glue::glue("Computing importances for a dataset with size {train_ds$.length()}."),
-                    "This can consume too much memory. We are going to use a sample of size 1e5",
-                    "You can disable this message by using the `importance_sample_size` argument."))
+      warning(
+        gettextf(
+          "Computing importances for a dataset with size %s. This can consume too much memory. We are going to use a sample of size 1e5, You can disable this message by using the `importance_sample_size` argument.",
+          train_ds$.length()))
       importance_sample_size <- 1e5
     }
     indexes <- as.numeric(torch::torch_randint(
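The epoch logging above is restructured rather than just re-wrapped: instead of concatenating a "Valid loss" fragment onto a base string, each variant becomes one complete format string, since fragments cannot be translated as whole sentences and word order differs across languages. A sketch of the resulting behaviour with a hypothetical `log_epoch()`:

```r
log_epoch <- function(epoch, train_loss, valid_loss = NULL) {
  if (is.null(valid_loss))
    message(gettextf("[Epoch %03d] Loss: %3f", epoch, train_loss))
  else
    message(gettextf("[Epoch %03d] Loss: %3f, Valid loss: %3f",
                     epoch, train_loss, valid_loss))
}
log_epoch(7, 0.1234)          # [Epoch 007] Loss: 0.123400
log_epoch(7, 0.1234, 0.2345)  # [Epoch 007] Loss: 0.123400, Valid loss: 0.234500
```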
2 changes: 1 addition & 1 deletion R/parsnip.R
@@ -242,7 +242,7 @@ tabnet <- function(mode = "unknown", epochs = NULL, penalty = NULL, batch_size =
                    num_independent = NULL, num_shared = NULL, momentum = NULL) {
 
   if (!requireNamespace("parsnip", quietly = TRUE))
-    rlang::abort("Package \"parsnip\" needed for this function to work. Please install it.")
+    stop("Package \"parsnip\" needed for this function to work. Please install it.", call. = FALSE)
 
   if (!tabnet_env$parsnip_added) {
     add_parsnip_tabnet()
20 changes: 9 additions & 11 deletions R/pretraining.R
@@ -136,7 +136,7 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
     if (config$optimizer == "adam")
       optimizer <- torch::optim_adam(network$parameters, lr = config$learn_rate)
     else
-      rlang::abort("Currently only the 'adam' optimizer is supported.")
+      stop("Currently only the 'adam' optimizer is supported.", call. = FALSE)
 
   }
 
@@ -150,7 +150,7 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
   } else if (config$lr_scheduler == "step") {
     scheduler <- torch::lr_step(optimizer, config$step_size, config$lr_decay)
   } else {
-    rlang::abort("Currently only the 'step' and 'reduce_on_plateau' scheduler are supported.")
+    stop("Currently only the 'step' and 'reduce_on_plateau' scheduler are supported.", call. = FALSE)
   }
 
   # initialize metrics & checkpoints
@@ -195,12 +195,10 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
       metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics)
     }
 
-    message <- sprintf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train$loss))
-    if (has_valid)
-      message <- paste0(message, sprintf(" Valid loss: %3f", mean(metrics[[epoch]]$valid$loss)))
-
-    if (config$verbose)
-      rlang::inform(message)
+    if (config$verbose & !has_valid)
+      message(gettextf("[Epoch %03d] Loss: %3f", epoch, mean(metrics[[epoch]]$train$loss)))
+    if (config$verbose & has_valid)
+      message(gettextf("[Epoch %03d] Loss: %3f, Valid loss: %3f", epoch, mean(metrics[[epoch]]$train$loss), mean(metrics[[epoch]]$valid$loss)))
 
     # Early-stopping checks
     if (config$early_stopping && config$early_stopping_monitor=="valid_loss"){
@@ -240,9 +238,9 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
 
   importance_sample_size <- config$importance_sample_size
   if (is.null(config$importance_sample_size) && train_ds$.length() > 1e5) {
-    rlang::warn(c(glue::glue("Computing importances for a dataset with size {train_ds$.length()}."),
-                  "This can consume too much memory. We are going to use a sample of size 1e5",
-                  "You can disable this message by using the `importance_sample_size` argument."))
+    warning(domain=NA,
+            gettextf("Computing importances for a dataset with size %s. This can consume too much memory. We are going to use a sample of size 1e5. You can disable this message by using the `importance_sample_size` argument.", train_ds$.length()),
+            call. = FALSE)
     importance_sample_size <- 1e5
   }
   indexes <- as.numeric(torch::torch_randint(
10 changes: 5 additions & 5 deletions R/tab-network.R
@@ -266,9 +266,9 @@ tabnet_pretrainer <- torch::nn_module(
     self$initial_bn <- torch::nn_batch_norm1d(self$input_dim, momentum = momentum)
 
     if (self$n_steps <= 0)
-      stop("n_steps should be a positive integer.")
+      stop("'n_steps' should be a positive integer.")
     if (self$n_independent == 0 && self$n_shared == 0)
-      stop("n_shared and n_independant can't be both zero.")
+      stop("'n_shared' and 'n_independant' can't be both zero.")
 
     self$virtual_batch_size <- virtual_batch_size
     self$embedder <- embedding_generator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
@@ -452,9 +452,9 @@ tabnet_nn <- torch::nn_module(
     self$mask_type <- mask_type
 
     if (self$n_steps <= 0)
-      stop("n_steps should be a positive integer.")
+      stop("'n_steps' should be a positive integer.")
     if (self$n_independent == 0 && self$n_shared == 0)
-      stop("n_shared and n_independant can't be both zero.")
+      stop("'n_shared' and 'n_independant' can't be both zero.")
 
     self$virtual_batch_size <- virtual_batch_size
     self$embedder <- embedding_generator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
@@ -494,7 +494,7 @@ attentive_transformer <- torch::nn_module(
     else if (mask_type == "entmax")
       self$selector <- entmax(dim = -1)
     else
-      stop("Please choose either sparsemax or entmax as masktype")
+      stop("Please choose either 'sparsemax' or 'entmax' as 'mask_type'")
 
   },
   forward = function(priors, processed_feat) {
(6 more changed files not shown)
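The files not shown presumably include the catalog itself (a `po/R-fr.po` plus its compiled form; an assumption, since they are not listed above). Assuming a standard R translation setup, the new French messages can be exercised like this:

```r
# gettext honours the LANGUAGE environment variable for catalog selection.
Sys.setenv(LANGUAGE = "fr")
library(tabnet)
# Dispatches to tabnet_explain.default(), which now signals a gettextf()
# message; with the FR catalog installed the French msgstr is used.
try(tabnet_explain(42, new_data = NULL))
Sys.setenv(LANGUAGE = "en")
```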
