From d0c310ecd280f4972b8fb5dccff1d6f0381f0678 Mon Sep 17 00:00:00 2001 From: Gowtham Rao Date: Wed, 31 Jan 2024 20:22:51 -0500 Subject: [PATCH] Conditionally use timeId https://github.com/OHDSI/FeatureExtraction/issues/225 --- R/CompareCohorts.R | 108 +++++++++++++++++++++++++++++++++------------ 1 file changed, 80 insertions(+), 28 deletions(-) diff --git a/R/CompareCohorts.R b/R/CompareCohorts.R index 80cf2074..4b86fe6e 100644 --- a/R/CompareCohorts.R +++ b/R/CompareCohorts.R @@ -59,6 +59,14 @@ computeStandardizedDifference <- function(covariateData1, covariateData2, cohort if (!isAggregatedCovariateData(covariateData2)) { stop("Covariate2 data is not aggregated") } + if (colnames(covariateData1$covariates) |> sort() != colnames(covariateData1$covariates) |> sort()) { + stop("Covariate1 and Covariate2 do not have the same structure") + } + covariateDataHasTimeId <- FALSE + if ("timeId" %in% colnames(covariateData1$covariates)) { + covariateDataHasTimeId <- TRUE + } + result <- tibble() if (!is.null(covariateData1$covariates) && !is.null(covariateData2$covariates)) { covariates1 <- covariateData1$covariates @@ -66,24 +74,37 @@ computeStandardizedDifference <- function(covariateData1, covariateData2, cohort covariates1 <- covariates1 %>% filter(cohortDefinitionId == cohortId1) } - covariates1 <- covariates1 %>% - select( - covariateId = "covariateId", - count1 = "sumValue" - ) %>% - collect() + + if (covariateDataHasTimeId) { + covariates1 <- covariates1 %>% + select(timeId = "timeId", + covariateId = "covariateId", + count1 = "sumValue") %>% + collect() + } else { + covariates1 <- covariates1 %>% + select(covariateId = "covariateId", + count1 = "sumValue") %>% + collect() + } covariates2 <- covariateData2$covariates if (!is.null(cohortId2)) { covariates2 <- covariates2 %>% filter(cohortDefinitionId == cohortId2) } - covariates2 <- covariates2 %>% - select( - covariateId = "covariateId", - count2 = "sumValue" - ) %>% - collect() + if (covariateDataHasTimeId) { + covariates2 <- covariates2 %>% + select(timeId = "timeId", + covariateId = "covariateId", + count2 = "sumValue") %>% + collect() + } else { + covariates2 <- covariates2 %>% + select(covariateId = "covariateId", + count2 = "sumValue") %>% + collect() + } n1 <- attr(covariateData1, "metaData")$populationSize if (!is.null(cohortId1)) { @@ -102,7 +123,13 @@ computeStandardizedDifference <- function(covariateData1, covariateData2, cohort m$sd2 <- sqrt(m$mean2 * (1 - m$mean2)) m$sd <- sqrt((m$sd1^2 + m$sd2^2) / 2) m$stdDiff <- (m$mean2 - m$mean1) / m$sd - result <- bind_rows(result, m[, c("covariateId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")]) + if (covariateDataHasTimeId) { + result <- + bind_rows(result, m[, c("covariateId", "timeId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")]) + } else { + result <- + bind_rows(result, m[, c("covariateId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")]) + } } if (!is.null(covariateData1$covariatesContinuous) && !is.null(covariateData2$covariatesContinuous)) { covariates1 <- covariateData1$covariatesContinuous @@ -110,26 +137,45 @@ computeStandardizedDifference <- function(covariateData1, covariateData2, cohort covariates1 <- covariates1 %>% filter(cohortDefinitionId == cohortId1) } - covariates1 <- covariates1 %>% - select( - covariateId = "covariateId", - mean1 = "averageValue", - sd1 = "standardDeviation" - ) %>% - collect() + + if (covariateDataHasTimeId) { + covariates1 <- covariates1 %>% + select( + timeId = "timeId", + covariateId = "covariateId", + mean1 = "averageValue", + sd1 = "standardDeviation" + ) %>% + collect() + } else { + covariates1 <- covariates1 %>% + select(covariateId = "covariateId", + mean1 = "averageValue", + sd1 = "standardDeviation") %>% + collect() + } covariates2 <- covariateData2$covariatesContinuous if (!is.null(cohortId2)) { covariates2 <- covariates2 %>% filter(cohortDefinitionId == cohortId2) } - covariates2 <- covariates2 %>% - select( - covariateId = "covariateId", - mean2 = "averageValue", - sd2 = "standardDeviation" - ) %>% - collect() + if (covariateDataHasTimeId) { + covariates2 <- covariates2 %>% + select( + timeId = "timeId", + covariateId = "covariateId", + mean2 = "averageValue", + sd2 = "standardDeviation" + ) %>% + collect() + } else { + covariates2 <- covariates2 %>% + select(covariateId = "covariateId", + mean2 = "averageValue", + sd2 = "standardDeviation") %>% + collect() + } m <- merge(covariates1, covariates2, all = T) m$mean1[is.na(m$mean1)] <- 0 @@ -138,7 +184,13 @@ computeStandardizedDifference <- function(covariateData1, covariateData2, cohort m$sd2[is.na(m$sd2)] <- 0 m$sd <- sqrt(m$sd1^2 + m$sd2^2) m$stdDiff <- (m$mean2 - m$mean1) / m$sd - result <- bind_rows(result, m[, c("covariateId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")]) + if (covariateDataHasTimeId) { + result <- + bind_rows(result, m[, c("covariateId", "timeId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")]) + } else { + result <- + bind_rows(result, m[, c("covariateId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")]) + } } covariateRef1 <- covariateData1$covariateRef %>% collect()