Skip to content

Commit

Permalink
Merge with main
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Dec 5, 2023
2 parents 02d3d44 + 8942e11 commit fa55849
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 3 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: DeepPatientLevelPrediction
Type: Package
Title: Deep Learning For Patient Level Prediction Using Data In The OMOP Common Data Model
Version: 2.0.1.9999
Version: 2.0.2
Date: 18-04-2023
Authors@R: c(
person("Egill", "Fridgeirsson", email = "[email protected]", role = c("aut", "cre")),
Expand Down
2 changes: 1 addition & 1 deletion R/TrainingCache-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ trainingCache <- R6::R6Class(
return(all(unlist(lapply(private$.paramPersistence$gridSearchPredictions,
function(x) !is.null(x$gridPerformance)))))
},

#' @description
#' Gets the last index from the cached grid search
#' @returns Last grid search index
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,4 @@ if (!dir.exists(fitEstimatorPath)) {
fitEstimatorResults <- fitEstimator(trainData$Train,
modelSettings = modelSettings,
analysisId = 1,
analysisPath = fitEstimatorPath)
analysisPath = fitEstimatorPath)
32 changes: 32 additions & 0 deletions tests/testthat/test-Transformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ test_that("transformer nn-module works", {
dim_hidden = 32
)
output <- model(input)
expect_equal(output$shape[0], 10L)
input$num <- reticulate::py_none()
output <- model(input)
expect_equal(output$shape[0], 10L)
input$num <- reticulate::py_none()
output <- model(input)
Expand Down Expand Up @@ -164,3 +167,32 @@ test_that("numerical embedding works as expected", {
expect_equal(out$shape[[2]], embeddings)

})

test_that("numerical embedding works as expected", {
embeddings <- 32L # size of embeddings
features <- 2L # number of numerical features
patients <- 9L

numTensor <- torch$randn(c(patients, features))

numericalEmbeddingClass <- reticulate::import_from_path("ResNet", path=path)$NumericalEmbedding
numericalEmbedding <- numericalEmbeddingClass(num_embeddings = features,
embedding_dim = embeddings,
bias = TRUE)
out <- numericalEmbedding(numTensor)

# should be patients x features x embedding size
expect_equal(out$shape[[0]], patients)
expect_equal(out$shape[[1]], features)
expect_equal(out$shape[[2]], embeddings)

numericalEmbedding <- numericalEmbeddingClass(num_embeddings = features,
embedding_dim = embeddings,
bias = FALSE)

out <- numericalEmbedding(numTensor)
expect_equal(out$shape[[0]], patients)
expect_equal(out$shape[[1]], features)
expect_equal(out$shape[[2]], embeddings)

})

0 comments on commit fa55849

Please sign in to comment.