diff --git a/R/Estimator.R b/R/Estimator.R index 656b647..374bf2d 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -58,11 +58,9 @@ setEstimator <- function(learningRate = "auto", ) { checkIsClass(learningRate, c("numeric", "character")) - if (inherits(learningRate, "character")) { - if (learningRate != "auto") { - stop(paste0('Learning rate should be either a numeric or "auto", - you provided: ', learningRate)) - } + if (inherits(learningRate, "character") && learningRate != "auto") { + stop(paste0('Learning rate should be either a numeric or "auto", + you provided: ', learningRate)) } checkIsClass(weightDecay, "numeric") checkHigherEqual(weightDecay, 0.0) @@ -113,23 +111,7 @@ setEstimator <- function(learningRate = "auto", class(estimatorSettings$device)) } - paramsToTune <- list() - for (name in names(estimatorSettings)) { - param <- estimatorSettings[[name]] - if (length(param) > 1 && is.atomic(param)) { - paramsToTune[[paste0("estimator.", name)]] <- param - } - if ("params" %in% names(param)) { - for (name2 in names(param[["params"]])) { - param2 <- param[["params"]][[name2]] - if (length(param2) > 1) { - paramsToTune[[paste0("estimator.", name, ".", name2)]] <- param2 - } - } - } - } - estimatorSettings$paramsToTune <- paramsToTune - + estimatorSettings$paramsToTune <- extractParamsToTune(estimatorSettings) return(estimatorSettings) } @@ -342,14 +324,10 @@ gridCvDeep <- function(mappedData, currentModelParams <- paramSearch[[gridId]][modelSettings$modelParamNames] currentEstimatorSettings <- - fillEstimatorSettings(modelSettings$estimatorSettings, fitParams, + fillEstimatorSettings(modelSettings$estimatorSettings, + fitParams, paramSearch[[gridId]]) - - # initiate prediction - prediction <- NULL - - fold <- labels$index - ParallelLogger::logInfo(paste0("Max fold: ", max(fold))) + currentEstimatorSettings$modelType <- modelSettings$modelType currentModelParams$catFeatures <- dataset$get_cat_features()$shape[[1]] currentModelParams$numFeatures <- dataset$get_numerical_features()$shape[[1]] @@ -363,63 +341,35 @@ gridCvDeep <- function(mappedData, currentEstimatorSettings$learningRate <- lr } - learnRates <- list() - for (i in 1:max(fold)) { - ParallelLogger::logInfo(paste0("Fold ", i)) - trainDataset <- - torch$utils$data$Subset(dataset, - indices = as.integer(which(fold != i) - 1)) - # -1 for python 0-based indexing - testDataset <- - torch$utils$data$Subset(dataset, - indices = as.integer(which(fold == i) - 1)) - # -1 for python 0-based indexing - - estimator <- createEstimator(modelType = modelSettings$modelType, - modelParameters = currentModelParams, - estimatorSettings = - currentEstimatorSettings) - estimator$fit(trainDataset, testDataset) - - ParallelLogger::logInfo("Calculating predictions on left out - fold set...") - - prediction <- rbind( - prediction, - predictDeepEstimator( - plpModel = estimator, - data = testDataset, - cohort = labels[fold == i, ] - ) - ) - learnRates[[i]] <- list( - LRs = estimator$learn_rate_schedule, - bestEpoch = estimator$best_epoch - ) - } + crossValidationResults <- + doCrossvalidation(dataset, + labels = labels, + modelSettings = currentModelParams, + estimatorSettings = currentEstimatorSettings) + learnRates <- crossValidationResults$learnRates + prediction <- crossValidationResults$prediction + + gridPerformance <- + PatientLevelPrediction::computeGridPerformance(prediction, + paramSearch[[gridId]]) maxIndex <- which.max(unlist(sapply(learnRates, `[`, 2))) gridSearchPredictons[[gridId]] <- list( prediction = prediction, param = paramSearch[[gridId]], - gridPerformance = - PatientLevelPrediction::computeGridPerformance(prediction, - paramSearch[[gridId]]) + gridPerformance = gridPerformance ) gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary$learnRates <- rep(list(unlist(learnRates[[maxIndex]]$LRs)), nrow(gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary)) gridSearchPredictons[[gridId]]$param$learnSchedule <- learnRates[[maxIndex]] - # remove all predictions that are not the max performance indexOfMax <- which.max(unlist(lapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance))) for (i in seq_along(gridSearchPredictons)) { - if (!is.null(gridSearchPredictons[[i]])) { - if (i != indexOfMax) { - gridSearchPredictons[[i]]$prediction <- list(NULL) - } + if (!is.null(gridSearchPredictons[[i]]) && i != indexOfMax) { + gridSearchPredictons[[i]]$prediction <- list(NULL) } } ParallelLogger::logInfo(paste0("Caching all grid search results and @@ -543,3 +493,67 @@ createEstimator <- function(modelType, estimator_settings = estimatorSettings) return(estimator) } + +doCrossvalidation <- function(dataset, + labels, + modelSettings, + estimatorSettings) { + fold <- labels$index + ParallelLogger::logInfo(paste0("Max fold: ", max(fold))) + learnRates <- list() + prediction <- NULL + for (i in 1:max(fold)) { + ParallelLogger::logInfo(paste0("Fold ", i)) + + # -1 for python 0-based indexing + trainDataset <- torch$utils$data$Subset(dataset, + indices = + as.integer(which(fold != i) - 1)) + + # -1 for python 0-based indexing + testDataset <- torch$utils$data$Subset(dataset, + indices = + as.integer(which(fold == i) - 1)) + estimator <- createEstimator(modelType = estimatorSettings$modelType, + modelParameters = modelSettings, + estimatorSettings = estimatorSettings) + estimator$fit(trainDataset, testDataset) + + ParallelLogger::logInfo("Calculating predictions on left out fold set...") + + prediction <- rbind( + prediction, + predictDeepEstimator( + plpModel = estimator, + data = testDataset, + cohort = labels[fold == i, ] + ) + ) + learnRates[[i]] <- list( + LRs = estimator$learn_rate_schedule, + bestEpoch = estimator$best_epoch + ) + } + return(results = list(prediction = prediction, + learnRates = learnRates)) + +} + +extractParamsToTune <- function(estimatorSettings) { + paramsToTune <- list() + for (name in names(estimatorSettings)) { + param <- estimatorSettings[[name]] + if (length(param) > 1 && is.atomic(param)) { + paramsToTune[[paste0("estimator.", name)]] <- param + } + if ("params" %in% names(param)) { + for (name2 in names(param[["params"]])) { + param2 <- param[["params"]][[name2]] + if (length(param2) > 1) { + paramsToTune[[paste0("estimator.", name, ".", name2)]] <- param2 + } + } + } + } + return(paramsToTune) +}