diff --git a/R/ExtractData.R b/R/ExtractData.R index bf403406..b2462dd8 100644 --- a/R/ExtractData.R +++ b/R/ExtractData.R @@ -221,8 +221,9 @@ getPlpData <- function( checkIsClass(covariateSettings[[i]], "covariateSettings") } } - - checkIsClass(restrictPlpDataSettings, "restrictPlpDataSettings") + if (!is.null(restrictPlpDataSettings)) { + checkIsClass(restrictPlpDataSettings, "restrictPlpDataSettings") + } @@ -448,7 +449,7 @@ summary.plpData <- function(object, ...) { eventCount = 0, personCount = 0 ) - for (i in seq_along(outcomeCounts)) { + for (i in seq_len(nrow(outcomeCounts))) { outcomeCounts$eventCount[i] <- sum(object$outcomes$outcomeId == attr(object$outcomes, "metaData")$outcomeIds[i]) outcomeCounts$personCount[i] <- length(unique(object$outcomes$rowId[object$outcomes$outcomeId == attr(object$outcomes, "metaData")$outcomeIds[i]])) } diff --git a/tests/testthat/test-extractData.R b/tests/testthat/test-extractData.R index 0e15eb37..28a80f79 100644 --- a/tests/testthat/test-extractData.R +++ b/tests/testthat/test-extractData.R @@ -19,24 +19,24 @@ context("extractPlp") test_that("summary.plpData", { attr(plpData$outcomes, "metaData")$outcomeIds <- c(outcomeId) sum <- summary.plpData(plpData) - testthat::expect_equal(class(sum),'summary.plpData') + testthat::expect_equal(class(sum), "summary.plpData") }) test_that("getPlpData errors", { testthat::expect_error( getPlpData( databaseDetails = list(targetId = NULL) - ) + ) ) testthat::expect_error( getPlpData( - databaseDetails = list(targetId = c(1,2)) - ) + databaseDetails = list(targetId = c(1, 2)) + ) ) testthat::expect_error( getPlpData( databaseDetails = list(targetId = 1, outcomeIds = NULL) - ) + ) ) }) @@ -51,11 +51,35 @@ test_that("getCovariateData", { test_that("createDatabaseDetails with NULL cdmDatabaseId errors", { testthat::expect_error(createDatabaseDetails( - connectionDetails = list(), - cdmDatabaseSchema = 'main', - cdmDatabaseId = NULL, - targetId = 1, + connectionDetails = list(), + cdmDatabaseSchema = "main", + cdmDatabaseId = NULL, + targetId = 1, outcomeIds = outcomeId )) }) +test_that("getPlpData checks covariateSettings object", { + testthat::expect_error(getPlpData( + databaseDetails = list(targetId = 1, outcomeIds = outcomeId), + covariateSettings = list() + )) + + settings1 <- + FeatureExtraction::createCovariateSettings(useDemographicsGender = TRUE) + settings2 <- + FeatureExtraction::createCovariateSettings(useDemographicsAge = TRUE) + plpData <- getPlpData( + databaseDetails = databaseDetails, + covariateSettings = list(settings1, settings2) + ) + expect_equal(plpData$covariateData$covariateRef %>% dplyr::pull(.data$analysisId %>% length()), 3) + + settings3 <- list(covariateId = 3) + class(settings3) <- "NotCovariateSettings" + + expect_Error(getPlpData( + databaseDetails = databaseDetails, + covariateSettings = list(settings1, settings3) + )) +})