From 2d8f9afc349fc6594a392b672e5439b8557ba986 Mon Sep 17 00:00:00 2001 From: Egill Axfjord Fridgeirsson Date: Thu, 12 Oct 2023 19:46:36 +0200 Subject: [PATCH] Fix numerical embeddings + add tests (#92) * fix dimension mismatch for numerical embeddings * add unit tests --- inst/python/ResNet.py | 4 ++-- tests/testthat/test-Transformer.R | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/inst/python/ResNet.py b/inst/python/ResNet.py index f680eb2..cef4b49 100644 --- a/inst/python/ResNet.py +++ b/inst/python/ResNet.py @@ -130,9 +130,9 @@ def __init__(self, nn.init.kaiming_uniform_(parameter, a=math.sqrt(5)) def forward(self, input): - x = self.weight.unsqueeze(0) * input.unsqueeze(-1) + x = self.weight[None] * input[..., None] if self.bias is not None: - x = x + self.bias.unsqueeze(-1) + x = x + self.bias[None] return x diff --git a/tests/testthat/test-Transformer.R b/tests/testthat/test-Transformer.R index b3e421f..00e6ebc 100644 --- a/tests/testthat/test-Transformer.R +++ b/tests/testthat/test-Transformer.R @@ -126,3 +126,32 @@ test_that("dimHidden ratio works as expected", { dimHiddenRatio = 4/3)) }) + +test_that("numerical embedding works as expected", { + embeddings <- 32L # size of embeddings + features <- 2L # number of numerical features + patients <- 9L + + numTensor <- torch$randn(c(patients, features)) + + numericalEmbeddingClass <- reticulate::import_from_path("ResNet", path=path)$NumericalEmbedding + numericalEmbedding <- numericalEmbeddingClass(num_embeddings = features, + embedding_dim = embeddings, + bias = TRUE) + out <- numericalEmbedding(numTensor) + + # should be patients x features x embedding size + expect_equal(out$shape[[0]], patients) + expect_equal(out$shape[[1]], features) + expect_equal(out$shape[[2]], embeddings) + + numericalEmbedding <- numericalEmbeddingClass(num_embeddings = features, + embedding_dim = embeddings, + bias = FALSE) + + out <- numericalEmbedding(numTensor) + expect_equal(out$shape[[0]], patients) + expect_equal(out$shape[[1]], features) + expect_equal(out$shape[[2]], embeddings) + + })