Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pofu #624

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ Roxygen: list(markdown = TRUE, r6 = FALSE)
RoxygenNote: 7.3.2
VignetteBuilder: knitr
Collate:
'DataBackendJoin.R'
'DataBackendMultiCbind.R'
'Graph.R'
'GraphLearner.R'
'mlr_pipeops.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
S3method(unmarshal_model,pipeop_learner_cv_state_marshaled)
export("%>>!%")
export("%>>%")
export(DataBackendJoin)
export(DataBackendMultiCbind)
export(Graph)
export(GraphLearner)
export(LearnerClassifAvg)
Expand Down
160 changes: 160 additions & 0 deletions R/DataBackendJoin.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@


#' @title DataBackend Representing the Join of Two DataBackends
#'
#' @usage NULL
#' @format `R6Class` object inheriting from `DataBackend`.
#'
#' @description
#' Presents two backends, `b1` and `b2`, as a single table. Rows are matched either on the
#' row identifiers of the backends or on one column of each backend. The row matching is computed once,
#' during construction, and stored as an index table; the column data is fetched from `b1` and `b2`
#' on demand.
#'
#' Columns that are present in both backends are taken from `b2`.
#'
#' The resulting backend gets a newly generated integer primary key (`"..row_id"`, or `"..row_id.<n>"`
#' if that name is already taken) that enumerates the rows of the join result.
#'
#' @section Construction:
#' ```
#' DataBackendJoin$new(b1, b2, type, by_b1 = NULL, by_b2 = NULL, b1_index_colname = NULL, b2_index_colname = NULL)
#' ```
#'
#' * `b1`, `b2` :: `DataBackend`\cr
#'   Backends to join. Both must support the `"data.table"` data format.
#' * `type` :: `character(1)`\cr
#'   One of `"left"` (keep all rows of `b1`), `"right"` (keep all rows of `b2`), `"outer"` (keep all rows
#'   of both) or `"inner"` (keep only matching rows).
#' * `by_b1`, `by_b2` :: `character(1)` | `NULL`\cr
#'   Column of `b1` / `b2` to join by. If `NULL`, the row identifiers of the respective backend are used.
#' * `b1_index_colname`, `b2_index_colname` :: `character(1)` | `NULL`\cr
#'   If given, a column with this name is added that contains the row identifier of `b1` / `b2` that the
#'   respective result row originates from (`NA` if there is no matching row in that backend).
#'
#' @export
DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = FALSE,
  public = list(
    initialize = function(b1, b2, type, by_b1 = NULL, by_b2 = NULL, b1_index_colname = NULL, b2_index_colname = NULL) {
      assert_backend(b1)
      assert_backend(b2)

      if ("data.table" %nin% intersect(b1$data_formats, b2$data_formats)) {
        stop("DataBackendJoin currently only supports DataBackends that support 'data.table' format.")
      }

      assert_choice(type, c("left", "right", "outer", "inner"))

      colnames_b1 = b1$colnames
      colnames_b2 = b2$colnames
      allcolnames = union(colnames_b1, colnames_b2)

      assert_choice(by_b1, colnames_b1, null.ok = TRUE)
      assert_choice(by_b2, colnames_b2, null.ok = TRUE)

      assert_string(b1_index_colname, null.ok = TRUE)
      assert_string(b2_index_colname, null.ok = TRUE)

      # An index column may reuse the name of the respective backend's primary key, but must not
      # shadow any other column.
      if (!is.null(b1_index_colname) && b1_index_colname %in% setdiff(allcolnames, b1$primary_key)) {
        stopf("b1_index_colname '%s' already a non-primary-key column in b1 or b2.", b1_index_colname)
      }
      if (!is.null(b2_index_colname) && b2_index_colname %in% setdiff(allcolnames, b2$primary_key)) {
        stopf("b2_index_colname '%s' already a non-primary-key column in b1 or b2.", b2_index_colname)
      }
      if (!is.null(b1_index_colname) && !is.null(b2_index_colname) && b1_index_colname == b2_index_colname) {
        stopf("b1_index_colname and b2_index_colname must be different, but are both '%s'.", b1_index_colname)
      }

      # all columns of the result except for the new primary key
      colnames_joined = unique(c(allcolnames, b1_index_colname, b2_index_colname))

      rownames_b1 = b1$rownames
      rownames_b2 = b2$rownames

      joinby_b1 = if (is.null(by_b1)) rownames_b1 else b1$data(rownames_b1, by_b1, data_format = "data.table")[[1]]
      joinby_b2 = if (is.null(by_b2)) rownames_b2 else b2$data(rownames_b2, by_b2, data_format = "data.table")[[1]]

      # index_table maps each row of the join result to the originating row ids of b1 and b2 (NA if absent).
      index_table = merge(
        data.table(rownames_b1 = rownames_b1, joinby_b1 = joinby_b1),
        data.table(rownames_b2 = rownames_b2, joinby_b2 = joinby_b2),
        by.x = "joinby_b1", by.y = "joinby_b2",
        all.x = type %in% c("left", "outer"), all.y = type %in% c("right", "outer"),
        sort = FALSE, allow.cartesian = TRUE
      )

      # only the row id mapping is needed from here on
      set(index_table, , "joinby_b1", NULL)

      # Generate a primary key name that collides neither with a backend column nor with an index column.
      pk = "..row_id"
      index = 0
      while (pk %in% colnames_joined) {
        index = index + 1
        pk = paste0("..row_id.", index)
      }

      super$initialize(list(
        b1 = b1, b2 = b2,
        colnames_b1 = setdiff(colnames_b1, colnames_b2),  # columns that only b1 can provide
        allcolnames = c(colnames_joined, pk),
        index_table = index_table,
        b1_index_colname = b1_index_colname,
        b2_index_colname = b2_index_colname,
        pk = pk,
        aux_hash = calculate_hash(by_b1, by_b2, type, b1_index_colname, b2_index_colname)
      ), primary_key = pk, data_formats = "data.table")
    },

    # Retrieve the requested rows (row numbers of the join result) and columns as a data.table.
    # Row ids outside of the join result are dropped. The `data_format` argument is not used;
    # both backends are always queried in "data.table" format.
    data = function(rows, cols, data_format = "data.table") {
      d = private$.data
      rows = rows[inrange(rows, 1, nrow(d$index_table))]
      indices = d$index_table[rows]

      rownames_b1 = indices[["rownames_b1"]]
      rownames_b2 = indices[["rownames_b2"]]
      present_b1 = !is.na(rownames_b1)
      present_b2 = !is.na(rownames_b2)
      b1_rows = rownames_b1[present_b1]
      b2_rows = rownames_b2[present_b2]

      # Position of each result row within the data returned by the backends; NA where the
      # backend has no matching row, which turns into an all-NA row when used for subsetting.
      b1_index = rep(NA_integer_, nrow(indices))
      b1_index[present_b1] = seq_along(b1_rows)
      b2_index = rep(NA_integer_, nrow(indices))
      b2_index[present_b2] = seq_along(b2_rows)

      data = d$b2$data(b2_rows, cols, data_format = "data.table")[b2_index]
      # only ask b1 for columns that b2 could not provide
      remainingcols = intersect(cols, d$colnames_b1)
      if (length(remainingcols)) {
        data = cbind(data, d$b1$data(b1_rows, remainingcols, data_format = "data.table")[b1_index])
      }
      setkeyv(data, NULL)
      if (d$pk %in% cols) {
        set(data, , d$pk, rows)
      }
      if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) {
        set(data, , d$b2_index_colname, rownames_b2)
      }
      if (!is.null(d$b1_index_colname) && d$b1_index_colname %in% cols) {
        set(data, , d$b1_index_colname, rownames_b1)
      }
      data[, intersect(cols, names(data)), with = FALSE]
    },

    # First `n` rows of the join result, with all columns.
    head = function(n = 6L) {
      rows = first(self$rownames, n)
      self$data(rows = rows, cols = self$colnames)
    },

    # Distinct values of `cols` within `rows`, as a named list. If `na_rm` is FALSE, NA is added
    # for the columns of a backend that has no matching row for some of the requested rows.
    distinct = function(rows, cols, na_rm = TRUE) {
      d = private$.data
      # drop invalid row ids, consistent with $data()
      rows = rows[inrange(rows, 1, nrow(d$index_table))]
      indices = d$index_table[rows]
      rownames_b1 = indices[["rownames_b1"]]
      rownames_b2 = indices[["rownames_b2"]]
      b1_rows = rownames_b1[!is.na(rownames_b1)]
      b2_rows = rownames_b2[!is.na(rownames_b2)]

      d2 = d$b2$distinct(rows = b2_rows, cols = cols, na_rm = na_rm)
      if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) {
        d2[[d$b2_index_colname]] = if (na_rm) unique(b2_rows) else unique(rownames_b2)
      }
      # b2 takes precedence: only query b1 for what is still missing
      d1 = d$b1$distinct(rows = b1_rows, cols = setdiff(cols, names(d2)), na_rm = na_rm)
      if (!is.null(d$b1_index_colname) && d$b1_index_colname %in% cols) {
        d1[[d$b1_index_colname]] = if (na_rm) unique(b1_rows) else unique(rownames_b1)
      }

      if (!na_rm && length(b1_rows) < length(rows)) {
        d1 = map(d1, function(x) if (any(is.na(x))) x else c(x, NA))
      }
      if (!na_rm && length(b2_rows) < length(rows)) {
        d2 = map(d2, function(x) if (any(is.na(x))) x else c(x, NA))
      }
      res = c(d1, d2)
      if (d$pk %in% cols) {
        res[[d$pk]] = unique(rows)
      }

      res[match(cols, names(res), nomatch = 0)]
    },

    # Number of missing values per column in `cols` within `rows`. Rows without a matching row in
    # a backend count as missing for all columns provided by that backend.
    missings = function(rows, cols) {
      d = private$.data
      # drop invalid row ids, consistent with $data(); otherwise they would be counted as missing
      rows = rows[inrange(rows, 1, nrow(d$index_table))]
      indices = d$index_table[rows]
      rownames_b1 = indices[["rownames_b1"]]
      rownames_b2 = indices[["rownames_b2"]]
      b1_rows = rownames_b1[!is.na(rownames_b1)]
      b2_rows = rownames_b2[!is.na(rownames_b2)]

      m2 = d$b2$missings(b2_rows, cols)
      if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) {
        m2[d$b2_index_colname] = 0L
      }
      m1 = d$b1$missings(b1_rows, setdiff(cols, names(m2)))
      if (!is.null(d$b1_index_colname) && d$b1_index_colname %in% cols) {
        m1[d$b1_index_colname] = 0L
      }
      # unmatched rows are NA in every column of the respective backend, including its index column
      m1 = m1 + length(rows) - length(b1_rows)
      m2 = m2 + length(rows) - length(b2_rows)
      res = c(m1, m2)
      if (d$pk %in% cols) {
        res[d$pk] = 0L
      }
      res[match(cols, names(res), nomatch = 0)]
    }
  ),
  active = list(
    rownames = function() seq_len(nrow(private$.data$index_table)),
    colnames = function() private$.data$allcolnames,
    nrow = function() nrow(private$.data$index_table),
    ncol = function() length(private$.data$allcolnames)
  ),
  private = list(
    # The hash depends on both backends and on the join configuration.
    .calculate_hash = function() {
      d = private$.data
      calculate_hash(d$b1$hash, d$b2$hash, d$aux_hash)
    }
  )
)
134 changes: 134 additions & 0 deletions R/DataBackendMultiCbind.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@


