From 0a65fc553df59d69bb2adaf17490fef170e1834d Mon Sep 17 00:00:00 2001
From: egillax
Date: Mon, 18 Nov 2024 16:23:53 +0100
Subject: [PATCH] add existing splitter and tests

---
Notes:
    createExistingSplitSettings() wraps a caller-supplied data frame of
    rowId/index pairs in a "splitSettings" list whose "fun" attribute is
    "existingSplitter". existingSplitter() returns that data frame
    unchanged after stopping if any of its rowIds is absent from
    population$rowId.

    Review observations:
    - existingSplitter() checks one direction only: population rows that
      have no entry in splitIds are not reported.
    - checkHigherEqual(splitIds$index, -1) presumably also accepts an
      index of 0 (the helper is defined outside this patch; confirm),
      while the documentation describes only -1 and positive fold ids.
    - The test hunk adds a second copy of the "no subject is not assigned
      a fold" assertion in front of the new test (hunk heading:
      "Data splitting by subject").
    - Line breaks, indentation and blank context lines in this copy were
      restored from a whitespace-collapsed rendering; confirm against
      commit 0a65fc5 before applying.

 R/DataSplitting.R                   | 35 ++++++++++++++++++++++++++---
 tests/testthat/test-dataSplitting.R | 35 +++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/R/DataSplitting.R b/R/DataSplitting.R
index 0c2374285..97960a7e8 100644
--- a/R/DataSplitting.R
+++ b/R/DataSplitting.R
@@ -93,6 +93,26 @@ createDefaultSplitSetting <- function(testFraction = 0.25,
   return(splitSettings)
 }
 
+#' Create the settings for defining how the plpData are split into
+#' test/validation/train sets using an existing split - good to use for
+#' reproducing results from a different run
+#' @param splitIds (data.frame) A data frame with rowId and index columns of
+#' type integer/numeric. Index is -1 for test set, positive integer for train
+#' set folds
+#' @return An object of class \code{splitSettings}
+#' @export
+createExistingSplitSettings <- function(splitIds) {
+  checkIsClass(splitIds, "data.frame")
+  checkColumnNames(splitIds, c("rowId", "index"))
+  checkIsClass(splitIds$rowId, c("integer", "numeric"))
+  checkIsClass(splitIds$index, c("integer", "numeric"))
+  checkHigherEqual(splitIds$index, -1)
+
+  splitSettings <- list(splitIds = splitIds)
+  attr(splitSettings, "fun") <- "existingSplitter"
+  class(splitSettings) <- "splitSettings"
+  return(splitSettings)
+}
 
 
 #' Split the plpData into test/train sets using a splitting settings of class
@@ -561,7 +581,16 @@ checkInputsSplit <- function(test, train, nfold, seed) {
   ParallelLogger::logDebug(paste0("nfold: ", nfold))
   checkIsClass(nfold, c("numeric", "integer"))
   checkHigher(nfold, 1)
-  
-  ParallelLogger::logInfo(paste0('seed: ', seed))
-  checkIsClass(seed, c('numeric','integer'))
+
+  ParallelLogger::logInfo(paste0("seed: ", seed))
+  checkIsClass(seed, c("numeric", "integer"))
+}
+
+existingSplitter <- function(population, splitSettings) {
+  splitIds <- splitSettings$splitIds
+  # check all row Ids are in population
+  if (sum(!splitIds$rowId %in% population$rowId) > 0) {
+    stop("Not all rowIds in splitIds are in the population")
+  }
+  return(splitIds)
 }
diff --git a/tests/testthat/test-dataSplitting.R b/tests/testthat/test-dataSplitting.R
index b8ce628bb..64d2f5cd9 100644
--- a/tests/testthat/test-dataSplitting.R
+++ b/tests/testthat/test-dataSplitting.R
@@ -417,5 +417,40 @@ test_that("Data splitting by subject", {
 
   # test that no subject is not assigned a fold
   expect_equal(sum(test$index==0), 0)
+  # test that no subject is not assigned a fold
+  expect_equal(sum(test$index == 0), 0)
+})
+
+test_that("Existing data splitter works", {
+  # split by age
+  age <- population$ageYear
+  # create empty index same lengths as age
+  index <- rep(0, length(age))
+  index[age > 43] <- -1 # test set
+  index[age <= 35] <- 1 # train fold 1
+  index[age > 35 & age <= 43] <- 2 # train fold 2
+  splitIds <- data.frame(rowId = population$rowId, index = index)
+  splitSettings <- createExistingSplitSettings(splitIds)
+  ageSplit <- splitData(
+    plpData = plpData,
+    population = population,
+    splitSettings = splitSettings
+  )
+
+  # test only old people in test
+  expect_equal(
+    length(ageSplit$Test$labels$rowId),
+    sum(age > 43)
+  )
+  # only young people in train
+  expect_equal(
+    length(ageSplit$Train$labels$rowId),
+    sum(age <= 43)
+  )
+  # no overlap
+  expect_equal(
+    length(intersect(ageSplit$Test$labels$rowId, ageSplit$Train$labels$rowId)),
+    0
+  )
 
 })