Skip to content

Commit

Permalink
refactor: extract internal tuned values in instance (#164)
Browse files Browse the repository at this point in the history
* refactor: extract internal tuned values in instance

* ...

* ...
  • Loading branch information
be-marc authored Oct 16, 2024
1 parent b3b8c74 commit 3de4d0c
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 10 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ BugReports: https://github.com/mlr-org/mlr3mbo/issues
Depends:
R (>= 3.1.0)
Imports:
bbotk (>= 1.0.0),
bbotk (>= 1.1.1),
checkmate (>= 2.0.0),
data.table,
lgr (>= 0.3.4),
mlr3 (>= 0.21.0),
mlr3misc (>= 0.11.0),
mlr3tuning (>= 1.0.0),
mlr3tuning (>= 1.0.2),
paradox (>= 1.0.0),
spacefillr,
R6 (>= 2.4.1)
Expand Down
12 changes: 6 additions & 6 deletions R/ResultAssignerArchive.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ ResultAssignerArchive = R6Class("ResultAssignerArchive",
#' @param instance ([bbotk::OptimInstanceBatchSingleCrit] | [bbotk::OptimInstanceBatchMultiCrit])\cr
#' The [bbotk::OptimInstance] the final result should be assigned to.
assign_result = function(instance) {
res = instance$archive$best()
xdt = res[, instance$search_space$ids(), with = FALSE]
xydt = instance$archive$best()
xdt = xydt[, instance$search_space$ids(), with = FALSE]
if (inherits(instance, "OptimInstanceBatchMultiCrit")) {
ydt = res[, instance$archive$cols_y, with = FALSE]
instance$assign_result(xdt, ydt)
ydt = xydt[, instance$archive$cols_y, with = FALSE]
instance$assign_result(xdt, ydt, xydt = xydt)
}
else {
y = unlist(res[, instance$archive$cols_y, with = FALSE])
instance$assign_result(xdt, y)
y = unlist(xydt[, instance$archive$cols_y, with = FALSE])
instance$assign_result(xdt, y, xydt = xydt)
}
}
),
Expand Down
5 changes: 3 additions & 2 deletions R/ResultAssignerSurrogate.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,16 @@ ResultAssignerSurrogate = R6Class("ResultAssignerSurrogate",
}
archive_tmp = archive$clone(deep = TRUE)
archive_tmp$data[, self$surrogate$cols_y := means]
best = archive_tmp$best()[, archive_tmp$cols_x, with = FALSE]
xydt = archive_tmp$best()
best = xydt[, archive_tmp$cols_x, with = FALSE]

# ys are still the ones originally evaluated
best_y = if (inherits(instance, "OptimInstanceBatchSingleCrit")) {
unlist(archive$data[best, on = archive$cols_x][, archive$cols_y, with = FALSE])
} else if (inherits(instance, "OptimInstanceBatchMultiCrit")) {
archive$data[best, on = archive$cols_x][, archive$cols_y, with = FALSE]
}
instance$assign_result(xdt = best, best_y)
instance$assign_result(xdt = best, best_y, xydt = xydt)
}
),

Expand Down
28 changes: 28 additions & 0 deletions tests/testthat/test_ResultAssignerArchive.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,31 @@ test_that("ResultAssignerArchive works with OptimizerMbo and bayesopt_smsego", {
expect_data_table(instance$result, min.rows = 1L)
})

test_that("ResultAssignerArchive passes internal tuned values", {
result_assigner = ResultAssignerArchive$new()

learner = lrn("classif.debug",
validate = 0.2,
early_stopping = TRUE,
x = to_tune(0.2, 0.3),
iter = to_tune(upper = 1000, internal = TRUE, aggr = function(x) 99))

instance = ti(
task = tsk("pima"),
learner = learner,
resampling = rsmp("cv", folds = 3),
measures = msr("classif.ce"),
terminator = trm("evals", n_evals = 20),
store_benchmark_result = TRUE
)
surrogate = SurrogateLearner$new(REGR_KM_DETERM)
acq_function = AcqFunctionEI$new()
acq_optimizer = AcqOptimizer$new(opt("random_search", batch_size = 2L), terminator = trm("evals", n_evals = 2L))

tuner = tnr("mbo", result_assigner = result_assigner)
expect_data_table(tuner$optimize(instance), nrows = 1)
expect_list(instance$archive$data$internal_tuned_values, len = 20, types = "list")
expect_equal(instance$archive$data$internal_tuned_values[[1]], list(iter = 99))
expect_false(instance$result_learner_param_vals$early_stopping)
expect_equal(instance$result_learner_param_vals$iter, 99)
})
28 changes: 28 additions & 0 deletions tests/testthat/test_ResultAssignerSurrogate.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,31 @@ test_that("ResultAssignerSurrogate works with OptimizerMbo and bayesopt_smsego",
expect_data_table(instance$result, min.rows = 1L)
})

test_that("ResultAssignerSurrogate passes internal tuned values", {
result_assigner = ResultAssignerSurrogate$new()

learner = lrn("classif.debug",
validate = 0.2,
early_stopping = TRUE,
x = to_tune(0.2, 0.3),
iter = to_tune(upper = 1000, internal = TRUE, aggr = function(x) 99))

instance = ti(
task = tsk("pima"),
learner = learner,
resampling = rsmp("cv", folds = 3),
measures = msr("classif.ce"),
terminator = trm("evals", n_evals = 20),
store_benchmark_result = TRUE
)
surrogate = SurrogateLearner$new(REGR_KM_DETERM)
acq_function = AcqFunctionEI$new()
acq_optimizer = AcqOptimizer$new(opt("random_search", batch_size = 2L), terminator = trm("evals", n_evals = 2L))

tuner = tnr("mbo", result_assigner = result_assigner)
expect_data_table(tuner$optimize(instance), nrows = 1)
expect_list(instance$archive$data$internal_tuned_values, len = 20, types = "list")
expect_equal(instance$archive$data$internal_tuned_values[[1]], list(iter = 99))
expect_false(instance$result_learner_param_vals$early_stopping)
expect_equal(instance$result_learner_param_vals$iter, 99)
})

0 comments on commit 3de4d0c

Please sign in to comment.