Skip to content

Commit

Permalink
fixing parallel stuff for windows
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Mar 20, 2024
1 parent 699480b commit f31113a
Show file tree
Hide file tree
Showing 28 changed files with 204 additions and 219 deletions.
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@
^\.github$
^vignettes/articles-online-only$
^release-prep\.R$
^doc$
^Meta$
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@
inst/doc
dev-helpers.R
release-prep.R
/doc/
/Meta/
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export(read_cmdstan_csv)
export(read_sample_csv)
export(rebuild_cmdstan)
export(register_knitr_engine)
export(remaining_columns_to_read)
export(set_cmdstan_path)
export(set_num_threads)
export(write_stan_file)
Expand Down
2 changes: 1 addition & 1 deletion R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -1446,7 +1446,7 @@ validate_seed <- function(seed, num_procs) {
#' @return An integer vector of length `num_procs`.
maybe_generate_seed <- function(seed, num_procs) {
if (is.null(seed)) {
seed <- base::sample(.Machine$integer.max, 1)
seed <- base::sample(.Machine$integer.max, num_procs)
} else if (length(seed) == 1 && num_procs > 1) {
seed <- rep(as.integer(seed), num_procs)
} else if (length(seed) != num_procs) {
Expand Down
2 changes: 1 addition & 1 deletion R/cmdstanr-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@
#' @inherit cmdstan_model examples
#' @import R6
#'
NULL
"_PACKAGE"

if (getRversion() >= "2.15.1") utils::globalVariables(c("self", "private", "super"))
1 change: 1 addition & 0 deletions R/csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,7 @@ unrepair_variable_names <- function(names) {
names
}

#' @export
remaining_columns_to_read <- function(requested, currently_read, all) {
if (is.null(requested)) {
if (is.null(all)) {
Expand Down
4 changes: 0 additions & 4 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -1347,11 +1347,7 @@ CmdStanMCMC <- R6::R6Class(
},
# override the CmdStanFit output method
output = function(id = NULL) {
if (is.null(id)) {
self$runset$procs$proc_output()
} else {
cat(paste(self$runset$procs$proc_output(id), collapse = "\n"))
}
},

# override the CmdStanFit draws method
Expand Down
69 changes: 47 additions & 22 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
#' data = stan_data,
#' seed = 123,
#' chains = 2,
#' parallel_chains = 2
#' threads = 2
#' )
#'
#' # Use 'posterior' package for summaries
Expand Down Expand Up @@ -341,6 +341,9 @@ CmdStanModel <- R6::R6Class(
"- ", new_hpp_loc)
private$hpp_file_ <- new_hpp_loc
invisible(private$hpp_file_)
},
# Report whether the model's stored C++ options enable Stan threading.
#
# The option name is matched case-insensitively because it is read here as
# `STAN_THREADS` but spelled `stan_threads` in the `compile()` defaults.
# Always returns a single TRUE/FALSE: a missing, NA or non-logical option
# counts as FALSE, so the result is safe to use directly in `if ()`.
threads_enabled = function() {
  opts <- private$cpp_options_
  key <- match("stan_threads", tolower(names(opts)))
  if (is.na(key)) {
    # Option was never set; `as.logical(NULL)` would give `logical(0)`.
    return(FALSE)
  }
  isTRUE(as.logical(opts[[key]]))
}
)
)
Expand Down Expand Up @@ -414,7 +417,6 @@ CmdStanModel <- R6::R6Class(
#' [`$expose_functions()`][model-method-expose_functions] method.
#' @param dry_run (logical) If `TRUE`, the code will do all checks before compilation,
#' but skip the actual C++ compilation. Used to speedup tests.
#'
#' @param threads Deprecated and will be removed in a future release. Please
#' turn on threading via `cpp_options = list(stan_threads = TRUE)` instead.
#'
Expand Down Expand Up @@ -461,7 +463,7 @@ compile <- function(quiet = TRUE,
pedantic = FALSE,
include_paths = NULL,
user_header = NULL,
cpp_options = list(),
cpp_options = list(stan_threads = os_use_single_process()),
stanc_options = list(),
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE),
compile_model_methods = FALSE,
Expand Down Expand Up @@ -1218,9 +1220,19 @@ sample <- function(data = NULL,
if (fixed_param) {
save_warmup <- FALSE
}
if (self$threads_enabled()) {
num_procs = 1
parallel_procs = 1
threads_per_proc = threads
} else {
num_procs = chains
parallel_procs = chains
threads_per_proc = as.integer(threads / chains)
}
procs <- CmdStanMCMCProcs$new(
num_procs = 1,
parallel_procs = 1,
num_procs = num_procs,
parallel_procs = parallel_procs,
threads_per_proc = assert_valid_threads(threads_per_proc, self$cpp_options(), multiple_chains = TRUE),
show_stderr_messages = show_exceptions,
show_stdout_messages = show_messages
)
Expand Down Expand Up @@ -1266,7 +1278,7 @@ sample <- function(data = NULL,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables,
threads = threads
threads = threads_per_proc
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1382,9 +1394,19 @@ sample_mpi <- function(data = NULL,
chains <- 1
save_warmup <- FALSE
}
# Decide how many CmdStan processes to launch and how many threads each one
# receives. With threading enabled a single process gets the full `threads`
# budget; otherwise one process is launched per chain and the budget is
# split evenly between them.
if (self$threads_enabled()) {
  num_procs <- 1
  parallel_procs <- 1
  threads_per_proc <- threads
} else {
  num_procs <- chains
  parallel_procs <- chains
  # Divide by `chains` (the chain count used above); `num_chains` is not
  # defined in this function and would raise "object not found".
  threads_per_proc <- as.integer(threads / chains)
}
procs <- CmdStanMCMCProcs$new(
num_procs = 1,
parallel_procs = 1,
num_procs = num_procs,
parallel_procs = parallel_procs,
threads_per_proc = assert_valid_threads(threads_per_proc, self$cpp_options(), multiple_chains = TRUE),
show_stderr_messages = show_exceptions,
show_stdout_messages = show_messages
)
Expand Down Expand Up @@ -1455,10 +1477,6 @@ CmdStanModel$set("public", name = "sample_mpi", value = sample_mpi)
#' metadata of an example model, e.g.,
#' `cmdstanr_example(method="optimize")$metadata()`.
#' @template model-common-args
#' @param threads (positive integer) If the model was
#' [compiled][model-method-compile] with threading support, the number of
#' threads to use in parallelized sections (e.g., when
#' using the Stan functions `reduce_sum()` or `map_rect()`).
#' @param iter (positive integer) The maximum number of iterations.
#' @param algorithm (string) The optimization algorithm. One of `"lbfgs"`,
#' `"bfgs"`, or `"newton"`. The control parameters below are only available
Expand Down Expand Up @@ -1509,8 +1527,12 @@ optimize <- function(data = NULL,
history_size = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
if (self$threads_enabled() && is.null(threads)) {
threads <- 1
}
procs <- CmdStanProcs$new(
num_procs = 1,
threads_per_proc = assert_valid_threads(threads, self$cpp_options(), multiple_chains = TRUE),
show_stderr_messages = show_exceptions,
show_stdout_messages = show_messages
)
Expand Down Expand Up @@ -1579,7 +1601,6 @@ CmdStanModel$set("public", name = "optimize", value = optimize)
#' installed version of CmdStan.
#'
#' @template model-common-args
#' @inheritParams model-method-optimize
#' @param save_latent_dynamics Ignored for this method.
#' @param mode (multiple options) The mode to center the approximation at. One
#' of the following:
Expand Down Expand Up @@ -1647,8 +1668,12 @@ laplace <- function(data = NULL,
if (!is.null(mode) && !is.null(opt_args)) {
stop("Cannot specify both 'opt_args' and 'mode' arguments.", call. = FALSE)
}
if (self$threads_enabled() && is.null(threads)) {
threads <- 1
}
procs <- CmdStanProcs$new(
num_procs = 1,
threads_per_proc = assert_valid_threads(threads, self$cpp_options(), multiple_chains = TRUE),
show_stderr_messages = show_exceptions,
show_stdout_messages = show_messages
)
Expand Down Expand Up @@ -1743,10 +1768,6 @@ CmdStanModel$set("public", name = "laplace", value = laplace)
#' installed version of CmdStan.
#'
#' @template model-common-args
#' @param threads (positive integer) If the model was
#' [compiled][model-method-compile] with threading support, the number of
#' threads to use in parallelized sections (e.g., when using the Stan
#' functions `reduce_sum()` or `map_rect()`).
#' @param algorithm (string) The algorithm. Either `"meanfield"` or
#' `"fullrank"`.
#' @param iter (positive integer) The _maximum_ number of iterations.
Expand Down Expand Up @@ -1869,10 +1890,6 @@ CmdStanModel$set("public", name = "variational", value = variational)
#' installed version of CmdStan
#'
#' @template model-common-args
#' @param num_threads (positive integer) If the model was
#' [compiled][model-method-compile] with threading support, the number of
#' threads to use in parallelized sections (e.g., for multi-path pathfinder
#' as well as `reduce_sum`).
#' @param init_alpha (positive real) The initial step size parameter.
#' @param tol_obj (positive real) Convergence tolerance on changes in objective function value.
#' @param tol_rel_obj (positive real) Convergence tolerance on relative changes in objective function value.
Expand Down Expand Up @@ -1938,8 +1955,12 @@ pathfinder <- function(data = NULL,
calculate_lp = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
if (self$threads_enabled() && is.null(threads)) {
threads <- 1
}
procs <- CmdStanProcs$new(
num_procs = 1,
threads_per_proc = assert_valid_threads(threads, self$cpp_options(), multiple_chains = TRUE),
show_stderr_messages = show_exceptions,
show_stdout_messages = show_messages
)
Expand Down Expand Up @@ -2074,10 +2095,14 @@ generate_quantities <- function(fitted_params,
threads = NULL,
opencl_ids = NULL) {
fitted_params_files <- process_fitted_params(fitted_params)
if (self$threads_enabled() && is.null(threads)) {
threads <- 1
}
procs <- CmdStanGQProcs$new(
num_procs = length(fitted_params_files),
parallel_procs = checkmate::assert_integerish(parallel_chains, lower = 1,
null.ok = TRUE)
null.ok = TRUE),
threads_per_proc = assert_valid_threads(threads, self$cpp_options(), multiple_chains = TRUE),
)
model_variables <- NULL
if (is_variables_method_supported(self)) {
Expand Down
59 changes: 43 additions & 16 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,20 @@ CmdStanRun <- R6::R6Class(
call. = FALSE
)
}
private$latent_dynamics_files_
if (include_failed) {
private$latent_dynamics_files_
} else {
ok <- self$procs$is_finished() | self$procs$is_queued()
private$latent_dynamics_files_[ok]
}
},
# Paths of the CmdStan output files recorded for this run.
#
# @param include_failed If TRUE, every recorded path is returned. Otherwise
#   only the entries for which the process reports `is_finished()` or
#   `is_queued()` are kept.
output_files = function(include_failed = FALSE) {
  if (include_failed) {
    return(private$output_files_)
  }
  keep <- self$procs$is_finished() | self$procs$is_queued()
  private$output_files_[keep]
},
profile_files = function(include_failed = FALSE) {
files <- private$profile_files_
Expand All @@ -132,7 +142,12 @@ CmdStanRun <- R6::R6Class(
call. = FALSE
)
}
files
if (include_failed) {
files
} else {
ok <- self$procs$is_finished() | self$procs$is_queued()
files[ok]
}
},
save_output_files = function(dir = ".",
basename = NULL,
Expand Down Expand Up @@ -234,7 +249,7 @@ CmdStanRun <- R6::R6Class(
command = function() self$args$command(),
command_args = function(id = 1) {
# create a list of character vectors (one per run/chain) of cmdstan arguments
if (inherits(self$args$method_args, "GenerateQuantitiesArgs")) {
if (self$procs$num_procs() > 1) {
output_file = private$output_files_[id]
latent_dynamic_file = private$latent_dynamics_files_[id]
} else {
Expand Down Expand Up @@ -429,24 +444,29 @@ check_target_exe <- function(exe) {
start_time <- Sys.time()
chain_id <- 1
while (!all(procs$is_finished() | procs$is_failed())) {
procs$new_proc(
id = chain_id,
command = self$command(),
args = self$command_args(),
wd = dirname(self$exe_file()),
mpi_cmd = mpi_cmd,
mpi_args = mpi_args
)
procs$mark_proc_start(chain_id)
procs$set_active_procs(procs$active_procs() + 1)
while (procs$active_procs() != procs$parallel_procs() && procs$any_queued()) {
procs$new_proc(
id = chain_id,
command = self$command(),
args = self$command_args(chain_id),
wd = dirname(self$exe_file()),
mpi_cmd = mpi_cmd,
mpi_args = mpi_args
)
procs$mark_proc_start(chain_id)
procs$set_active_procs(procs$active_procs() + 1)
chain_id <- chain_id + 1
}
start_active_procs <- procs$active_procs()
while (procs$active_procs() == start_active_procs &&
procs$active_procs() > 0) {
procs$wait(0.1)
procs$poll(0)
if (!procs$is_queued(chain_id)) {
procs$process_output(chain_id)
procs$process_error_output(chain_id)
for (chain_iter in seq_len(chain_id)) {
if (!procs$is_queued(chain_iter)) {
procs$process_output(chain_iter)
procs$process_error_output(chain_iter)
}
}
procs$set_active_procs(procs$num_alive())
}
Expand Down Expand Up @@ -629,17 +649,20 @@ CmdStanProcs <- R6::R6Class(
public = list(
initialize = function(num_procs,
parallel_procs = NULL,
threads_per_proc = NULL,
show_stderr_messages = TRUE,
show_stdout_messages = TRUE) {
checkmate::assert_integerish(num_procs, lower = 1, len = 1, any.missing = FALSE)
checkmate::assert_integerish(parallel_procs, lower = 1, len = 1, any.missing = FALSE, null.ok = TRUE)
checkmate::assert_integerish(threads_per_proc, lower = 1, len = 1, null.ok = TRUE)
private$num_procs_ <- as.integer(num_procs)
if (is.null(parallel_procs)) {
private$parallel_procs_ <- private$num_procs_
} else {
private$parallel_procs_ <- as.integer(parallel_procs)
}
private$active_procs_ <- 0
private$threads_per_proc_ <- as.integer(threads_per_proc)
private$proc_ids_ <- seq_len(num_procs)
zeros <- rep(0, num_procs)
names(zeros) <- private$proc_ids_
Expand All @@ -662,6 +685,9 @@ CmdStanProcs <- R6::R6Class(
# Accessor: returns the value stored in `private$parallel_procs_`.
parallel_procs = function() {
private$parallel_procs_
},
# Accessor: returns the value stored in `private$threads_per_proc_`.
threads_per_proc = function() {
private$threads_per_proc_
},
# Accessor: returns the value stored in `private$proc_ids_`.
proc_ids = function() {
private$proc_ids_
},
Expand Down Expand Up @@ -883,6 +909,7 @@ CmdStanProcs <- R6::R6Class(
num_procs_ = integer(),
parallel_procs_ = integer(),
active_procs_ = integer(),
threads_per_proc_ = integer(),
proc_state_ = NULL,
proc_start_time_ = NULL,
proc_total_time_ = NULL,
Expand Down
4 changes: 4 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ os_is_linux <- function() {
isTRUE(Sys.info()[["sysname"]] == "Linux")
}

# TRUE when the platform helpers report WSL or Linux, FALSE otherwise.
# `os_is_linux()` is only consulted when `os_is_wsl()` is not TRUE.
os_use_single_process <- function() {
  is_wsl <- os_is_wsl()
  is_wsl || os_is_linux()
}

# TRUE on Windows when the running R is major version 4 with a minor version
# of at least 3.0; FALSE otherwise.
#
# The minor version is compared as a version number rather than as a string:
# a plain character comparison orders "10.0" before "3.0" and would wrongly
# reject such a release.
is_rtools43_toolchain <- function() {
  if (!os_is_windows()) {
    return(FALSE)
  }
  minor <- numeric_version(R.version$minor)
  R.version$major == "4" && minor >= "3.0"
}
Expand Down
5 changes: 5 additions & 0 deletions man-roxygen/model-common-args.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,8 @@
#' [`$output()`][fit-method-output] method of the resulting fit object can be
#' used to display the silenced messages.
#'
#' @param threads (positive integer) If the model was
#' [compiled][model-method-compile] with threading support, the number of
#' threads to use in parallelized sections (e.g., for running multiple chains
#' in parallel and when using the Stan functions
#' `reduce_sum()` or `map_rect()`).
Loading

0 comments on commit f31113a

Please sign in to comment.