Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add asynchronous decentralized bayesian optimization #145

Open
wants to merge 40 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0604202
draft
be-marc Feb 8, 2024
d4b56b9
fix: add start time
be-marc Feb 8, 2024
71fdc66
refactor: remove rush debug
be-marc Feb 9, 2024
762b136
fix: add callback
be-marc Feb 9, 2024
6a62d75
fix: transformation
be-marc Feb 9, 2024
c62f7c8
fix: xdomain
be-marc Feb 9, 2024
d92dc55
refactor: kill workers
be-marc Feb 10, 2024
2bb45a9
feat: cache archive
be-marc Feb 10, 2024
2f05764
chore: debug
be-marc Feb 10, 2024
0ed1969
chore: import rush
be-marc Feb 12, 2024
b7217e4
feat: add logging
be-marc Feb 12, 2024
6982d12
refactor: use optimize_decentralized()
be-marc Feb 12, 2024
0363e5d
feat: add exponential decay
be-marc Feb 19, 2024
c48a484
feat: add min-max imputation
be-marc Feb 19, 2024
c8784aa
feat: add n_worker parameter
be-marc Feb 22, 2024
b97bd18
draft
be-marc Apr 26, 2024
e890f13
Merge branch 'main' into adbo
be-marc Apr 26, 2024
3787ec9
draft
be-marc Apr 28, 2024
5a03bf7
refactor: remove stage
be-marc Apr 29, 2024
d77fad8
fix: description
be-marc Apr 29, 2024
170d650
feat: add n_worker argument
be-marc May 1, 2024
3bd95b0
fix: imports
be-marc May 1, 2024
43c2189
fix: tests
be-marc May 1, 2024
056e50e
ci: add redis
be-marc May 1, 2024
106c58b
Merge remote-tracking branch 'origin/main' into adbo
sumny Jun 21, 2024
81959dc
Merge branch 'main' into adbo
be-marc Jul 1, 2024
66f93eb
Merge branch 'main' into adbo
be-marc Jul 1, 2024
86f9659
Merge branch 'main' into adbo
be-marc Jul 25, 2024
a661667
refactor: add SurrogateLearnerAsync
be-marc Aug 18, 2024
1e0a31d
Merge remote-tracking branch 'origin/main' into adbo
sumny Aug 20, 2024
4abb9c6
feat: add OptimizerAsyncMbo
be-marc Aug 22, 2024
a55f037
tests: add tests
be-marc Aug 22, 2024
475152a
update
be-marc Aug 27, 2024
5a8f2ae
...
be-marc Sep 11, 2024
782040b
compatibility: mlr3 0.21.0
be-marc Sep 11, 2024
b43d542
...
be-marc Sep 11, 2024
51d9c0d
Merge branch 'compat_mlr3' into adbo
be-marc Sep 13, 2024
367583c
...
be-marc Sep 13, 2024
c7caa61
feat: n_workers
be-marc Sep 13, 2024
0709d10
docs: update
be-marc Sep 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ linters: linters_with_defaults(
object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
cyclocomp_linter = NULL, # do not check function complexity
commented_code_linter = NULL, # allow code in comments
line_length_linter = line_length_linter(120)
line_length_linter = line_length_linter(120),
indentation_linter(indent = 2L, hanging_indent_style = "never")
)
14 changes: 10 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ BugReports: https://github.com/mlr-org/mlr3mbo/issues
Depends:
R (>= 3.1.0)
Imports:
bbotk (>= 1.0.0),
bbotk (>= 1.1.0),
checkmate (>= 2.0.0),
data.table,
lgr (>= 0.3.4),
mlr3 (>= 0.14.0),
mlr3 (>= 0.20.2.9000),
mlr3misc (>= 0.11.0),
mlr3tuning (>= 1.0.0),
mlr3tuning (>= 1.0.1),
paradox (>= 1.0.0),
spacefillr,
R6 (>= 2.4.1)
Expand All @@ -64,9 +64,11 @@ Suggests:
rgenoud,
rmarkdown,
rpart,
rush,
stringi,
testthat (>= 3.0.0)
Remotes: mlr-org/bbotk
Remotes:
mlr-org/mlr3
ByteCompile: no
Encoding: UTF-8
Config/testthat/edition: 3
Expand All @@ -90,14 +92,17 @@ Collate:
'AcqFunctionSmsEgo.R'
'AcqOptimizer.R'
'aaa.R'
'OptimizerAsyncMbo.R'
'OptimizerMbo.R'
'mlr_result_assigners.R'
'ResultAssigner.R'
'ResultAssignerArchive.R'
'ResultAssignerSurrogate.R'
'Surrogate.R'
'SurrogateLearner.R'
'SurrogateLearnerAsync.R'
'SurrogateLearnerCollection.R'
'TunerAsyncMbo.R'
'TunerMbo.R'
'mlr_loop_functions.R'
'bayesopt_ego.R'
Expand All @@ -109,6 +114,7 @@ Collate:
'helper.R'
'loop_function.R'
'mbo_defaults.R'
'mlr_callbacks.R'
'sugar.R'
'zzz.R'
VignetteBuilder: knitr
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ export(AcqFunctionPI)
export(AcqFunctionSD)
export(AcqFunctionSmsEgo)
export(AcqOptimizer)
export(OptimizerAsyncMbo)
export(OptimizerMbo)
export(ResultAssigner)
export(ResultAssignerArchive)
export(ResultAssignerSurrogate)
export(Surrogate)
export(SurrogateLearner)
export(SurrogateLearnerAsync)
export(SurrogateLearnerCollection)
export(TunerAsyncMbo)
export(TunerMbo)
export(acqf)
export(acqfs)
Expand Down Expand Up @@ -58,6 +61,7 @@ importFrom(R6,R6Class)
importFrom(stats,dnorm)
importFrom(stats,pnorm)
importFrom(stats,quantile)
importFrom(stats,rexp)
importFrom(stats,runif)
importFrom(stats,setNames)
importFrom(utils,bibentry)
Expand Down
18 changes: 13 additions & 5 deletions R/AcqFunctionCB.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,25 @@ AcqFunctionCB = R6Class("AcqFunctionCB",
assert_r6(surrogate, "SurrogateLearner", null.ok = TRUE)
assert_number(lambda, lower = 0, finite = TRUE)

constants = ps(lambda = p_dbl(lower = 0, default = 2))
constants = ps(
lambda = p_dbl(lower = 0, default = 2)
)
constants$values$lambda = lambda

super$initialize("acq_cb", constants = constants, surrogate = surrogate, requires_predict_type_se = TRUE, direction = "same", label = "Lower / Upper Confidence Bound", man = "mlr3mbo::mlr_acqfunctions_cb")
super$initialize("acq_cb",
constants = constants,
surrogate = surrogate,
requires_predict_type_se = TRUE,
direction = "same",
label = "Lower / Upper Confidence Bound",
man = "mlr3mbo::mlr_acqfunctions_cb")
}
),

