diff --git a/R/utils.R b/R/utils.R index 40a6c5c7..4f40ca0c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -833,7 +833,8 @@ get_function_name <- function(fun_start, fun_end, model_lines) { "double", "Eigen::Matrix<(.*)>", "std::vector<(.*)>", - "std::tuple<(.*)>" + "std::tuple<(.*)>", + "std::complex<(.*)>" ) pattern <- paste0( # Only match if the type occurs at start of string diff --git a/tests/testthat/test-model-expose-functions.R b/tests/testthat/test-model-expose-functions.R index ba6e6004..ebe4f341 100644 --- a/tests/testthat/test-model-expose-functions.R +++ b/tests/testthat/test-model-expose-functions.R @@ -39,6 +39,36 @@ functions { tuple(int, tuple(array[] vector, array[] vector)) rtn_nest_tuple_vec_array(tuple(int, tuple(array[] vector, array[] vector)) x) { return x; } tuple(int, tuple(array[] row_vector, array[] row_vector)) rtn_nest_tuple_rowvec_array(tuple(int, tuple(array[] row_vector, array[] row_vector)) x) { return x; } tuple(int, tuple(array[] matrix, array[] matrix)) rtn_nest_tuple_matrix_array(tuple(int, tuple(array[] matrix, array[] matrix)) x) { return x; } + + complex rtn_complex(complex x) { return x; } + complex_vector rtn_complex_vec(complex_vector x) { return x; } + complex_row_vector rtn_complex_rowvec(complex_row_vector x) { return x; } + complex_matrix rtn_complex_matrix(complex_matrix x) { return x; } + + array[] complex rtn_complex_array(array[] complex x) { return x; } + array[] complex_vector rtn_complex_vec_array(array[] complex_vector x) { return x; } + array[] complex_row_vector rtn_complex_rowvec_array(array[] complex_row_vector x) { return x; } + array[] complex_matrix rtn_complex_matrix_array(array[] complex_matrix x) { return x; } + + tuple(complex, complex) rtn_tuple_complex(tuple(complex, complex) x) { return x; } + tuple(complex_vector, complex_vector) rtn_tuple_complex_vec(tuple(complex_vector, complex_vector) x) { return x; } + tuple(complex_row_vector, complex_row_vector) rtn_tuple_complex_rowvec(tuple(complex_row_vector, complex_row_vector) x) { return x; } + tuple(complex_matrix, complex_matrix) rtn_tuple_complex_matrix(tuple(complex_matrix, complex_matrix) x) { return x; } + + tuple(array[] complex, array[] complex) rtn_tuple_complex_array(tuple(array[] complex, array[] complex) x) { return x; } + tuple(array[] complex_vector, array[] complex_vector) rtn_tuple_complex_vec_array(tuple(array[] complex_vector, array[] complex_vector) x) { return x; } + tuple(array[] complex_row_vector, array[] complex_row_vector) rtn_tuple_complex_rowvec_array(tuple(array[] complex_row_vector, array[] complex_row_vector) x) { return x; } + tuple(array[] complex_matrix, array[] complex_matrix) rtn_tuple_complex_matrix_array(tuple(array[] complex_matrix, array[] complex_matrix) x) { return x; } + + tuple(int, tuple(complex, complex)) rtn_nest_tuple_complex(tuple(int, tuple(complex, complex)) x) { return x; } + tuple(int, tuple(complex_vector, complex_vector)) rtn_nest_tuple_complex_vec(tuple(int, tuple(complex_vector, complex_vector)) x) { return x; } + tuple(int, tuple(complex_row_vector, complex_row_vector)) rtn_nest_tuple_complex_rowvec(tuple(int, tuple(complex_row_vector, complex_row_vector)) x) { return x; } + tuple(int, tuple(complex_matrix, complex_matrix)) rtn_nest_tuple_complex_matrix(tuple(int, tuple(complex_matrix, complex_matrix)) x) { return x; } + + tuple(int, tuple(array[] complex, array[] complex)) rtn_nest_tuple_complex_array(tuple(int, tuple(array[] complex, array[] complex)) x) { return x; } + tuple(int, tuple(array[] complex_vector, array[] complex_vector)) rtn_nest_tuple_complex_vec_array(tuple(int, tuple(array[] complex_vector, array[] complex_vector)) x) { return x; } + tuple(int, tuple(array[] complex_row_vector, array[] complex_row_vector)) rtn_nest_tuple_complex_rowvec_array(tuple(int, tuple(array[] complex_row_vector, array[] complex_row_vector)) x) { return x; } + tuple(int, tuple(array[] complex_matrix, array[] complex_matrix)) rtn_nest_tuple_complex_matrix_array(tuple(int, tuple(array[] complex_matrix, array[] complex_matrix)) x) { return x; } }" stan_prog <- paste(function_decl, paste(readLines(testing_stan_file("bernoulli")), @@ -147,6 +177,90 @@ test_that("Functions handle types correctly", { expect_equal(mod$functions$rtn_nest_tuple_matrix_array(nest_tuple_matrix_array), nest_tuple_matrix_array) }) +test_that("Functions handle complex types correctly", { + skip_if(os_is_wsl()) + + ### Scalar + + complex_scalar <- complex(real = 2.1, imaginary = 21.3) + + expect_equal(mod$functions$rtn_complex(complex_scalar), complex_scalar) + + ### Container + + complex_vec <- complex(real = c(2,1.5,0.11, 1.2), imaginary = c(11.2,21.5,6.1,3.2)) + complex_rowvec <- t(complex_vec) + complex_matrix <- matrix(complex_vec, nrow=2, ncol=2) + + expect_equal(mod$functions$rtn_complex_vec(complex_vec), complex_vec) + expect_equal(mod$functions$rtn_complex_rowvec(complex_rowvec), complex_rowvec) + expect_equal(mod$functions$rtn_complex_matrix(complex_matrix), complex_matrix) + expect_equal(mod$functions$rtn_complex_array(complex_vec), complex_vec) + + ### Array of Container + + complex_vec_array <- list(complex_vec, complex_vec * 2, complex_vec + 0.1) + complex_rowvec_array <- list(complex_rowvec, complex_rowvec * 2, complex_rowvec + 0.1) + complex_matrix_array <- list(complex_matrix, complex_matrix * 2, complex_matrix + 0.1) + + expect_equal(mod$functions$rtn_complex_vec_array(complex_vec_array), complex_vec_array) + expect_equal(mod$functions$rtn_complex_rowvec_array(complex_rowvec_array), complex_rowvec_array) + expect_equal(mod$functions$rtn_complex_matrix_array(complex_matrix_array), complex_matrix_array) + + ### Tuple of Scalar + + tuple_complex <- list(complex_vec[1], complex_vec[2]) + expect_equal(mod$functions$rtn_tuple_complex(tuple_complex), tuple_complex) + + ### Tuple of Container + + tuple_complex_vec <- list(complex_vec, complex_vec * 1.2) + tuple_complex_rowvec <- list(complex_rowvec, complex_rowvec * 0.5) + tuple_complex_matrix <- list(complex_matrix, complex_matrix * 10.2) + + expect_equal(mod$functions$rtn_tuple_complex_array(tuple_complex_vec), tuple_complex_vec) + expect_equal(mod$functions$rtn_tuple_complex_vec(tuple_complex_vec), tuple_complex_vec) + expect_equal(mod$functions$rtn_tuple_complex_rowvec(tuple_complex_rowvec), tuple_complex_rowvec) + expect_equal(mod$functions$rtn_tuple_complex_matrix(tuple_complex_matrix), tuple_complex_matrix) + + ### Tuple of Container Arrays + + tuple_complex_vec_array <- list(complex_vec_array, complex_vec_array) + tuple_complex_rowvec_array <- list(complex_rowvec_array, complex_rowvec_array) + tuple_complex_matrix_array <- list(complex_matrix_array, complex_matrix_array) + + expect_equal(mod$functions$rtn_tuple_complex_vec_array(tuple_complex_vec_array), tuple_complex_vec_array) + expect_equal(mod$functions$rtn_tuple_complex_rowvec_array(tuple_complex_rowvec_array), tuple_complex_rowvec_array) + expect_equal(mod$functions$rtn_tuple_complex_matrix_array(tuple_complex_matrix_array), tuple_complex_matrix_array) + + ### Nested Tuple of Scalar + + nest_tuple_complex <- list(31, tuple_complex) + expect_equal(mod$functions$rtn_nest_tuple_complex(nest_tuple_complex), nest_tuple_complex) + + ### Nested Tuple of Container + + nest_tuple_complex_vec <- list(12, tuple_complex_vec) + nest_tuple_complex_rowvec <- list(2, tuple_complex_rowvec) + nest_tuple_complex_matrix <- list(-23, tuple_complex_matrix) + nest_tuple_complex_array <- list(21, tuple_complex_vec) + + expect_equal(mod$functions$rtn_nest_tuple_complex_array(nest_tuple_complex_vec), nest_tuple_complex_vec) + expect_equal(mod$functions$rtn_nest_tuple_complex_vec(nest_tuple_complex_vec), nest_tuple_complex_vec) + expect_equal(mod$functions$rtn_nest_tuple_complex_rowvec(nest_tuple_complex_rowvec), nest_tuple_complex_rowvec) + expect_equal(mod$functions$rtn_nest_tuple_complex_matrix(nest_tuple_complex_matrix), nest_tuple_complex_matrix) + + ### Nested Tuple of Container Arrays + + nest_tuple_complex_vec_array <- list(-21, tuple_complex_vec_array) + nest_tuple_complex_rowvec_array <- list(1000, tuple_complex_rowvec_array) + nest_tuple_complex_matrix_array <- list(0, tuple_complex_matrix_array) + + expect_equal(mod$functions$rtn_nest_tuple_complex_vec_array(nest_tuple_complex_vec_array), nest_tuple_complex_vec_array) + expect_equal(mod$functions$rtn_nest_tuple_complex_rowvec_array(nest_tuple_complex_rowvec_array), nest_tuple_complex_rowvec_array) + expect_equal(mod$functions$rtn_nest_tuple_complex_matrix_array(nest_tuple_complex_matrix_array), nest_tuple_complex_matrix_array) +}) + test_that("Functions can be exposed in fit object", { skip_if(os_is_wsl()) fit$expose_functions(verbose = TRUE)