Skip to content

Commit

Permalink
Merge pull request #508 from OHDSI/507-fix-predictCyclopsType
Browse files Browse the repository at this point in the history
Fix small bug in predictCyclopsType and lint function
  • Loading branch information
egillax authored Dec 4, 2024
2 parents 4c7dddd + fe69115 commit 621cdef
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions R/CyclopsModels.R
Original file line number Diff line number Diff line change
Expand Up @@ -291,54 +291,54 @@ predictCyclops <- function(plpModel, data, cohort ) {
}

predictCyclopsType <- function(coefficients, population, covariateData, modelType = "logistic") {
if (!(modelType %in% c("logistic", "poisson", "survival","cox"))) {
if (!(modelType %in% c("logistic", "poisson", "survival", "cox"))) {
stop(paste("Unknown modelType:", modelType))
}
if (!FeatureExtraction::isCovariateData(covariateData)){
if (!FeatureExtraction::isCovariateData(covariateData)) {
stop("Needs correct covariateData")
}

intercept <- coefficients$betas[coefficients$covariateId%in%'(Intercept)']
if(length(intercept)==0) intercept <- 0
betas <- coefficients$betas[!coefficients$covariateIds%in%'(Intercept)']
intercept <- coefficients$betas[coefficients$covariateIds %in% "(Intercept)"]
if (length(intercept) == 0) intercept <- 0
betas <- coefficients$betas[!coefficients$covariateIds %in% "(Intercept)"]
coefficients <- data.frame(beta = betas,
covariateId = coefficients$covariateIds[coefficients$covariateIds!='(Intercept)']
covariateId = coefficients$covariateIds[coefficients$covariateIds != "(Intercept)"]
)
coefficients <- coefficients[coefficients$beta != 0, ]
if(sum(coefficients$beta != 0)>0){
if (sum(coefficients$beta != 0) > 0) {
covariateData$coefficients <- coefficients
on.exit(covariateData$coefficients <- NULL, add = TRUE)

prediction <- covariateData$covariates %>%
dplyr::inner_join(covariateData$coefficients, by= 'covariateId') %>%
dplyr::mutate(values = .data$covariateValue*.data$beta) %>%
dplyr::inner_join(covariateData$coefficients, by = "covariateId") %>%
dplyr::mutate(values = .data$covariateValue * .data$beta) %>%
dplyr::group_by(.data$rowId) %>%
dplyr::summarise(value = sum(.data$values, na.rm = TRUE)) %>%
dplyr::select("rowId", "value")

prediction <- as.data.frame(prediction)
prediction <- merge(population, prediction, by ="rowId", all.x = TRUE, fill = 0)
prediction <- merge(population, prediction, by = "rowId", all.x = TRUE, fill = 0)
prediction$value[is.na(prediction$value)] <- 0
prediction$value <- prediction$value + intercept
} else{
warning('Model had no non-zero coefficients so predicted same for all population...')
} else {
warning("Model had no non-zero coefficients so predicted same for all population...")
prediction <- population
prediction$value <- rep(0, nrow(population)) + intercept
}
if (modelType == "logistic") {
link <- function(x) {
return(1/(1 + exp(0 - x)))
return(1 / (1 + exp(0 - x)))
}
prediction$value <- link(prediction$value)
attr(prediction, "metaData")$modelType <- 'binary'
attr(prediction, "metaData")$modelType <- "binary"
} else if (modelType == "poisson" || modelType == "survival" || modelType == "cox") {

# add baseline hazard stuff

prediction$value <- exp(prediction$value)
attr(prediction, "metaData")$modelType <- 'survival'
if(modelType == "survival"){ # is this needed?
attr(prediction, 'metaData')$timepoint <- max(population$survivalTime, na.rm = T)
attr(prediction, "metaData")$modelType <- "survival"
if (modelType == "survival") { # is this needed?
attr(prediction, "metaData")$timepoint <- max(population$survivalTime, na.rm = TRUE)
}

}
Expand Down Expand Up @@ -526,4 +526,4 @@ reparamTransferCoefs <- function(inCoefs) {
coefs <- data.frame(betas = coefs, covariateIds = rownames(coefs), row.names = NULL)

return(coefs)
}
}

0 comments on commit 621cdef

Please sign in to comment.