Skip to content

Commit

Permalink
Merge pull request #832 from stan-dev/variable-skeleton-dims
Browse files Browse the repository at this point in the history
Fix variable_skeleton() with containers
  • Loading branch information
andrjohns authored Aug 23, 2023
2 parents 2b04e4f + 97d1142 commit a9c898b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
6 changes: 5 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,11 @@ create_skeleton <- function(param_metadata, model_variables,
names(model_variables$generated_quantities))
}
lapply(param_metadata[target_params], function(par_dims) {
array(0, dim = ifelse(length(par_dims) == 0, 1, par_dims))
if ((length(par_dims) == 0)) {
array(0, dim = 1)
} else {
array(0, dim = par_dims)
}
})
}

Expand Down
35 changes: 35 additions & 0 deletions tests/testthat/test-model-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,38 @@ test_that("Model methods can be initialised for models with no data", {
expect_no_error(fit <- mod$sample())
expect_equal(fit$log_prob(5), -12.5)
})

test_that("Variable skeleton returns correct dimensions for matrices", {
skip_if(os_is_wsl())

stan_file <- write_stan_file("
data {
int N;
int K;
}
parameters {
real x_real;
matrix[N,K] x_mat;
vector[K] x_vec;
row_vector[K] x_rowvec;
}
model {
x_real ~ std_normal();
}")
mod <- cmdstan_model(stan_file, compile_model_methods = TRUE,
force_recompile = TRUE)
N <- 4
K <- 3
fit <- mod$sample(data = list(N = N, K = K), chains = 1,
iter_warmup = 1, iter_sampling = 1)

target_skeleton <- list(
x_real = array(0, dim = 1),
x_mat = array(0, dim = c(N, K)),
x_vec = array(0, dim = K),
x_rowvec = array(0, dim = K)
)

expect_equal(fit$variable_skeleton(),
target_skeleton)
})

0 comments on commit a9c898b

Please sign in to comment.