Skip to content

Commit

Permalink
Fix numerical embeddings + add tests (#92)
Browse files Browse the repository at this point in the history
* fix dimension mismatch for numerical embeddings

* add unit tests
  • Loading branch information
egillax authored Oct 12, 2023
1 parent 84bbb18 commit 2d8f9af
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
4 changes: 2 additions & 2 deletions inst/python/ResNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ def __init__(self,
nn.init.kaiming_uniform_(parameter, a=math.sqrt(5))

def forward(self, input):
x = self.weight.unsqueeze(0) * input.unsqueeze(-1)
x = self.weight[None] * input[..., None]
if self.bias is not None:
x = x + self.bias.unsqueeze(-1)
x = x + self.bias[None]
return x


Expand Down
29 changes: 29 additions & 0 deletions tests/testthat/test-Transformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,32 @@ test_that("dimHidden ratio works as expected", {
dimHiddenRatio = 4/3))

})

test_that("numerical embedding works as expected", {
embeddings <- 32L # size of embeddings
features <- 2L # number of numerical features
patients <- 9L

numTensor <- torch$randn(c(patients, features))

numericalEmbeddingClass <- reticulate::import_from_path("ResNet", path=path)$NumericalEmbedding
numericalEmbedding <- numericalEmbeddingClass(num_embeddings = features,
embedding_dim = embeddings,
bias = TRUE)
out <- numericalEmbedding(numTensor)

# should be patients x features x embedding size
expect_equal(out$shape[[0]], patients)
expect_equal(out$shape[[1]], features)
expect_equal(out$shape[[2]], embeddings)

numericalEmbedding <- numericalEmbeddingClass(num_embeddings = features,
embedding_dim = embeddings,
bias = FALSE)

out <- numericalEmbedding(numTensor)
expect_equal(out$shape[[0]], patients)
expect_equal(out$shape[[1]], features)
expect_equal(out$shape[[2]], embeddings)

})

0 comments on commit 2d8f9af

Please sign in to comment.