#' @title DataBackend Column-Binding Multiple DataBackends
#'
#' @usage NULL
#' @format `R6Class` object inheriting from `DataBackend`.
#'
#' @description
#' Presents a list of backends as a single table by binding their columns, matching rows by the
#' values of the backends' primary keys. The rows of the result are the union of the rows of all backends;
#' a column is `NA` for rows that are absent from the backend that provides it.
#'
#' If a column is present in more than one backend, the backend that comes *later* in `bs` takes precedence.
#'
#' If all backends use the same primary key column name, this name is also the primary key of the result.
#' Otherwise a new primary key column (`"..row_id"`, or `"..row_id.<n>"` if that name is already taken)
#' is added.
#'
#' @section Construction:
#' ```
#' DataBackendMultiCbind$new(bs)
#' ```
#'
#' * `bs` :: `list` of `DataBackend`\cr
#'   Non-empty list of backends to combine. The supported data formats are the ones that all
#'   backends have in common.
#'
#' @export
DataBackendMultiCbind = R6Class("DataBackendMultiCbind", inherit = DataBackend, cloneable = FALSE,
  public = list(
    initialize = function(bs) {
      assert_list(bs, min.len = 1)
      lapply(bs, assert_backend)

      formats = Reduce(intersect, map(bs, "data_formats"))

      private$.colnames = unique(unlist(map(bs, "colnames")))

      # primary key: if all backends have the same pk, just use that one.
      otherpk = unique(unlist(map(bs, "primary_key")))
      if (length(otherpk) == 1) {
        pk = otherpk
      } else {
        # otherwise: introduce a new primary key that is completely different from the previous ones.
        pk = "..row_id"
        index = 0
        while (pk %in% private$.colnames) {
          index = index + 1
          pk = paste0("..row_id.", index)
        }
        private$.colnames = c(private$.colnames, pk)
      }

      # Backends are stored in reverse order: the methods below query them first-to-last and take
      # each column from the first backend that has it, so later entries of `bs` take precedence.
      super$initialize(list(bs = rev(bs)), pk, formats)
    },

    # Retrieve the requested rows (primary key values) and columns.
    data = function(rows, cols, data_format = "data.table") {
      bs = private$.data$bs

      urows = unique(rows)

      datas = vector("list", length(bs))
      allrows = vector("list", length(bs))
      pks = character(length(bs))
      include_pk = logical(length(bs))
      cols_remaining = cols
      for (i in seq_along(bs)) {
        ## Not doing 'if (length(cols_remaining)) break' because there could still be tables remaining that add rows
        pk = bs[[i]]$primary_key
        pks[[i]] = pk
        include_pk[[i]] = pk %in% cols_remaining
        # The backend's pk is always fetched since it is needed for matching rows; it is only part
        # of the result if it was requested and not provided by an earlier backend.
        query_cols = if (include_pk[[i]]) cols_remaining else c(pk, cols_remaining)
        datas[[i]] = bs[[i]]$data(urows, query_cols, data_format = data_format)
        # If the pk was only added for matching, it is not in cols_remaining, so removing all
        # returned column names is correct regardless of the order of the returned columns.
        cols_remaining = setdiff(cols_remaining, colnames(datas[[i]]))
        allrows[[i]] = datas[[i]][[pk]]
      }
      presentrows = unique(unlist(allrows))
      join = list(presentrows)
      result = do.call(cbind, pmap(list(datas, pks, include_pk), function(data, pk, include) {
        if (include) {
          result = data[join, on = pk, nomatch = NA]
          # The join fills the pk column from `join` even for rows that are absent from this
          # backend; these entries must be NA. set() only accepts integer row numbers for `i`.
          absent = which(result[[pk]] %nin% data[[pk]])
          if (length(absent)) {
            set(result, absent, pk, NA)
          }
          result
        } else {
          data[join, -pk, on = pk, with = FALSE, nomatch = NA]
        }
      }))
      sbk = self$primary_key

      set(result, , sbk, presentrows)
      # restore the requested row order, including duplicates; rows present in no backend are dropped
      join = list(rows)
      result[join, intersect(cols, colnames(result)), with = FALSE, on = sbk, nomatch = NULL]
    },

    # First `n` rows, with all columns.
    head = function(n = 6L) {
      rows = head(self$rownames, n)
      self$data(rows = rows, cols = self$colnames)
    },

    # Distinct values of `cols` within `rows`, as a named list. If `na_rm` is FALSE, NA is added
    # for the columns of a backend that does not contain all of the requested rows.
    distinct = function(rows, cols, na_rm = TRUE) {
      bs = private$.data$bs
      getpk = self$primary_key %in% cols
      reslist = list()
      remaining_cols = cols
      if (!na_rm || getpk) {
        rows = intersect(rows, self$rownames)
      }
      for (i in seq_along(bs)) {
        if (!length(remaining_cols)) break
        # only query columns that no earlier backend has provided
        reslist[[i]] = bs[[i]]$distinct(rows = rows, cols = remaining_cols, na_rm = na_rm)
        remaining_cols = setdiff(remaining_cols, names(reslist[[i]]))
        if (!na_rm && !all(rows %in% bs[[i]]$rownames)) {
          reslist[[i]] = map(reslist[[i]], function(x) if (any(is.na(x))) x else c(x, NA))
        }
      }
      result = unlist(reslist, recursive = FALSE)
      if (getpk) {
        result[[self$primary_key]] = rows
      }
      result[match(cols, names(result), nomatch = 0)]
    },

    # Number of missing values per column in `cols` within `rows`. Rows that are absent from the
    # backend providing a column count as missing for that column.
    missings = function(rows, cols) {
      rows = rows[rows %in% self$rownames]
      bs = private$.data$bs
      getpk = self$primary_key %in% cols
      reslist = list()
      remaining_cols = cols
      for (i in seq_along(bs)) {
        if (!length(remaining_cols)) break
        missingrows = sum(rows %nin% bs[[i]]$rownames)
        reslist[[i]] = bs[[i]]$missings(rows, remaining_cols) + missingrows
        remaining_cols = setdiff(remaining_cols, names(reslist[[i]]))
      }
      result = unlist(reslist)
      if (getpk) {
        result[[self$primary_key]] = 0L
      }
      result[match(cols, names(result), nomatch = 0)]
    }
  ),
  active = list(
    rownames = function() {
      # rev() undoes the reversal done in the constructor, so row ids are listed in the order of `bs`
      if (is.null(private$.rownames_cache)) private$.rownames_cache = unique(unlist(rev(map(private$.data$bs, "rownames"))))
      private$.rownames_cache
    },
    colnames = function() {
      private$.colnames
    },
    nrow = function() length(self$rownames),
    ncol = function() length(self$colnames)
  ),
  private = list(
    .rownames_cache = NULL,
    .colnames = NULL,
    # Combine the hashes of the wrapped backends (in stored order) instead of hashing the backend
    # objects themselves, consistent with DataBackendJoin.
    .calculate_hash = function() {
      do.call(calculate_hash, unname(map(private$.data$bs, "hash")))
    }
  )
)
Loading
Loading