Refactoring gridCvDeep #104

Merged · 4 commits · Nov 27, 2023
156 changes: 85 additions & 71 deletions R/Estimator.R
@@ -58,11 +58,9 @@ setEstimator <- function(learningRate = "auto",
 ) {

   checkIsClass(learningRate, c("numeric", "character"))
-  if (inherits(learningRate, "character")) {
-    if (learningRate != "auto") {
-      stop(paste0('Learning rate should be either a numeric or "auto",
-                  you provided: ', learningRate))
-    }
+  if (inherits(learningRate, "character") && learningRate != "auto") {
+    stop(paste0('Learning rate should be either a numeric or "auto",
+                you provided: ', learningRate))
   }
   checkIsClass(weightDecay, "numeric")
   checkHigherEqual(weightDecay, 0.0)
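
For intuition, a minimal standalone sketch of the collapsed guard. This is not part of the PR, and the wrapper name checkLearningRate is invented for illustration: a numeric value or the literal "auto" passes through, while any other string raises the error above.

# Hypothetical wrapper around the merged condition, for illustration only
checkLearningRate <- function(learningRate) {
  if (inherits(learningRate, "character") && learningRate != "auto") {
    stop(paste0('Learning rate should be either a numeric or "auto",
                you provided: ', learningRate))
  }
  learningRate
}

checkLearningRate(3e-4)    # numeric: returned unchanged
checkLearningRate("auto")  # the one permitted string
# checkLearningRate("fast")  # would stop() with the message above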
@@ -113,23 +111,7 @@ setEstimator <- function(learningRate = "auto",
               class(estimatorSettings$device))
   }

-  paramsToTune <- list()
-  for (name in names(estimatorSettings)) {
-    param <- estimatorSettings[[name]]
-    if (length(param) > 1 && is.atomic(param)) {
-      paramsToTune[[paste0("estimator.", name)]] <- param
-    }
-    if ("params" %in% names(param)) {
-      for (name2 in names(param[["params"]])) {
-        param2 <- param[["params"]][[name2]]
-        if (length(param2) > 1) {
-          paramsToTune[[paste0("estimator.", name, ".", name2)]] <- param2
-        }
-      }
-    }
-  }
-  estimatorSettings$paramsToTune <- paramsToTune
-
+  estimatorSettings$paramsToTune <- extractParamsToTune(estimatorSettings)
   return(estimatorSettings)
 }

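A hedged usage sketch of the extracted helper, with invented values and assuming the remaining setEstimator defaults are all scalar: any setting passed as a vector of length greater than one is collected into paramsToTune under an "estimator." prefix.

# Hypothetical call; learningRate is given two candidate values to tune
settings <- setEstimator(learningRate = c(1e-3, 1e-4),
                         weightDecay = 1e-6)
names(settings$paramsToTune)
# expected: "estimator.learningRate" (weightDecay has length 1, so it is skipped)
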
@@ -342,14 +324,10 @@ gridCvDeep <- function(mappedData,
     currentModelParams <- paramSearch[[gridId]][modelSettings$modelParamNames]

     currentEstimatorSettings <-
-      fillEstimatorSettings(modelSettings$estimatorSettings, fitParams,
+      fillEstimatorSettings(modelSettings$estimatorSettings,
+                            fitParams,
                             paramSearch[[gridId]])
-
-    # initiate prediction
-    prediction <- NULL
-
-    fold <- labels$index
-    ParallelLogger::logInfo(paste0("Max fold: ", max(fold)))
+    currentEstimatorSettings$modelType <- modelSettings$modelType
     currentModelParams$catFeatures <- dataset$get_cat_features()$shape[[1]]
     currentModelParams$numFeatures <-
       dataset$get_numerical_features()$shape[[1]]
@@ -363,63 +341,35 @@ gridCvDeep <- function(mappedData,
       currentEstimatorSettings$learningRate <- lr
     }

-    learnRates <- list()
-    for (i in 1:max(fold)) {
-      ParallelLogger::logInfo(paste0("Fold ", i))
-      trainDataset <-
-        torch$utils$data$Subset(dataset,
-                                indices = as.integer(which(fold != i) - 1))
-      # -1 for python 0-based indexing
-      testDataset <-
-        torch$utils$data$Subset(dataset,
-                                indices = as.integer(which(fold == i) - 1))
-      # -1 for python 0-based indexing
-
-      estimator <- createEstimator(modelType = modelSettings$modelType,
-                                   modelParameters = currentModelParams,
-                                   estimatorSettings =
-                                     currentEstimatorSettings)
-      estimator$fit(trainDataset, testDataset)
-
-      ParallelLogger::logInfo("Calculating predictions on left out
-                              fold set...")
-
-      prediction <- rbind(
-        prediction,
-        predictDeepEstimator(
-          plpModel = estimator,
-          data = testDataset,
-          cohort = labels[fold == i, ]
-        )
-      )
-      learnRates[[i]] <- list(
-        LRs = estimator$learn_rate_schedule,
-        bestEpoch = estimator$best_epoch
-      )
-    }
+    crossValidationResults <-
+      doCrossvalidation(dataset,
+                        labels = labels,
+                        modelSettings = currentModelParams,
+                        estimatorSettings = currentEstimatorSettings)
+    learnRates <- crossValidationResults$learnRates
+    prediction <- crossValidationResults$prediction

+    gridPerformance <-
+      PatientLevelPrediction::computeGridPerformance(prediction,
+                                                     paramSearch[[gridId]])
     maxIndex <- which.max(unlist(sapply(learnRates, `[`, 2)))
     gridSearchPredictons[[gridId]] <- list(
       prediction = prediction,
       param = paramSearch[[gridId]],
-      gridPerformance =
-        PatientLevelPrediction::computeGridPerformance(prediction,
-                                                       paramSearch[[gridId]])
+      gridPerformance = gridPerformance
     )
     gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary$learnRates <-
       rep(list(unlist(learnRates[[maxIndex]]$LRs)),
           nrow(gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary))
     gridSearchPredictons[[gridId]]$param$learnSchedule <-
       learnRates[[maxIndex]]

     # remove all predictions that are not the max performance
     indexOfMax <-
       which.max(unlist(lapply(gridSearchPredictons,
                               function(x) x$gridPerformance$cvPerformance)))
     for (i in seq_along(gridSearchPredictons)) {
-      if (!is.null(gridSearchPredictons[[i]])) {
-        if (i != indexOfMax) {
-          gridSearchPredictons[[i]]$prediction <- list(NULL)
-        }
+      if (!is.null(gridSearchPredictons[[i]]) && i != indexOfMax) {
+        gridSearchPredictons[[i]]$prediction <- list(NULL)
       }
     }
     ParallelLogger::logInfo(paste0("Caching all grid search results and
@@ -543,3 +493,67 @@ createEstimator <- function(modelType,
                              estimator_settings = estimatorSettings)
   return(estimator)
 }
+
+doCrossvalidation <- function(dataset,
+                              labels,
+                              modelSettings,
+                              estimatorSettings) {
+  fold <- labels$index
+  ParallelLogger::logInfo(paste0("Max fold: ", max(fold)))
+  learnRates <- list()
+  prediction <- NULL
+  for (i in 1:max(fold)) {
+    ParallelLogger::logInfo(paste0("Fold ", i))
+
+    # -1 for python 0-based indexing
+    trainDataset <- torch$utils$data$Subset(dataset,
+                                            indices =
+                                              as.integer(which(fold != i) - 1))
+
+    # -1 for python 0-based indexing
+    testDataset <- torch$utils$data$Subset(dataset,
+                                           indices =
+                                             as.integer(which(fold == i) - 1))
+    estimator <- createEstimator(modelType = estimatorSettings$modelType,
+                                 modelParameters = modelSettings,
+                                 estimatorSettings = estimatorSettings)
+    estimator$fit(trainDataset, testDataset)
+
+    ParallelLogger::logInfo("Calculating predictions on left out fold set...")
+
+    prediction <- rbind(
+      prediction,
+      predictDeepEstimator(
+        plpModel = estimator,
+        data = testDataset,
+        cohort = labels[fold == i, ]
+      )
+    )
+    learnRates[[i]] <- list(
+      LRs = estimator$learn_rate_schedule,
+      bestEpoch = estimator$best_epoch
+    )
+  }
+  return(results = list(prediction = prediction,
+                        learnRates = learnRates))
+
+}
+
+extractParamsToTune <- function(estimatorSettings) {
+  paramsToTune <- list()
+  for (name in names(estimatorSettings)) {
+    param <- estimatorSettings[[name]]
+    if (length(param) > 1 && is.atomic(param)) {
+      paramsToTune[[paste0("estimator.", name)]] <- param
+    }
+    if ("params" %in% names(param)) {
+      for (name2 in names(param[["params"]])) {
+        param2 <- param[["params"]][[name2]]
+        if (length(param2) > 1) {
+          paramsToTune[[paste0("estimator.", name, ".", name2)]] <- param2
+        }
+      }
+    }
+  }
+  return(paramsToTune)
+}