From 935f51d99b37b5c0334467a7f29980e31db7d88f Mon Sep 17 00:00:00 2001 From: Egill Axfjord Fridgeirsson Date: Wed, 19 Apr 2023 05:26:39 -0400 Subject: [PATCH] version 1.1.4 (#67) Adds device input as a function to estimator --- DESCRIPTION | 7 +++--- NEWS.md | 6 +++++ R/Estimator-class.R | 7 +++++- R/Estimator.R | 5 +++-- man/setEstimator.Rd | 5 +++-- tests/testthat/test-Estimator.R | 40 ++++++++++++++++++++++++++++++++- tests/testthat/test-LRFinder.R | 2 +- 7 files changed, 61 insertions(+), 11 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 7e4cb2e..03626df 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,8 +1,8 @@ Package: DeepPatientLevelPrediction Type: Package Title: Deep Learning For Patient Level Prediction Using Data In The OMOP Common Data Model -Version: 1.1.3 -Date: 15-12-2022 +Version: 1.1.4 +Date: 18-04-2023 Authors@R: c( person("Egill", "Fridgeirsson", email = "e.fridgeirsson@erasmusmc.nl", role = c("aut", "cre")), person("Jenna", "Reps", email = "jreps@its.jnj.com", role = c("aut")), @@ -24,8 +24,7 @@ Imports: ParallelLogger (>= 2.0.0), PatientLevelPrediction (>= 6.0.4), rlang, - torch (>= 0.9.0), - torchopt, + torch (>= 0.10.0), withr Suggests: devtools, diff --git a/NEWS.md b/NEWS.md index 0472025..8202096 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,9 @@ +DeepPatientLevelPrediction 1.1.4 +====================== + - Remove torchopt dependency since adamw is now in torch + - Update torch dependency to >=0.10.0 + - Allow device to be a function that resolves during Estimator initialization + DeepPatientLevelPrediction 1.1.3 ====================== - Fix actions after torch updated to v0.10 (#65) diff --git a/R/Estimator-class.R b/R/Estimator-class.R index 96a5f2d..5d2fe76 100644 --- a/R/Estimator-class.R +++ b/R/Estimator-class.R @@ -34,7 +34,12 @@ Estimator <- R6::R6Class( modelParameters, estimatorSettings) { self$seed <- estimatorSettings$seed - self$device <- estimatorSettings$device + if (is.function(estimatorSettings$device)) { + 
device <- estimatorSettings$device() + } else { + device <- estimatorSettings$device + } + self$device <- device torch::torch_manual_seed(seed=self$seed) self$model <- do.call(modelType, modelParameters) self$modelParameters <- modelParameters diff --git a/R/Estimator.R b/R/Estimator.R index c777d7d..3b9a000 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -26,7 +26,8 @@ #' @param weightDecay what weight_decay to use #' @param batchSize batchSize to use #' @param epochs how many epochs to train for -#' @param device what device to train on +#' @param device what device to train on, can be a string or a function that evaluates +#' to the device during runtime #' @param optimizer which optimizer to use #' @param scheduler which learning rate scheduler to use #' @param criterion loss function to use @@ -41,7 +42,7 @@ setEstimator <- function(learningRate='auto', batchSize = 512, epochs = 30, device='cpu', - optimizer = torchopt::optim_adamw, + optimizer = torch::optim_adamw, scheduler = list(fun=torch::lr_reduce_on_plateau, params=list(patience=1)), criterion = torch::nn_bce_with_logits_loss, diff --git a/man/setEstimator.Rd b/man/setEstimator.Rd index b924b03..8454c55 100644 --- a/man/setEstimator.Rd +++ b/man/setEstimator.Rd @@ -10,7 +10,7 @@ setEstimator( batchSize = 512, epochs = 30, device = "cpu", - optimizer = torchopt::optim_adamw, + optimizer = torch::optim_adamw, scheduler = list(fun = torch::lr_reduce_on_plateau, params = list(patience = 1)), criterion = torch::nn_bce_with_logits_loss, earlyStopping = list(useEarlyStopping = TRUE, params = list(patience = 4)), @@ -27,7 +27,8 @@ setEstimator( \item{epochs}{how many epochs to train for} -\item{device}{what device to train on} +\item{device}{what device to train on, can be a string or a function that evaluates +to the device during runtime} \item{optimizer}{which optimizer to use} diff --git a/tests/testthat/test-Estimator.R b/tests/testthat/test-Estimator.R index aa8459d..faf95db 100644 --- 
a/tests/testthat/test-Estimator.R +++ b/tests/testthat/test-Estimator.R @@ -298,4 +298,42 @@ test_that("setEstimator with paramsToTune is correctly added to hyperparameters" expect_equal(estimatorSettings2$learningRate, 1e-3) expect_equal(as.character(estimatorSettings2$metric), "auprc") expect_equal(estimatorSettings2$earlyStopping$params$patience, 10) -}) \ No newline at end of file +}) + +test_that("device as a function argument works", { + getDevice <- function() { + dev <- Sys.getenv("testDeepPLPDevice") + if (dev == ""){ + dev = "cpu" + } else{ + dev + } + } + + estimatorSettings <- setEstimator(device=getDevice) + + model <- setDefaultResNet(estimatorSettings = estimatorSettings) + model$param[[1]]$catFeatures <- 10 + + estimator <- Estimator$new(modelType="ResNet", + modelParameters = model$param[[1]], + estimatorSettings = estimatorSettings) + + expect_equal(estimator$device, "cpu") + + Sys.setenv("testDeepPLPDevice" = "meta") + + estimatorSettings <- setEstimator(device=getDevice) + + model <- setDefaultResNet(estimatorSettings = estimatorSettings) + model$param[[1]]$catFeatures <- 10 + + estimator <- Estimator$new(modelType="ResNet", + modelParameters = model$param[[1]], + estimatorSettings = estimatorSettings) + + expect_equal(estimator$device, "meta") + + Sys.unsetenv("testDeepPLPDevice") + + }) diff --git a/tests/testthat/test-LRFinder.R b/tests/testthat/test-LRFinder.R index 7f56a0f..30bdc8d 100644 --- a/tests/testthat/test-LRFinder.R +++ b/tests/testthat/test-LRFinder.R @@ -3,7 +3,7 @@ test_that("LR scheduler that changes per batch works", { model <- ResNet(catFeatures = 10, numFeatures = 1, sizeEmbedding = 32, sizeHidden = 64, numLayers = 1, hiddenFactor = 1) - optimizer <- torchopt::optim_adamw(model$parameters, lr=1e-7) + optimizer <- torch::optim_adamw(model$parameters, lr=1e-7) scheduler <- lrPerBatch(optimizer, startLR = 1e-7,