From 0668a36e95f5885aac06db5d9fdda06a6fa15e43 Mon Sep 17 00:00:00 2001 From: Lars Kotthoff Date: Mon, 5 Feb 2024 13:55:45 -0700 Subject: [PATCH] proposal for hook for loop function that allows plotting --- R/bayesopt_ego.R | 44 +++++++++++++++++++++++++++++++++++++++++++- R/bayesopt_emo.R | 8 +++++++- R/bayesopt_mpcl.R | 8 +++++++- R/bayesopt_parego.R | 8 +++++++- R/bayesopt_smsego.R | 8 +++++++- 5 files changed, 71 insertions(+), 5 deletions(-) diff --git a/R/bayesopt_ego.R b/R/bayesopt_ego.R index 9f8a9618..0196bed9 100644 --- a/R/bayesopt_ego.R +++ b/R/bayesopt_ego.R @@ -30,6 +30,9 @@ #' For example, if `random_interleave_iter = 2`, random interleaving is performed in the second, #' fourth, sixth, ... iteration. #' Default is `0`, i.e., no random interleaving is performed at all. +#' @param hook_fun ([function])\cr +#' [Function] to be called in each iteration of the loop, before evaluating the next proposed point. +#' See examples. #' #' @note #' * The `acq_function$surrogate`, even if already populated, will always be overwritten by the `surrogate`. @@ -82,6 +85,42 @@ #' #' optimizer$optimize(instance) #' +#' # same as above, but plot files with information at each iteration of the loop +#' library(ggplot2) +#' library(gridExtra) +#' +#' myPlot = function(xdt, instance, surrogate, acq_function, acq_optimizer) { +#' data.plot = data.table(x = seq(instance$objective$domain$lower, instance$objective$domain$upper, length.out = 100)) +#' data.plot$y = instance$objective$eval_dt(data.plot)$y +#' data.plot$acq_ei = acq_optimizer$acq_function$eval_dt(data.table(x = data.plot$x))$acq_ei +#' data.plot = data.table(data.plot, surrogate$predict(data.plot)) +#' +#' p1 = ggplot(data.plot, aes(x = x, y = y)) + +#' geom_ribbon(aes(ymin = mean - se, ymax = mean + se), fill = "lightgray") + +#' geom_line() + +#' geom_line(aes(y = mean), linetype = "dashed") + +#' geom_point(data = xdt, aes(x = x, y = surrogate$predict(data.table(x = x))$mean), color = "red") + +#' geom_point(data = instance$archive$data, aes(x = x, y = y), color = "darkgreen") + +#' theme(axis.title.x = element_blank(), +#' axis.text.x = element_blank(), +#' axis.ticks.x = element_blank()) +#' +#' p2 = ggplot(data.plot, aes(x = x, y = acq_ei)) + +#' geom_line() +#' +#' p = grid.arrange(p1, p2, ncol = 1) +#' ggsave(p, file = paste("mbo-", instance$archive$n_evals, ".pdf", sep = "")) +#' } +#' +#' optimizer = opt("mbo", +#' loop_function = bayesopt_ego, +#' surrogate = surrogate, +#' acq_function = acqfun, +#' acq_optimizer = acqopt, +#' args = list("hook_fun" = myPlot)) +#' +#' optimizer$optimize(instance) +#' #' # expected improvement per second example #' fun = function(xs) { #' list(y = xs$x ^ 2, time = abs(xs$x)) @@ -112,7 +151,8 @@ bayesopt_ego = function( acq_function, acq_optimizer, init_design_size = NULL, - random_interleave_iter = 0L + random_interleave_iter = 0L, + hook_fun = function(...) {} ) { # assertions @@ -154,6 +194,8 @@ bayesopt_ego = function( generate_design_random(search_space, n = 1L)$data }) + hook_fun(xdt, instance, surrogate, acq_function, acq_optimizer) + instance$eval_batch(xdt) if (instance$is_terminated) break } diff --git a/R/bayesopt_emo.R b/R/bayesopt_emo.R index 1caacd6f..cd61a8a0 100644 --- a/R/bayesopt_emo.R +++ b/R/bayesopt_emo.R @@ -30,6 +30,9 @@ #' For example, if `random_interleave_iter = 2`, random interleaving is performed in the second, #' fourth, sixth, ... iteration. #' Default is `0`, i.e., no random interleaving is performed at all. +#' @param hook_fun ([function])\cr +#' [Function] to be called in each iteration of the loop, before evaluating the next proposed point. +#' See examples in [mlr_loop_functions_ego]. #' #' @note #' * The `acq_function$surrogate`, even if already populated, will always be overwritten by the `surrogate`. @@ -85,7 +88,8 @@ bayesopt_emo = function( acq_function, acq_optimizer, init_design_size = NULL, - random_interleave_iter = 0L + random_interleave_iter = 0L, + hook_fun = function(...) {} ) { # assertions @@ -127,6 +131,8 @@ bayesopt_emo = function( generate_design_random(search_space, n = 1L)$data }) + hook_fun(xdt, instance, surrogate, acq_function, acq_optimizer) + instance$eval_batch(xdt) if (instance$is_terminated) break } diff --git a/R/bayesopt_mpcl.R b/R/bayesopt_mpcl.R index edac4cec..dbfeaa46 100644 --- a/R/bayesopt_mpcl.R +++ b/R/bayesopt_mpcl.R @@ -38,6 +38,9 @@ #' For example, if `random_interleave_iter = 2`, random interleaving is performed in the second, #' fourth, sixth, ... iteration. #' Default is `0`, i.e., no random interleaving is performed at all. +#' @param hook_fun ([function])\cr +#' [Function] to be called in each iteration of the loop, before evaluating the next proposed point. +#' See examples in [mlr_loop_functions_ego]. #' #' @note #' * The `acq_function$surrogate`, even if already populated, will always be overwritten by the `surrogate`. @@ -102,7 +105,8 @@ bayesopt_mpcl = function( init_design_size = NULL, q = 2L, liar = mean, - random_interleave_iter = 0L + random_interleave_iter = 0L, + hook_fun = function(...) {} ) { # assertions @@ -175,6 +179,8 @@ bayesopt_mpcl = function( acq_function$surrogate$archive = instance$archive + hook_fun(xdt, instance, surrogate, acq_function, acq_optimizer) + instance$eval_batch(xdt) if (instance$is_terminated) break diff --git a/R/bayesopt_parego.R b/R/bayesopt_parego.R index 62550871..96e8f769 100644 --- a/R/bayesopt_parego.R +++ b/R/bayesopt_parego.R @@ -40,6 +40,9 @@ #' For example, if `random_interleave_iter = 2`, random interleaving is performed in the second, #' fourth, sixth, ... iteration. #' Default is `0`, i.e., no random interleaving is performed at all. +#' @param hook_fun ([function])\cr +#' [Function] to be called in each iteration of the loop, before evaluating the next proposed point. +#' See examples in [mlr_loop_functions_ego]. #' #' @note #' * The `acq_function$surrogate`, even if already populated, will always be overwritten by the `surrogate`. @@ -105,7 +108,8 @@ bayesopt_parego = function( q = 1L, s = 100L, rho = 0.05, - random_interleave_iter = 0L + random_interleave_iter = 0L, + hook_fun = function(...) {} ) { # assertions @@ -169,6 +173,8 @@ bayesopt_parego = function( }) }, .fill = TRUE) + hook_fun(xdt, instance, surrogate, acq_function, acq_optimizer) + instance$eval_batch(xdt) if (instance$is_terminated) break } diff --git a/R/bayesopt_smsego.R b/R/bayesopt_smsego.R index 124fcd86..6a6b1dfb 100644 --- a/R/bayesopt_smsego.R +++ b/R/bayesopt_smsego.R @@ -29,6 +29,9 @@ #' For example, if `random_interleave_iter = 2`, random interleaving is performed in the second, #' fourth, sixth, ... iteration. #' Default is `0`, i.e., no random interleaving is performed at all. +#' @param hook_fun ([function])\cr +#' [Function] to be called in each iteration of the loop, before evaluating the next proposed point. +#' See examples in [mlr_loop_functions_ego]. #' #' @note #' * The `acq_function$surrogate`, even if already populated, will always be overwritten by the `surrogate`. @@ -90,7 +93,8 @@ bayesopt_smsego = function( acq_function, acq_optimizer, init_design_size = NULL, - random_interleave_iter = 0L + random_interleave_iter = 0L, + hook_fun = function(...) {} ) { # assertions @@ -134,6 +138,8 @@ bayesopt_smsego = function( generate_design_random(search_space, n = 1L)$data }) + hook_fun(xdt, instance, surrogate, acq_function, acq_optimizer) + instance$eval_batch(xdt) if (instance$is_terminated) break }