Skip to content

Commit

Permalink
add existing splitter and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Nov 18, 2024
1 parent bb82e1a commit 0a65fc5
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 3 deletions.
35 changes: 32 additions & 3 deletions R/DataSplitting.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,26 @@ createDefaultSplitSetting <- function(testFraction = 0.25,
return(splitSettings)
}

#' Create split settings that reuse an existing test/train assignment
#'
#' @description
#' Builds a \code{splitSettings} object that carries a user-supplied split
#' instead of generating a new one. Useful when the partitioning of a previous
#' run has to be reproduced exactly.
#'
#' @details
#' The input is validated with the package's check helpers before the settings
#' object is built: it must be a data frame, must contain the columns
#' \code{rowId} and \code{index}, both columns must be integer/numeric, and
#' \code{index} is checked against a lower bound of -1.
#'
#' @param splitIds (data.frame) A data frame with \code{rowId} and
#' \code{index} columns of type integer/numeric. \code{index} is -1 for the
#' test set and a positive integer for the train set folds.
#' @return An object of class \code{splitSettings}: a list holding
#' \code{splitIds}, with the attribute \code{fun} set to
#' \code{"existingSplitter"}.
#' @export
createExistingSplitSettings <- function(splitIds) {
  # Input validation; the first argument of each check is passed verbatim
  # because the helpers may use the expression text in their messages.
  checkIsClass(splitIds, "data.frame")
  checkColumnNames(splitIds, c("rowId", "index"))
  numericClasses <- c("integer", "numeric")
  checkIsClass(splitIds$rowId, numericClasses)
  checkIsClass(splitIds$index, numericClasses)
  checkHigherEqual(splitIds$index, -1)

  # The "fun" attribute names the splitter function to use for these settings.
  structure(
    list(splitIds = splitIds),
    fun = "existingSplitter",
    class = "splitSettings"
  )
}


#' Split the plpData into test/train sets using a splitting settings of class
Expand Down Expand Up @@ -561,7 +581,16 @@ checkInputsSplit <- function(test, train, nfold, seed) {
ParallelLogger::logDebug(paste0("nfold: ", nfold))
checkIsClass(nfold, c("numeric", "integer"))
checkHigher(nfold, 1)

ParallelLogger::logInfo(paste0('seed: ', seed))
checkIsClass(seed, c('numeric','integer'))

ParallelLogger::logInfo(paste0("seed: ", seed))
checkIsClass(seed, c("numeric", "integer"))
}

# Splitter that hands back a pre-computed split unchanged.
#
# population:    object with a rowId element (e.g. a data frame).
# splitSettings: list containing a splitIds element with a rowId column.
#
# Returns splitSettings$splitIds as-is. Raises an error when any rowId in
# splitIds does not occur in population$rowId.
existingSplitter <- function(population, splitSettings) {
  requestedSplit <- splitSettings$splitIds
  knownRowIds <- population$rowId

  # flag every requested rowId that the population does not contain
  isUnknown <- !(requestedSplit$rowId %in% knownRowIds)
  if (any(isUnknown)) {
    stop("Not all rowIds in splitIds are in the population")
  }

  return(requestedSplit)
}
35 changes: 35 additions & 0 deletions tests/testthat/test-dataSplitting.R
Original file line number Diff line number Diff line change
Expand Up @@ -417,5 +417,40 @@ test_that("Data splitting by subject", {
# test that no subject is not assigned a fold
expect_equal(sum(test$index==0), 0)

# test that no subject is not assigned a fold
expect_equal(sum(test$index == 0), 0)
})

test_that("Existing data splitter works", {
  # Build a split from age: age > 43 goes to the test set, the remaining
  # subjects are spread over two train folds.
  ageYears <- population$ageYear
  isTest <- ageYears > 43
  isFoldOne <- ageYears <= 35
  isFoldTwo <- ageYears > 35 & ageYears <= 43

  # start from an all-zero index with one entry per subject
  foldIndex <- numeric(length(ageYears))
  foldIndex[isFoldOne] <- 1 # train fold 1
  foldIndex[isFoldTwo] <- 2 # train fold 2
  foldIndex[isTest] <- -1 # test set

  existingSplit <- createExistingSplitSettings(
    data.frame(rowId = population$rowId, index = foldIndex)
  )
  ageSplit <- splitData(
    plpData = plpData,
    population = population,
    splitSettings = existingSplit
  )
  testRowIds <- ageSplit$Test$labels$rowId
  trainRowIds <- ageSplit$Train$labels$rowId

  # the test set has as many rows as there are subjects older than 43
  expect_equal(length(testRowIds), sum(isTest))
  # the train set has as many rows as there are subjects aged 43 or younger
  expect_equal(length(trainRowIds), sum(ageYears <= 43))
  # no rowId appears in both sets
  expect_equal(length(intersect(testRowIds, trainRowIds)), 0)
})

0 comments on commit 0a65fc5

Please sign in to comment.