diff --git a/stan/math/fwd/fun/log_softmax.hpp b/stan/math/fwd/fun/log_softmax.hpp index f145094325b..a9685d7ca00 100644 --- a/stan/math/fwd/fun/log_softmax.hpp +++ b/stan/math/fwd/fun/log_softmax.hpp @@ -32,17 +32,18 @@ inline auto log_softmax(T&& x) { * * @tparam Vec Eigen vector with `fvar` scalar * @param x vector to transform - * @return log softmax of the vector - * @throw std::domain_error if the input size is 0 + * @return log softmax of the vector, or an empty result if the input is empty */ template * = nullptr> inline auto log_softmax(Vec&& x) { using vec = std::decay_t; constexpr int Rows = vec::RowsAtCompileTime; constexpr int Cols = vec::ColsAtCompileTime; - using T = typename value_type_t::Scalar; - check_nonzero_size("log_softmax", "x", x); + using T = typename value_type_t::Scalar; decltype(auto) x_ref = to_ref(std::forward(x)); + if (x_ref.size() == 0) { + return Eigen::Matrix, Rows, Cols>{}; + } const auto s = softmax(value_of(x_ref)); const auto d_in = x_ref.d(); const auto dot_sd = s.dot(d_in); diff --git a/stan/math/fwd/fun/softmax.hpp b/stan/math/fwd/fun/softmax.hpp index 16f9cb5554e..97d019b6c23 100644 --- a/stan/math/fwd/fun/softmax.hpp +++ b/stan/math/fwd/fun/softmax.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -30,7 +31,7 @@ inline auto softmax(T&& x) { * * @tparam Vec Eigen vector with `fvar` scalar * @param x vector to transform - * @return softmax of the vector + * @return softmax of the vector, or an empty result if the input is empty */ template * = nullptr> inline auto softmax(Vec&& x) { @@ -38,10 +39,10 @@ inline auto softmax(Vec&& x) { constexpr int Rows = vec::RowsAtCompileTime; constexpr int Cols = vec::ColsAtCompileTime; using T = typename value_type_t::Scalar; - if (x.size() == 0) { - return Eigen::Matrix, Rows, Cols>(); - } decltype(auto) x_ref = to_ref(std::forward(x)); + if (x_ref.size() == 0) { + return Eigen::Matrix, Rows, Cols>{}; + } const auto s = softmax(value_of(x_ref)); const auto d_in = x_ref.d(); const auto dot_sd = s.dot(d_in); diff --git a/stan/math/opencl/prim/log_softmax.hpp b/stan/math/opencl/prim/log_softmax.hpp index 49a32a29c31..32a961e9517 100644 --- a/stan/math/opencl/prim/log_softmax.hpp +++ b/stan/math/opencl/prim/log_softmax.hpp @@ -21,7 +21,9 @@ namespace math { template * = nullptr> inline matrix_cl log_softmax(const T& a) { - check_nonzero_size("log_softmax (OpenCL)", "x", a); + if (a.size() == 0) { + return matrix_cl(a.rows(), a.cols()); + } return make_holder_cl([](auto&& x) { return x - log_sum_exp(x); }, to_ref(a)); } diff --git a/stan/math/opencl/prim/softmax.hpp b/stan/math/opencl/prim/softmax.hpp index 12c6c310f05..7f34ac4fbea 100644 --- a/stan/math/opencl/prim/softmax.hpp +++ b/stan/math/opencl/prim/softmax.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace stan { @@ -23,7 +24,7 @@ template softmax(const T& a) { check_vector("softmax (OpenCL)", "a", a); if (a.size() == 0) { - return a; + return matrix_cl(a.rows(), a.cols()); } matrix_cl theta; if constexpr (stan::internal::is_trivial_kg_expression::value) { diff --git a/stan/math/opencl/rev/softmax.hpp b/stan/math/opencl/rev/softmax.hpp index 5cc0384730e..5d888f85db1 100644 --- a/stan/math/opencl/rev/softmax.hpp +++ b/stan/math/opencl/rev/softmax.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -22,11 +23,11 @@ namespace math { template * = nullptr> inline var_value> softmax(const var_value& A) { - if (A.size() == 0) { - return A; - } return make_callback_var( softmax(A.val()), [A](vari_value>& res) mutable { + if (res.val().size() == 0) { + return; + } A.adj() += elt_multiply( res.val(), (res.adj() - dot_product(res.adj(), res.val()))); }); diff --git a/stan/math/prim/fun/log_softmax.hpp b/stan/math/prim/fun/log_softmax.hpp index ac4393adfa8..d6c9017edb6 100644 --- a/stan/math/prim/fun/log_softmax.hpp +++ b/stan/math/prim/fun/log_softmax.hpp @@ -37,9 +37,9 @@ namespace math { * * @tparam Container type of input: an Eigen vector, `std::vector` of doubles, * or nested container whose scalar type is arithmetic - * @param[in] x vector or container of vectors to transform - * @return log softmax of the input, preserving the container structure - * @throw std::domain_error if any input vector is empty + * @param x vector or container of vectors to transform + * @return log softmax of the input, preserving the container structure; an + * empty result if any input vector is empty */ template * = nullptr, require_container_t* = nullptr, @@ -47,12 +47,16 @@ template * = nullptr, is_eigen>::value && !is_eigen_vector>::value>>* = nullptr> inline auto log_softmax(Container&& x) { - check_nonzero_size("log_softmax", "x", x); return make_holder( [](auto&& a) { return apply_vector_unary>::apply( std::forward(a), - [](auto&& v) { return v.array() - log_sum_exp(v); }); + [](auto&& v) -> plain_type_t { + if (v.size() == 0) { + return v; + } + return (v.array() - log_sum_exp(v)).matrix(); + }); }, to_ref(std::forward(x))); } diff --git a/stan/math/prim/fun/softmax.hpp b/stan/math/prim/fun/softmax.hpp index 01443e9856e..63a51876a60 100644 --- a/stan/math/prim/fun/softmax.hpp +++ b/stan/math/prim/fun/softmax.hpp @@ -1,19 +1,18 @@ #ifndef STAN_MATH_PRIM_FUN_SOFTMAX_HPP #define STAN_MATH_PRIM_FUN_SOFTMAX_HPP +#include #include #include #include #include -#include namespace stan { namespace math { /** - * Return the softmax of the specified vector. + * Return the softmax of the specified vector, or of each vector in a container. * - *

* \f$ * \mbox{softmax}(y) * = \frac{\exp(y)} @@ -39,36 +38,33 @@ namespace math { * \end{array} * \f$ * - * @tparam Vec type of the input vector - * @param[in] v Vector to transform. - * @return Unit simplex result of the softmax transform of the vector. + * @tparam Container type of input: an Eigen vector, `std::vector` of doubles, + * or nested container whose scalar type is arithmetic + * @param x vector or container of vectors to transform + * @return softmax of the input, preserving the container structure; an empty + * result if any input vector is empty */ -template * = nullptr> -inline plain_type_t softmax(Vec&& v) { - if (v.size() == 0) { - return v; - } - decltype(auto) v_ref = to_ref(std::forward(v)); - const auto theta = (v_ref.array() - v_ref.maxCoeff()).exp(); - return (theta / theta.sum()).matrix(); -} - -/** - * Return the softmax of each vector in an array. - * - * @tparam T `std::vector` whose scalar type is arithmetic - * @param[in] x Array of vectors to transform. - * @return Array of unit simplex results. - */ -template * = nullptr> -inline auto softmax(T&& x) { - return apply_vector_unary::apply(std::forward(x), [](auto&& v) { - return softmax(std::forward(v)); - }); +template * = nullptr, + require_container_t* = nullptr, + require_not_t>::value + && !is_eigen_vector>::value>>* = nullptr> +inline auto softmax(Container&& x) { + return make_holder( + [](auto&& a) { + return apply_vector_unary>::apply( + std::forward(a), + [](auto&& v) -> plain_type_t { + if (v.size() == 0) { + return v; + } + const auto theta = (v.array() - v.maxCoeff()).exp(); + return (theta / theta.sum()).matrix(); + }); + }, + to_ref(std::forward(x))); } } // namespace math } // namespace stan - #endif diff --git a/stan/math/rev/fun/log_softmax.hpp b/stan/math/rev/fun/log_softmax.hpp index 47a59104bd4..4c77c09e15f 100644 --- a/stan/math/rev/fun/log_softmax.hpp +++ b/stan/math/rev/fun/log_softmax.hpp @@ -18,13 +18,14 @@ namespace math { * * @tparam T a `var_value` or Eigen vector/row_vector with `var` scalar * @param x input - * @return log softmax of the input - * @throw std::domain_error if the input size is 0 + * @return log softmax of the input, or an empty result if the input is empty */ template * = nullptr> inline auto log_softmax(T&& x) { - check_nonzero_size("log_softmax", "x", x); auto x_arena = to_arena(std::forward(x)); + if (x_arena.size() == 0) { + return x_arena; + } using return_t = return_var_matrix_t, T>; arena_t res = log_softmax(x_arena.val()); @@ -42,7 +43,6 @@ inline auto log_softmax(T&& x) { * @tparam T `std::vector` whose scalar type is `var` * @param x array of vectors to transform * @return array of log softmax results - * @throw std::domain_error if any element size is 0 */ template * = nullptr> inline auto log_softmax(T&& x) { diff --git a/stan/math/rev/fun/softmax.hpp b/stan/math/rev/fun/softmax.hpp index 8ef25d76d70..cd377c9cb1d 100644 --- a/stan/math/rev/fun/softmax.hpp +++ b/stan/math/rev/fun/softmax.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -18,16 +19,16 @@ namespace math { * * @tparam T a `var_value` or Eigen vector/row_vector with `var` scalar * @param x input - * @return softmax of the input + * @return softmax of the input, or an empty result if the input is empty */ template * = nullptr> inline auto softmax(T&& x) { auto x_arena = to_arena(std::forward(x)); - using return_t - = return_var_matrix_t, T>; if (x_arena.size() == 0) { return x_arena; } + using return_t + = return_var_matrix_t, T>; arena_t res = softmax(x_arena.val()); reverse_pass_callback([x_arena, res]() mutable { x_arena.adj().array() diff --git a/test/unit/math/mix/fun/log_softmax_test.cpp b/test/unit/math/mix/fun/log_softmax_test.cpp index 5d8915fea10..b50b769c26d 100644 --- a/test/unit/math/mix/fun/log_softmax_test.cpp +++ b/test/unit/math/mix/fun/log_softmax_test.cpp @@ -4,7 +4,7 @@ TEST(MathMixMatFun, logSoftmax) { auto f = [](const auto& x) { return stan::math::log_softmax(x); }; // Column Vectors - Eigen::VectorXd x0(0); // error case + Eigen::VectorXd x0(0); stan::test::expect_ad(f, x0); stan::test::expect_ad_matvar(f, x0); @@ -34,7 +34,7 @@ TEST(MathMixMatFun, logSoftmax) { stan::test::expect_ad_matvar(f, x3c); // Row Vectors - Eigen::RowVectorXd rx0(0); // error case + Eigen::RowVectorXd rx0(0); stan::test::expect_ad(f, rx0); stan::test::expect_ad_matvar(f, rx0); @@ -64,7 +64,7 @@ TEST(MathMixMatFun, logSoftmax) { stan::test::expect_ad_matvar(f, rx3c); // std vectors - std::vector stx0(0); // error case + std::vector stx0(0); stan::test::expect_ad(f, stx0); std::vector stx1{0}; @@ -83,7 +83,7 @@ TEST(MathMixMatFun, logSoftmax) { stan::test::expect_ad(f, stx3c); // Nested containers - std::vector stvx0{x0, x0}; // error case + std::vector stvx0{x0, x0}; stan::test::expect_ad(f, stvx0); stan::test::expect_ad_matvar(f, stvx0); @@ -91,7 +91,7 @@ TEST(MathMixMatFun, logSoftmax) { stan::test::expect_ad(f, stvx1); stan::test::expect_ad_matvar(f, stvx1); - std::vector strx0{rx0, rx0}; // error case + std::vector strx0{rx0, rx0}; stan::test::expect_ad(f, strx0); stan::test::expect_ad_matvar(f, strx0); @@ -99,7 +99,7 @@ TEST(MathMixMatFun, logSoftmax) { stan::test::expect_ad(f, strx1); stan::test::expect_ad_matvar(f, strx1); - std::vector> ststx0{stx0, stx0}; // error case + std::vector> ststx0{stx0, stx0}; stan::test::expect_ad(f, ststx0); std::vector> ststx1{stx1, stx1}; diff --git a/test/unit/math/mix/fun/softmax_test.cpp b/test/unit/math/mix/fun/softmax_test.cpp index 248bd975376..97b3fc64332 100644 --- a/test/unit/math/mix/fun/softmax_test.cpp +++ b/test/unit/math/mix/fun/softmax_test.cpp @@ -69,7 +69,7 @@ TEST(MathMixMatFun, softmax) { expect_ad_matvar(f, rd2); // Arrays of vectors (array[] vector and array[] row_vector) - std::vector stvx0{a, a}; // error case + std::vector stvx0{a, a}; stan::test::expect_ad(tols, f, stvx0); expect_ad_matvar(f, stvx0); @@ -81,7 +81,7 @@ TEST(MathMixMatFun, softmax) { stan::test::expect_ad(tols, f, stvx2); expect_ad_matvar(f, stvx2); - std::vector strx0{ra, ra}; // error case + std::vector strx0{ra, ra}; stan::test::expect_ad(tols, f, strx0); expect_ad_matvar(f, strx0); diff --git a/test/unit/math/opencl/rev/log_softmax_test.cpp b/test/unit/math/opencl/rev/log_softmax_test.cpp index b9efb726921..2a258cf5295 100644 --- a/test/unit/math/opencl/rev/log_softmax_test.cpp +++ b/test/unit/math/opencl/rev/log_softmax_test.cpp @@ -22,6 +22,11 @@ TEST(OpenCLLogSoftmax, prim_rev_size_1) { stan::math::test::compare_cpu_opencl_prim_rev(log_softmax_functor, a); } +TEST(OpenCLLogSoftmax, prim_rev_size_0) { + Eigen::VectorXd a(0); + EXPECT_EQ(0, stan::math::log_softmax(a).size()); +} + TEST(OpenCLLogSoftmax, prim_rev_values_large) { int N = 71; diff --git a/test/unit/math/opencl/rev/softmax_test.cpp b/test/unit/math/opencl/rev/softmax_test.cpp index dbf1fc1b4f8..6f32284897b 100644 --- a/test/unit/math/opencl/rev/softmax_test.cpp +++ b/test/unit/math/opencl/rev/softmax_test.cpp @@ -13,10 +13,8 @@ TEST(OpenCLSoftmax, prim_rev_values_small) { } TEST(OpenCLSoftmax, prim_rev_size_0) { - int N = 0; - - Eigen::VectorXd a(N); - stan::math::test::compare_cpu_opencl_prim_rev(softmax_functor, a); + Eigen::VectorXd a(0); + EXPECT_EQ(0, stan::math::softmax(a).size()); } TEST(OpenCLSoftmax, prim_rev_values_large) { diff --git a/test/unit/math/prim/fun/log_softmax_test.cpp b/test/unit/math/prim/fun/log_softmax_test.cpp index 9649a20e13d..4e327f8b13d 100644 --- a/test/unit/math/prim/fun/log_softmax_test.cpp +++ b/test/unit/math/prim/fun/log_softmax_test.cpp @@ -84,9 +84,8 @@ TEST(MathMatrixPrimMat, log_softmax_neg_inf) { EXPECT_FLOAT_EQ(2.0 - lse_finite, result[2]); } -TEST(MathMatrixPrimMat, log_softmax_exception) { +TEST(MathMatrixPrimMat, log_softmax_empty) { using stan::math::log_softmax; - stan::math::vector_d v0; // size == 0 - - EXPECT_THROW(log_softmax(v0), std::invalid_argument); + stan::math::vector_d v0; + EXPECT_EQ(0, log_softmax(v0).size()); } diff --git a/test/unit/math/prim/fun/softmax_test.cpp b/test/unit/math/prim/fun/softmax_test.cpp index 7a38156d95b..8880a03ed3e 100644 --- a/test/unit/math/prim/fun/softmax_test.cpp +++ b/test/unit/math/prim/fun/softmax_test.cpp @@ -46,6 +46,12 @@ TEST(MathMatrixPrimMat, softmax_neg_inf) { EXPECT_FLOAT_EQ(1.0, theta.sum()); } +TEST(MathMatrixPrimMat, softmax_empty) { + using stan::math::softmax; + Eigen::Matrix v0; // size == 0 + EXPECT_EQ(0, softmax(v0).size()); +} + TEST(MathMatrixPrimMat, softmax_row_vector) { using Eigen::Dynamic; using Eigen::Matrix;