From c8bcb99de4e3d128d48ea89d29ecb92d5db7d06a Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 31 Oct 2024 16:00:14 -0400 Subject: [PATCH] Cherry-pick diagnose changes from #1290 --- src/cmdstan/diagnose.cpp | 110 +++++++++--------- .../interface/example_output/corr_gauss.nom | 4 +- .../example_output/corr_gauss_depth15.nom | 4 +- .../example_output/corr_gauss_depth8.nom | 4 +- .../example_output/eight_schools.nom | 4 +- src/test/interface/example_output/mix.nom | 8 +- 6 files changed, 67 insertions(+), 67 deletions(-) diff --git a/src/cmdstan/diagnose.cpp b/src/cmdstan/diagnose.cpp index 76d8fdbbf8..9a451e9d57 100644 --- a/src/cmdstan/diagnose.cpp +++ b/src/cmdstan/diagnose.cpp @@ -1,12 +1,15 @@ +#include #include -#include +#include #include #include #include #include #include -double RHAT_MAX = 1.05; +using cmdstan::return_codes; + +double RHAT_MAX = 1.01499; // round to 1.01 void diagnose_usage() { std::cout << "USAGE: diagnose [ ... ]" @@ -26,7 +29,7 @@ void diagnose_usage() { int main(int argc, const char *argv[]) { if (argc == 1) { diagnose_usage(); - return 0; + return return_codes::OK; } // Parse any arguments specifying filenames @@ -45,49 +48,47 @@ int main(int argc, const char *argv[]) { if (!filenames.size()) { std::cout << "No valid input files, exiting." << std::endl; - return 0; + return return_codes::NOT_OK; } std::cout << std::fixed << std::setprecision(2); - // Parse specified files - std::cout << "Processing csv files: " << filenames[0]; - ifstream.open(filenames[0].c_str()); - - stan::io::stan_csv stan_csv - = stan::io::stan_csv_reader::parse(ifstream, &std::cout); - stan::mcmc::chains<> chains(stan_csv); - ifstream.close(); - - if (filenames.size() > 1) - std::cout << ", "; - else - std::cout << std::endl << std::endl; - - for (std::vector::size_type chain = 1; chain < filenames.size(); - ++chain) { - std::cout << filenames[chain]; - ifstream.open(filenames[chain].c_str()); - stan_csv = stan::io::stan_csv_reader::parse(ifstream, &std::cout); - chains.add(stan_csv); - ifstream.close(); - if (chain < filenames.size() - 1) - std::cout << ", "; - else - std::cout << std::endl << std::endl; + std::vector csv_parsed; + for (int i = 0; i < filenames.size(); ++i) { + std::ifstream infile; + std::stringstream out; + stan::io::stan_csv sample; + infile.open(filenames[i].c_str()); + try { + sample = stan::io::stan_csv_reader::parse(infile, &out); + // csv_reader warnings are errors - fail fast. + if (!out.str().empty()) { + throw std::invalid_argument(out.str()); + } + csv_parsed.push_back(sample); + } catch (const std::invalid_argument &e) { + std::cout << "Cannot parse input csv file: " << filenames[i] << e.what() + << "." << std::endl; + return return_codes::NOT_OK; + } } - + stan::mcmc::chainset chains(csv_parsed); + stan::io::stan_csv_metadata metadata = csv_parsed[0].metadata; + std::vector param_names = csv_parsed[0].header; + size_t num_params = param_names.size(); int num_samples = chains.num_samples(); std::vector bad_n_eff_names; std::vector bad_rhat_names; bool has_errors = false; - for (int i = 0; i < chains.num_params(); ++i) { - if (chains.param_name(i) == std::string("treedepth__")) { + for (int i = 0; i < num_params; ++i) { + if (param_names[i] == std::string("treedepth__")) { std::cout << "Checking sampler transitions treedepth." << std::endl; - int max_limit = stan_csv.metadata.max_depth; + int max_limit = metadata.max_depth; long n_max = 0; - Eigen::VectorXd t_samples = chains.samples(i); + Eigen::MatrixXd draws = chains.samples(i); + Eigen::VectorXd t_samples + = Eigen::Map(draws.data(), draws.size()); for (long n = 0; n < t_samples.size(); ++n) { if (t_samples(n) >= max_limit) { ++n_max; @@ -109,7 +110,7 @@ int main(int argc, const char *argv[]) { std::cout << "Treedepth satisfactory for all transitions." << std::endl << std::endl; } - } else if (chains.param_name(i) == std::string("divergent__")) { + } else if (param_names[i] == std::string("divergent__")) { std::cout << "Checking sampler transitions for divergences." << std::endl; int n_divergent = chains.samples(i).sum(); if (n_divergent > 0) { @@ -129,26 +130,22 @@ int main(int argc, const char *argv[]) { std::cout << "No divergent transitions found." << std::endl << std::endl; } - } else if (chains.param_name(i) == std::string("energy__")) { + } else if (param_names[i] == std::string("energy__")) { std::cout << "Checking E-BFMI - sampler transitions HMC potential energy." << std::endl; - Eigen::VectorXd e_samples = chains.samples(i); + Eigen::MatrixXd draws = chains.samples(i); + Eigen::VectorXd e_samples + = Eigen::Map(draws.data(), draws.size()); double delta_e_sq_mean = 0; - double e_mean = 0; - double e_var = 0; - e_mean += e_samples(0); - e_var += e_samples(0) * (e_samples(0) - e_mean); + double e_mean = chains.mean(i); + double e_var = chains.variance(i); for (long n = 1; n < e_samples.size(); ++n) { double e = e_samples(n); double delta_e_sq = (e - e_samples(n - 1)) * (e - e_samples(n - 1)); double d = delta_e_sq - delta_e_sq_mean; delta_e_sq_mean += d / n; d = e - e_mean; - e_mean += d / (n + 1); - e_var += d * (e - e_mean); } - - e_var /= static_cast(e_samples.size() - 1); double e_bfmi = delta_e_sq_mean / e_var; double e_bfmi_threshold = 0.3; if (e_bfmi < e_bfmi_threshold) { @@ -163,14 +160,16 @@ int main(int argc, const char *argv[]) { } else { std::cout << "E-BFMI satisfactory." << std::endl << std::endl; } - } else if (chains.param_name(i).find("__") == std::string::npos) { - double n_eff = chains.effective_sample_size(i); + } else if (param_names[i].find("__") == std::string::npos) { + auto [ess_bulk, ess_tail] = chains.split_rank_normalized_ess(i); + double n_eff = ess_bulk < ess_tail ? ess_bulk : ess_tail; if (n_eff / num_samples < 0.001) - bad_n_eff_names.push_back(chains.param_name(i)); + bad_n_eff_names.push_back(param_names[i]); - double split_rhat = chains.split_potential_scale_reduction(i); + auto [rhat_bulk, rhat_tail] = chains.split_rank_normalized_rhat(i); + double split_rhat = rhat_bulk > rhat_tail ? rhat_bulk : rhat_tail; if (split_rhat > RHAT_MAX) - bad_rhat_names.push_back(chains.param_name(i)); + bad_rhat_names.push_back(param_names[i]); } } if (bad_n_eff_names.size() > 0) { @@ -187,13 +186,15 @@ int main(int argc, const char *argv[]) { << " may be substantially lower than quoted." << std::endl << std::endl; } else { - std::cout << "Effective sample size satisfactory." << std::endl + std::cout << "Rank-normalized split effective sample size satisfactory " + << "for all parameters." << std::endl << std::endl; } if (bad_rhat_names.size() > 0) { has_errors = true; - std::cout << "The following parameters had split R-hat greater than " + std::cout << "The following parameters had rank-normalized split R-hat " + "greater than " << RHAT_MAX << ":" << std::endl; std::cout << " "; for (size_t n = 0; n < bad_rhat_names.size() - 1; ++n) @@ -207,7 +208,8 @@ int main(int argc, const char *argv[]) { << " effective parameterization." << std::endl << std::endl; } else { - std::cout << "Split R-hat values satisfactory all parameters." << std::endl + std::cout << "Rank-normalized split R-hat values satisfactory " + << "for all parameters." << std::endl << std::endl; } if (!has_errors) @@ -215,5 +217,5 @@ int main(int argc, const char *argv[]) { else std::cout << "Processing complete." << std::endl; - return 0; + return return_codes::OK; } diff --git a/src/test/interface/example_output/corr_gauss.nom b/src/test/interface/example_output/corr_gauss.nom index a6cd709481..64d07f8279 100644 --- a/src/test/interface/example_output/corr_gauss.nom +++ b/src/test/interface/example_output/corr_gauss.nom @@ -9,8 +9,8 @@ No divergent transitions found. Checking E-BFMI - sampler transitions HMC potential energy. E-BFMI satisfactory. -Effective sample size satisfactory. +Rank-normalized split effective sample size satisfactory for all parameters. -Split R-hat values satisfactory all parameters. +Rank-normalized split R-hat values satisfactory for all parameters. Processing complete. diff --git a/src/test/interface/example_output/corr_gauss_depth15.nom b/src/test/interface/example_output/corr_gauss_depth15.nom index f348c9e339..aa9095d6c7 100644 --- a/src/test/interface/example_output/corr_gauss_depth15.nom +++ b/src/test/interface/example_output/corr_gauss_depth15.nom @@ -7,8 +7,8 @@ No divergent transitions found. Checking E-BFMI - sampler transitions HMC potential energy. E-BFMI satisfactory. -Effective sample size satisfactory. +Rank-normalized split effective sample size satisfactory for all parameters. -Split R-hat values satisfactory all parameters. +Rank-normalized split R-hat values satisfactory for all parameters. Processing complete, no problems detected. diff --git a/src/test/interface/example_output/corr_gauss_depth8.nom b/src/test/interface/example_output/corr_gauss_depth8.nom index 32c345dc92..5a476ee328 100644 --- a/src/test/interface/example_output/corr_gauss_depth8.nom +++ b/src/test/interface/example_output/corr_gauss_depth8.nom @@ -9,8 +9,8 @@ No divergent transitions found. Checking E-BFMI - sampler transitions HMC potential energy. E-BFMI satisfactory. -Effective sample size satisfactory. +Rank-normalized split effective sample size satisfactory for all parameters. -Split R-hat values satisfactory all parameters. +Rank-normalized split R-hat values satisfactory for all parameters. Processing complete. diff --git a/src/test/interface/example_output/eight_schools.nom b/src/test/interface/example_output/eight_schools.nom index b306a5bb6b..ae4fc866dc 100644 --- a/src/test/interface/example_output/eight_schools.nom +++ b/src/test/interface/example_output/eight_schools.nom @@ -11,8 +11,8 @@ Checking E-BFMI - sampler transitions HMC potential energy. The E-BFMI, 0.26, is below the nominal threshold of 0.30 which suggests that HMC may have trouble exploring the target distribution. If possible, try to reparameterize the model. -Effective sample size satisfactory. +Rank-normalized split effective sample size satisfactory for all parameters. -Split R-hat values satisfactory all parameters. +Rank-normalized split R-hat values satisfactory for all parameters. Processing complete. diff --git a/src/test/interface/example_output/mix.nom b/src/test/interface/example_output/mix.nom index fbe76363e9..11693838f2 100644 --- a/src/test/interface/example_output/mix.nom +++ b/src/test/interface/example_output/mix.nom @@ -7,12 +7,10 @@ No divergent transitions found. Checking E-BFMI - sampler transitions HMC potential energy. E-BFMI satisfactory. -The following parameters had fewer than 0.001 effective draws per transition: - mu[1], mu[2], theta -Such low values indicate that the effective sample size estimators may be biased high and actual performance may be substantially lower than quoted. +Rank-normalized split effective sample size satisfactory for all parameters. -The following parameters had split R-hat greater than 1.05: - mu[1], mu[2], theta +The following parameters had rank-normalized split R-hat greater than 1.01: + mu[1], mu[2], sigma[1], theta Such high values indicate incomplete mixing and biased estimation. You should consider regularizating your model with additional prior information or a more effective parameterization.