diff --git a/stan/math/prim/fun/dot_self.hpp b/stan/math/prim/fun/dot_self.hpp index 89d5deffe25..c9d67f56e86 100644 --- a/stan/math/prim/fun/dot_self.hpp +++ b/stan/math/prim/fun/dot_self.hpp @@ -10,6 +10,11 @@ namespace stan { namespace math { +template * = nullptr> +inline T dot_self(const T& x) { + return x * x; +} + inline double dot_self(const std::vector& x) { double sum = 0.0; for (double i : x) { diff --git a/stan/math/prim/prob/std_normal_lpdf.hpp b/stan/math/prim/prob/std_normal_lpdf.hpp index de1e4c15698..05585e38aab 100644 --- a/stan/math/prim/prob/std_normal_lpdf.hpp +++ b/stan/math/prim/prob/std_normal_lpdf.hpp @@ -4,10 +4,10 @@ #include #include #include -#include #include #include -#include +#include +#include #include namespace stan { @@ -43,22 +43,16 @@ return_type_t std_normal_lpdf(const T_y& y) { return 0.0; } - T_partials_return logp(0.0); + const auto& y_val = as_value_column_vector_or_scalar(y_ref); + T_partials_return logp = -dot_self(y_val) / 2.0; auto ops_partials = make_partials_propagator(y_ref); - scalar_seq_view y_vec(y_ref); - size_t N = stan::math::size(y); - - for (size_t n = 0; n < N; n++) { - const T_partials_return y_val = y_vec.val(n); - logp += y_val * y_val; - if (!is_constant_all::value) { - partials<0>(ops_partials)[n] -= y_val; - } + if (!is_constant_all::value) { + partials<0>(ops_partials) = -y_val; } - logp *= -0.5; + if (include_summand::value) { - logp += NEG_LOG_SQRT_TWO_PI * N; + logp += NEG_LOG_SQRT_TWO_PI * math::size(y); } return ops_partials.build(logp); diff --git a/test/unit/math/mix/prob/std_normal_test.cpp b/test/unit/math/mix/prob/std_normal_test.cpp index 22a3483409a..828293e0084 100644 --- a/test/unit/math/mix/prob/std_normal_test.cpp +++ b/test/unit/math/mix/prob/std_normal_test.cpp @@ -7,4 +7,12 @@ TEST_F(AgradRev, mathMixScalFun_std_normal) { stan::test::expect_ad(f, -0.3); stan::test::expect_ad(f, 0.0); stan::test::expect_ad(f, 1.7); + + Eigen::VectorXd x(3); + x << -0.3, 0.0, 1.7; + std::vector x2{0.0, 1.7}; + + stan::test::expect_ad(f, x); + stan::test::expect_ad(f, x.transpose().eval()); + stan::test::expect_ad(f, x2); }