diff --git a/inst/include/bvhardraw.h b/inst/include/bvhardraw.h index 32be743d..8fdc289e 100644 --- a/inst/include/bvhardraw.h +++ b/inst/include/bvhardraw.h @@ -816,6 +816,29 @@ inline void dl_mn_sparsity(Eigen::VectorXd& group_param, Eigen::VectorXi& grp_ve } } +inline void dl_mn_sparsity(Eigen::VectorXd& group_param, Eigen::VectorXi& grp_vec, Eigen::VectorXi& grp_id, + double& global_param, Eigen::Ref local_param, Eigen::Ref latent_param, + double& shape, double& rate, + Eigen::Ref coef_vec, boost::random::mt19937& rng) { + Eigen::Array group_id; + int mn_size = 0; + for (int i = 0; i < grp_id.size(); i++) { + group_id = grp_vec.array() == grp_id[i]; + mn_size = group_id.count(); + Eigen::VectorXd mn_scl(mn_size); + for (int j = 0, k = 0; j < coef_vec.size(); ++j) { + if (group_id[j]) { + mn_scl[k++] = coef_vec[j] * coef_vec[j] / (global_param * global_param * local_param[j] * local_param[j] * latent_param[j]); + } + } + group_param[i] = sqrt(1 / gamma_rand( + shape + mn_size / 2, + 1 / (rate + mn_scl.sum() / 2), + rng + )); + } +} + // Log-density for Dirichlet Hyperparameter in DL // // Log density of Dirichlet hyperparameter ignoring constant term diff --git a/inst/include/bvharmcmc.h b/inst/include/bvharmcmc.h index 031e0717..5c289978 100644 --- a/inst/include/bvharmcmc.h +++ b/inst/include/bvharmcmc.h @@ -739,7 +739,9 @@ class McmcDl : public BaseMcmc { using BaseMcmc::contem_coef; using BaseMcmc::updateCoefRecords; void updateCoefPrec() override { - dl_mn_sparsity(group_lev, grp_vec, grp_id, global_lev, local_lev, shape, scl, coef_vec.head(num_alpha), rng); + // dl_mn_sparsity(group_lev, grp_vec, grp_id, global_lev, local_lev, shape, scl, coef_vec.head(num_alpha), rng); + dl_latent(latent_local, global_lev * local_lev.array() * coef_var.array(), coef_vec.head(num_alpha), rng); + dl_mn_sparsity(group_lev, grp_vec, grp_id, global_lev, local_lev, latent_local, shape, scl, coef_vec.head(num_alpha), rng); for (int j = 0; j < num_grp; j++) { coef_var = (grp_vec.array() == grp_id[j]).select( group_lev[j], @@ -752,7 +754,7 @@ class McmcDl : public BaseMcmc { if (is_group::value) { global_lev = dl_global_sparsity(local_lev.array() * coef_var.array(), dir_concen, coef_vec.head(num_alpha), rng); } - dl_latent(latent_local, global_lev * local_lev.array() * coef_var.array(), coef_vec.head(num_alpha), rng); + // dl_latent(latent_local, global_lev * local_lev.array() * coef_var.array(), coef_vec.head(num_alpha), rng); prior_alpha_prec.head(num_alpha) = 1 / ((global_lev * local_lev.array() * coef_var.array()).square() * latent_local.array()); } void updatePenalty() override {