private = list(
.fun = function(xdt, ...) {
constants = list(...)
lambda = constants$lambda
.fun = function(xdt, lambda) {
#constants = list(...)
#lambda = constants$lambda
p = self$surrogate$predict(xdt)
cb = p$mean - self$surrogate_max_to_min * lambda * p$se
data.table(acq_cb = cb)
Expand Down
203 changes: 203 additions & 0 deletions R/OptimizerAsyncMbo.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
#' @title Asynchronous Model Based Optimization
#'
#' @name mlr_optimizers_async_mbo
#'
#' @description
#' `OptimizerAsyncMbo` class that implements asynchronous Model Based Optimization (MBO).
#'
#' @section Parameters:
#' \describe{
#' \item{`initial_design`}{`data.table::data.table()`\cr
#' Initial design of the optimization.
#' If `NULL`, a design of size `design_size` is generated with `design_function`.}
#' \item{`design_size`}{`integer(1)`\cr
#' Size of the initial design.}
#' \item{`design_function`}{`character(1)`\cr
#' Function to generate the initial design.
#' One of `c("random", "sobol", "lhs")`.}
#' \item{`n_workers`}{`integer(1)`\cr
#' Number of parallel workers.
#' If `NULL`, all rush workers set with [rush::rush_plan()] are used.}
#' }
#'
#' @template param_surrogate
#' @template param_acq_function
#' @template param_acq_optimizer
#'
#' @param param_set [paradox::ParamSet]\cr
#' Set of control parameters.
#'
#' @export
OptimizerAsyncMbo = R6Class("OptimizerAsyncMbo",
inherit = OptimizerAsync,

public = list(

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(surrogate = NULL, acq_function = NULL, acq_optimizer = NULL, param_set = NULL) {
default_param_set = ps(
initial_design = p_uty(),
design_size = p_int(lower = 1, default = 10),
design_function = p_fct(c("random", "sobol", "lhs"), default = "sobol"),
n_workers = p_int(lower = 1L)
)
param_set = c(default_param_set, param_set)

param_set$set_values(design_size = 10, design_function = "sobol")

super$initialize("async_mbo",
param_set = param_set,
param_classes = c("ParamLgl", "ParamInt", "ParamDbl", "ParamFct"),
properties = c("dependencies", "single-crit"),
packages = c("mlr3mbo", "rush"),
label = "Asynchronous Model Based Optimization",
man = "mlr3mbo::OptimizerAsyncMbo")

self$surrogate = assert_r6(surrogate, classes = "Surrogate", null.ok = TRUE)
self$acq_function = assert_r6(acq_function, classes = "AcqFunction", null.ok = TRUE)
self$acq_optimizer = assert_r6(acq_optimizer, classes = "AcqOptimizer", null.ok = TRUE)
},


#' @description
#' Performs the optimization on a [OptimInstanceAsyncSingleCrit] or [OptimInstanceAsyncMultiCrit] until termination.
#' The single evaluations will be written into the [ArchiveAsync].
#' The result will be written into the instance object.
#'
#' @param inst ([OptimInstanceAsyncSingleCrit] | [OptimInstanceAsyncMultiCrit]).
#'
#' @return [data.table::data.table()]
optimize = function(inst) {
pv = self$param_set$values

# initial design
design = if (is.null(pv$initial_design)) {
# generate initial design
generate_design = switch(pv$design_function,
"random" = generate_design_random,
"sobol" = generate_design_sobol,
"lhs" = generate_design_lhs)

lg$debug("Generating sobol design with size %s", pv$design_size)
generate_design(inst$search_space, n = pv$design_size)$data
} else {
# use provided initial design
lg$debug("Using provided initial design with size %s", nrow(pv$initial_design))
pv$initial_design
}
optimize_async_default(inst, self, design, n_workers = pv$n_workers)
}
),

active = list(
#' @template field_surrogate
surrogate = function(rhs) {
if (missing(rhs)) {
private$.surrogate
} else {
private$.surrogate = assert_r6(rhs, classes = "Surrogate", null.ok = TRUE)
}
},

#' @template field_acq_function
acq_function = function(rhs) {
if (missing(rhs)) {
private$.acq_function
} else {
private$.acq_function = assert_r6(rhs, classes = "AcqFunction", null.ok = TRUE)
}
},

#' @template field_acq_optimizer
acq_optimizer = function(rhs) {
if (missing(rhs)) {
private$.acq_optimizer
} else {
private$.acq_optimizer = assert_r6(rhs, classes = "AcqOptimizer", null.ok = TRUE)
}
},

#' @template field_param_classes
param_classes = function(rhs) {
assert_ro_binding(rhs)
param_classes_surrogate = c("logical" = "ParamLgl", "integer" = "ParamInt", "numeric" = "ParamDbl", "factor" = "ParamFct")
if (!is.null(self$surrogate)) {
param_classes_surrogate = param_classes_surrogate[c("logical", "integer", "numeric", "factor") %in% self$surrogate$feature_types] # surrogate has precedence over acq_function$surrogate
}
param_classes_acq_opt = if (!is.null(self$acq_optimizer)) {
self$acq_optimizer$optimizer$param_classes
} else {
c("ParamLgl", "ParamInt", "ParamDbl", "ParamFct")
}
unname(intersect(param_classes_surrogate, param_classes_acq_opt))
},

#' @template field_properties
properties = function(rhs) {
assert_ro_binding(rhs)

properties_surrogate = "dependencies"
if (!is.null(self$surrogate)) {
if ("missings" %nin% self$surrogate$properties) {
properties_surrogate = character()
}
}
unname(c(properties_surrogate))
},

#' @template field_packages
packages = function(rhs) {
assert_ro_binding(rhs)
union("mlr3mbo", c(self$acq_function$packages, self$surrogate$packages, self$acq_optimizer$optimizer$packages))
}
),

private = list(
.surrogate = NULL,
.acq_function = NULL,
.acq_optimizer = NULL,

.optimize = function(inst) {
pv = self$param_set$values
search_space = inst$search_space
archive = inst$archive

if (is.null(self$acq_function)) {
self$acq_function = self$acq_optimizer$acq_function %??% default_acqfunction(inst)
}

if (is.null(self$surrogate)) {
self$surrogate = self$acq_function$surrogate %??% default_surrogate(inst)
}

if (is.null(self$acq_optimizer)) {
self$acq_optimizer = default_acqoptimizer(self$acq_function)
}

self$surrogate$archive = inst$archive
self$acq_function$surrogate = self$surrogate
self$acq_optimizer$acq_function = self$acq_function

lg$debug("Optimizer '%s' evaluates the initial design", self$id)
get_private(inst)$.eval_queue()

lg$debug("Optimizer '%s' starts the tuning phase", self$id)

# actual loop
while (!inst$is_terminated) {
# sample
self$acq_function$surrogate$update()
self$acq_function$update()
xdt = self$acq_optimizer$optimize()
xs = transpose_list(xdt)[[1]]

# eval
get_private(inst)$.eval_point(xs)
}
}
)
)

#' @include aaa.R
optimizers[["async_mbo"]] = OptimizerAsyncMbo
Loading
Loading