Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply 'styler' style for mlr #2498

Merged
merged 8 commits into from
Apr 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
8 changes: 6 additions & 2 deletions R/Aggregation.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,21 @@ NULL
#' @return ([Aggregation]).
#' @examples
#' # computes the interquartile range on all performance values
#' test.iqr = makeAggregation(id = "test.iqr", name = "Test set interquartile range",
#' test.iqr = makeAggregation(
#' id = "test.iqr", name = "Test set interquartile range",
#' properties = "req.test",
#' fun = function (task, perf.test, perf.train, measure, group, pred) IQR(perf.test))
#' fun = function(task, perf.test, perf.train, measure, group, pred) IQR(perf.test)
#' )
#' @export
makeAggregation = function(id, name = id, properties, fun) {

  # Construct an Aggregation S3 object.
  # id: short identifier; name: human-readable label (defaults to id);
  # properties: character vector of required properties (e.g. "req.test");
  # fun: function(task, perf.test, perf.train, measure, group, pred) that
  #   aggregates the performance values.
  assertString(id)
  assertString(name)
  # validate the remaining arguments too, consistent with the checks above
  assertCharacter(properties, any.missing = FALSE)
  assertFunction(fun)
  makeS3Obj("Aggregation", id = id, name = name, fun = fun, properties = properties)
}

