Skip to content

Commit

Permalink
adjust tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Jan 14, 2024
1 parent 0b533b9 commit af15285
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 19 deletions.
2 changes: 1 addition & 1 deletion R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ generate_acq_codomain = function(surrogate, id, direction = "same") {
if (surrogate$archive$codomain$length > 1L) {
stop("Not supported yet.") # FIXME: But should be?
}
tags = surrogate$archive$codomain$params[[1L]]$tags
tags = surrogate$archive$codomain$tags[[1L]]
tags = tags[tags %in% c("minimize", "maximize")] # only filter out the relevant one
} else {
tags = direction
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,14 @@ expect_acqfunction = function(acqf) {
expect_man_exists(acqf$man)
}


sortnames = function(x) {
if (!is.null(names(x))) {
x <- x[order(names(x), decreasing = TRUE)]
}
x
}

expect_equal_sorted = function(x, y, ...) {
expect_equal(sortnames(x), sortnames(y), ...)
}
12 changes: 6 additions & 6 deletions tests/testthat/test_AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ test_that("AcqOptimizer param_set", {
acqopt = AcqOptimizer$new(opt("random_search", batch_size = 1L), trm("evals", n_evals = 1L))
expect_r6(acqopt$param_set, "ParamSet")
expect_setequal(acqopt$param_set$ids(), c("n_candidates", "logging_level", "warmstart", "warmstart_size", "skip_already_evaluated", "catch_errors"))
expect_r6(acqopt$param_set$params$n_candidates, "ParamInt")
expect_r6(acqopt$param_set$params$logging_level, "ParamFct")
expect_r6(acqopt$param_set$params$warmstart, "ParamLgl")
expect_r6(acqopt$param_set$params$warmstart_size, "ParamInt")
expect_r6(acqopt$param_set$params$skip_already_evaluated, "ParamLgl")
expect_r6(acqopt$param_set$params$catch_errors, "ParamLgl")
expect_equal(acqopt$param_set$class[["n_candidates"]], "ParamInt")
expect_equal(acqopt$param_set$class[["logging_level"]], "ParamFct")
expect_equal(acqopt$param_set$class[["warmstart"]], "ParamLgl")
expect_equal(acqopt$param_set$class[["warmstart_size"]], "ParamInt")
expect_equal(acqopt$param_set$class[["skip_already_evaluated"]], "ParamLgl")
expect_equal(acqopt$param_set$class[["catch_errors"]], "ParamLgl")
expect_error({acqopt$param_set = list()}, regexp = "param_set is read-only.")
})

Expand Down
10 changes: 5 additions & 5 deletions tests/testthat/test_SurrogateLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ test_that("SurrogateLearner API works", {
# upgrading error class works
surrogate = SurrogateLearner$new(LearnerRegrError$new(), archive = inst$archive)
expect_error(surrogate$update(), class = "surrogate_update_error")

surrogate$param_set$values$catch_errors = FALSE
expect_error(surrogate$optimize(), class = "simpleError")

Expand Down Expand Up @@ -51,10 +51,10 @@ test_that("param_set", {
surrogate = SurrogateLearner$new(learner = REGR_FEATURELESS, archive = inst$archive)
expect_r6(surrogate$param_set, "ParamSet")
expect_setequal(surrogate$param_set$ids(), c("assert_insample_perf", "perf_measure", "perf_threshold", "catch_errors"))
expect_r6(surrogate$param_set$params$assert_insample_perf, "ParamLgl")
expect_r6(surrogate$param_set$params$perf_measure, "ParamUty")
expect_r6(surrogate$param_set$params$perf_threshold, "ParamDbl")
expect_r6(surrogate$param_set$params$catch_errors, "ParamLgl")
expect_equal(surrogate$param_set$class[["assert_insample_perf"]], "ParamLgl")
expect_equal(surrogate$param_set$class[["perf_measure"]], "ParamUty")
expect_equal(surrogate$param_set$class[["perf_threshold"]], "ParamDbl")
expect_equal(surrogate$param_set$class[["catch_errors"]], "ParamLgl")
expect_error({surrogate$param_set = list()}, regexp = "param_set is read-only.")
})

Expand Down
14 changes: 7 additions & 7 deletions tests/testthat/test_mbo_defaults.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ test_that("default_surrogate", {
surrogate = default_surrogate(MAKE_INST_1D())
expect_r6(surrogate, "SurrogateLearner")
expect_r6(surrogate$learner, "LearnerRegrKM")
expect_equal(surrogate$learner$param_set$values,
expect_equal_sorted(surrogate$learner$param_set$values,
list(covtype = "matern5_2", optim.method = "gen", control = list(trace = FALSE), nugget.stability = 1e-08))
expect_equal(surrogate$learner$encapsulate, c(train = "evaluate", predict = "evaluate"))
expect_r6(surrogate$learner$fallback, "LearnerRegrRanger")
Expand All @@ -30,7 +30,7 @@ test_that("default_surrogate", {
surrogate = default_surrogate(MAKE_INST_1D_NOISY())
expect_r6(surrogate, "SurrogateLearner")
expect_r6(surrogate$learner, "LearnerRegrKM")
expect_equal(surrogate$learner$param_set$values,
expect_equal_sorted(surrogate$learner$param_set$values,
list(covtype = "matern5_2", optim.method = "gen", control = list(trace = FALSE), nugget.estim = TRUE, jitter = 1e-12))
expect_equal(surrogate$learner$encapsulate, c(train = "evaluate", predict = "evaluate"))
expect_r6(surrogate$learner$fallback, "LearnerRegrRanger")
Expand All @@ -39,7 +39,7 @@ test_that("default_surrogate", {
surrogate = default_surrogate(MAKE_INST(OBJ_1D_2, search_space = PS_1D))
expect_r6(surrogate, "SurrogateLearnerCollection")
expect_list(surrogate$learner, types = "LearnerRegrKM")
expect_equal(surrogate$learner[[1L]]$param_set$values,
expect_equal_sorted(surrogate$learner[[1L]]$param_set$values,
list(covtype = "matern5_2", optim.method = "gen", control = list(trace = FALSE), nugget.stability = 1e-08))
expect_equal(surrogate$learner[[1L]]$encapsulate, c(train = "evaluate", predict = "evaluate"))
expect_r6(surrogate$learner[[1L]]$fallback, "LearnerRegrRanger")
Expand All @@ -51,7 +51,7 @@ test_that("default_surrogate", {
surrogate = default_surrogate(MAKE_INST(OBJ_1D_2_NOISY, search_space = PS_1D))
expect_r6(surrogate, "SurrogateLearnerCollection")
expect_list(surrogate$learner, types = "LearnerRegrKM")
expect_equal(surrogate$learner[[1L]]$param_set$values,
expect_equal_sorted(surrogate$learner[[1L]]$param_set$values,
list(covtype = "matern5_2", optim.method = "gen", control = list(trace = FALSE), nugget.estim = TRUE, jitter = 1e-12))
expect_equal(surrogate$learner[[1L]]$encapsulate, c(train = "evaluate", predict = "evaluate"))
expect_r6(surrogate$learner[[1L]]$fallback, "LearnerRegrRanger")
Expand All @@ -63,7 +63,7 @@ test_that("default_surrogate", {
surrogate = default_surrogate(MAKE_INST(OBJ_1D_MIXED, search_space = PS_1D_MIXED))
expect_r6(surrogate, "SurrogateLearner")
expect_r6(surrogate$learner, "LearnerRegrRanger")
expect_equal(surrogate$learner$param_set$values,
expect_equal_sorted(surrogate$learner$param_set$values,
list(num.threads = 1L, num.trees = 100L, keep.inbag = TRUE, se.method = "jack"))
expect_equal(surrogate$learner$encapsulate, c(train = "evaluate", predict = "evaluate"))
expect_r6(surrogate$learner$fallback, "LearnerRegrRanger")
Expand All @@ -72,7 +72,7 @@ test_that("default_surrogate", {
surrogate = default_surrogate(MAKE_INST(OBJ_1D_2_MIXED, search_space = PS_1D_MIXED))
expect_r6(surrogate, "SurrogateLearnerCollection")
expect_list(surrogate$learner, types = "LearnerRegrRanger")
expect_equal(surrogate$learner[[1L]]$param_set$values,
expect_equal_sorted(surrogate$learner[[1L]]$param_set$values,
list(num.threads = 1L, num.trees = 100L, keep.inbag = TRUE, se.method = "jack"))
expect_equal(surrogate$learner[[1L]]$encapsulate, c(train = "evaluate", predict = "evaluate"))
expect_r6(surrogate$learner[[1L]]$fallback, "LearnerRegrRanger")
Expand All @@ -85,7 +85,7 @@ test_that("default_surrogate", {
expect_r6(surrogate, "SurrogateLearner")
expect_r6(surrogate$learner, "GraphLearner")
expect_equal(surrogate$learner$graph$ids(), c("imputesample", "imputeoor", "colapply", "regr.ranger"))
expect_equal(surrogate$learner$param_set$values,
expect_equal_sorted(surrogate$learner$param_set$values,
list(imputesample.affect_columns = mlr3pipelines::selector_type("logical"),
imputeoor.min = TRUE,
imputeoor.offset = 1,
Expand Down

0 comments on commit af15285

Please sign in to comment.