Skip to content

Commit

Permalink
feat: restart lost workers
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Nov 30, 2023
1 parent 531b231 commit 80dd506
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 172 deletions.
101 changes: 99 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,99 @@
inst/doc
attic
# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig
# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,r
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,linux,r

### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### R ###
# History files
.Rhistory
.Rapp.history

# Session Data files
.RData
.RDataTmp

# User-specific files
.Ruserdata

# Example code in package build process
*-Ex.R

# Output files from R CMD build
/*.tar.gz

# Output files from R CMD check
/*.Rcheck/

# RStudio files
.Rproj.user/

# produced vignettes
vignettes/*.html
vignettes/*.pdf

# OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3
.httr-oauth

# knitr and R markdown default cache directories
*_cache/
/cache/

# Temporary files created by R markdown
*.utf8.md
*.knit.md

# R Environment Variables
.Renviron

# pkgdown site
docs/

# translation temp files
po/*~

# RStudio Connect folder
rsconnect/

### R.Bookdown Stack ###
# R package: bookdown caching files
/*_files/

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,r

# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)

.Rprofile
attic/
159 changes: 93 additions & 66 deletions R/Rush.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ Rush = R6::R6Class("Rush",
}
self$connector = redux::hiredis(self$config)
private$.hostname = get_hostname()
private$.pid_exists = choose_pid_exists()
},

#' @description
Expand Down Expand Up @@ -166,6 +165,8 @@ Rush = R6::R6Class("Rush",
#'
#' @param n_workers (`integer(1)`)\cr
#' Number of workers to be started.
#' @param supervise (`logical(1)`)\cr
#' Whether to kill the workers when the main R process is shut down.
#' @param wait_for_workers (`logical(1)`)\cr
#' Whether to wait until all workers are available.
#' @param ... (`any`)\cr
Expand All @@ -179,12 +180,13 @@ Rush = R6::R6Class("Rush",
heartbeat_expire = NULL,
lgr_thresholds = NULL,
lgr_buffer_size = 0,

supervise = TRUE,
worker_loop = worker_loop_default,
...
) {
n_workers = assert_count(n_workers %??% rush_env$n_workers)
assert_flag(wait_for_workers)
assert_flag(supervise)

# push worker config to redis
private$.push_worker_config(
Expand All @@ -202,15 +204,35 @@ Rush = R6::R6Class("Rush",
self$processes = c(self$processes, set_names(map(worker_ids, function(worker_id) {
processx::process$new("Rscript",
args = c("-e", sprintf("rush::start_worker(network_id = '%s', worker_id = '%s', hostname = '%s', url = '%s')",
self$network_id, worker_id, private$.hostname, self$config$url)))
self$network_id, worker_id, private$.hostname, self$config$url)),
supervise = supervise)
}), worker_ids))

if (wait_for_workers) self$wait_for_workers(n_workers)

return(invisible(worker_ids))
},

#' @description
#' Restart workers.
#'
#' @param worker_ids (`character()`)\cr
#' Worker ids to be restarted.
restart_workers = function(worker_ids) {
assert_subset(unlist(worker_ids), self$worker_ids)
r = self$connector

lg$error("Restarting %i worker(s): %s", length(worker_ids), str_collapse(worker_ids))
processes = set_names(map(worker_ids, function(worker_id) {
# restart worker
processx::process$new("Rscript",
args = c("-e", sprintf("rush::start_worker(network_id = '%s', worker_id = '%s', hostname = '%s', url = '%s')",
self$network_id, worker_id, private$.hostname, self$config$url)))
}), worker_ids)
self$processes = insert_named(self$processes, processes)

return(invisible(worker_ids))
},

#' @description
#' Create script to start workers.
Expand Down Expand Up @@ -290,30 +312,33 @@ Rush = R6::R6Class("Rush",

} else if (type == "kill") {
worker_info = self$worker_info[list(worker_ids), , on = "worker_id"]

# kill local
local_workers = worker_info[list("local"), , on = c("host"), nomatch = NULL]
local_workers = worker_info[list("local"), worker_id, on = c("host"), nomatch = NULL]
lg$debug("Killing %i local worker(s) %s", length(local_workers), as_short_string(local_workers))

if (nrow(local_workers)) {
tools::pskill(local_workers$pid)
cmds = map(local_workers$worker_id, function(worker_id) {
c("SMOVE", private$.get_key("running_worker_ids"), private$.get_key("killed_worker_ids"), worker_id)
})
r$pipeline(.commands = cmds)
}
# kill with processx
walk(local_workers, function(worker_id) {
killed = self$processes[[worker_id]]$kill()
if (!killed) lg$error("Failed to kill worker %s", worker_id)
})

# set worker state
cmds_local = map(local_workers, function(worker_id) {
c("SMOVE", private$.get_key("running_worker_ids"), private$.get_key("killed_worker_ids"), worker_id)
})

# kill remote
remote_workers = worker_info[list("remote"), worker_id, on = c("host"), nomatch = NULL]

if (length(remote_workers)) {
# push kill signal to heartbeat
cmds = unlist(map(remote_workers, function(worker_id) {
list(
c("LPUSH", private$.get_worker_key("kill", worker_id), "TRUE"),
c("SMOVE", private$.get_key("running_worker_ids"), private$.get_key("killed_worker_ids"), worker_id))
}), recursive = FALSE)
r$pipeline(.commands = cmds)
}
remote_workers = worker_info [list("remote"), worker_id, on = c("host"), nomatch = NULL]
lg$debug("Killing %i remote worker(s) %s", length(remote_workers), as_short_string(remote_workers))

# push kill signal to heartbeat process and set worker state
cmds_remote = unlist(map(remote_workers, function(worker_id) {
list(
c("LPUSH", private$.get_worker_key("kill", worker_id), "TRUE"),
c("SMOVE", private$.get_key("running_worker_ids"), private$.get_key("killed_worker_ids"), worker_id))
}), recursive = FALSE)

r$pipeline(.commands = c(cmds_local, cmds_remote))
}

return(invisible(self))
Expand All @@ -326,12 +351,17 @@ Rush = R6::R6Class("Rush",
#' Checking local workers on unix systems only takes a few microseconds per worker.
#' But checking local workers on windows might be very slow.
#' Workers with a heartbeat process are checked with the heartbeat.
detect_lost_workers = function() {
#' Lost tasks are marked as `"lost"`.
#'
#' @param restart (`logical(1)`)\cr
#' Whether to restart lost workers.
detect_lost_workers = function(restart = FALSE) {
assert_flag(restart)
r = self$connector

# check workers with a heartbeat
heartbeat_keys = r$SMEMBERS(private$.get_key("heartbeat_keys"))
if (length(heartbeat_keys)) {
lost_workers = if (length(heartbeat_keys)) {
lg$debug("Checking %i worker(s) with heartbeat", length(heartbeat_keys))
running = as.logical(r$pipeline(.commands = map(heartbeat_keys, function(heartbeat_key) c("EXISTS", heartbeat_key))))
if (all(running)) return(invisible(self))
Expand All @@ -340,50 +370,55 @@ Rush = R6::R6Class("Rush",
heartbeat_keys = heartbeat_keys[!running]
lost_workers = self$worker_info[heartbeat == heartbeat_keys, worker_id]

# move worker ids to lost workers set and remove heartbeat keys
cmds = map(lost_workers, function(worker_id) c("SMOVE", private$.get_key("running_worker_ids"), private$.get_key("lost_worker_ids"), worker_id))
r$pipeline(.commands = c(cmds, list(c("SREM", "heartbeat_keys", heartbeat_keys))))
lg$error("Lost %i worker(s): %s", length(lost_workers), str_collapse(lost_workers))
# set worker state
cmds = map(lost_workers, function(worker_id) {
c("SMOVE", private$.get_key("running_worker_ids"), private$.get_key("lost_worker_ids"), worker_id)
})

# remove heartbeat keys
cmds = c(cmds, list(c("SREM", "heartbeat_keys", heartbeat_keys)))
r$pipeline(.commands = cmds)
lost_workers
}

# check local workers without a heartbeat
local_pids = r$SMEMBERS(private$.get_key("local_pids"))
if (length(local_pids)) {
lg$debug("Checking %i worker(s) with process id", length(local_pids))
running = map_lgl(local_pids, function(pid) private$.pid_exists(pid))
local_workers = r$SMEMBERS(private$.get_key("local_workers"))
lost_workers = if (length(local_workers)) {
lg$debug("Checking %i worker(s) with process id", length(local_workers))
running = map_lgl(local_workers, function(worker_id) self$processes[[worker_id]]$is_alive())
if (all(running)) return(invisible(self))

# search for associated worker ids
local_pids = local_pids[!running]
lost_workers = self$worker_info[pid == local_pids, worker_id]

# move worker ids to lost workers set and remove pids
cmds = map(lost_workers, function(worker_id) c("SMOVE", private$.get_key("running_worker_ids"), private$.get_key("lost_worker_ids"), worker_id))
r$pipeline(.commands = c(cmds, list(c("SREM", private$.get_key("local_pids"), local_pids))))
lost_workers = local_workers[!running]
lg$error("Lost %i worker(s): %s", length(lost_workers), str_collapse(lost_workers))
}

return(invisible(self))
},
if (restart) {
self$restart_workers(unlist(lost_workers))
lost_workers
} else {
# set worker state
cmds = map(lost_workers, function(worker_id) {
c("SMOVE", private$.get_key("running_worker_ids"), private$.get_key("lost_worker_ids"), worker_id)
})

#' @description
#' Detect lost tasks.
#' Changes the state of tasks to `"lost"` if the worker crashed.
detect_lost_tasks = function() {
r = self$connector
if (!self$n_workers) return(invisible(self))
self$detect_lost_workers()
lost_workers = self$lost_worker_ids
# remove local pids
cmds = c(cmds, list(c("SREM", private$.get_key("local_workers"), lost_workers)))
r$pipeline(.commands = cmds)
lost_workers
}
}

# mark lost tasks
if (length(lost_workers)) {
running_tasks = self$fetch_running_tasks(fields = "worker_extra")
if (!nrow(running_tasks)) return(invisible(self))
bin_state = redux::object_to_bin(list(state = "lost"))
keys = running_tasks[lost_workers, keys, on = "worker_id"]

lg$error("Lost %i task(s): %s", length(keys), str_collapse(keys))

cmds = unlist(map(keys, function(key) {
list(
c("HSET", key, "state", list(bin_state)),
list("HSET", key, "state", failed_state),
c("SREM", private$.get_key("running_tasks"), key),
c("RPUSH", private$.get_key("failed_tasks"), key))
}), recursive = FALSE)
Expand Down Expand Up @@ -416,7 +451,6 @@ Rush = R6::R6Class("Rush",
r$DEL(private$.get_worker_key("kill", worker_id))
r$DEL(private$.get_worker_key("heartbeat", worker_id))
r$DEL(private$.get_worker_key("queued_tasks", worker_id))
r$DEL(private$.get_worker_key("log", worker_id))
r$DEL(private$.get_worker_key("events", worker_id))
})

Expand All @@ -439,13 +473,12 @@ Rush = R6::R6Class("Rush",
r$DEL(private$.get_key("lost_worker_ids"))
r$DEL(private$.get_key("start_args"))
r$DEL(private$.get_key("terminate_on_idle"))
r$DEL(private$.get_key("local_pids"))
r$DEL(private$.get_key("local_workers"))
r$DEL(private$.get_key("heartbeat_keys"))

# reset counters and caches
private$.cached_results_dt = data.table()
private$.cached_tasks_dt = data.table()
private$.cached_worker_info = data.table()
private$.n_seen_results = 0

return(invisible(self))
Expand Down Expand Up @@ -751,15 +784,15 @@ Rush = R6::R6Class("Rush",
#'
#' @param keys (`character()`)\cr
#' Keys of the tasks to wait for.
#' @param detect_lost_tasks (`logical(1)`)\cr
#' @param detect_lost_workers (`logical(1)`)\cr
#' Whether to detect failed tasks.
#' Comes with an overhead.
wait_for_tasks = function(keys, detect_lost_tasks = FALSE) {
wait_for_tasks = function(keys, detect_lost_workers = FALSE) {
assert_character(keys, min.len = 1)
assert_flag(detect_lost_tasks)
assert_flag(detect_lost_workers)

while (any(keys %nin% c(self$finished_tasks, self$failed_tasks)) && self$n_running_workers > 0) {
if (detect_lost_tasks) self$detect_lost_tasks()
if (detect_lost_workers) self$detect_lost_workers()
Sys.sleep(0.01)
}

Expand Down Expand Up @@ -1007,8 +1040,7 @@ Rush = R6::R6Class("Rush",
#' Contains information about the workers.
worker_info = function(rhs) {
assert_ro_binding(rhs)
if (nrow(private$.cached_worker_info) == self$n_workers) return(private$.cached_worker_info)

if (!self$n_running_workers) return(data.table())
r = self$connector

fields = c("worker_id", "pid", "host", "hostname", "heartbeat")
Expand All @@ -1019,8 +1051,6 @@ Rush = R6::R6Class("Rush",
# fix type
worker_info[, pid := as.integer(pid)][]

# cache result
private$.cached_worker_info = worker_info
worker_info
},

Expand Down Expand Up @@ -1078,9 +1108,6 @@ Rush = R6::R6Class("Rush",

.cached_tasks_list = list(),

# cache of the worker info which usually does not change after starting the workers
.cached_worker_info = data.table(),

# counter of the seen results for the latest results methods
.n_seen_results = 0,

Expand Down
Loading

0 comments on commit 80dd506

Please sign in to comment.