From 7789c57fd6e5ba77f874c896ac8b5cae227de131 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 13 Jan 2025 23:57:50 +0800 Subject: [PATCH 1/2] Simplify vectorisation of std_normal_lpdf --- stan/math/prim/fun/dot_self.hpp | 5 +++++ stan/math/prim/prob/std_normal_lpdf.hpp | 22 ++++++++-------------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/stan/math/prim/fun/dot_self.hpp b/stan/math/prim/fun/dot_self.hpp index 89d5deffe25..c9d67f56e86 100644 --- a/stan/math/prim/fun/dot_self.hpp +++ b/stan/math/prim/fun/dot_self.hpp @@ -10,6 +10,11 @@ namespace stan { namespace math { +template * = nullptr> +inline T dot_self(const T& x) { + return x * x; +} + inline double dot_self(const std::vector& x) { double sum = 0.0; for (double i : x) { diff --git a/stan/math/prim/prob/std_normal_lpdf.hpp b/stan/math/prim/prob/std_normal_lpdf.hpp index de1e4c15698..afe6fb3b70f 100644 --- a/stan/math/prim/prob/std_normal_lpdf.hpp +++ b/stan/math/prim/prob/std_normal_lpdf.hpp @@ -4,10 +4,10 @@ #include #include #include -#include #include #include -#include +#include +#include #include namespace stan { @@ -43,22 +43,16 @@ return_type_t std_normal_lpdf(const T_y& y) { return 0.0; } - T_partials_return logp(0.0); + T_partials_return y_val = as_value_column_vector_or_scalar(y_ref); + T_partials_return logp = -dot_self(y_val) / 2.0; auto ops_partials = make_partials_propagator(y_ref); - scalar_seq_view y_vec(y_ref); - size_t N = stan::math::size(y); - - for (size_t n = 0; n < N; n++) { - const T_partials_return y_val = y_vec.val(n); - logp += y_val * y_val; - if (!is_constant_all::value) { - partials<0>(ops_partials)[n] -= y_val; - } + if (!is_constant_all::value) { + partials<0>(ops_partials) = -y_val; } - logp *= -0.5; + if (include_summand::value) { - logp += NEG_LOG_SQRT_TWO_PI * N; + logp += NEG_LOG_SQRT_TWO_PI * math::size(y); } return ops_partials.build(logp); From a929e7eb519bdcfe678f735398cb196c3cad49e9 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Tue, 14 Jan 2025 20:32:04 +0800 Subject: [PATCH 2/2] Fix opencl --- stan/math/prim/prob/std_normal_lpdf.hpp | 2 +- test/unit/math/mix/prob/std_normal_test.cpp | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/stan/math/prim/prob/std_normal_lpdf.hpp b/stan/math/prim/prob/std_normal_lpdf.hpp index afe6fb3b70f..05585e38aab 100644 --- a/stan/math/prim/prob/std_normal_lpdf.hpp +++ b/stan/math/prim/prob/std_normal_lpdf.hpp @@ -43,7 +43,7 @@ return_type_t std_normal_lpdf(const T_y& y) { return 0.0; } - T_partials_return y_val = as_value_column_vector_or_scalar(y_ref); + const auto& y_val = as_value_column_vector_or_scalar(y_ref); T_partials_return logp = -dot_self(y_val) / 2.0; auto ops_partials = make_partials_propagator(y_ref); diff --git a/test/unit/math/mix/prob/std_normal_test.cpp b/test/unit/math/mix/prob/std_normal_test.cpp index 22a3483409a..828293e0084 100644 --- a/test/unit/math/mix/prob/std_normal_test.cpp +++ b/test/unit/math/mix/prob/std_normal_test.cpp @@ -7,4 +7,12 @@ TEST_F(AgradRev, mathMixScalFun_std_normal) { stan::test::expect_ad(f, -0.3); stan::test::expect_ad(f, 0.0); stan::test::expect_ad(f, 1.7); + + Eigen::VectorXd x(3); + x << -0.3, 0.0, 1.7; + std::vector x2{0.0, 1.7}; + + stan::test::expect_ad(f, x); + stan::test::expect_ad(f, x.transpose().eval()); + stan::test::expect_ad(f, x2); }