From 3bf5fd59d82b80986dd61161c9f663df5e3dcd8c Mon Sep 17 00:00:00 2001 From: mb706 Date: Fri, 1 Oct 2021 16:51:40 +0200 Subject: [PATCH 01/16] new databackend for outer joins --- R/DataBackendJoin.R | 133 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 R/DataBackendJoin.R diff --git a/R/DataBackendJoin.R b/R/DataBackendJoin.R new file mode 100644 index 000000000..5eba66108 --- /dev/null +++ b/R/DataBackendJoin.R @@ -0,0 +1,133 @@ + + +#' @export +DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = FALSE, + public = list( + initialize = function(b1, b2, by_b1 = NULL, by_b2 = NULL, type = "outer", b1_index_colname = NULL, b2_index_colname = NULL) { + assert_backend(b1) + assert_backend(b2) + + if ("data.table" %nin% intersect(b1$data_formats, b2$data_formats)) { + stop("DataBackendJoin currently only supports DataBackends that support 'data.table' format.") + } + + assert_choice(type, c("left", "right", "outer", "inner")) + + colnames_b1 = b1$colnames + colnames_b2 = b2$colnames + allcolnames = union(colnames_b1, colnames_b2) + + assert_choice(by_b1, colnames_b1, null.ok = TRUE) + assert_choice(by_b2, colnames_b2, null.ok = TRUE) + + assert_string(b1_index_colname, null.ok = TRUE) + assert_string(b2_index_colname, null.ok = TRUE) + + if (!is.null(b1_index_colname) && b1_index_colname %in% setdiff(allcolnames, b1$primary_key)) stopf("b1_index_colname '%s' already a non-primary-key column in b1 or b2.", b1_index_colname) + if (!is.null(b2_index_colname) && b2_index_colname %in% setdiff(allcolnames, b2$primary_key)) stopf("b2_index_colname '%s' already a non-primary-key column in b2 or b2.", b2_index_colname) + if (!is.null(b1_index_colname) && !is.null(b2_index_colname) && b1_index_colname == b2_index_colname) stop("b1_index_colname and b2_index_colname must be different, but are both '%s'.", b1_index_colname) + + rownames_b1 = b1$rownames + rownames_b2 = b2$rownames + + joinby_b1 = if 
(is.null(by_b1)) rownames_b1 else b1$data(rownames_b1, by_b1, data_format = "data.table")[[1]] + joinby_b2 = if (is.null(by_b2)) rownames_b2 else b2$data(rownames_b2, by_b2, data_format = "data.table")[[1]] + + index_table = merge(data.table(rownames_b1, joinby_b1), data.table(rownames_b2, joinby_b2), by.x = "joinby_b1", by.y = "joinby_b2", + all.x = type %in% c("left", "outer"), all.y = type %in% c("right", "outer"), sort = FALSE, allow.cartesian = TRUE) + + index_table[, c("joinby_b1", "joinby_b2") := NULL] + + pk = "..row_id" + index = 0 + while (pk %in% allcolnames) { + index = index + 1 + pk = paste0("..row_id.", index) + } + + super$initialize(list( + b1 = b1, b2 = b2, + colnames_b1 = setdiff(colnames_b1, colnames_b2) + allcolnames = union(colnames_b1, colnames_b2, b1_index_colname, b2_index_colname, pk), + index_table = index_table, + b1_index_colname = b1_index_colname, + b2_index_colname = b2_index_colname, + pk = pk, + aux_hash = calculate_hash(by_b1, by_b2, type, b1_index_colname, b2_index_colname) + ), primary_key = pk, data_formats = "data.table") + }, + + data = function(rows, cols, data_format = "data.table") { + d = private$.data + rows = rows[inrange(rows, 1, nrow(d$index_table))] + indices = d$index_table[rows] + b1_rows = indices[!is.na(rownames_b1), rownames_b1] + b2_rows = indices[!is.na(rownames_b2), rownames_b2] + indices[!is.na(rownames_b1), b1_index := seq_len(nrow(b1_rows))] + indices[!is.na(rownames_b2), b2_index := seq_len(nrow(b2_rows))] + b1_index = indices[, b1_index] + b2_index = indices[, b2_index] + + data = d$b2$data(b2_rows, cols, data_format = "data.table")[b2_index] + remainingcols = intersect(cols, d$colnames_b1) + if (length(remainingcols)) { + data = cbind(data, d$b1$data(b1_rows, cols, data_format = "data.table")[b1_index]) + } + if (d$pk %in% cols) { + data[, (d$pk) := rows] + } + if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) { + rownames_b2 = b2_rows$rownames_b2 + data[, (d$b2_index_colname) := 
rownames_b2] + } + if (!is.null(d$b1_index_colname) && d$b1_index_colname %in% cols) { + rownames_b1 = b1_rows$rownames_b1 + data[, (d$b1_index_colname) := rownames_b1] + } + data[, intersect(cols, names(data)), with = FALSE] + } + + head = function(n = 6L) { + rows = head(self$rownames, n) + self$data(rows = rows, cols = self$colnames) + }, + distinct = function(rows, cols, na_rm = TRUE) { + indices = private$.data$index_table[rows] + b1_rows = indices[!is.na(rownames_b1), rownames_b1] + b2_rows = indices[!is.na(rownames_b2), rownames_b2] + d2 = private$.data$b2$distinct(rows = b2_rows, cols = cols, na_rm = na_rm) + d1 = private$.data$b1$distinct(rows = b1_rows, cols = setdiff(cols, names(d2)), na_rm = na_rm) + if (!na_rm && nrow(b1_rows) < length(rows)) { + d1 = map(d1, function(x) if (any(is.na(x))) x else c(x, NA)) + } + if (!na_rm && nrow(b2_rows) < length(rows)) { + d2 = map(d2, function(x) if (any(is.na(x))) x else c(x, NA)) + } + res = c(d1, d2) + res[match(cols, names(res), nomatch = 0)] + }, + missings = functionrows, cols) { + indices = private$.data$index_table[rows] + b1_rows = indices[!is.na(rownames_b1), rownames_b1] + b2_rows = indices[!is.na(rownames_b2), rownames_b2] + m2 = private$.data$b2$missings(b2_rows, cols) + m1 = private$.data$b1$missings(b1_rows, setdiff(cols, names(m2))) + m1 = m1 + length(rows) - nrow(b1_rows) + m2 = m2 + length(rows) - nrow(b2_rows) + res = c(m1, m2) + res[match(cols, names(res), nomatch = 0)] + } + ), + active = list( + rownames = function() seq_len(nrow(private$.data$index_table)), + colnames = function() private$.data$allcolnames, + nrow = function() nrow(private$.data$index_table) + ncol = function() length(private$.data$allcolnames) + ), + private = list( + .calculate_hash = function() { + d = private$.data + calculate_hash(d$b1$hash, d$b2$hash,d$aux_hash) + } + ) +) From 2ac1ea9f9d4c118a7a7e0b361e406948205ad86f Mon Sep 17 00:00:00 2001 From: mb706 Date: Fri, 1 Oct 2021 17:17:26 +0200 Subject: [PATCH 02/16] fix 
some bugs --- DESCRIPTION | 1 + NAMESPACE | 1 + R/DataBackendJoin.R | 12 ++++++------ 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 0cc808896..d8be8bbd2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -89,6 +89,7 @@ NeedsCompilation: no Roxygen: list(markdown = TRUE, r6 = FALSE) RoxygenNote: 7.1.2 Collate: + 'DataBackendJoin.R' 'Graph.R' 'GraphLearner.R' 'mlr_pipeops.R' diff --git a/NAMESPACE b/NAMESPACE index c4dd8b15c..1692f9e1a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -25,6 +25,7 @@ S3method(print,Multiplicity) S3method(print,Selector) export("%>>%") export("%>>>%") +export(DataBackendJoin) export(Graph) export(GraphLearner) export(LearnerClassifAvg) diff --git a/R/DataBackendJoin.R b/R/DataBackendJoin.R index 5eba66108..87be5e977 100644 --- a/R/DataBackendJoin.R +++ b/R/DataBackendJoin.R @@ -36,7 +36,7 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = index_table = merge(data.table(rownames_b1, joinby_b1), data.table(rownames_b2, joinby_b2), by.x = "joinby_b1", by.y = "joinby_b2", all.x = type %in% c("left", "outer"), all.y = type %in% c("right", "outer"), sort = FALSE, allow.cartesian = TRUE) - index_table[, c("joinby_b1", "joinby_b2") := NULL] + index_table[, "joinby_b1" := NULL] pk = "..row_id" index = 0 @@ -47,8 +47,8 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = super$initialize(list( b1 = b1, b2 = b2, - colnames_b1 = setdiff(colnames_b1, colnames_b2) - allcolnames = union(colnames_b1, colnames_b2, b1_index_colname, b2_index_colname, pk), + colnames_b1 = setdiff(colnames_b1, colnames_b2), + allcolnames = unique(c(colnames_b1, colnames_b2, b1_index_colname, b2_index_colname, pk)), index_table = index_table, b1_index_colname = b1_index_colname, b2_index_colname = b2_index_colname, @@ -85,7 +85,7 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = data[, (d$b1_index_colname) := rownames_b1] } data[, 
intersect(cols, names(data)), with = FALSE] - } + }, head = function(n = 6L) { rows = head(self$rownames, n) @@ -106,7 +106,7 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = res = c(d1, d2) res[match(cols, names(res), nomatch = 0)] }, - missings = functionrows, cols) { + missings = function(rows, cols) { indices = private$.data$index_table[rows] b1_rows = indices[!is.na(rownames_b1), rownames_b1] b2_rows = indices[!is.na(rownames_b2), rownames_b2] @@ -121,7 +121,7 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = active = list( rownames = function() seq_len(nrow(private$.data$index_table)), colnames = function() private$.data$allcolnames, - nrow = function() nrow(private$.data$index_table) + nrow = function() nrow(private$.data$index_table), ncol = function() length(private$.data$allcolnames) ), private = list( From e0eaf9b4c0ea3cf213c4b36fb4ba1b83e3eedac9 Mon Sep 17 00:00:00 2001 From: mb706 Date: Fri, 1 Oct 2021 18:51:56 +0200 Subject: [PATCH 03/16] testing DataBackendJoin --- R/DataBackendJoin.R | 47 ++++++--- tests/testthat/test_DataBackendJoin.R | 137 ++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 11 deletions(-) create mode 100644 tests/testthat/test_DataBackendJoin.R diff --git a/R/DataBackendJoin.R b/R/DataBackendJoin.R index 87be5e977..110efe7fb 100644 --- a/R/DataBackendJoin.R +++ b/R/DataBackendJoin.R @@ -3,7 +3,7 @@ #' @export DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = FALSE, public = list( - initialize = function(b1, b2, by_b1 = NULL, by_b2 = NULL, type = "outer", b1_index_colname = NULL, b2_index_colname = NULL) { + initialize = function(b1, b2, type, by_b1 = NULL, by_b2 = NULL, b1_index_colname = NULL, b2_index_colname = NULL) { assert_backend(b1) assert_backend(b2) @@ -27,6 +27,8 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = if (!is.null(b2_index_colname) && b2_index_colname %in% 
setdiff(allcolnames, b2$primary_key)) stopf("b2_index_colname '%s' already a non-primary-key column in b2 or b2.", b2_index_colname) if (!is.null(b1_index_colname) && !is.null(b2_index_colname) && b1_index_colname == b2_index_colname) stop("b1_index_colname and b2_index_colname must be different, but are both '%s'.", b1_index_colname) + colnames = unique(c(allcolnames, b1_index_colname, b2_index_colname)) + rownames_b1 = b1$rownames rownames_b2 = b2$rownames @@ -63,8 +65,8 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = indices = d$index_table[rows] b1_rows = indices[!is.na(rownames_b1), rownames_b1] b2_rows = indices[!is.na(rownames_b2), rownames_b2] - indices[!is.na(rownames_b1), b1_index := seq_len(nrow(b1_rows))] - indices[!is.na(rownames_b2), b2_index := seq_len(nrow(b2_rows))] + indices[!is.na(rownames_b1), b1_index := seq_len(length(b1_rows))] + indices[!is.na(rownames_b2), b2_index := seq_len(length(b2_rows))] b1_index = indices[, b1_index] b2_index = indices[, b2_index] @@ -73,15 +75,16 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = if (length(remainingcols)) { data = cbind(data, d$b1$data(b1_rows, cols, data_format = "data.table")[b1_index]) } + setkeyv(data, NULL) if (d$pk %in% cols) { data[, (d$pk) := rows] } if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) { - rownames_b2 = b2_rows$rownames_b2 + rownames_b2 = indices$rownames_b2 data[, (d$b2_index_colname) := rownames_b2] } if (!is.null(d$b1_index_colname) && d$b1_index_colname %in% cols) { - rownames_b1 = b1_rows$rownames_b1 + rownames_b1 = indices$rownames_b1 data[, (d$b1_index_colname) := rownames_b1] } data[, intersect(cols, names(data)), with = FALSE] @@ -92,29 +95,51 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = self$data(rows = rows, cols = self$colnames) }, distinct = function(rows, cols, na_rm = TRUE) { - indices = private$.data$index_table[rows] + d = private$.data + 
indices = d$index_table[rows] b1_rows = indices[!is.na(rownames_b1), rownames_b1] b2_rows = indices[!is.na(rownames_b2), rownames_b2] d2 = private$.data$b2$distinct(rows = b2_rows, cols = cols, na_rm = na_rm) + if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) { + d2[[d$b2_index_colname]] = if (na_rm) unique(b2_rows) else unique(indices$rownames_b2) + } d1 = private$.data$b1$distinct(rows = b1_rows, cols = setdiff(cols, names(d2)), na_rm = na_rm) - if (!na_rm && nrow(b1_rows) < length(rows)) { + if (!is.null(d$b1_index_colname) && d$b1_index_colname %in% cols) { + d1[[d$b1_index_colname]] = if (na_rm) unique(b1_rows) else unique(indices$rownames_b1) + } + + if (!na_rm && length(b1_rows) < length(rows)) { d1 = map(d1, function(x) if (any(is.na(x))) x else c(x, NA)) } - if (!na_rm && nrow(b2_rows) < length(rows)) { + if (!na_rm && length(b2_rows) < length(rows)) { d2 = map(d2, function(x) if (any(is.na(x))) x else c(x, NA)) } res = c(d1, d2) + if (d$pk %in% cols) { + res[[d$pk]] = unique(rows) + } + res[match(cols, names(res), nomatch = 0)] }, missings = function(rows, cols) { - indices = private$.data$index_table[rows] + d = private$.data + indices = d$index_table[rows] b1_rows = indices[!is.na(rownames_b1), rownames_b1] b2_rows = indices[!is.na(rownames_b2), rownames_b2] m2 = private$.data$b2$missings(b2_rows, cols) + if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) { + m2[d$b2_index_colname] = 0L + } m1 = private$.data$b1$missings(b1_rows, setdiff(cols, names(m2))) - m1 = m1 + length(rows) - nrow(b1_rows) - m2 = m2 + length(rows) - nrow(b2_rows) + if (!is.null(d$b1_index_colname) && d$b1_index_colname %in% cols) { + m1[d$b1_index_colname] = 0L + } + m1 = m1 + length(rows) - length(b1_rows) + m2 = m2 + length(rows) - length(b2_rows) res = c(m1, m2) + if (d$pk %in% cols) { + res[d$pk] = 0L + } res[match(cols, names(res), nomatch = 0)] } ), diff --git a/tests/testthat/test_DataBackendJoin.R b/tests/testthat/test_DataBackendJoin.R new 
file mode 100644 index 000000000..6992e981f --- /dev/null +++ b/tests/testthat/test_DataBackendJoin.R @@ -0,0 +1,137 @@ +context("DataBackendJoin") + + +test_that("DataBackendJoin works as expected", { + + d1 <- data.table(x = c(letters[-2], NA), y = LETTERS, z = rep(1:13, 2), id = (1:26) * 10L) + d2 <- data.table(a = c(paste0(letters, LETTERS)[-2], NA), y = letters, idx = (27:2) * 10L) + d3 <- data.table(a = c(paste0(letters, LETTERS)[-2], NA), y = letters, id = (27:2)) + + d1b <- DataBackendDataTable$new(d1, "id") + d2b <- DataBackendDataTable$new(d2, "idx") + d3b <- DataBackendDataTable$new(d3, "id") + + + + dbj <- DataBackendJoin$new(d1b, d2b) + + expect_backend(dbj) + + expect_identical(dbj$data(1:3, c("x", "y", "id", "a", "idx")), + data.table(x = letters[1:4][-2], y = c(NA, letters[26:25]), id = (1:3) * 10L, a = c(NA, NA, "zZ"), idx = c(NA, (2:3) * 10L)) + ) + + expect_identical(dbj$data(1:3, c("x", "y", "id", "a", "idx", dbj$primary_key)), + data.table(x = letters[1:4][-2], y = c(NA, letters[26:25]), id = (1:3) * 10L, a = c(NA, NA, "zZ"), idx = c(NA, (2:3) * 10L), ..row_id = 1:3) + ) + + expect_identical(dbj$missings(1:3, c("x", "y", "id", "a", "idx")), c(x = 0L, y = 1L, id = 0L, a = 2L, idx = 1L)) + expect_identical(dbj$missings(1:3, c("x", "y", "id", "a", "idx", dbj$primary_key)), c(x = 0L, y = 1L, id = 0L, a = 2L, idx = 1L, ..row_id = 0L)) + + + dbj <- DataBackendJoin$new(d1b, d3b, by_b1 = "z", b1_index_colname = "b1index", b2_index_colname = "b2index", type = "outer") + + expect_backend(dbj) + + expected = merge(d1[, c("x", "z", "id"), with = FALSE], rev(d3), by.x = "z", by.y = "id", all = TRUE, sort = FALSE)[, .(x, y, z = ifelse(z %inrange% c(1, 13), z, NA), id = z, a, b1index = id, b2index = z)] + expected[, ..row_id := seq_len(nrow(expected))] + expected[id == 1, id := NA] + expected[b2index == 1, b2index := NA] + + expect_equal(dbj$data(dbj$rownames, dbj$colnames), expected, check.attributes = FALSE) + + + dbj <- DataBackendJoin$new(d1b, d3b, by_b1 
= "z", b1_index_colname = "b1index", b2_index_colname = "b2index", type = "inner") + expect_backend(dbj) + + expected = merge(d1[, c("x", "z", "id"), with = FALSE], rev(d3), by.x = "z", by.y = "id", all = FALSE, sort = FALSE)[, .(x, y, z = ifelse(z %inrange% c(1, 13), z, NA), id = z, a, b1index = id, b2index = z)] + expected[, ..row_id := seq_len(nrow(expected))] + expect_equal(dbj$data(dbj$rownames, dbj$colnames), expected, check.attributes = FALSE) + + dbj <- DataBackendJoin$new(d1b, d3b, by_b1 = "z", b1_index_colname = "b1index", b2_index_colname = "b2index", type = "left") + expect_backend(dbj) + + expected = merge(d1[, c("x", "z", "id"), with = FALSE], rev(d3), by.x = "z", by.y = "id", all.x = TRUE, all.y = FALSE, sort = FALSE)[, .(x, y, z = ifelse(z %inrange% c(1, 13), z, NA), id = z, a, b1index = id, b2index = z)] + expected[, ..row_id := seq_len(nrow(expected))] + expected[id == 1, id := NA] + expected[b2index == 1, b2index := NA] + + expect_equal(dbj$data(dbj$rownames, dbj$colnames), expected, check.attributes = FALSE) + + dbj <- DataBackendJoin$new(d1b, d3b, by_b1 = "z", b1_index_colname = "b1index", b2_index_colname = "b2index", type = "right") + expect_backend(dbj) + + expected = merge(d1[, c("x", "z", "id"), with = FALSE], rev(d3), by.x = "z", by.y = "id", all.x = FALSE, all.y = TRUE, sort = FALSE)[, .(x, y, z = ifelse(z %inrange% c(1, 13), z, NA), id = z, a, b1index = id, b2index = z)] + expected[, ..row_id := seq_len(nrow(expected))] + + expect_equal(dbj$data(dbj$rownames, dbj$colnames), expected, check.attributes = FALSE) +}) + +test_that("DataBackendJoin edge cases", { + d1 <- data.table(x = c(letters[-2], NA), y = LETTERS, z = rep(1:13, 2), ..row_id = (1:26) * 10L) + d2 <- data.table(a = c(paste0(letters, LETTERS)[-2], NA), y = letters, ..row_id = (27:2) * 10L) + + d1b <- DataBackendDataTable$new(d1, "..row_id") + d2b <- DataBackendDataTable$new(d2, "..row_id") + + dbj = DataBackendJoin$new(d1b, d2b, type = "inner", b1_index_colname = "b1", 
b2_index_colname = "b2") + expect_backend(dbj) + + expect_set_equal(dbj$colnames, c(colnames(d1), colnames(d2), "..row_id.1", "b1", "b2")) + + expect_equal( + dbj$data(dbj$rownames, dbj$colnames), + d1[d2, .(x, y = i.y, z, ..row_id, a, b1 = ..row_id, b2 = ..row_id, ..row_id.1 = seq_len(25)), on = "..row_id", nomatch = NULL], + check.attributes = FALSE + ) + + d1 = data.table(x = c(1, 2), y = c("a", "b"), ..row_id = 1:2) + d2 = data.table(x = 1, y = "z", ..row_id = 3) + d1b <- DataBackendDataTable$new(d1, "..row_id") + d2b <- DataBackendDataTable$new(d2, "..row_id") + + dbj <- DataBackendJoin$new(d1b, d2b, type = "inner", by_b1 = "x", by_b2 = "x", b1_index_colname = "b1", b2_index_colname = "b2") + expect_backend(dbj) + + expect_equal(dbj$data(dbj$rownames, dbj$colnames), data.table(x = 1, y = "z", ..row_id = 3, b1 = 1, b2 = 3, ..row_id.1 = 1)) + + dbj <- DataBackendJoin$new(d1b, d2b, type = "outer", by_b1 = "x", by_b2 = "x", b1_index_colname = "b1", b2_index_colname = "b2") + expect_backend(dbj) + + expect_equal(dbj$data(dbj$rownames, dbj$colnames), data.table(x = c(1, NA), y = c("z", NA), ..row_id = c(3, NA), b1 = 1:2, b2 = c(3, NA), ..row_id.1 = c(1, 2))) + + dbj <- DataBackendJoin$new(d1b, d2b, type = "inner", b1_index_colname = "b1", b2_index_colname = "b2") + expect_backend(dbj) + + expect_equal(dbj$data(dbj$rownames, dbj$colnames), data.table(x = 1, y = "z", ..row_id = 3, b1 = 1, b2 = 3, ..row_id.1 = 1)[0]) + + dbj <- DataBackendJoin$new(d1b, d2b, type = "outer", b1_index_colname = "b1", b2_index_colname = "b2") + expect_backend(dbj) + + expect_equal(dbj$data(dbj$rownames, dbj$colnames), data.table(x = c(NA, NA, 1), y = c(NA, NA, "z"), ..row_id = c(NA, NA, 3), b1 = c(1, 2, NA), b2 = c(NA, NA, 3), ..row_id.1 = 1:3)) + +}) + +test_that("DataBackendJoin errors", { + + d1 <- data.table(x = c(letters[-2], NA), y = LETTERS, z = rep(1:13, 2), id = (1:26) * 10L) + d2 <- data.table(a = c(paste0(letters, LETTERS)[-2], NA), y = letters, idx = (27:2) * 10L) + + d1b <- 
DataBackendDataTable$new(d1, "id") + d2b <- DataBackendDataTable$new(d2, "idx") + + expect_error(DataBackendJoin$new(d1b, d2b, by_b1 = "n"), "by_b1.*of set.*but is 'n'") + expect_error(DataBackendJoin$new(d1b, d2b, by_b2 = "n"), "by_b2.*of set.*but is 'n'") + + expect_error(DataBackendJoin$new(d1b, d2b, type = "inner", b1_index_colname = "x"), "already a non-primary-key") + expect_error(DataBackendJoin$new(d1b, d2b, type = "inner", b1_index_colname = "a"), "already a non-primary-key") + expect_error(DataBackendJoin$new(d1b, d2b, type = "inner", b2_index_colname = "x"), "already a non-primary-key") + expect_error(DataBackendJoin$new(d1b, d2b, type = "inner", b2_index_colname = "a"), "already a non-primary-key") + expect_error(DataBackendJoin$new(d1b, d2b, type = "inner", b1_index_colname = "n", b2_index_colname = "n"), "must be different") + + + + + + +}) From 39b761bcfb9bc0ee03932549704983a7b6d33bd8 Mon Sep 17 00:00:00 2001 From: mb706 Date: Fri, 1 Oct 2021 21:06:49 +0200 Subject: [PATCH 04/16] adding DataBackendMultiCbind --- R/DataBackendMultiCbind.R | 110 ++++++++++++++++++++++++++ tests/testthat/test_DataBackendJoin.R | 8 +- 2 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 R/DataBackendMultiCbind.R diff --git a/R/DataBackendMultiCbind.R b/R/DataBackendMultiCbind.R new file mode 100644 index 000000000..8ebe83ba9 --- /dev/null +++ b/R/DataBackendMultiCbind.R @@ -0,0 +1,110 @@ + + +#' @export +DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = FALSE, + public = list( + initialize = function(bs) { + assert_list(bs, min.len = 1) + lapply(bs, assert_backend) + + formats = Reduce(intersect, map(bs, "data_formats")) + + super$initialize(list(bs = rev(bs)), bs[[1]]$primary_key, formats) + }, + data = function(rows, cols, data_format = "data.table") { + bs = private$.data$bs + + urows = unique(rows) + + datas = list() + pks = character(length(bs)) + include_pk = logical(length(bs)) + cols_remaining = cols + allrows = 
list() + for (i in seq_along(bs)) { + ## Not doing 'if (length(cols_remaining)) break' because there could still be tables remaining that add rows + pk = bs[[i]]$primary_key + pks[[i]] = pk + include_pk = pk %in% cols_remaining + if (include_pk[[i]]) { + datas[[i]] = bs[[i]]$data(urows, cols_remaining, data_format = data_format) + cols_remaining = setdiff(cols_remaining, colnames(datas[[i]])) + } else { + datas[[i]] = bs[[i]]$data(urows, c(pk, cols_remaining), data_format = data_format) + cols_remaining = setdiff(cols_remaining, colnames(datas[[i]])[-1]) + } + allrows[[i]] = datas[[i]][[pk]] + } + presentrows = unique(unlist(allrows)) + result = do.call(rbind, pmap(list(datas, pks, include_pk), function(data, pk, include) { + if (include) data[presentrows] else data[presentrows, -pk, with = FALSE, nomatch = NA] + })) + sbk = self$primary_key + if (sbk %in% cols) result[, (sbk) := presentrows] + data[J(presentrows), on = sbk, nomatch = NULL] + }, + head = function(n = 6L) { + rows = head(self$rownames, n) + self$data(rows = rows, cols = self$colnames) + }, + distinct = function(rows, cols, na_rm = TRUE) { + bs = private$data$bs + getpk = self$primary_key %in% cols + results = list() + remaining_cols = cols + if (na_rm || getpk) { + rows = intersect(rows, self$rownames) + } + for (i in seq_along(bs)) { + if (!length(remaining_cols)) break + results[[i]] = bs[[i]]$distinct(rows = rows, cols = cols, na_rm = na_rm) + remaining_cols = setdiff(remaining_cols, names(results[[i]])) + if (na_rm && !all(rows %in% bs[[i]]$rownames)) { + results[[i]] = c(results[[i]], NA) + } + } + result = unlist(result, recursive = FALSE) + if (getpk) { + result[[self$primary_key]] = rows + } + result[match(cols, names(result), nomatch = 0)] + }, + missings = function(rows, cols) { + rows = rows[rows %in% self$rownames] + bs = private$data$bs + getpk = self$primary_key %in% cols + results = list() + remaining_cols = cols + for (i in seq_along(bs)) { + if (!length(remaining_cols)) break + 
missingrows = sum(rows %nin% bs[[i]]$rownames) + results[[i]] = bs[[i]]$missing(rows, remaining_cols) + missingrows + remaining_cols = setdiff(remaining_cols, names(results[[i]])) + } + result = unlist(result) + if (self$primary_key %in% cols) { + result[[self$primary_key]] = 0L + } + result[match(cols, names(result), nomatch = 0)] + } + ), + active = list( + rownames = function() { + if (is.null(private$.rownames_cache)) private$.rownames_cache = unique(unlist(map(bs, "rownames"))) + private$.rownames_cache + }, + colnames = function() { + if (is.null(private$.colnames_cache)) private$.colnames_cache = unique(unlist(map(bs, "colnames"))) + private$.colnames_cache + }, + nrow = function() length(self$rownames), + ncol = function() length(self$colnames) + ), + private = list( + .rownames_cache = NULL, + .colnames_cache = NULL, + .calculate_hash = function() { + do.call(calculate_hash, private$.data$bs) + } + ) +) diff --git a/tests/testthat/test_DataBackendJoin.R b/tests/testthat/test_DataBackendJoin.R index 6992e981f..d7c9f0306 100644 --- a/tests/testthat/test_DataBackendJoin.R +++ b/tests/testthat/test_DataBackendJoin.R @@ -13,10 +13,16 @@ test_that("DataBackendJoin works as expected", { - dbj <- DataBackendJoin$new(d1b, d2b) + dbj <- DataBackendJoin$new(d1b, d2b, type = "outer") expect_backend(dbj) + expect_identical(dbj$distinct(1:2, dbj$colnames), list(x = c("a", "c"), y = "z", z = 1:2, id = c(10L, 20L), a = character(0), idx = 20L, ..row_id = 1:2)) + expect_identical(dbj$distinct(1:2, dbj$colnames, na_rm = FALSE), list(x = c("a", "c"), y = c("z", NA), z = 1:2, id = c(10L, 20L), a = NA_character_, idx = c(20L, NA), ..row_id = 1:2)) + + expect_identical(dbj$missings(1:2, dbj$colnames), c(x = 0L, y = 1L, z = 0L, id = 0L, a = 2L, idx = 1L, ..row_id = 0L)) + expect_identical(dbj$missings(c(1:2, 2:1), dbj$colnames), 2L * c(x = 0L, y = 1L, z = 0L, id = 0L, a = 2L, idx = 1L, ..row_id = 0L)) + expect_identical(dbj$data(1:3, c("x", "y", "id", "a", "idx")), data.table(x 
= letters[1:4][-2], y = c(NA, letters[26:25]), id = (1:3) * 10L, a = c(NA, NA, "zZ"), idx = c(NA, (2:3) * 10L)) ) From 32ca417a92b7015cd887bd520b00b1e6114061af Mon Sep 17 00:00:00 2001 From: mb706 Date: Sat, 2 Oct 2021 19:52:46 +0200 Subject: [PATCH 05/16] DataBackendMultiCbind --- DESCRIPTION | 1 + NAMESPACE | 1 + R/DataBackendMultiCbind.R | 70 ++++++++++++------- man/mlr3pipelines-package.Rd | 6 +- tests/testthat/test_DataBackendJoin.R | 5 -- tests/testthat/test_DataBackendMultiCbind.R | 77 +++++++++++++++++++++ 6 files changed, 130 insertions(+), 30 deletions(-) create mode 100644 tests/testthat/test_DataBackendMultiCbind.R diff --git a/DESCRIPTION b/DESCRIPTION index d8be8bbd2..db917d22b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -90,6 +90,7 @@ Roxygen: list(markdown = TRUE, r6 = FALSE) RoxygenNote: 7.1.2 Collate: 'DataBackendJoin.R' + 'DataBackendMultiCbind.R' 'Graph.R' 'GraphLearner.R' 'mlr_pipeops.R' diff --git a/NAMESPACE b/NAMESPACE index 1692f9e1a..2b12790fa 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -26,6 +26,7 @@ S3method(print,Selector) export("%>>%") export("%>>>%") export(DataBackendJoin) +export(DataBackendMultiCbind) export(Graph) export(GraphLearner) export(LearnerClassifAvg) diff --git a/R/DataBackendMultiCbind.R b/R/DataBackendMultiCbind.R index 8ebe83ba9..1b6ed4ebb 100644 --- a/R/DataBackendMultiCbind.R +++ b/R/DataBackendMultiCbind.R @@ -1,7 +1,7 @@ #' @export -DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = FALSE, +DataBackendMultiCbind = R6Class("DataBackendMultiCbind", inherit = DataBackend, cloneable = FALSE, public = list( initialize = function(bs) { assert_list(bs, min.len = 1) @@ -9,7 +9,24 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = formats = Reduce(intersect, map(bs, "data_formats")) - super$initialize(list(bs = rev(bs)), bs[[1]]$primary_key, formats) + private$.colnames = unique(unlist(map(bs, "colnames"))) + + # primary key: if all backends have the same pk, 
just use that one. + otherpk = unique(unlist(map(bs, "primary_key"))) + if (length(otherpk) == 1) { + pk = otherpk + } else { + # otherwise: introduce a new primary key that is completely different from the previous ones. + pk = "..row_id" + index = 0 + while (pk %in% private$.colnames) { + index = index + 1 + pk = paste0("..row_id.", index) + } + private$.colnames = c(private$.colnames, pk) + } + + super$initialize(list(bs = rev(bs)), pk, formats) }, data = function(rows, cols, data_format = "data.table") { bs = private$.data$bs @@ -25,7 +42,7 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = ## Not doing 'if (length(cols_remaining)) break' because there could still be tables remaining that add rows pk = bs[[i]]$primary_key pks[[i]] = pk - include_pk = pk %in% cols_remaining + include_pk[[i]] = pk %in% cols_remaining if (include_pk[[i]]) { datas[[i]] = bs[[i]]$data(urows, cols_remaining, data_format = data_format) cols_remaining = setdiff(cols_remaining, colnames(datas[[i]])) @@ -36,34 +53,40 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = allrows[[i]] = datas[[i]][[pk]] } presentrows = unique(unlist(allrows)) - result = do.call(rbind, pmap(list(datas, pks, include_pk), function(data, pk, include) { - if (include) data[presentrows] else data[presentrows, -pk, with = FALSE, nomatch = NA] + result = do.call(cbind, pmap(list(datas, pks, include_pk), function(data, pk, include) { + if (include) { + result = data[J(presentrows), on = pk, nomatch = NA] + droppk = result[[pk]] %nin% data[[pk]] + result[droppk, (pk) := NA] + } else { + data[J(presentrows), -pk, on = pk, with = FALSE, nomatch = NA] + } })) sbk = self$primary_key - if (sbk %in% cols) result[, (sbk) := presentrows] - data[J(presentrows), on = sbk, nomatch = NULL] + result[, (sbk) := presentrows] + result[J(rows), intersect(cols, colnames(result)), with = FALSE, on = sbk, nomatch = NULL] }, head = function(n = 6L) { rows = head(self$rownames, 
n) self$data(rows = rows, cols = self$colnames) }, distinct = function(rows, cols, na_rm = TRUE) { - bs = private$data$bs + bs = private$.data$bs getpk = self$primary_key %in% cols - results = list() + reslist = list() remaining_cols = cols - if (na_rm || getpk) { + if (!na_rm || getpk) { rows = intersect(rows, self$rownames) } for (i in seq_along(bs)) { if (!length(remaining_cols)) break - results[[i]] = bs[[i]]$distinct(rows = rows, cols = cols, na_rm = na_rm) - remaining_cols = setdiff(remaining_cols, names(results[[i]])) - if (na_rm && !all(rows %in% bs[[i]]$rownames)) { - results[[i]] = c(results[[i]], NA) + reslist[[i]] = bs[[i]]$distinct(rows = rows, cols = cols, na_rm = na_rm) + remaining_cols = setdiff(remaining_cols, names(reslist[[i]])) + if (!na_rm && !all(rows %in% bs[[i]]$rownames)) { + reslist[[i]] = map(reslist[[i]], function(x) if (any(is.na(x))) x else c(x, NA)) } } - result = unlist(result, recursive = FALSE) + result = unlist(reslist, recursive = FALSE) if (getpk) { result[[self$primary_key]] = rows } @@ -71,17 +94,17 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = }, missings = function(rows, cols) { rows = rows[rows %in% self$rownames] - bs = private$data$bs + bs = private$.data$bs getpk = self$primary_key %in% cols - results = list() + reslist = list() remaining_cols = cols for (i in seq_along(bs)) { if (!length(remaining_cols)) break missingrows = sum(rows %nin% bs[[i]]$rownames) - results[[i]] = bs[[i]]$missing(rows, remaining_cols) + missingrows - remaining_cols = setdiff(remaining_cols, names(results[[i]])) + reslist[[i]] = bs[[i]]$missings(rows, remaining_cols) + missingrows + remaining_cols = setdiff(remaining_cols, names(reslist[[i]])) } - result = unlist(result) + result = unlist(reslist) if (self$primary_key %in% cols) { result[[self$primary_key]] = 0L } @@ -90,19 +113,18 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = ), active = list( rownames = function() { - if 
(is.null(private$.rownames_cache)) private$.rownames_cache = unique(unlist(map(bs, "rownames"))) + if (is.null(private$.rownames_cache)) private$.rownames_cache = unique(unlist(rev(map(private$.data$bs, "rownames")))) private$.rownames_cache }, colnames = function() { - if (is.null(private$.colnames_cache)) private$.colnames_cache = unique(unlist(map(bs, "colnames"))) - private$.colnames_cache + private$.colnames }, nrow = function() length(self$rownames), ncol = function() length(self$colnames) ), private = list( .rownames_cache = NULL, - .colnames_cache = NULL, + .colnames = NULL, .calculate_hash = function() { do.call(calculate_hash, private$.data$bs) } diff --git a/man/mlr3pipelines-package.Rd b/man/mlr3pipelines-package.Rd index 0084f20bb..cc209fe4c 100644 --- a/man/mlr3pipelines-package.Rd +++ b/man/mlr3pipelines-package.Rd @@ -8,7 +8,11 @@ \description{ \if{html}{\figure{logo.png}{options: align='right' alt='logo' width='120'}} -Dataflow programming toolkit that enriches 'mlr3' with a diverse set of pipelining operators ('PipeOps') that can be composed into graphs. Operations exist for data preprocessing, model fitting, and ensemble learning. Graphs can themselves be treated as 'mlr3' 'Learners' and can therefore be resampled, benchmarked, and tuned. +Dataflow programming toolkit that enriches 'mlr3' with a diverse + set of pipelining operators ('PipeOps') that can be composed into graphs. + Operations exist for data preprocessing, model fitting, and ensemble + learning. Graphs can themselves be treated as 'mlr3' 'Learners' and can + therefore be resampled, benchmarked, and tuned. 
} \seealso{ Useful links: diff --git a/tests/testthat/test_DataBackendJoin.R b/tests/testthat/test_DataBackendJoin.R index d7c9f0306..15e0246b4 100644 --- a/tests/testthat/test_DataBackendJoin.R +++ b/tests/testthat/test_DataBackendJoin.R @@ -135,9 +135,4 @@ test_that("DataBackendJoin errors", { expect_error(DataBackendJoin$new(d1b, d2b, type = "inner", b2_index_colname = "a"), "already a non-primary-key") expect_error(DataBackendJoin$new(d1b, d2b, type = "inner", b1_index_colname = "n", b2_index_colname = "n"), "must be different") - - - - - }) diff --git a/tests/testthat/test_DataBackendMultiCbind.R b/tests/testthat/test_DataBackendMultiCbind.R new file mode 100644 index 000000000..7d0a2f361 --- /dev/null +++ b/tests/testthat/test_DataBackendMultiCbind.R @@ -0,0 +1,77 @@ +context("DataBackendMultiCbind") + + +test_that("DataBackendMultiCbind works as expected", { + + d1 = data.table(x = c(letters[1:3], NA), y = LETTERS[1:4], z = c(1, 2, 2, 1), id = (1:4) * 10L) + d2 = data.table(a = c(paste0(letters[1:3], LETTERS[1:3]), NA), y = letters[20:23], id = -(2:5) * 10L, idx = (0:3) * 10L) + d3 = data.table(x = as.character(1:4), z = 9:6, id = (3:6) * 10L) + + d1b <- DataBackendDataTable$new(d1, "id") + d2b <- DataBackendDataTable$new(d2, "idx") + d3b <- DataBackendDataTable$new(d3, "id") + + dbmc <- DataBackendMultiCbind$new(list(d1b, d2b, d3b)) + + expect_backend(dbmc) + + expect_equal( + dbmc$data((0:6) * 10L, dbmc$colnames), + data.table( + x = c(NA_character_, NA, NA, 1:4), y = c(d2$y, NA, NA, NA), z = c(NA, NA, NA, 9:6), + id = c(NA, NA, NA, 3:6) * 10L, a = c(d2$a, NA, NA, NA), idx = c((0:3) * 10L, NA, NA, NA), + ..row_id = (0:6) * 10L + ) + ) + + + dbmc <- DataBackendMultiCbind$new(list(d1b, d3b)) + + expect_backend(dbmc) + + expect_equal( + dbmc$data((0:6) * 10L, dbmc$colnames), + data.table( + x = c(NA_character_, NA, 1:4), y = c(d1$y, NA, NA), z = c(NA, NA, 9:6), + id = (1:6) * 10L + ) + ) + + d0b = DataBackendDataTable$new(data.table(id = c(10:20)), "id") + + 
dbmc <- DataBackendMultiCbind$new(list(d0b, d1b)) + + expect_backend(dbmc) + + expect_set_equal(dbmc$rownames, c(10:20, 30L, 40L)) + + expect_equal(dbmc$data(c(10:20, 30L, 40L), dbmc$colnames), data.table( + id = c(10:20, 30L, 40L), + x = c("a", rep(NA, 9), letters[2:3], NA), + y = c("A", rep(NA, 9), LETTERS[2:4]), + z = c(1, rep(NA, 9), 2, 2, 1) + )) + + expect_identical(dbmc$data(11, dbmc$colnames), data.table(id = 11L, x = NA_character_, y = NA_character_, z = NA_real_)) + + expect_identical(dbmc$data(11, "x"), data.table(x = NA_character_)) + + + + dbmc <- DataBackendMultiCbind$new(list(d0b, d3b)) + + expect_backend(dbmc) + + + expect_set_equal(dbmc$rownames, c(10:20, (3:6)*10L)) + + expect_equal(dbmc$data(c(10:20, (3:6) * 10L), dbmc$colnames), data.table( + id = c(10:20, (3:6) * 10L), + x = c(rep(NA_character_, 11), 1:4), + z = c(rep(NA, 11), 9:6) + )) + + +}) + + From 5197023f05cc7ddbc3d3a09f1d40dea39047960a Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 4 Oct 2021 14:15:07 +0200 Subject: [PATCH 06/16] space --- R/PipeOpScaleMaxAbs.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/PipeOpScaleMaxAbs.R b/R/PipeOpScaleMaxAbs.R index 6813c2230..2e6f0e13a 100644 --- a/R/PipeOpScaleMaxAbs.R +++ b/R/PipeOpScaleMaxAbs.R @@ -64,7 +64,7 @@ PipeOpScaleMaxAbs = R6Class("PipeOpScaleMaxAbs", private = list( .get_state_dt = function(dt, levels, target) { - lapply(dt, function(x){ + lapply(dt, function(x) { s = max(abs(range(x, na.rm = TRUE, finite = TRUE))) if (s == 0) { s = 1 From 82a3521490ae4f786ae32417657f32fe5a720395 Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 4 Oct 2021 14:16:57 +0200 Subject: [PATCH 07/16] repair tests --- tests/testthat/test_DataBackendJoin.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_DataBackendJoin.R b/tests/testthat/test_DataBackendJoin.R index 15e0246b4..de5b94bd9 100644 --- a/tests/testthat/test_DataBackendJoin.R +++ b/tests/testthat/test_DataBackendJoin.R @@ -126,8 +126,8 
@@ test_that("DataBackendJoin errors", { d1b <- DataBackendDataTable$new(d1, "id") d2b <- DataBackendDataTable$new(d2, "idx") - expect_error(DataBackendJoin$new(d1b, d2b, by_b1 = "n"), "by_b1.*of set.*but is 'n'") - expect_error(DataBackendJoin$new(d1b, d2b, by_b2 = "n"), "by_b2.*of set.*but is 'n'") + expect_error(DataBackendJoin$new(d1b, d2b, type = "outer", by_b1 = "n"), "by_b1.*of set.*but is 'n'") + expect_error(DataBackendJoin$new(d1b, d2b, type = "outer", by_b2 = "n"), "by_b2.*of set.*but is 'n'") expect_error(DataBackendJoin$new(d1b, d2b, type = "inner", b1_index_colname = "x"), "already a non-primary-key") expect_error(DataBackendJoin$new(d1b, d2b, type = "inner", b1_index_colname = "a"), "already a non-primary-key") From ca92ec25f72f9246a703fc665ff8f4e0eaaf3527 Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 4 Oct 2021 16:11:28 +0200 Subject: [PATCH 08/16] alternative to chain_graphs --- NAMESPACE | 1 + R/operators.R | 35 ++++++++++++++++++++++++++++---- man/concat_graphs.Rd | 29 ++++++++++++++++++++++++-- man/mlr3pipelines-package.Rd | 6 +----- man/mlr_pipeops_tunethreshold.Rd | 2 +- 5 files changed, 61 insertions(+), 12 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 2b12790fa..d94399531 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -110,6 +110,7 @@ export(as_graph) export(as_pipeop) export(assert_graph) export(assert_pipeop) +export(chain_graphs) export(concat_graphs) export(filter_noop) export(greplicate) diff --git a/R/operators.R b/R/operators.R index b19eccc65..e83eef214 100644 --- a/R/operators.R +++ b/R/operators.R @@ -38,12 +38,34 @@ #' Note that if `g1` is `NULL`, `g2` converted to a [`Graph`] will be returned. #' Analogously, if `g2` is `NULL`, `g1` converted to a [`Graph`] will be returned. 
#' +#' @section Chaining Graphs: +#' `concat_graphs` can also be called with the `glist` argument, which takes an arbitrary amount of [`Graph`]s or [`PipeOp`]s (or objects that can be automatically +#' converted into [`Graph`]s or [`PipeOp`]s) and joins them in a serial [`Graph`], as if connecting them using [`%>>%`]. +#' +#' Care is taken to avoid unnecessarily cloning of components. A call of +#' `chain_graphs(list(g1, g2, g3, g4, ...), in_place = FALSE)` is equivalent to +#' `g1 %>>% g2 %>>>% g3 %>>>% g4 %>>>% ...`. +#' A call of `chain_graphs(list(g1, g2, g3, g4, ...), in_place = FALSE)` +#' is equivalent to `g1 %>>>% g2 %>>>% g3 %>>>% g4 %>>>% ...` (differing in the +#' first operator being `%>>>%` as well). +#' +#' `concat_graphs(glist = )` (implicitly with `in_place = FALSE`) is a safe way of generating large linear pipelines quickly, while +#' still avoiding to change any of its inputs by reference, and avoiding the risk of ending up with broken objects. +#' #' @param g1 ([`Graph`] | [`PipeOp`] | [`Learner`][mlr3::Learner] | [`Filter`][mlr3filters::Filter] | `list` | `...`) \cr #' [`Graph`] / [`PipeOp`] / object-convertible-to-[`PipeOp`] to put in front of `g2`. #' @param g2 ([`Graph`] | [`PipeOp`] | [`Learner`][mlr3::Learner] | [`Filter`][mlr3filters::Filter] | `list` | `...`) \cr #' [`Graph`] / [`PipeOp`] / object-convertible-to-[`PipeOp`] to put after `g1`. +#' @param glist `list` of ([`Graph`] | [`PipeOp`] | [`Learner`][mlr3::Learner] | [`Filter`][mlr3filters::Filter] | `list` | `...`)\cr +#' List of elements which are the [`Graph`]s to be joined. Elements must be convertible to [`Graph`] or [`PipeOp`] using [`as_graph()`] and [`as_pipeop()`]. +#' `NULL` is the neutral element of [`%>>%`] and skipped. When this is given, `g1` and `g2` must not be given. #' @param in_place (`logical(1)`)\cr -#' Whether to try to avoid cloning `g1`. If `g1` is not a [`Graph`], then it is cloned regardless. 
+#' When `g1` and `g2` are given: Whether to try to avoid cloning `g1`. If `g1` is not a [`Graph`], then it is cloned regardless.\n +#' When `glist` is given instead: +#' Whether to try to avoid cloning the first element of `glist`, similar to the difference +#' of `%>>>%` over `%>>%`. This can only be avoided if `glist[[1]]` is already a [`Graph`]. +#' Beware that, when `in_place` is `TRUE` and if `concat_graphs()` fails because of id collisions, then `glist[[1]]` will possibly be in an incompletely +#' modified state. #' #' @return [`Graph`]: the constructed [`Graph`]. #' @family Graph operators @@ -81,7 +103,11 @@ #' o1 %>>>% o2 #' #' o1 # not changed, becuase not a Graph. -concat_graphs = function(g1, g2, in_place = FALSE) { +concat_graphs = function(g1, g2, glist, in_place = FALSE) { + if (!missing(glist)) { + if (!missing(g1) || !missing(g2)) stop("When glist is given, g1 and g2 must not be given") + return(chain_graphs(graphs = glist, in_place = in_place)) + } assert_flag(in_place) # neutral elements handling if (is.null(g1)) return(if (!is.null(g2)) as_graph(g2, clone = TRUE)) @@ -94,7 +120,6 @@ concat_graphs = function(g1, g2, in_place = FALSE) { if (nrow(g1out) != 1 && nrow(g1out) != nrow(g2in) && !(nrow(g2in) == 1 && g2in$channel.name == "...")) { stopf("Graphs / PipeOps to be connected have mismatching number of inputs / outputs.") } - g = gunion(list(g1, g2), in_place = in_place) # check that types agree @@ -114,6 +139,8 @@ concat_graphs = function(g1, g2, in_place = FALSE) { # build edges from free output channels of g1 and free input channels of g2 new_edges = cbind(g1out[, list(src_id = get("op.id"), src_channel = get("channel.name"))], g2in[, list(dst_id = get("op.id"), dst_channel = get("channel.name"))]) + + g = gunion(list(g1, g2), in_place = in_place) g$edges = rbind(g$edges, new_edges) g } @@ -158,7 +185,7 @@ strip_multiplicity_type = function(type) { #' Beware that, if `chain_graphs()` fails because of id collisions, then `graphs[[1]]` will 
possibly be in an incompletely #' modified state when `in_place` is `TRUE`. #' @return [`Graph`] the resulting [`Graph`], or `NULL` if there are no non-null values in `graphs`. -#' +#' @export chain_graphs = function(graphs, in_place = FALSE) { assert_list(graphs) graphs = discard(graphs, is.null) diff --git a/man/concat_graphs.Rd b/man/concat_graphs.Rd index 79952c01f..340200c61 100644 --- a/man/concat_graphs.Rd +++ b/man/concat_graphs.Rd @@ -6,7 +6,7 @@ \alias{\%>>>\%} \title{PipeOp Composition Operator} \usage{ -concat_graphs(g1, g2, in_place = FALSE) +concat_graphs(g1, g2, glist, in_place = FALSE) g1 \%>>\% g2 @@ -19,8 +19,17 @@ g1 \%>>>\% g2 \item{g2}{(\code{\link{Graph}} | \code{\link{PipeOp}} | \code{\link[mlr3:Learner]{Learner}} | \code{\link[mlr3filters:Filter]{Filter}} | \code{list} | \code{...}) \cr \code{\link{Graph}} / \code{\link{PipeOp}} / object-convertible-to-\code{\link{PipeOp}} to put after \code{g1}.} +\item{glist}{\code{list} of (\code{\link{Graph}} | \code{\link{PipeOp}} | \code{\link[mlr3:Learner]{Learner}} | \code{\link[mlr3filters:Filter]{Filter}} | \code{list} | \code{...})\cr +List of elements which are the \code{\link{Graph}}s to be joined. Elements must be convertible to \code{\link{Graph}} or \code{\link{PipeOp}} using \code{\link[=as_graph]{as_graph()}} and \code{\link[=as_pipeop]{as_pipeop()}}. +\code{NULL} is the neutral element of \code{\link{\%>>\%}} and skipped. When this is given, \code{g1} and \code{g2} must not be given.} + \item{in_place}{(\code{logical(1)})\cr -Whether to try to avoid cloning \code{g1}. If \code{g1} is not a \code{\link{Graph}}, then it is cloned regardless.} +When \code{g1} and \code{g2} are given: Whether to try to avoid cloning \code{g1}. If \code{g1} is not a \code{\link{Graph}}, then it is cloned regardless.\n +When \code{glist} is given instead: +Whether to try to avoid cloning the first element of \code{glist}, similar to the difference +of \verb{\%>>>\%} over \verb{\%>>\%}. 
This can only be avoided if \code{glist[[1]]} is already a \code{\link{Graph}}. +Beware that, when \code{in_place} is \code{TRUE} and if \code{concat_graphs()} fails because of id collisions, then \code{glist[[1]]} will possibly be in an incompletely +modified state.} } \value{ \code{\link{Graph}}: the constructed \code{\link{Graph}}. @@ -65,6 +74,22 @@ it is cloned just as when \verb{\%>>\%} is used; \verb{\%>>>\%} only avoids \cod Note that if \code{g1} is \code{NULL}, \code{g2} converted to a \code{\link{Graph}} will be returned. Analogously, if \code{g2} is \code{NULL}, \code{g1} converted to a \code{\link{Graph}} will be returned. } +\section{Chaining Graphs}{ + +\code{concat_graphs} can also be called with the \code{glist} argument, which takes an arbitrary amount of \code{\link{Graph}}s or \code{\link{PipeOp}}s (or objects that can be automatically +converted into \code{\link{Graph}}s or \code{\link{PipeOp}}s) and joins them in a serial \code{\link{Graph}}, as if connecting them using \code{\link{\%>>\%}}. + +Care is taken to avoid unnecessarily cloning of components. A call of +\code{chain_graphs(list(g1, g2, g3, g4, ...), in_place = FALSE)} is equivalent to +\code{g1 \%>>\% g2 \%>>>\% g3 \%>>>\% g4 \%>>>\% ...}. +A call of \code{chain_graphs(list(g1, g2, g3, g4, ...), in_place = FALSE)} +is equivalent to \code{g1 \%>>>\% g2 \%>>>\% g3 \%>>>\% g4 \%>>>\% ...} (differing in the +first operator being \verb{\%>>>\%} as well). + +\verb{concat_graphs(glist = )} (implicitly with \code{in_place = FALSE}) is a safe way of generating large linear pipelines quickly, while +still avoiding to change any of its inputs by reference, and avoiding the risk of ending up with broken objects. 
+} + \examples{ o1 = PipeOpScale$new() o2 = PipeOpPCA$new() diff --git a/man/mlr3pipelines-package.Rd b/man/mlr3pipelines-package.Rd index cc209fe4c..0084f20bb 100644 --- a/man/mlr3pipelines-package.Rd +++ b/man/mlr3pipelines-package.Rd @@ -8,11 +8,7 @@ \description{ \if{html}{\figure{logo.png}{options: align='right' alt='logo' width='120'}} -Dataflow programming toolkit that enriches 'mlr3' with a diverse - set of pipelining operators ('PipeOps') that can be composed into graphs. - Operations exist for data preprocessing, model fitting, and ensemble - learning. Graphs can themselves be treated as 'mlr3' 'Learners' and can - therefore be resampled, benchmarked, and tuned. +Dataflow programming toolkit that enriches 'mlr3' with a diverse set of pipelining operators ('PipeOps') that can be composed into graphs. Operations exist for data preprocessing, model fitting, and ensemble learning. Graphs can themselves be treated as 'mlr3' 'Learners' and can therefore be resampled, benchmarked, and tuned. } \seealso{ Useful links: diff --git a/man/mlr_pipeops_tunethreshold.Rd b/man/mlr_pipeops_tunethreshold.Rd index 56947c7ef..210ac5361 100644 --- a/man/mlr_pipeops_tunethreshold.Rd +++ b/man/mlr_pipeops_tunethreshold.Rd @@ -19,7 +19,7 @@ Returns a single \code{\link[mlr3:PredictionClassif]{PredictionClassif}}. This PipeOp should be used in conjunction with \code{\link{PipeOpLearnerCV}} in order to optimize thresholds of cross-validated predictions. In order to optimize thresholds without cross-validation, use \code{\link{PipeOpLearnerCV}} -in conjunction with \code{\link[mlr3:mlr_resamplings_insample]{ResamplingInsample}}. +in conjunction with \code{\link[mlr3:ResamplingInsample]{ResamplingInsample}}. 
} \section{Construction}{ \preformatted{* `PipeOpTuneThreshold$new(id = "tunethreshold", param_vals = list())` \\cr From 52234577bbb193d49dcad853857f2f6a00c5ca8b Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 4 Oct 2021 16:44:59 +0200 Subject: [PATCH 09/16] gunion and concat_graphs performance enhancement --- R/Graph.R | 13 ++++++++----- R/gunion.R | 11 +++++++---- R/operators.R | 10 +++++++--- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/R/Graph.R b/R/Graph.R index d05f5b83d..3f7d4c9e4 100644 --- a/R/Graph.R +++ b/R/Graph.R @@ -485,7 +485,7 @@ graph_channels = function(ids, channels, pipeops, direction) { return(data.table(name = character(), train = character(), predict = character(), op.id = character(), channel.name = character())) } - map_dtr(pipeops, function(po) { + ret = map_dtr(pipeops, function(po) { # Note: This uses data.frame and is 20% faster than the fastest data.table I could come up with # (and factor 2 faster than a naive data.table implementation below). 
@@ -494,10 +494,7 @@ graph_channels = function(ids, channels, pipeops, direction) { df = as.data.frame(po[[direction]], stringsAsFactors = FALSE) rows = df$name %nin% channels[ids == po$id] if (!any(rows)) { - return(data.frame(name = character(), - train = character(), predict = character(), - op.id = character(), channel.name = character(), - stringsAsFactors = FALSE)) + return(NULL) } df$op.id = po$id df = df[rows, @@ -506,6 +503,12 @@ graph_channels = function(ids, channels, pipeops, direction) { names(df)[5] = "channel.name" df }) + + if (!nrow(ret)) { + return(data.table(name = character(), train = character(), predict = character(), + op.id = character(), channel.name = character())) + } + ret } graph_channels_dt = function(ids, channels, pipeops, direction) { diff --git a/R/gunion.R b/R/gunion.R index ba4063e81..ab968b207 100644 --- a/R/gunion.R +++ b/R/gunion.R @@ -17,26 +17,29 @@ #' `NULL` values automatically get converted to [`PipeOpNOP`] with a random ID of the format `nop_********`. #' The list can be named, in which case the #' IDs of the elements are prefixed with the names, separated by a dot (`.`). -#' @param in_place (`logical(1)`)\cr +#' @param in_place (`logical(1)` | `logical`)\cr #' Whether to try to avoid cloning the first element of `graphs`, similar to the difference #' of [`%>>>%`] over [`%>>%`]. This can only be avoided if `graphs[[1]]` is already a [`Graph`].\cr #' Unlike [`chain_graphs()`], `gunion()` does all checks *before* mutating `graphs[[1]]`, so it will not leave `graphs[[1]]` -#' in an incompletely modified state when it fails. +#' in an incompletely modified state when it fails.\cr +#' `in_place` may also be of length `graph`, in which case it determines for each element of `graphs` whether it is cloned. +#' This is for internal usage and is not recommended. #' @return [`Graph`] the resulting [`Graph`]. 
#' #' @family Graph operators #' @export gunion = function(graphs, in_place = FALSE) { assert_list(graphs) + assert(check_flag(in_place), check_logical(in_place, any.missing = FALSE, len = length(graphs))) if (length(graphs) == 0) return(Graph$new()) graphs = map_if(graphs, is.null, function(x) po("nop", id = paste0("nop_", paste(sample(c(letters, 0:9), 8, TRUE), collapse = "")))) - do_clone = c(!in_place, rep(TRUE, length(graphs) - 1)) + do_clone = if (length(in_place) == length(graphs)) !in_place else c(!in_place, rep(TRUE, length(graphs) - 1)) graphs = structure(pmap(list(x = graphs, clone = do_clone), as_graph), names = names(graphs)) graphs = Filter(function(x) length(x$pipeops), graphs) if (length(graphs) == 0) return(Graph$new()) - if (in_place) { + if (in_place[[1]]) { g = graphs[[1]] g$.__enclos_env__$private$.param_set = NULL # clear param_set cache } else { diff --git a/R/operators.R b/R/operators.R index e83eef214..816f77c45 100644 --- a/R/operators.R +++ b/R/operators.R @@ -103,6 +103,8 @@ #' o1 %>>>% o2 #' #' o1 # not changed, becuase not a Graph. +#' +#' concat_graphs(glist = list(o1, o2, o3)) concat_graphs = function(g1, g2, glist, in_place = FALSE) { if (!missing(glist)) { if (!missing(g1) || !missing(g2)) stop("When glist is given, g1 and g2 must not be given") @@ -113,8 +115,10 @@ concat_graphs = function(g1, g2, glist, in_place = FALSE) { if (is.null(g1)) return(if (!is.null(g2)) as_graph(g2, clone = TRUE)) if (is.null(g2)) return(as_graph(g1, clone = !in_place)) - g1 = as_graph(g1) - g2 = as_graph(g2) + # one idea would be to not clone here, and let `gunion()` decide whether to clone. However, + # that would lead to `PipeOp`s being cloned twice, so we clone here explicitly and tell gunion to do things in-place. 
+ g1 = as_graph(g1, clone = !in_place) + g2 = as_graph(g2, clone = TRUE) g1out = g1$output g2in = g2$input if (nrow(g1out) != 1 && nrow(g1out) != nrow(g2in) && !(nrow(g2in) == 1 && g2in$channel.name == "...")) { @@ -140,7 +144,7 @@ concat_graphs = function(g1, g2, glist, in_place = FALSE) { new_edges = cbind(g1out[, list(src_id = get("op.id"), src_channel = get("channel.name"))], g2in[, list(dst_id = get("op.id"), dst_channel = get("channel.name"))]) - g = gunion(list(g1, g2), in_place = in_place) + g = gunion(list(g1, g2), in_place = c(TRUE, TRUE)) # at this point graphs are already cloned. g$edges = rbind(g$edges, new_edges) g } From 9c3d9b7984f42cfbc47dc49b986f77bcbc4e347c Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 4 Oct 2021 22:01:05 +0200 Subject: [PATCH 10/16] robustify against unfortunate data.table column names --- R/DataBackendJoin.R | 12 +++++++----- R/DataBackendMultiCbind.R | 14 ++++++++------ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/R/DataBackendJoin.R b/R/DataBackendJoin.R index 110efe7fb..208a1f0dc 100644 --- a/R/DataBackendJoin.R +++ b/R/DataBackendJoin.R @@ -38,7 +38,7 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = index_table = merge(data.table(rownames_b1, joinby_b1), data.table(rownames_b2, joinby_b2), by.x = "joinby_b1", by.y = "joinby_b2", all.x = type %in% c("left", "outer"), all.y = type %in% c("right", "outer"), sort = FALSE, allow.cartesian = TRUE) - index_table[, "joinby_b1" := NULL] + set(index_table, , "joinby_b1", NULL) pk = "..row_id" index = 0 @@ -77,26 +77,27 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = } setkeyv(data, NULL) if (d$pk %in% cols) { - data[, (d$pk) := rows] + set(data, , d$pk, rows) } if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) { rownames_b2 = indices$rownames_b2 - data[, (d$b2_index_colname) := rownames_b2] + set(data, , d$b2_index_colname, rownames_b2) } if (!is.null(d$b1_index_colname) && 
d$b1_index_colname %in% cols) { rownames_b1 = indices$rownames_b1 - data[, (d$b1_index_colname) := rownames_b1] + set(data, ,d$b1_index_colname, rownames_b1) } data[, intersect(cols, names(data)), with = FALSE] }, head = function(n = 6L) { - rows = head(self$rownames, n) + rows = first(self$rownames, n) self$data(rows = rows, cols = self$colnames) }, distinct = function(rows, cols, na_rm = TRUE) { d = private$.data indices = d$index_table[rows] + rownames_b1 = rownames_b2 = NULL b1_rows = indices[!is.na(rownames_b1), rownames_b1] b2_rows = indices[!is.na(rownames_b2), rownames_b2] d2 = private$.data$b2$distinct(rows = b2_rows, cols = cols, na_rm = na_rm) @@ -124,6 +125,7 @@ DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = missings = function(rows, cols) { d = private$.data indices = d$index_table[rows] + rownames_b1 = rownames_b2 = NULL b1_rows = indices[!is.na(rownames_b1), rownames_b1] b2_rows = indices[!is.na(rownames_b2), rownames_b2] m2 = private$.data$b2$missings(b2_rows, cols) diff --git a/R/DataBackendMultiCbind.R b/R/DataBackendMultiCbind.R index 1b6ed4ebb..d62136a69 100644 --- a/R/DataBackendMultiCbind.R +++ b/R/DataBackendMultiCbind.R @@ -53,18 +53,20 @@ DataBackendMultiCbind = R6Class("DataBackendMultiCbind", inherit = DataBackend, allrows[[i]] = datas[[i]][[pk]] } presentrows = unique(unlist(allrows)) + join = list(presentrows) result = do.call(cbind, pmap(list(datas, pks, include_pk), function(data, pk, include) { if (include) { - result = data[J(presentrows), on = pk, nomatch = NA] - droppk = result[[pk]] %nin% data[[pk]] - result[droppk, (pk) := NA] + result = data[join, on = pk, nomatch = NA] + set(result, result[[pk]] %nin% data[[pk]], pk, NA) } else { - data[J(presentrows), -pk, on = pk, with = FALSE, nomatch = NA] + data[join, -pk, on = pk, with = FALSE, nomatch = NA] } })) sbk = self$primary_key - result[, (sbk) := presentrows] - result[J(rows), intersect(cols, colnames(result)), with = FALSE, on = sbk, nomatch = 
NULL] + + set(result, , sbk, presentrows) + join = list(rows) + result[join, intersect(cols, colnames(result)), with = FALSE, on = sbk, nomatch = NULL] }, head = function(n = 6L) { rows = head(self$rownames, n) From ba62dc9536b478b308ae1a8565b69f0b730fce89 Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 4 Oct 2021 22:01:20 +0200 Subject: [PATCH 11/16] document POFU plans --- R/PipeOpFeatureUnion.R | 63 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/R/PipeOpFeatureUnion.R b/R/PipeOpFeatureUnion.R index 8379d76c5..a7e1d7bee 100644 --- a/R/PipeOpFeatureUnion.R +++ b/R/PipeOpFeatureUnion.R @@ -12,8 +12,11 @@ #' across all [`Task`][mlr3::Task]s. Only the target column(s) of the first [`Task`][mlr3::Task] #' are kept. #' -#' If `assert_targets_equal` is `TRUE` then target column names are compared and an error is thrown -#' if they differ across inputs. +#' `PipeOpFeatureUnion` tries to merge columns that are identical, while preventing accidental +#' overwrites of columns that contain differing data. This is controlled using the `feature_clash` +#' (for columns containing features, weights etc.) and `target_clash` (for tharget columns) +#' hyperparameters. The `assert_target_equal` construction parameter / field can still be used +#' as well but is deprecated and will generate a warning. #' #' If input tasks share some feature names but these features are not identical an error is thrown. #' This check is performed by first comparing the features names and if duplicates are found, also @@ -41,6 +44,7 @@ #' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise #' be set during construction. Default `list()`. #' * `assert_targets_equal` :: `logical(1)`\cr +#' DEPRECATED; use `target_clash` hyperparameter instead.\cr #' If `assert_targets_equal` is `TRUE` (Default), task target column names are checked for #' agreement. 
Disagreeing target column names are usually a bug, so this should often be left at #' the default. @@ -61,7 +65,33 @@ #' The `$state` is left empty (`list()`). #' #' @section Parameters: -#' [`PipeOpFeatureUnion`] has no Parameters. +#' * `target_clash` :: `character(1)`\cr +#' How to handle target columns that differ between input [`Task`][mlr3::Task]s. `"allow_same_hash"` +#' checks the names and `$col_hashes` and throws an error if they disagree. `"allow_same_content"` (default) is +#' more permissive: If `$col_hashes` disagree, then it checks the target content, if the content of both +#' columns agree, then merging of tasks is still allowed. This avoids some rare false-positives, but in cases +#' where hashes *do* disagree this may be slow for [`Task`][mlr3::Task]s with many rows or targets. +#' `"ignore"` does not check for target agreement and overwrites the target with the target of the *rightmost* / +#' highest numbered input [`Task`][mlr3::Task]. Use with caution. This is the only option that allows feature-union of [`Task`][mlr3::Task]s +#' that differ in the names of their target column (and all target columns except the rightmost / highest numbered input +#' [`Task`][mlr3::Task]'s target are dropped in that case).\cr +#' The deprecated field `assert_targets_equal` sets this value to `"allow_same_content"` (i.e. default) when `TRUE` and to +#' `"ignore"` when `FALSE`. +#' * `feature_clash` :: `character(1)`\cr +#' How to handle non-target columns that have the same name but differ between input [`Task`][mlr3::Task]s. `"allow_same_hash"` +#' checks the names and `$col_hashes` and throws an error if they disagree. `"allow_same_content"` (default) is +#' more permissive: If `$col_hashes` disagree, then it checks the column content, if the content of both +#' columns agree, then merging of tasks is still allowed. This avoids some rare false-positives, but in cases +#' where hashes *do* disagree this may be slow for large [`Task`][mlr3::Task]s. 
+#' `"ignore"` does not check for column data agreement and overwrites columns of the same name with the values of the *rightmost* / +#' highest numbered input [`Task`][mlr3::Task].\cr +#' Some column roles (`"group"`, `"weight"`, `"name"`) do not allow more than one column role present in a [`Task`][mlr3::Task] (see +#' `$col_roles` documentation there). When up to one [`Task`][mlr3::Task] has a column of these column role, it is taken for the +#' resulting [`Task`][mlr3::Task] without any issue. When more than one [`Task`][mlr3::Task] has a column with one of these roles, +#' but with the same name, the `feature_clash` policy applies as described above. When more than one [`Task`][mlr3::Task] has a +#' column with one of these roles, but they have *different* names, then an error is thrown when `feature_clash` is not `"ignore"`. +#' When it is `"ignore"`, the *rightmost* / highest numbered input [`Task`][mlr3::Task]'s column is used and all others of this +#' role are discarded. #' #' @section Internals: #' [`PipeOpFeatureUnion`] uses the [`Task`][mlr3::Task] `$cbind()` method to bind the input values @@ -99,21 +129,26 @@ PipeOpFeatureUnion = R6Class("PipeOpFeatureUnion", inherit = PipeOp, public = list( - assert_targets_equal = NULL, + inprefix = NULL, initialize = function(innum = 0L, collect_multiplicity = FALSE, id = "featureunion", param_vals = list(), assert_targets_equal = TRUE) { assert( check_int(innum, lower = 0L), check_character(innum, min.len = 1L, any.missing = FALSE) ) + params = ps( + target_clash = p_fct(c("allow_same_hash", "allow_same_content", "ignore")), + feature_clash = p_fct(c("forbid", "allow_same_hash", "allow_same_content", "ignore")) + ) + params$values = list(target_clash = "allow_same_content", feature_clash = "allow_same_content") + if (is.numeric(innum)) { self$inprefix = rep("", innum) } else { self$inprefix = innum innum = length(innum) } - assert_flag(assert_targets_equal) - self$assert_targets_equal = assert_targets_equal + 
inname = if (innum) rep_suffix("input", innum) else "..." intype = "Task" private$.collect = assert_flag(collect_multiplicity) @@ -129,9 +164,25 @@ PipeOpFeatureUnion = R6Class("PipeOpFeatureUnion", output = data.table(name = "output", train = "Task", predict = "Task"), tags = "ensemble" ) + + # the following is DEPRECATED + if (!missing(assert_targets_equal)) { + # do this after init so the AB can modify self$param_set + assert_flag(assert_targets_equal) + self$assert_targets_equal = assert_targets_equal + } + } + ), + active = list( + assert_targets_equal = function(rhs) { + if (!missing(rhs)) private$.assert_targets_equal = rhs + self$param_set$values$target_clash = if (private$.assert_targets_equal) "allow_same_content" else "ignore" + warning("PipeOpFeatureUnion assert_targets_equal is deprecated. Use the 'target_clash' hyperparameter.") + private$.assert_targets_equal } ), private = list( + .assert_targets_equal = NULL, .train = function(inputs) { self$state = list() if (private$.collect) inputs = unclass(inputs[[1]]) From 5a1debd8c89af7b0abb0d39f70c1d375038b0241 Mon Sep 17 00:00:00 2001 From: mb706 Date: Tue, 5 Oct 2021 09:57:03 +0200 Subject: [PATCH 12/16] %>>>% --> %>>!% --- NAMESPACE | 2 +- R/gunion.R | 2 +- R/operators.R | 44 +++++++++++++++---------------- R/pipeline_bagging.R | 6 ++--- R/pipeline_branch.R | 2 +- R/pipeline_ovr.R | 2 +- R/pipeline_stacking.R | 4 +-- tests/testthat/test_doublearrow.R | 6 ++--- 8 files changed, 34 insertions(+), 34 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index d94399531..af09ced9f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -24,7 +24,7 @@ S3method(predict,Graph) S3method(print,Multiplicity) S3method(print,Selector) export("%>>%") -export("%>>>%") +export("%>>!%") export(DataBackendJoin) export(DataBackendMultiCbind) export(Graph) diff --git a/R/gunion.R b/R/gunion.R index ab968b207..36aab005c 100644 --- a/R/gunion.R +++ b/R/gunion.R @@ -19,7 +19,7 @@ #' IDs of the elements are prefixed with the names, separated by a 
dot (`.`). #' @param in_place (`logical(1)` | `logical`)\cr #' Whether to try to avoid cloning the first element of `graphs`, similar to the difference -#' of [`%>>>%`] over [`%>>%`]. This can only be avoided if `graphs[[1]]` is already a [`Graph`].\cr +#' of [`%>>!%`] over [`%>>%`]. This can only be avoided if `graphs[[1]]` is already a [`Graph`].\cr #' Unlike [`chain_graphs()`], `gunion()` does all checks *before* mutating `graphs[[1]]`, so it will not leave `graphs[[1]]` #' in an incompletely modified state when it fails.\cr #' `in_place` may also be of length `graph`, in which case it determines for each element of `graphs` whether it is cloned. diff --git a/R/operators.R b/R/operators.R index 816f77c45..b87d3653b 100644 --- a/R/operators.R +++ b/R/operators.R @@ -5,7 +5,7 @@ #' Both source and sink can either be #' a [`Graph`] or a [`PipeOp`] (or an object that can be automatically converted into a [`Graph`] or [`PipeOp`], see [`as_graph()`] and [`as_pipeop()`]). #' -#' `%>>%` and `%>>>%` try to automatically match output channels of `g1` to input channels of `g2`; this is only possible if either +#' `%>>%` and `%>>!%` try to automatically match output channels of `g1` to input channels of `g2`; this is only possible if either #' * the number of output channels of `g1` (as given by `g1$output`) is equal to the #' number of input channels of `g2` (as given by `g2$input`), or #' * `g1` has only one output channel (i.e. `g1$output` has one line), or @@ -17,23 +17,23 @@ #' #' `%>>%` always creates deep copies of its input arguments, so they cannot be modified by reference afterwards. #' To access individual [`PipeOp`]s after composition, use the resulting [`Graph`]'s `$pipeops` list. -#' `%>>>%`, on the other hand, tries to avoid cloning its first argument: If it is a [`Graph`], then this [`Graph`] +#' `%>>!%`, on the other hand, tries to avoid cloning its first argument: If it is a [`Graph`], then this [`Graph`] #' will be modified in-place. 
#' -#' When `%>>>%` fails, then it leaves `g1` in an incompletely modified state. It is therefore usually recommended to use +#' When `%>>!%` fails, then it leaves `g1` in an incompletely modified state. It is therefore usually recommended to use #' `%>>%`, since the very marginal gain of performance from -#' using `%>>>%` often does not outweigh the risk of either modifying objects by-reference that should not be modified or getting +#' using `%>>!%` often does not outweigh the risk of either modifying objects by-reference that should not be modified or getting #' graphs that are in an incompletely modified state. However, -#' when creating long [`Graph`]s, chaining with `%>>>%` instead of `%>>%` can give noticeable performance benefits -#' because `%>>%` makes a number of `clone()`-calls that is quadratic in chain length, `%>>>%` only linear. +#' when creating long [`Graph`]s, chaining with `%>>!%` instead of `%>>%` can give noticeable performance benefits +#' because `%>>%` makes a number of `clone()`-calls that is quadratic in chain length, `%>>!%` only linear. #' -#' `concat_graphs(g1, g2, in_place = FALSE)` is equivalent to `g1 %>>% g2`. `concat_graphs(g1, g2, in_place = TRUE)` is equivalent to `g1 %>>>% g2`. +#' `concat_graphs(g1, g2, in_place = FALSE)` is equivalent to `g1 %>>% g2`. `concat_graphs(g1, g2, in_place = TRUE)` is equivalent to `g1 %>>!% g2`. #' #' Both arguments of `%>>%` are automatically converted to [`Graph`]s using [`as_graph()`]; this means that objects on either side may be objects #' that can be automatically converted to [`PipeOp`]s (such as [`Learner`][mlr3::Learner]s or [`Filter`][mlr3filters::Filter]s), or that can #' be converted to [`Graph`]s. This means, in particular, `list`s of [`Graph`]s, [`PipeOp`]s or objects convertible to that, because -#' [`as_graph()`] automatically applies [`gunion()`] to `list`s. See examples. 
If the first argument of `%>>>%` is not a [`Graph`], then -#' it is cloned just as when `%>>%` is used; `%>>>%` only avoids `clone()` if the first argument is a [`Graph`]. +#' [`as_graph()`] automatically applies [`gunion()`] to `list`s. See examples. If the first argument of `%>>!%` is not a [`Graph`], then +#' it is cloned just as when `%>>%` is used; `%>>!%` only avoids `clone()` if the first argument is a [`Graph`]. #' #' Note that if `g1` is `NULL`, `g2` converted to a [`Graph`] will be returned. #' Analogously, if `g2` is `NULL`, `g1` converted to a [`Graph`] will be returned. @@ -44,10 +44,10 @@ #' #' Care is taken to avoid unnecessarily cloning of components. A call of #' `chain_graphs(list(g1, g2, g3, g4, ...), in_place = FALSE)` is equivalent to -#' `g1 %>>% g2 %>>>% g3 %>>>% g4 %>>>% ...`. +#' `g1 %>>% g2 %>>!% g3 %>>!% g4 %>>!% ...`. #' A call of `chain_graphs(list(g1, g2, g3, g4, ...), in_place = FALSE)` -#' is equivalent to `g1 %>>>% g2 %>>>% g3 %>>>% g4 %>>>% ...` (differing in the -#' first operator being `%>>>%` as well). +#' is equivalent to `g1 %>>!% g2 %>>!% g3 %>>!% g4 %>>!% ...` (differing in the +#' first operator being `%>>!%` as well). #' #' `concat_graphs(glist = )` (implicitly with `in_place = FALSE`) is a safe way of generating large linear pipelines quickly, while #' still avoiding to change any of its inputs by reference, and avoiding the risk of ending up with broken objects. @@ -63,7 +63,7 @@ #' When `g1` and `g2` are given: Whether to try to avoid cloning `g1`. If `g1` is not a [`Graph`], then it is cloned regardless.\n #' When `glist` is given instead: #' Whether to try to avoid cloning the first element of `glist`, similar to the difference -#' of `%>>>%` over `%>>%`. This can only be avoided if `glist[[1]]` is already a [`Graph`]. +#' of `%>>!%` over `%>>%`. This can only be avoided if `glist[[1]]` is already a [`Graph`]. 
#' Beware that, when `in_place` is `TRUE` and if `concat_graphs()` fails because of id collisions, then `glist[[1]]` will possibly be in an incompletely #' modified state. #' @@ -96,11 +96,11 @@ #' add_edge(o1$id, o3$id, dst_channel = 1)$ #' add_edge(o2$id, o3$id, dst_channel = 2) #' -#' pipe1 %>>>% o3 # modify pipe1 in-place +#' pipe1 %>>!% o3 # modify pipe1 in-place #' #' pipe1 # contains o1, o2, and o3 now. #' -#' o1 %>>>% o2 +#' o1 %>>!% o2 #' #' o1 # not changed, becuase not a Graph. #' @@ -157,7 +157,7 @@ concat_graphs = function(g1, g2, glist, in_place = FALSE) { #' @rdname concat_graphs #' @export -`%>>>%` = function(g1, g2) { +`%>>!%` = function(g1, g2) { concat_graphs(g1, g2, in_place = TRUE) } @@ -174,10 +174,10 @@ strip_multiplicity_type = function(type) { #' #' Care is taken to avoid unnecessarily cloning of components. A call of #' `chain_graphs(list(g1, g2, g3, g4, ...), in_place = FALSE)` is equivalent to -#' `g1 %>>% g2 %>>>% g3 %>>>% g4 %>>>% ...`. +#' `g1 %>>% g2 %>>!% g3 %>>!% g4 %>>!% ...`. #' A call of `chain_graphs(list(g1, g2, g3, g4, ...), in_place = FALSE)` -#' is equivalent to `g1 %>>>% g2 %>>>% g3 %>>>% g4 %>>>% ...` (differing in the -#' first operator being `%>>>%` as well). +#' is equivalent to `g1 %>>!% g2 %>>!% g3 %>>!% g4 %>>!% ...` (differing in the +#' first operator being `%>>!%` as well). #' #' @param graphs `list` of ([`Graph`] | [`PipeOp`] | `NULL` | `...`)\cr #' List of elements which are the @@ -185,7 +185,7 @@ strip_multiplicity_type = function(type) { #' `NULL` is the neutral element of [`%>>%`] and skipped. #' @param in_place (`logical(1)`)\cr #' Whether to try to avoid cloning the first element of `graphs`, similar to the difference -#' of [`%>>>%`] over [`%>>%`]. This can only be avoided if `graphs[[1]]` is already a [`Graph`]. +#' of [`%>>!%`] over [`%>>%`]. This can only be avoided if `graphs[[1]]` is already a [`Graph`]. 
#' Beware that, if `chain_graphs()` fails because of id collisions, then `graphs[[1]]` will possibly be in an incompletely #' modified state when `in_place` is `TRUE`. #' @return [`Graph`] the resulting [`Graph`], or `NULL` if there are no non-null values in `graphs`. @@ -197,8 +197,8 @@ chain_graphs = function(graphs, in_place = FALSE) { if (!in_place) { # all except the first graph get cloned, so if we are in_place, # we only need to take care to clone it. We convert it to a Graph, - # so `%>>>%` will not clone it again. + # so `%>>!%` will not clone it again. graphs[[1]] = as_graph(graphs[[1]], clone = TRUE) } - Reduce(`%>>>%`, graphs) + Reduce(`%>>!%`, graphs) } diff --git a/R/pipeline_bagging.R b/R/pipeline_bagging.R index 03eeab75a..afcf0c7f9 100644 --- a/R/pipeline_bagging.R +++ b/R/pipeline_bagging.R @@ -49,9 +49,9 @@ pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL) averager = as_graph(averager, clone = TRUE) } - po("replicate", param_vals = list(reps = iterations)) %>>>% - po("subsample", param_vals = list(frac = frac)) %>>>% - g %>>>% + po("replicate", param_vals = list(reps = iterations)) %>>!% + po("subsample", param_vals = list(frac = frac)) %>>!% + g %>>!% averager } diff --git a/R/pipeline_branch.R b/R/pipeline_branch.R index 2ebd9b1a6..32b923a90 100644 --- a/R/pipeline_branch.R +++ b/R/pipeline_branch.R @@ -78,7 +78,7 @@ pipeline_branch = function(graphs, prefix_branchops = "", prefix_paths = FALSE) poname_prefix = "" } - graph = gunion(graphs) %>>>% PipeOpUnbranch$new(branches, id = paste0(prefix_branchops, "unbranch")) + graph = gunion(graphs) %>>!% PipeOpUnbranch$new(branches, id = paste0(prefix_branchops, "unbranch")) branch_id = paste0(prefix_branchops, "branch") po_branch = PipeOpBranch$new(branches, id = branch_id) diff --git a/R/pipeline_ovr.R b/R/pipeline_ovr.R index 4127546c7..0fdc4b2a6 100644 --- a/R/pipeline_ovr.R +++ b/R/pipeline_ovr.R @@ -43,7 +43,7 @@ #' g3$train(task) #' g3$predict(task) pipeline_ovr = 
function(graph) { - PipeOpOVRSplit$new() %>>>% graph %>>>% PipeOpOVRUnite$new() + PipeOpOVRSplit$new() %>>!% graph %>>!% PipeOpOVRUnite$new() } mlr_graphs$add("ovr", pipeline_ovr) diff --git a/R/pipeline_stacking.R b/R/pipeline_stacking.R index 87fe0ae08..f80b7645b 100644 --- a/R/pipeline_stacking.R +++ b/R/pipeline_stacking.R @@ -52,8 +52,8 @@ pipeline_stacking = function(base_learners, super_learner, method = "cv", folds if (use_features) base_learners_cv = c(base_learners_cv, po("nop")) - gunion(base_learners_cv, in_place = TRUE) %>>>% - po("featureunion") %>>>% + gunion(base_learners_cv, in_place = TRUE) %>>!% + po("featureunion") %>>!% super_learner } diff --git a/tests/testthat/test_doublearrow.R b/tests/testthat/test_doublearrow.R index c1370327b..2e142d84c 100644 --- a/tests/testthat/test_doublearrow.R +++ b/tests/testthat/test_doublearrow.R @@ -89,7 +89,7 @@ test_that("triple-arrow", { p2graph1 = as_graph(p2) - gr2 = gr %>>>% p2graph1 + gr2 = gr %>>!% p2graph1 expect_equal(gr2$pipeops, list(p1 = p1, p2 = p2)) @@ -100,7 +100,7 @@ test_that("triple-arrow", { expect_deep_clone(gr2$pipeops$p2, p2) - gr3 = gr %>>>% p3 + gr3 = gr %>>!% p3 expect_identical(gr3, gr) expect_identical(gr2, gr) @@ -143,7 +143,7 @@ test_that("triple-arrow", { # not mutable in-place - gr = p1 %>>>% p2graph1 + gr = p1 %>>!% p2graph1 expect_deep_clone(p1, PipeOpNOP$new("p1")) expect_deep_clone(p2graph1, as_graph(p2)) From 97b907f2e7682cc048d16105bc897ada6942a6f1 Mon Sep 17 00:00:00 2001 From: mb706 Date: Tue, 5 Oct 2021 10:46:03 +0200 Subject: [PATCH 13/16] Graph() --- R/Graph.R | 10 ++++++++++ R/operators.R | 30 ++---------------------------- 2 files changed, 12 insertions(+), 28 deletions(-) diff --git a/R/Graph.R b/R/Graph.R index 3f7d4c9e4..467db209f 100644 --- a/R/Graph.R +++ b/R/Graph.R @@ -81,6 +81,11 @@ #' channel `dst_channel` (identified by its name or number as listed in the [`PipeOp`]'s `$input`). 
#' If source or destination [`PipeOp`] have only one input / output channel and `src_channel` / `dst_channel` #' are therefore unambiguous, they can be omitted (i.e. left as `NULL`). +#' * `chain(gs, clone = TRUE)` \cr +#' (`list` of `Graph`s, `logical(1)`) -> `self` \cr +#' Takes a list of `Graph`s or [`PipeOp`]s (or objects that can be automatically converted into `Graph`s or [`PipeOp`]s, +#' see [`as_graph()`] and [`as_pipeop()`]) as inputs and joins them in a serial `Graph` coming after `self`, as if +#' connecting them using [`%>>%`]. #' * `plot(html)` \cr #' (`logical(1)`) -> `NULL` \cr #' Plot the [`Graph`], using either the \pkg{igraph} package (for `html = FALSE`, default) or @@ -248,6 +253,11 @@ Graph = R6Class("Graph", invisible(self) }, + chain = function(gs, clone = TRUE) { + assert_list(gs) + chain_graphs(c(list(gs), gs), in_place = TRUE) + }, + plot = function(html = FALSE) { assert_flag(html) if (!length(self$pipeops)) { diff --git a/R/operators.R b/R/operators.R index b87d3653b..e2b97641c 100644 --- a/R/operators.R +++ b/R/operators.R @@ -38,34 +38,12 @@ #' Note that if `g1` is `NULL`, `g2` converted to a [`Graph`] will be returned. #' Analogously, if `g2` is `NULL`, `g1` converted to a [`Graph`] will be returned. #' -#' @section Chaining Graphs: -#' `concat_graphs` can also be called with the `glist` argument, which takes an arbitrary amount of [`Graph`]s or [`PipeOp`]s (or objects that can be automatically -#' converted into [`Graph`]s or [`PipeOp`]s) and joins them in a serial [`Graph`], as if connecting them using [`%>>%`]. -#' -#' Care is taken to avoid unnecessarily cloning of components. A call of -#' `chain_graphs(list(g1, g2, g3, g4, ...), in_place = FALSE)` is equivalent to -#' `g1 %>>% g2 %>>!% g3 %>>!% g4 %>>!% ...`. -#' A call of `chain_graphs(list(g1, g2, g3, g4, ...), in_place = FALSE)` -#' is equivalent to `g1 %>>!% g2 %>>!% g3 %>>!% g4 %>>!% ...` (differing in the -#' first operator being `%>>!%` as well). 
-#' -#' `concat_graphs(glist = )` (implicitly with `in_place = FALSE`) is a safe way of generating large linear pipelines quickly, while -#' still avoiding to change any of its inputs by reference, and avoiding the risk of ending up with broken objects. -#' #' @param g1 ([`Graph`] | [`PipeOp`] | [`Learner`][mlr3::Learner] | [`Filter`][mlr3filters::Filter] | `list` | `...`) \cr #' [`Graph`] / [`PipeOp`] / object-convertible-to-[`PipeOp`] to put in front of `g2`. #' @param g2 ([`Graph`] | [`PipeOp`] | [`Learner`][mlr3::Learner] | [`Filter`][mlr3filters::Filter] | `list` | `...`) \cr #' [`Graph`] / [`PipeOp`] / object-convertible-to-[`PipeOp`] to put after `g1`. -#' @param glist `list` of ([`Graph`] | [`PipeOp`] | [`Learner`][mlr3::Learner] | [`Filter`][mlr3filters::Filter] | `list` | `...`)\cr -#' List of elements which are the [`Graph`]s to be joined. Elements must be convertible to [`Graph`] or [`PipeOp`] using [`as_graph()`] and [`as_pipeop()`]. -#' `NULL` is the neutral element of [`%>>%`] and skipped. When this is given, `g1` and `g2` must not be given. #' @param in_place (`logical(1)`)\cr -#' When `g1` and `g2` are given: Whether to try to avoid cloning `g1`. If `g1` is not a [`Graph`], then it is cloned regardless.\n -#' When `glist` is given instead: -#' Whether to try to avoid cloning the first element of `glist`, similar to the difference -#' of `%>>!%` over `%>>%`. This can only be avoided if `glist[[1]]` is already a [`Graph`]. -#' Beware that, when `in_place` is `TRUE` and if `concat_graphs()` fails because of id collisions, then `glist[[1]]` will possibly be in an incompletely -#' modified state. +#' Whether to try to avoid cloning `g1`. If `g1` is not a [`Graph`], then it is cloned regardless.\n #' #' @return [`Graph`]: the constructed [`Graph`]. #' @family Graph operators @@ -105,11 +83,7 @@ #' o1 # not changed, becuase not a Graph. 
#' #' concat_graphs(glist = list(o1, o2, o3)) -concat_graphs = function(g1, g2, glist, in_place = FALSE) { - if (!missing(glist)) { - if (!missing(g1) || !missing(g2)) stop("When glist is given, g1 and g2 must not be given") - return(chain_graphs(graphs = glist, in_place = in_place)) - } +concat_graphs = function(g1, g2, in_place = FALSE) { assert_flag(in_place) # neutral elements handling if (is.null(g1)) return(if (!is.null(g2)) as_graph(g2, clone = TRUE)) From 23f36d892cce362a65fcd268241d8cdb4b2c78c7 Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 12 Aug 2024 12:10:12 +0200 Subject: [PATCH 14/16] clean up --- R/operators.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/operators.R b/R/operators.R index 2925599f4..2558f6dc4 100644 --- a/R/operators.R +++ b/R/operators.R @@ -43,7 +43,7 @@ #' @param g2 ([`Graph`] | [`PipeOp`] | [`Learner`][mlr3::Learner] | [`Filter`][mlr3filters::Filter] | `list` | `...`) \cr #' [`Graph`] / [`PipeOp`] / object-convertible-to-[`PipeOp`] to put after `g1`. #' @param in_place (`logical(1)`)\cr -#' Whether to try to avoid cloning `g1`. If `g1` is not a [`Graph`], then it is cloned regardless.\n +#' Whether to try to avoid cloning `g1`. If `g1` is not a [`Graph`], then it is cloned regardless. #' #' @return [`Graph`]: the constructed [`Graph`]. #' @family Graph operators From 4c268d2c6a34d4ccf09cfee9e8f7a830396cfe39 Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 12 Aug 2024 14:58:28 +0200 Subject: [PATCH 15/16] some notes --- attic/pofu.md | 187 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 attic/pofu.md diff --git a/attic/pofu.md b/attic/pofu.md new file mode 100644 index 000000000..7fa73aa3f --- /dev/null +++ b/attic/pofu.md @@ -0,0 +1,187 @@ +# POFU design document + +## Issues + +### 126: POFU could drop all targets except first task's one + +Instead of checking that all targets are the same and throwing an error if they are not. 
+ +### 216: POFU with differing row IDs + +PipeOpFeatureUnion could under some circumstances want to unite tasks that have differing row IDs, e.g. after PipeOpSubsample on two different paths sampled (and `$filter()`ed) different sets of rows. + +```r +graph = greplicate(PipeOpSubsample$new() %>>% + PipeOpLearnerCV$new("classif.rpart"), 2) %>>% + PipeOpFeatureUnion$new() +graph$plot() # this is what it looks like + +graph$train("iris") # assertion error +``` +mlr-org/mlr3#309 could solve part of this, but the problem goes deeper: +* what if we do sampling with replacement? +* what if PipeOpLearnerCV has a resampling that predicts some entries multiple times, e.g. RepCV or bootstrapping? + +### 271: Use DataBackend info to avoid unnecessary data comparison + +Using col_hashes + +### 388: POFU should use DataBackend cbind + +but backends do not have info about col roles + +### 390: POFU assert_targets_equal parameter should go + +### 570: assertion on 'rows' failed + +branch with subsample on one end gives error + +```r +library(mlr3) +library(mlr3pipelines) + +task = tsk("iris") +resampling = rsmp("holdout") + +graph = gunion( + list( + po("pca") %>>% po("learner_cv", id = "featureless", lrn("classif.featureless")), + po("subsample") %>>% po("learner_cv", id = "rpart", lrn("classif.rpart"))) + ) %>>% + po("featureunion") %>>% + po("learner", lrn("classif.rpart")) + +resample(task, graph, resampling) + +``` + +- possibly improve error message +- what are the options here? fill with NAs? aggregate? 
+ +### 571: POFU seems more broken now + +this does not seem to give an error any more: + +```r + pos = PipeOpSubsample$new() + pos$param_set$values$frac = 0.5 + g = pipeline_greplicate( + pos %>>% PipeOpPCA$new(), + 2 + ) %>>% PipeOpFeatureUnion$new(c("a", "b")) + task = mlr_tasks$get("iris") + expect_error(g$train(task), "Assertion on 'rows'") +``` + + +### 607: PipeOpPredictionUnion + +https://github.com/mlr-org/miesmuschel/blob/smashy_ex/R/PipeOpPredictionUnion.R + +name is confusing, since it is rbinding, not cbinding + +### 634: New Pipeop: Split data by row_ids / logical arg + +https://gist.github.com/pfistfl/6b190f0612535817bdd33fe8f8bd6548 + +- how do we combine this with zero-inflated things? + +### 646: Bootstrap resampling + +Apparently the problem is that bootstrapping uses some rows repeatedly, which somehow breaks with mlr3's assumption that row_ids are unique values. + +```r +library("mlr3") +library("mlr3pipelines") +options(mlr3.debug=TRUE) +resample(tsk("iris"), po("pca") %>>% lrn("classif.featureless"), rsmp("bootstrap")) +``` + + +### 696: PipeOpFeatureUnion breaks predict_newdata when all features of original task aver overwritten + + +```r +gr <- list(po("select", selector = selector_none()), po("nop")) %>>!% po("featureunion", innum = c("a", "")) %>>!% + { l <- lrn("classif.rpart") ; l$properties <- setdiff(l$properties, "missings") ; l } +gr$train(tsk("iris")) +#> $classif.rpart.output +#> NULL +#> +gr$predict(tsk("iris")) +#> $classif.rpart.output +#> for 150 observations: +#> row_ids truth response +#> 1 setosa setosa +#> 2 setosa setosa +#> 3 setosa setosa +#> --- +#> 148 virginica virginica +#> 149 virginica virginica +#> 150 virginica virginica + +lr <- as_learner(gr) +lr$train(tsk("iris")) +lr$predict_newdata(iris[1:4]) +#> Error in map_values(names(x), self$old, self$new) : +#> Assertion on 'x' failed: Must be of type 'atomic vector', not 'NULL'. 
#> This happened PipeOp classif.rpart's $predict()
```

### 697: POFU should use feature_types as reported by input tasks and not the datatypes it gets from $data() (mlr-org/mlr3#685).

- how does data conversion work between task and backend?

### cbind backend simplification


## Notes

- task filter: integer ids as reported by backend
- backend$data
- duplicated IDs, possibly problem with resampling
- col conversion
  - "conversion should happen"
  - setequal factorlevel: convert, otherwise kA
  - maybe happens when there are fewer levels than before (? -- check)
  - predict-newdata

## Synthesis

### Mostly independent of data backend

- Handling Tasks with different Targets (126)
- Data Comparison when merging (271)
- Error not triggered (571)
- predict_newdata issues (696: all cols replaced, 697: feature types)

### POFU behaviour

- POFU should use DataBackend cbind (388)
- POFU assert_targets_equal parameter should go (390)

### "merging" with different row IDs

- 216, 570, 646
- what if multiple rows are generated?
  - fill with NAs?
  - aggregate?
  - might be relevant for PipeOpLearnerCV with RepCV or bootstrapping
- left, right, outer, inner join?
- train() vs predict()
- predictions
  - see how missing predictions / NAs are handled

### Predictions

- PipeOpFeatureUnion (607)

### Splitting

- by row_ids / logical arg (634)
- zero-inflated stuff

### What else?
+ +- predict cols that are then used as input +- auto-simplification From c6f33eeb637d8282a9afd6841827130ac4b6b402 Mon Sep 17 00:00:00 2001 From: mb706 Date: Tue, 13 Aug 2024 15:05:19 +0200 Subject: [PATCH 16/16] notes --- attic/pofu.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/attic/pofu.md b/attic/pofu.md index 7fa73aa3f..bcc1a6c35 100644 --- a/attic/pofu.md +++ b/attic/pofu.md @@ -134,7 +134,6 @@ lr$predict_newdata(iris[1:4]) ### cbind backend simplification - ## Notes - task filter: integer ids as reported by backend @@ -146,6 +145,11 @@ lr$predict_newdata(iris[1:4]) - maybe happens when there are fewer levels than before (? -- check) - predict-newdata +- databackendrename +- backends are read-only, but want to be able to copy / extend +- do we want one multicbind, one join? +- how does this id-stuff work again? + ## Synthesis ### Mostly independent of data backend @@ -171,6 +175,7 @@ lr$predict_newdata(iris[1:4]) - train() vs predict() - predictions - see how missing predictions / NAs are handled +- how does it cope with prediction IDs being wrong? ### Predictions @@ -185,3 +190,14 @@ lr$predict_newdata(iris[1:4]) - predict cols that are then used as input - auto-simplification + +## Use Cases + + - operation performed on subset of rows, e.g. subset >> op() | otherop() + - join means NAs are introduced + - learner_cv makes prediction only for some inputs, or makes multiple predictions + - join means NAs on missing predictions? + - join aggregation of (learner_cv) predictions? + - rows with missing predictions are dropped? + - in all of the above: predict just cbinds + -