#' @export
print.Aggregation = function(x, ...) {

  # print method for Aggregation objects: show only the id
  catf("Aggregation function: %s", x$id)
}
21 changes: 16 additions & 5 deletions R/BaggingWrapper.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#' @family wrapper
#' @export
makeBaggingWrapper = function(learner, bw.iters = 10L, bw.replace = TRUE, bw.size, bw.feats = 1) {

learner = checkLearner(learner, type = c("classif", "regr"))
pv = list()
if (!missing(bw.iters)) {
Expand All @@ -60,8 +61,9 @@ makeBaggingWrapper = function(learner, bw.iters = 10L, bw.replace = TRUE, bw.siz
assertNumber(bw.feats, lower = 0, upper = 1)
pv$bw.feats = bw.feats
}
if (learner$predict.type != "response")
if (learner$predict.type != "response") {
stop("Predict type of the basic learner must be 'response'.")
}
id = stri_paste(learner$id, "bagged", sep = ".")
packs = learner$package
ps = makeParamSet(
Expand All @@ -76,6 +78,7 @@ makeBaggingWrapper = function(learner, bw.iters = 10L, bw.replace = TRUE, bw.siz

#' @export
print.BaggingModel = function(x, ...) {

s = capture.output(print.WrappedModel(x))
u = sprintf("Bagged Learner: %s", class(x$learner$next.learner)[1L])
s = append(s, u, 1L)
Expand All @@ -86,8 +89,9 @@ print.BaggingModel = function(x, ...) {
trainLearner.BaggingWrapper = function(.learner, .task, .subset = NULL, .weights = NULL,
bw.iters = 10, bw.replace = TRUE, bw.size, bw.feats = 1, ...) {

if (missing(bw.size))
if (missing(bw.size)) {
bw.size = if (bw.replace) 1 else 0.632
}
.task = subsetTask(.task, subset = .subset)
n = getTaskSize(.task)
# number of observations to sample
Expand All @@ -104,6 +108,7 @@ trainLearner.BaggingWrapper = function(.learner, .task, .subset = NULL, .weights
}

doBaggingTrainIteration = function(i, n, m, k, bw.replace, task, learner, weights) {

setSlaveOptions()
bag = sample(seq_len(n), m, replace = bw.replace)
task = subsetTask(task, features = sample(getTaskFeatureNames(task), k, replace = FALSE))
Expand All @@ -112,21 +117,25 @@ doBaggingTrainIteration = function(i, n, m, k, bw.replace, task, learner, weight

#' @export
predictLearner.BaggingWrapper = function(.learner, .model, .newdata, .subset = NULL, ...) {

models = getLearnerModel(.model, more.unwrap = FALSE)
g = if (.learner$type == "classif") as.character else identity
p = asMatrixCols(lapply(models, function(m) {

nd = .newdata[, m$features, drop = FALSE]
g(predict(m, newdata = nd, subset = .subset, ...)$data$response)
}))
if (.learner$predict.type == "response") {
if (.learner$type == "classif")
if (.learner$type == "classif") {
as.factor(apply(p, 1L, computeMode))
else
} else {
rowMeans(p)
}
} else {
if (.learner$type == "classif") {
levs = .model$task.desc$class.levels
p = apply(p, 1L, function(x) {

x = factor(x, levels = levs) # we need all level for the table and we need them in consistent order!
as.numeric(prop.table(table(x)))
})
Expand All @@ -141,12 +150,14 @@ predictLearner.BaggingWrapper = function(.learner, .model, .newdata, .subset = N
# be response, we can estimates probs and se on the outside
#' @export
setPredictType.BaggingWrapper = function(learner, predict.type) {

  # only change predict.type on the wrapper itself; the inner base learner must
  # stay at "response" (probs/se are estimated from the bagged responses outside)
  setPredictType.Learner(learner, predict.type)
}

#' @export
getLearnerProperties.BaggingWrapper = function(learner) {
switch(learner$type,

switch(learner$type,
"classif" = union(getLearnerProperties(learner$next.learner), "prob"),
"regr" = union(getLearnerProperties(learner$next.learner), "se")
)
Expand Down
12 changes: 8 additions & 4 deletions R/BaseEnsemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,23 @@ makeBaseEnsemble = function(id, base.learners, bls.type = NULL,
base.learners = lapply(base.learners, checkLearner, type = bls.type)

tt = unique(extractSubList(base.learners, "type"))
if (length(tt) > 1L)
if (length(tt) > 1L) {
stopf("Base learners must all be of same type, but have: %s", collapse(tt))
if (is.null(ens.type))
}
if (is.null(ens.type)) {
ens.type = tt
}

ids = unique(extractSubList(base.learners, "id"))
if (length(ids) != length(base.learners))
if (length(ids) != length(base.learners)) {
stop("Base learners must all have unique ids!")
}

# check that all predict.types are the same
pts = unique(extractSubList(base.learners, "predict.type"))
if (length(pts) > 1L)
if (length(pts) > 1L) {
stopf("Base learners must all have same predict.type, but have: %s", collapse(pts))
}

# join all parsets of base.learners + prefix param names with base learner id
# (we could also do this operation on-the.fly in getParamSet.BaseEnsemble,
Expand Down
10 changes: 9 additions & 1 deletion R/BaseEnsemble_operators.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# find the learner for a given param name, so <learnerid>.<paramid>
matchBaseEnsembleLearner = function(ensemble, pn) {

patterns = stri_paste("^", names(ensemble$base.learners), "\\.")
j = which(vlapply(patterns, stri_detect_regex, str = pn))
par.id = stri_replace_first(pn, "", regex = patterns[j])
Expand All @@ -8,10 +9,13 @@ matchBaseEnsembleLearner = function(ensemble, pn) {

#' @export
getHyperPars.BaseEnsemble = function(learner, for.fun = c("train", "predict", "both")) {

pvs = lapply(learner$base.learners, function(lrn) {

xs = getHyperPars(lrn, for.fun = for.fun)
if (length(xs) > 0L)
if (length(xs) > 0L) {
names(xs) = stri_paste(lrn$id, ".", names(xs))
}
return(xs)
})
# if we dont do this, R prefixes the list names again.
Expand All @@ -24,6 +28,7 @@ getHyperPars.BaseEnsemble = function(learner, for.fun = c("train", "predict", "b
# set hyper pars down in ensemble base learners, identify correct base learner + remove prefix
#' @export
setHyperPars2.BaseEnsemble = function(learner, par.vals) {

ns = names(par.vals)
parnames.bls = names(learner$par.set.bls$pars)
for (i in seq_along(par.vals)) {
Expand All @@ -43,6 +48,7 @@ setHyperPars2.BaseEnsemble = function(learner, par.vals) {

#' @export
removeHyperPars.BaseEnsemble = function(learner, ids) {

parnames.bls = names(learner$par.set.bls$pars)
for (id in ids) {
if (id %in% parnames.bls) {
Expand All @@ -63,6 +69,7 @@ removeHyperPars.BaseEnsemble = function(learner, ids) {
# if one does not want this, one must override
#' @export
setPredictType.BaseEnsemble = function(learner, predict.type) {

# this does the check for the prop
lrn = setPredictType.Learner(learner, predict.type)
lrn$base.learners = lapply(lrn$base.learners, setPredictType, predict.type = predict.type)
Expand All @@ -71,6 +78,7 @@ setPredictType.BaseEnsemble = function(learner, predict.type) {

#' @export
makeWrappedModel.BaseEnsemble = function(learner, learner.model, task.desc, subset, features, factor.levels, time) {

  # build the model via the next method, then tag it with the ensemble subclass.
  # NOTE: NextMethod() must be called without arguments here -- NextMethod(x)
  # would pass x as the 'generic' argument (cf. makeWrappedModel.BaseWrapper,
  # which correctly calls NextMethod()).
  x = NextMethod()
  addClasses(x, "BaseEnsembleModel")
}
15 changes: 12 additions & 3 deletions R/BaseWrapper.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
#' @export
makeBaseWrapper = function(id, type, next.learner, package = character(0L), par.set = makeParamSet(),
par.vals = list(), learner.subclass, model.subclass, cache = FALSE) {
if (inherits(next.learner, "OptWrapper") && is.element("TuneWrapper", learner.subclass))

if (inherits(next.learner, "OptWrapper") && is.element("TuneWrapper", learner.subclass)) {
stop("Cannot wrap a tuning wrapper around another optimization wrapper!")
}
ns = intersect(names(par.set$pars), names(next.learner$par.set$pars))
if (length(ns) > 0L)
if (length(ns) > 0L) {
stopf("Hyperparameter names in wrapper clash with base learner names: %s", collapse(ns))
}

learner = makeLearnerBaseConstructor(classes = c(learner.subclass, "BaseWrapper"),
id = id,
Expand All @@ -43,6 +46,7 @@ makeBaseWrapper = function(id, type, next.learner, package = character(0L), par.

#' @export
print.BaseWrapper = function(x, ...) {

s = ""
y = x
while (inherits(y, "BaseWrapper")) {
Expand All @@ -68,6 +72,7 @@ print.BaseWrapper = function(x, ...) {

#' @export
predictLearner.BaseWrapper = function(.learner, .model, .newdata, ...) {

args = removeFromDots(names(.learner$par.vals), ...)
do.call(predictLearner, c(
list(.learner = .learner$next.learner, .model = .model$learner.model$next.model, .newdata = .newdata),
Expand All @@ -77,6 +82,7 @@ predictLearner.BaseWrapper = function(.learner, .model, .newdata, ...) {

#' @export
makeWrappedModel.BaseWrapper = function(learner, learner.model, task.desc, subset = NULL, features, factor.levels, time) {

  # build the model via the next method, then tag it with the wrapper's
  # model subclass(es) plus the generic "BaseWrapperModel" class
  x = NextMethod()
  addClasses(x, c(learner$model.subclass, "BaseWrapperModel"))
}
Expand All @@ -85,22 +91,25 @@ makeWrappedModel.BaseWrapper = function(learner, learner.model, task.desc, subse

#' @export
isFailureModel.BaseWrapperModel = function(model) {

  # a NoFeaturesModel is never considered a failure; otherwise delegate
  # the check to the wrapped inner model
  if (inherits(model$learner.model, "NoFeaturesModel")) {
    return(FALSE)
  }
  return(isFailureModel(model$learner.model$next.model))
}

#' @export
getFailureModelMsg.BaseWrapperModel = function(model) {

  # delegate to the wrapped inner model's failure message
  return(getFailureModelMsg(model$learner.model$next.model))
}

#' @export
getFailureModelDump.BaseWrapperModel = function(model) {

  # delegate to the wrapped inner model's failure dump
  return(getFailureModelDump(model$learner.model$next.model))
}

#' @export
getLearnerProperties.BaseWrapper = function(learner) {

  # default properties: everything the wrapper's type allows, restricted to
  # what the wrapped (next) learner actually supports
  allowed.props = listLearnerProperties(learner$type)
  inner.props = getLearnerProperties(learner$next.learner)
  intersect(allowed.props, inner.props)
}

13 changes: 10 additions & 3 deletions R/BaseWrapper_operators.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
#' @export
getParamSet.BaseWrapper = function(x) {

  # combine the wrapper's own param set with the one of the wrapped learner chain
  own.ps = x$par.set
  inner.ps = getParamSet(x$next.learner)
  c(own.ps, inner.ps)
}


#' @export
getHyperPars.BaseWrapper = function(learner, for.fun = c("train", "predict", "both")) {

  # hyperpars of the wrapped learner first, then those set on the wrapper itself
  inner.pars = getHyperPars(learner$next.learner, for.fun)
  own.pars = getHyperPars.Learner(learner, for.fun)
  c(inner.pars, own.pars)
}


#' @export
setHyperPars2.BaseWrapper = function(learner, par.vals) {

ns = names(par.vals)
pds.n = names(learner$par.set$pars)
for (i in seq_along(par.vals)) {
Expand All @@ -26,18 +29,22 @@ setHyperPars2.BaseWrapper = function(learner, par.vals) {

#' @export
removeHyperPars.BaseWrapper = function(learner, ids) {

  # split the ids into those set on this wrapper and those belonging further down
  own.ids = intersect(names(learner$par.vals), ids)
  if (length(own.ids) > 0L) {
    learner = removeHyperPars.Learner(learner, own.ids)
  }
  # recurse: remove the remaining ids from the wrapped learner chain
  remaining.ids = setdiff(ids, own.ids)
  learner$next.learner = removeHyperPars(learner$next.learner, remaining.ids)
  return(learner)
}



# walk down the wrapper chain and return the innermost (leaf) learner
getLeafLearner = function(learner) {

  if (inherits(learner, "BaseWrapper")) {
    return(getLeafLearner(learner$next.learner))
  }
  return(learner)
}

Expand All @@ -46,7 +53,7 @@ getLeafLearner = function(learner) {
# if one does not want this, one must override
#' @export
setPredictType.BaseWrapper = function(learner, predict.type) {

  # propagate the predict.type down the whole wrapper chain first, then set it
  # on this wrapper itself (setPredictType.Learner does the property check)
  learner$next.learner = setPredictType(learner$next.learner, predict.type)
  setPredictType.Learner(learner, predict.type)
}

10 changes: 7 additions & 3 deletions R/BenchmarkResultOrderLevels.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
# useful for plotting in ggplot2
# if order.tsks is NULL, just return the df
orderBMRTasks = function(bmr, df = NULL, order.tsks) {
if (is.null(df))

if (is.null(df)) {
df = as.data.frame(bmr)
}
if (!is.null(order.tsks)) {
assertCharacter(order.tsks, len = length(getBMRTaskIds(bmr)))
assertSetEqual(order.tsks, getBMRTaskIds(bmr), ordered = FALSE)
Expand All @@ -15,9 +17,11 @@ orderBMRTasks = function(bmr, df = NULL, order.tsks) {
# order levels of learner.ids of a BenchmarkResult or similar data.frame
# useful for plotting in ggplot2
# if order.tsks is NULL, just return the df
orderBMRLrns = function(bmr, df = NULL, order.lrns){
if (is.null(df))
orderBMRLrns = function(bmr, df = NULL, order.lrns) {

if (is.null(df)) {
df = as.data.frame(bmr)
}
if (!is.null(order.lrns)) {
assertCharacter(order.lrns, len = length(getBMRLearnerIds(bmr)))
assertSetEqual(order.lrns, getBMRLearnerIds(bmr), ordered = FALSE)
Expand Down
Loading