diff --git a/stan/math/prim/prob.hpp b/stan/math/prim/prob.hpp index 7e278d68a18..919f63b62c0 100644 --- a/stan/math/prim/prob.hpp +++ b/stan/math/prim/prob.hpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/prob/beta_neg_binomial_cdf.hpp b/stan/math/prim/prob/beta_neg_binomial_cdf.hpp new file mode 100644 index 00000000000..a762e9ae9f4 --- /dev/null +++ b/stan/math/prim/prob/beta_neg_binomial_cdf.hpp @@ -0,0 +1,168 @@ +#ifndef STAN_MATH_PRIM_PROB_BETA_NEG_BINOMIAL_CDF_HPP +#define STAN_MATH_PRIM_PROB_BETA_NEG_BINOMIAL_CDF_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** \ingroup prob_dists + * Returns the CDF of the Beta-Negative Binomial distribution with given + * number of successes, prior success, and prior failure parameters. + * Given containers of matching sizes, returns the product of probabilities. + * + * @tparam T_n type of failure parameter + * @tparam T_r type of number of successes parameter + * @tparam T_alpha type of prior success parameter + * @tparam T_beta type of prior failure parameter + * + * @param n failure parameter + * @param r Number of successes parameter + * @param alpha prior success parameter + * @param beta prior failure parameter + * @param precision precision for `grad_F32`, default \f$10^{-8}\f$ + * @param max_steps max iteration allowed for `grad_F32`, default \f$10^{8}\f$ + * @return probability or sum of probabilities + * @throw std::domain_error if r, alpha, or beta fails to be positive + * @throw std::invalid_argument if container sizes mismatch + */ +template +inline return_type_t beta_neg_binomial_cdf( + const T_n& n, const T_r& r, const T_alpha& alpha, const T_beta& beta, + const double precision = 1e-8, const int max_steps = 1e8) { + static constexpr const char* function = "beta_neg_binomial_cdf"; + check_consistent_sizes( + function, "Failures variable", n, "Number of successes parameter", r, + "Prior success parameter", alpha, "Prior failure parameter", beta); + if (size_zero(n, r, alpha, beta)) { + return 1.0; + } + + using T_r_ref = ref_type_t; + T_r_ref r_ref = r; + using T_alpha_ref = ref_type_t; + T_alpha_ref alpha_ref = alpha; + using T_beta_ref = ref_type_t; + T_beta_ref beta_ref = beta; + check_positive_finite(function, "Number of successes parameter", r_ref); + check_positive_finite(function, "Prior success parameter", alpha_ref); + check_positive_finite(function, "Prior failure parameter", beta_ref); + + scalar_seq_view n_vec(n); + scalar_seq_view r_vec(r_ref); + scalar_seq_view alpha_vec(alpha_ref); + scalar_seq_view beta_vec(beta_ref); + int size_n = stan::math::size(n); + size_t max_size_seq_view = max_size(n, r, alpha, beta); + + // Explicit return for extreme values + // The gradients are technically ill-defined, but treated as zero + for (int i = 0; i < size_n; i++) { + if (n_vec.val(i) < 0) { + return 0.0; + } + } + + using T_partials_return = partials_return_t; + T_partials_return cdf(1.0); + auto ops_partials = make_partials_propagator(r_ref, alpha_ref, beta_ref); + for (size_t i = 0; i < max_size_seq_view; i++) { + // Explicit return for extreme values + // The gradients are technically ill-defined, but treated as zero + if (n_vec.val(i) == std::numeric_limits::max()) { + return 1.0; + } + auto n_dbl = n_vec.val(i); + auto r_dbl = r_vec.val(i); + auto alpha_dbl = alpha_vec.val(i); + auto beta_dbl = beta_vec.val(i); + auto b_plus_n = beta_dbl + n_dbl; + auto r_plus_n = r_dbl + n_dbl; + auto a_plus_r = alpha_dbl + r_dbl; + using a_t = return_type_t; + using b_t = return_type_t; + auto F = hypergeometric_3F2( + std::initializer_list{1.0, b_plus_n + 1.0, r_plus_n + 1.0}, + std::initializer_list{n_dbl + 2.0, a_plus_r + b_plus_n + 1.0}, + 1.0); + auto C = lgamma(r_plus_n + 1.0) + lbeta(a_plus_r, b_plus_n + 1.0) + - lgamma(r_dbl) - lbeta(alpha_dbl, beta_dbl) - lgamma(n_dbl + 2.0); + auto ccdf = stan::math::exp(C) * F; + cdf *= 1.0 - ccdf; + + if constexpr (!is_constant_all::value) { + auto chain_rule_term = -ccdf / (1.0 - ccdf); + auto digamma_n_r_alpha_beta = digamma(a_plus_r + b_plus_n + 1.0); + T_partials_return dF[6]; + grad_F32::value, !is_constant_all::value, + false, true, false>(dF, 1.0, b_plus_n + 1.0, r_plus_n + 1.0, + n_dbl + 2.0, a_plus_r + b_plus_n + 1.0, 1.0, + precision, max_steps); + + if constexpr (!is_constant::value || !is_constant::value) { + auto digamma_r_alpha = digamma(a_plus_r); + if constexpr (!is_constant::value) { + auto partial_lccdf = digamma(r_plus_n + 1.0) + + (digamma_r_alpha - digamma_n_r_alpha_beta) + + (dF[2] + dF[4]) / F - digamma(r_dbl); + partials<0>(ops_partials)[i] += partial_lccdf * chain_rule_term; + } + if constexpr (!is_constant::value) { + auto partial_lccdf = digamma_r_alpha - digamma_n_r_alpha_beta + + dF[4] / F - digamma(alpha_dbl); + partials<1>(ops_partials)[i] += partial_lccdf * chain_rule_term; + } + } + + if constexpr (!is_constant::value + || !is_constant::value) { + auto digamma_alpha_beta = digamma(alpha_dbl + beta_dbl); + if constexpr (!is_constant::value) { + partials<1>(ops_partials)[i] += digamma_alpha_beta * chain_rule_term; + } + if constexpr (!is_constant::value) { + auto partial_lccdf = digamma(b_plus_n + 1.0) - digamma_n_r_alpha_beta + + (dF[1] + dF[4]) / F + - (digamma(beta_dbl) - digamma_alpha_beta); + partials<2>(ops_partials)[i] += partial_lccdf * chain_rule_term; + } + } + } + } + + if constexpr (!is_constant::value) { + for (size_t i = 0; i < stan::math::size(r); ++i) { + partials<0>(ops_partials)[i] *= cdf; + } + } + if constexpr (!is_constant::value) { + for (size_t i = 0; i < stan::math::size(alpha); ++i) { + partials<1>(ops_partials)[i] *= cdf; + } + } + if constexpr (!is_constant::value) { + for (size_t i = 0; i < stan::math::size(beta); ++i) { + partials<2>(ops_partials)[i] *= cdf; + } + } + + return ops_partials.build(cdf); +} + +} // namespace math +} // namespace stan +#endif diff --git a/test/prob/beta_neg_binomial/beta_neg_binomial_cdf_test.hpp b/test/prob/beta_neg_binomial/beta_neg_binomial_cdf_test.hpp new file mode 100644 index 00000000000..8b80dcefa17 --- /dev/null +++ b/test/prob/beta_neg_binomial/beta_neg_binomial_cdf_test.hpp @@ -0,0 +1,91 @@ +// Arguments: Ints, Doubles, Doubles, Doubles +#include +#include +#include + +using stan::math::var; +using std::numeric_limits; +using std::vector; + +class AgradCdfBetaNegBinomial : public AgradCdfTest { + public: + void valid_values(vector>& parameters, vector& cdf) { + vector param(4); + + param[0] = 0; // n + param[1] = 1.0; // r + param[2] = 5.0; // alpha + param[3] = 1.0; // beta + parameters.push_back(param); + cdf.push_back(0.833333333333333); // expected cdf + } + + void invalid_values(vector& index, vector& value) { + // n + + // r + index.push_back(1U); + value.push_back(0.0); + + index.push_back(1U); + value.push_back(-1.0); + + index.push_back(1U); + value.push_back(std::numeric_limits::infinity()); + + // alpha + index.push_back(2U); + value.push_back(0.0); + + index.push_back(2U); + value.push_back(-1.0); + + index.push_back(2U); + value.push_back(std::numeric_limits::infinity()); + + // beta + index.push_back(3U); + value.push_back(0.0); + + index.push_back(3U); + value.push_back(-1.0); + + index.push_back(3U); + value.push_back(std::numeric_limits::infinity()); + } + + // BOUND INCLUDED IN ORDER FOR TEST TO PASS WITH CURRENT FRAMEWORK + bool has_lower_bound() { return false; } + + bool has_upper_bound() { return false; } + + template + stan::return_type_t cdf(const T_n& n, const T_r& r, + const T_size1& alpha, + const T_size2& beta, const T4&, + const T5&) { + return stan::math::beta_neg_binomial_cdf(n, r, alpha, beta); + } + + template + stan::return_type_t cdf_function( + const T_n& n, const T_r& r, const T_size1& alpha, const T_size2& beta, + const T4&, const T5&) { + using stan::math::lbeta; + using stan::math::lgamma; + using stan::math::log_sum_exp; + using std::vector; + + vector> lpmf_values; + + for (int i = 0; i <= n; i++) { + auto lpmf = lbeta(i + r, alpha + beta) - lbeta(r, alpha) + + lgamma(i + beta) - lgamma(i + 1) - lgamma(beta); + lpmf_values.push_back(lpmf); + } + + return exp(log_sum_exp(lpmf_values)); + } +};