diff --git a/Project.toml b/Project.toml index cd039ff4..8065f868 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.6.0" [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756" @@ -29,6 +30,7 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" [compat] Aqua = "0.8.7" BayesBase = "1.5.0" +BlockArrays = "1.3.0" Distributions = "0.25" DomainSets = "0.5.2, 0.6, 0.7" FastCholesky = "1.0" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 674e0375..28d7e69d 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -4,4 +4,5 @@ using ExponentialFamily.BayesBase const SUITE = BenchmarkGroup() -include("benchmarks/bernoulli.jl") \ No newline at end of file +include("benchmarks/bernoulli.jl") +include("benchmarks/tensordirichlet.jl") \ No newline at end of file diff --git a/benchmark/benchmarks/tensordirichlet.jl b/benchmark/benchmarks/tensordirichlet.jl new file mode 100644 index 00000000..8449de09 --- /dev/null +++ b/benchmark/benchmarks/tensordirichlet.jl @@ -0,0 +1,42 @@ + +SUITE["tensordirichlet"] = BenchmarkGroup( + ["tensordirichlet", "distribution"], + "prod" => BenchmarkGroup(["prod", "multiplication"]), + "convert" => BenchmarkGroup(["convert"]) +) + +# `prod` BenchmarkGroup ======================== +for rank in (3, 4, 5, 6) + for d in (5, 10, 20) + left = TensorDirichlet(rand([d for _ in 1:rank]...) .+ 1) + right = TensorDirichlet(rand([d for _ in 1:rank]...) 
.+ 1) + SUITE["tensordirichlet"]["prod"]["Closed(rank=$rank, d=$d)"] = @benchmarkable prod(ClosedProd(), $left, $right) + end +end + +# ============================================== + +# `convert` BenchmarkGroup ===================== +SUITE["tensordirichlet"]["convert"]["Convert from D to EF"] = @benchmarkable convert(ExponentialFamilyDistribution, dist) setup = begin + dist = TensorDirichlet(rand(5, 5, 5)) +end + +SUITE["tensordirichlet"]["convert"]["Convert from EF to D"] = @benchmarkable convert(Distribution, efdist) setup = begin + efdist = convert(ExponentialFamilyDistribution, TensorDirichlet(rand(5, 5, 5))) +end +# ============================================== + +for rank in (3, 4, 5, 6) + for d in (5, 10, 20) + distribution = TensorDirichlet(rand([d for _ in 1:rank]...)) + sample = rand(distribution) + SUITE["tensordirichlet"]["mean"]["rank=$rank, d=$d"] = @benchmarkable mean($distribution) + SUITE["tensordirichlet"]["rand"]["rank=$rank, d=$d"] = @benchmarkable rand($distribution) + SUITE["tensordirichlet"]["logpdf"]["rank=$rank, d=$d"] = @benchmarkable logpdf($distribution, $sample) + SUITE["tensordirichlet"]["var"]["rank=$rank, d=$d"] = @benchmarkable var($distribution) + SUITE["tensordirichlet"]["cov"]["rank=$rank, d=$d"] = @benchmarkable cov($distribution) + SUITE["tensordirichlet"]["entropy"]["rank=$rank, d=$d"] = @benchmarkable entropy($distribution) + end +end + +# ============================================== \ No newline at end of file diff --git a/docs/src/library.md b/docs/src/library.md index a271e018..9afe9323 100644 --- a/docs/src/library.md +++ b/docs/src/library.md @@ -21,4 +21,5 @@ ExponentialFamily.WishartFast ExponentialFamily.InverseWishartFast ExponentialFamily.NormalGamma ExponentialFamily.MvNormalWishart +ExponentialFamily.TensorDirichlet ``` \ No newline at end of file diff --git a/scripts/benchmark.jl b/scripts/benchmark.jl index 4ef9c9f9..1009bcc1 100644 --- a/scripts/benchmark.jl +++ b/scripts/benchmark.jl @@ -10,6 +10,6 @@ if 
isempty(ARGS) export_markdown("./benchmark_logs/last.md", result) else name = first(ARGS) - BenchmarkTools.judge(ExponentialFamily, name; judgekwargs = Dict(:time_tolerance => 0.1, :memory_tolerance => 0.05)) - export_markdown("benchmark_vs_$(name)_result.md", result) + result = BenchmarkTools.judge(ExponentialFamily, name; judgekwargs = Dict(:time_tolerance => 0.1, :memory_tolerance => 0.05)) + export_markdown("./benchmark_logs/benchmark_vs_$(name)_result.md", result) end diff --git a/src/ExponentialFamily.jl b/src/ExponentialFamily.jl index f2ec4c1c..bba5df01 100644 --- a/src/ExponentialFamily.jl +++ b/src/ExponentialFamily.jl @@ -72,5 +72,6 @@ include("distributions/poisson.jl") include("distributions/chi_squared.jl") include("distributions/mv_normal_wishart.jl") include("distributions/normal_gamma.jl") +include("distributions/tensor_dirichlet.jl") end diff --git a/src/distributions/binomial.jl b/src/distributions/binomial.jl index 67e37915..da5c284d 100644 --- a/src/distributions/binomial.jl +++ b/src/distributions/binomial.jl @@ -78,6 +78,10 @@ function (::NaturalToMean{Binomial})(tuple_of_η::Tuple{Any}, _) return (logistic(η₁),) end +function unpack_parameters(::Type{Binomial}, packed, _) + return unpack_parameters(Binomial, packed) +end + function unpack_parameters(::Type{Binomial}, packed) return (first(packed),) end diff --git a/src/distributions/categorical.jl b/src/distributions/categorical.jl index b8a12b4c..011ea64d 100644 --- a/src/distributions/categorical.jl +++ b/src/distributions/categorical.jl @@ -64,6 +64,10 @@ function (::NaturalToMean{Categorical})(tuple_of_η::Tuple{V}, _) where {V <: Ab return (softmax(convert(Vector, η)),) end +function unpack_parameters(::Type{Categorical}, packed, _) + return unpack_parameters(Categorical, packed) +end + function unpack_parameters(::Type{Categorical}, packed) return (packed,) end diff --git a/src/distributions/laplace.jl b/src/distributions/laplace.jl index 5661107c..b7b83094 100644 --- 
a/src/distributions/laplace.jl +++ b/src/distributions/laplace.jl @@ -152,6 +152,10 @@ function (::NaturalToMean{Laplace})(tuple_of_η::Tuple{Any}, _) return (-inv(η₁),) end +function unpack_parameters(::Type{Laplace}, packed, _) + return (first(packed),) +end + function unpack_parameters(::Type{Laplace}, packed) return (first(packed),) end diff --git a/src/distributions/negative_binomial.jl b/src/distributions/negative_binomial.jl index d9163c2e..9759a7b9 100644 --- a/src/distributions/negative_binomial.jl +++ b/src/distributions/negative_binomial.jl @@ -92,6 +92,10 @@ function (::NaturalToMean{NegativeBinomial})(tuple_of_η::Tuple{Any}, _) return (one(η₁) - exp(η₁),) end +function unpack_parameters(::Type{NegativeBinomial}, packed, _) + return (first(packed),) +end + function unpack_parameters(::Type{NegativeBinomial}, packed) return (first(packed),) end diff --git a/src/distributions/pareto.jl b/src/distributions/pareto.jl index ebd06d36..4d57b2c1 100644 --- a/src/distributions/pareto.jl +++ b/src/distributions/pareto.jl @@ -134,6 +134,10 @@ function (::NaturalToMean{Pareto})(tuple_of_η::Tuple{Any}, _) return (-η₁ - one(η₁),) end +function unpack_parameters(::Type{Pareto}, packed, _) + return (first(packed),) +end + function unpack_parameters(::Type{Pareto}, packed) return (first(packed),) end diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl new file mode 100644 index 00000000..a1106cbc --- /dev/null +++ b/src/distributions/tensor_dirichlet.jl @@ -0,0 +1,317 @@ +export TensorDirichlet, ContinuousTensorDistribution + +import SpecialFunctions: digamma, loggamma +import Base: eltype +import Distributions: pdf, logpdf +using Distributions +using SpecialFunctions, LogExpFunctions + +import FillArrays: Ones, Eye +import LoopVectorization: vmap, vmapreduce +using LinearAlgebra, Random + +using BlockArrays: BlockDiagonal + +const ContinuousTensorDistribution = Distribution{ArrayLikeVariate, Continuous} + +""" + TensorDirichlet{T 
<: Real, N, A <: AbstractArray{T, N}, Ts} <: ContinuousTensorDistribution + +A tensor-valued Dirichlet distribution, where `T` is the element type of the tensor `A`. The distribution generalizes the Dirichlet distribution to handle multiple sets of parameters organized in a tensor structure. This distribution collects multiple independent Dirichlet distributions into a single efficient interface. The Dirichlet counts for the independent Dirichlet distributions are stored along the first dimension of `a`. This distribution can be used as a conjugate prior to a Categorical distribution with multiple switch cases (such as a discrete state-transition with controls). + +# Fields +- `a::A`: The tensor parameter of the distribution, where each slice represents parameters of a Dirichlet distribution. +- `α0::Ts`: The sum of parameters along the first dimension. +- `lmnB::Ts`: The log multinomial beta function values for each slice. + +The distribution models multiple independent Dirichlet distributions organized in a tensor structure, where each slice `a[:,i,j,...]` represents the parameters of an independent Dirichlet distribution. 
+""" +struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}, Ts} <: ContinuousTensorDistribution + a::A + α0::Ts + lmnB::Ts + function TensorDirichlet(alpha::AbstractArray{T, N}) where {T, N} + if !all(x -> x > zero(x), alpha) + throw(ArgumentError("All elements of the alpha tensor should be positive")) + end + alpha0 = sum(alpha; dims = 1) + lmnB = sum(loggamma, alpha; dims = 1) - loggamma.(alpha0) + new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0, lmnB) + end +end + +function BayesBase.logpdf(dist::TensorDirichlet{T, N, A, Ts}, xs::AbstractVector{A}) where {T, N, A, Ts} + return map(x -> logpdf(dist, x), xs) +end + +function BayesBase.pdf(dist::TensorDirichlet{R, N, A}, x::AbstractArray{T, N}) where {R, A, T <: Real, N} + return exp(logpdf(dist, x)) +end + +function BayesBase.pdf(dist::TensorDirichlet, xs::AbstractVector) + return map(x -> pdf(dist, x), xs) +end + +BayesBase.params(dist::TensorDirichlet) = (dist.a,) + +function unpack_parameters(::Type{TensorDirichlet}, packed, conditioner) + packed = view(packed, 1:length(packed)) + return (reshape(packed, conditioner),) +end + +function join_conditioner(::Type{TensorDirichlet}, cparams, _) + return cparams +end + +function separate_conditioner(::Type{TensorDirichlet}, tuple_of_θ) + return (tuple_of_θ, size(tuple_of_θ[1])) +end + +isbasemeasureconstant(::Type{TensorDirichlet}) = ConstantBaseMeasure() + +getbasemeasure(::Type{TensorDirichlet}, conditioner) = (x) -> one(Float64) +getlogbasemeasure(::Type{TensorDirichlet}, conditioner) = (x) -> zero(Float64) + +getsufficientstatistics(::Type{TensorDirichlet}, conditioner) = (x -> vmap(log, x),) + +BayesBase.mean(dist::TensorDirichlet) = dist.a ./ dist.α0 +function BayesBase.cov(dist::TensorDirichlet{T}) where {T} + s = size(dist.a) + news = (first(s), first(s), Base.tail(s)...) 
+ v = zeros(T, news) + for i in CartesianIndices(Base.tail(size(dist.a))) + v[:, :, i] .= cov(Dirichlet(dist.a[:, i])) + end + return v +end +function BayesBase.var(dist::TensorDirichlet{T, N, A, Ts}) where {T, N, A, Ts} + α = dist.a + α0 = dist.α0 + c = inv.(α0 .^ 2 .* (α0 .+ 1)) + v = α .* (α0 .- α) .* c + return v +end +BayesBase.std(dist::TensorDirichlet) = sqrt.(var(dist)) + +Base.size(dist::TensorDirichlet) = size(dist.a) +Base.eltype(::TensorDirichlet{T}) where {T} = T + +function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int) + return TensorDirichlet(ones(dims, dims)) +end + +function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Tuple) + return TensorDirichlet(ones(Float64, dims)) +end + +function BayesBase.entropy(dist::TensorDirichlet) + α = dist.a + α0 = dist.α0 + lmnB = dist.lmnB + return sum(-sum((α .- one(eltype(α))) .* (digamma.(α) .- digamma.(α0)); dims = 1) .+ lmnB) +end + +BayesBase.promote_variate_type(::Type{Multivariate}, ::Type{<:TensorDirichlet}) = TensorDirichlet +BayesBase.promote_variate_type(::Type{ArrayLikeVariate}, ::Type{<:Dirichlet}) = TensorDirichlet + +function BayesBase.rand(rng::AbstractRNG, dist::TensorDirichlet{T}) where {T} + container = similar(dist.a) + return rand!(rng, dist, container) +end + +function BayesBase.rand(rng::AbstractRNG, dist::TensorDirichlet{T}, nsamples::Int64) where {T} + container = [similar(dist.a) for _ in 1:nsamples] + rand!(rng, dist, container) + return container +end + +function BayesBase.rand!(rng::AbstractRNG, dist::TensorDirichlet, container::AbstractArray{T, N}) where {T <: Real, N} + for (i, αi) in zip(eachindex(container), dist.a) + @inbounds container[i] = rand(rng, Gamma(αi)) + end + container .= container ./ sum(container; dims = 1) +end + +# Add method for handling vector of arrays +function BayesBase.rand!( + rng::AbstractRNG, + dist::TensorDirichlet{T, N, A, Ts}, + container::AbstractArray{A, M} +) where {T <: Real, N, A <: AbstractArray{T, N}, Ts, M} + for c in container + 
size(c) == size(dist.a) || error("Size mismatch") + end + + @inbounds for c in container + rand!(rng, dist, c) + end + + return container +end + +function BayesBase.logpdf(dist::TensorDirichlet{R, N, A}, x::AbstractArray{T, N}) where {R, A, T <: Real, N} + if !insupport(dist, x) + return sum(xlogy.(one(eltype(dist.a)), zero(eltype(x)))) + end + α = dist.a + α0 = dist.α0 + s = sum(xlogy.(α .- 1, x); dims = 1) + return sum(s .- dist.lmnB) +end + +check_logpdf(::ExponentialFamilyDistribution{TensorDirichlet}, x::AbstractVector) = (MapBasedLogpdfCall(), x) +check_logpdf(::ExponentialFamilyDistribution{TensorDirichlet}, x) = (PointBasedLogpdfCall(), x) + +BayesBase.default_prod_rule(::Type{<:TensorDirichlet}, ::Type{<:TensorDirichlet}) = PreserveTypeProd(Distribution) + +function BayesBase.prod(::PreserveTypeProd{Distribution}, left::TensorDirichlet, right::TensorDirichlet) + return TensorDirichlet(left.a .+ right.a .- 1) +end + +function BayesBase.insupport(dist::TensorDirichlet{T, N, A, Ts}, x::AbstractArray{<:Real, N}) where {T, N, A, Ts} # eltype of `x` may differ from the distribution's, matching what `logpdf` accepts + return size(dist) == size(x) && !any(x -> x < zero(x), x) && all(z -> z ≈ 1, sum(x; dims = 1)) +end + +function BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x) + l = getconditioner(ef) + values = map(CartesianIndices(Base.tail(size(x)))) do i + slice = @view x[:, i] + sum(slice) ≈ 1 && all(y -> y > 0, slice) + end + return l == size(x) && all(values) +end + +# Natural parametrization + +function isproper(::NaturalParametersSpace, ::Type{TensorDirichlet}, η, conditioner) + return length(η) > 1 && all(isless.(-1, η)) && all(!isinf, η) && all(!isnan, η) +end +function isproper(::MeanParametersSpace, ::Type{TensorDirichlet}, θ, conditioner) + return length(θ) > 1 && all(>(0), θ) && all(!isinf, θ) +end + +function (::MeanToNatural{TensorDirichlet})(tuple_of_θ::Tuple{Any}, _) + (α,) = tuple_of_θ + return (α - Ones{Float64}(size(α)),) +end + +function (::NaturalToMean{TensorDirichlet})(tuple_of_η::Tuple{Any}, _) + 
(η,) = tuple_of_η + return (η + Ones{Float64}(size(η)),) +end + +function getlogpartition(::NaturalParametersSpace, ::Type{TensorDirichlet}, conditioner::NTuple{N, Int}) where {N} + k = conditioner[1] # Number of parameters per distribution + n_distributions = prod(Base.tail(conditioner)) # Total number of distributions + dirichlet_logpartition = getlogpartition(NaturalParametersSpace(), Dirichlet) + + return function (η::AbstractVector) + result = zero(eltype(η)) + for i in 1:n_distributions + idx_start = (i - 1) * k + 1 + idx_end = i * k + @views params = η[idx_start:idx_end] + result += dirichlet_logpartition(params) + end + + return result + end +end + +function getgradlogpartition( + ::NaturalParametersSpace, + ::Type{TensorDirichlet}, + conditioner::NTuple{N, Int} +) where {N} + k = conditioner[1] # Number of parameters per distribution + n_distributions = prod(Base.tail(conditioner)) # Total number of distributions + + # Get the "gradlogpartition" function for a standard Dirichlet + dirichlet_gradlogpartition = getgradlogpartition(NaturalParametersSpace(), Dirichlet) + + return function (η::AbstractVector{T}) where {T} + # Preallocate the output. We know we need `k * n_distributions` entries, + # of the same element type as `η`. + out = Vector{T}(undef, k * n_distributions) + + for i in 1:n_distributions + @inbounds begin + # For the i-th distribution, grab the slice of η + # and apply the Dirichlet gradlogpartition. 
+ out[(i-1)*k+1:i*k] = dirichlet_gradlogpartition( + @view η[(i-1)*k+1:i*k] + ) + end + end + return out + end +end + +function getfisherinformation(::NaturalParametersSpace, ::Type{TensorDirichlet}, conditioner) + k = conditioner[1] # Number of parameters per distribution + n_distributions = prod(Base.tail(conditioner)) # Total number of distributions + dirichlet_fisher = getfisherinformation(NaturalParametersSpace(), Dirichlet) + + return function (η::AbstractVector) + blocks = Vector{Matrix{Float64}}(undef, n_distributions) + + for i in 1:n_distributions + idx_start = (i - 1) * k + 1 + idx_end = i * k + @views params = η[idx_start:idx_end] + blocks[i] = dirichlet_fisher(params) + end + + return BlockDiagonal(blocks) + end +end + +# Mean parametrization + +getlogpartition(::MeanParametersSpace, ::Type{TensorDirichlet}, conditioner) = + (η) -> begin + return mapreduce(getlogpartition(MeanParametersSpace(), Dirichlet), +, eachcol(reshape(η, first(conditioner), :))) # sum over length-k columns, one per independent Dirichlet (not over scalar entries) + end + +function getgradlogpartition(::MeanParametersSpace, ::Type{TensorDirichlet}, conditioner::NTuple{N, Int}) where {N} + k = conditioner[1] # Number of parameters per distribution + n_distributions = prod(Base.tail(conditioner)) # Total number of distributions + dirichlet_gradlogpartition = getgradlogpartition(MeanParametersSpace(), Dirichlet) + + return function (θ::AbstractVector{T}) where {T} + # Preallocate the output + out = Vector{T}(undef, k * n_distributions) + + for i in 1:n_distributions + @inbounds begin + # For each distribution, compute its gradient + out[(i-1)*k+1:i*k] = dirichlet_gradlogpartition( + @view θ[(i-1)*k+1:i*k] + ) + end + end + return out + end +end + +function getfisherinformation(::MeanParametersSpace, ::Type{TensorDirichlet}, conditioner::NTuple{N, Int}) where {N} + k = conditioner[1] # Number of parameters per distribution + n_distributions = prod(Base.tail(conditioner)) # Total number of distributions + dirichlet_fisher = getfisherinformation(MeanParametersSpace(), Dirichlet) + + return function 
(θ::AbstractVector{T}) where {T} + # Create blocks for block diagonal matrix + blocks = Vector{Matrix{Float64}}(undef, n_distributions) + + for i in 1:n_distributions + @inbounds begin + # For each distribution, compute its Fisher information + blocks[i] = dirichlet_fisher( + @view θ[(i-1)*k+1:i*k] + ) + end + end + + return BlockDiagonal(blocks) + end +end diff --git a/src/distributions/von_mises.jl b/src/distributions/von_mises.jl index e01ae1c2..e3bf120a 100644 --- a/src/distributions/von_mises.jl +++ b/src/distributions/von_mises.jl @@ -62,6 +62,12 @@ function (::NaturalToMean{VonMises})(tuple_of_η::Tuple{Any, Any}, conditioner) return (conditioner * π + μ, κ) end +function unpack_parameters(::Type{VonMises}, packed, _) + fi = firstindex(packed) + si = firstindex(packed) + 1 + return (packed[fi], packed[si]) +end + function unpack_parameters(::Type{VonMises}, packed) fi = firstindex(packed) si = firstindex(packed) + 1 diff --git a/src/distributions/von_mises_fisher.jl b/src/distributions/von_mises_fisher.jl index 68fae745..cd3eb652 100644 --- a/src/distributions/von_mises_fisher.jl +++ b/src/distributions/von_mises_fisher.jl @@ -59,12 +59,22 @@ function (::NaturalToMean{VonMisesFisher})(tuple_of_η::Tuple{Any}) return (μ, κ) end +function unpack_parameters(::MeanParametersSpace, ::Type{VonMisesFisher}, packed, ::Nothing) + (μ, κ) = (view(packed, 1:length(packed)-1), packed[end]) + + return (μ, κ) +end + function unpack_parameters(::MeanParametersSpace, ::Type{VonMisesFisher}, packed) (μ, κ) = (view(packed, 1:length(packed)-1), packed[end]) return (μ, κ) end +function unpack_parameters(::NaturalParametersSpace, ::Type{VonMisesFisher}, packed, ::Nothing) + return (packed,) +end + function unpack_parameters(::NaturalParametersSpace, ::Type{VonMisesFisher}, packed) return (packed,) end diff --git a/src/distributions/weibull.jl b/src/distributions/weibull.jl index 33fe3f8d..896bd8a3 100644 --- a/src/distributions/weibull.jl +++ b/src/distributions/weibull.jl @@ -85,6 
+85,10 @@ function (::NaturalToMean{Weibull})(tuple_of_η::Tuple{Any}, conditioner) return ((-η)^inv(-conditioner),) end +function unpack_parameters(::Type{Weibull}, packed, _) + return (first(packed),) +end + function unpack_parameters(::Type{Weibull}, packed) return (first(packed),) end diff --git a/src/exponential_family.jl b/src/exponential_family.jl index 1a264ed1..cf603e37 100644 --- a/src/exponential_family.jl +++ b/src/exponential_family.jl @@ -401,7 +401,7 @@ function (transformation::NaturalToMean{T})(v::AbstractVector, ::Nothing) where end function (transformation::NaturalToMean{T})(v::AbstractVector, conditioner) where {T <: Distribution} - return pack_parameters(MeanParametersSpace(), T, transformation(unpack_parameters(NaturalParametersSpace(), T, v), conditioner)) + return pack_parameters(MeanParametersSpace(), T, transformation(unpack_parameters(NaturalParametersSpace(), T, v, conditioner), conditioner)) end function (transformation::MeanToNatural{T})(v::AbstractVector) where {T <: Distribution} @@ -413,7 +413,7 @@ function (transformation::MeanToNatural{T})(v::AbstractVector, ::Nothing) where end function (transformation::MeanToNatural{T})(v::AbstractVector, conditioner) where {T <: Distribution} - return pack_parameters(NaturalParametersSpace(), T, transformation(unpack_parameters(MeanParametersSpace(), T, v), conditioner)) + return pack_parameters(NaturalParametersSpace(), T, transformation(unpack_parameters(MeanParametersSpace(), T, v, conditioner), conditioner)) end """ @@ -697,7 +697,7 @@ Evaluates and returns the probability density function of the exponential family BayesBase.pdf(ef::ExponentialFamilyDistribution, x) = _pdf(ef, x) function _pdf(ef, x) - vartype, _x = check_logpdf(variate_form(typeof(ef)), typeof(x), eltype(x), ef, x) + vartype, _x = check_logpdf(ef, x) _pdf(vartype, ef, _x) end @@ -798,7 +798,16 @@ See also: [`MeanParametersSpace`](@ref), [`NaturalParametersSpace`](@ref) """ function unpack_parameters end 
-unpack_parameters(ef::ExponentialFamilyDistribution{T}) where {T} = unpack_parameters(NaturalParametersSpace(), T, getnaturalparameters(ef)) +unpack_parameters(ef::ExponentialFamilyDistribution{T}) where {T} = + unpack_parameters(NaturalParametersSpace(), T, getnaturalparameters(ef), getconditioner(ef)) + +function unpack_parameters(::Union{MeanParametersSpace, NaturalParametersSpace}, ::Type{T}, packed, conditioner) where {T} + unpack_parameters(T, packed, conditioner) +end + +function unpack_parameters(::Union{MeanParametersSpace, NaturalParametersSpace}, ::Type{T}, packed, ::Nothing) where {T} + unpack_parameters(T, packed) +end # Assume that for the most distributions the `unpack_parameters` does not depend on the `space` parameter unpack_parameters(::Union{MeanParametersSpace, NaturalParametersSpace}, ::Type{T}, packed) where {T} = unpack_parameters(T, packed) diff --git a/test/distributions/distributions_setuptests.jl b/test/distributions/distributions_setuptests.jl index 7292f8f7..82a5145a 100644 --- a/test/distributions/distributions_setuptests.jl +++ b/test/distributions/distributions_setuptests.jl @@ -68,7 +68,8 @@ function test_exponentialfamily_interface(distribution; test_fisherinformation_against_hessian = true, test_fisherinformation_against_jacobian = true, test_plogpdf_interface = true, - option_assume_no_allocations = false + option_assume_no_allocations = false, + nsamples_for_gradlogpartition_properties = 6000 ) T = ExponentialFamily.exponential_family_typetag(distribution) @@ -84,7 +85,7 @@ function test_exponentialfamily_interface(distribution; test_packing_unpacking && run_test_packing_unpacking(distribution) test_isproper && run_test_isproper(distribution; assume_no_allocations = option_assume_no_allocations) test_basic_functions && run_test_basic_functions(distribution; assume_no_allocations = option_assume_no_allocations) - test_gradlogpartition_properties && run_test_gradlogpartition_properties(distribution) + 
test_gradlogpartition_properties && run_test_gradlogpartition_properties(distribution, nsamples = nsamples_for_gradlogpartition_properties) test_fisherinformation_properties && run_test_fisherinformation_properties(distribution) test_fisherinformation_against_hessian && run_test_fisherinformation_against_hessian(distribution; assume_no_allocations = option_assume_no_allocations) test_fisherinformation_against_jacobian && run_test_fisherinformation_against_jacobian(distribution; assume_no_allocations = option_assume_no_allocations) @@ -96,7 +97,7 @@ function run_test_plogpdf_interface(distribution) ef = convert(ExponentialFamily.ExponentialFamilyDistribution, distribution) η = getnaturalparameters(ef) samples = rand(StableRNG(42), distribution, 10) - _, _samples = ExponentialFamily.check_logpdf(variate_form(typeof(ef)), typeof(samples), eltype(samples), ef, samples) + _, _samples = ExponentialFamily.check_logpdf(ef, samples) ss_vectors = map(s -> ExponentialFamily.pack_parameters(ExponentialFamily.sufficientstatistics(ef, s)), _samples) unnormalized_logpdfs = map(v -> dot(v, η), ss_vectors) @test all(unnormalized_logpdfs ≈ map(x -> ExponentialFamily._plogpdf(ef, x, 0, 0), _samples)) @@ -165,11 +166,11 @@ function run_test_parameters_conversion(distribution) ) end - @test all(unpack_parameters(NaturalParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, tuple_of_η)) .== tuple_of_η) - @test all(unpack_parameters(MeanParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) .== tuple_of_θ) + @test all(unpack_parameters(NaturalParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, tuple_of_η), conditioner) .== tuple_of_η) + @test all(unpack_parameters(MeanParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ), conditioner) .== tuple_of_θ) - @test_opt unpack_parameters(NaturalParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, tuple_of_η)) - @test_opt unpack_parameters(MeanParametersSpace(), 
T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) + @test_opt unpack_parameters(NaturalParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, tuple_of_η), conditioner) + @test_opt unpack_parameters(MeanParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ), conditioner) # Extra methods for conditioner free distributions if isnothing(conditioner) @@ -299,6 +300,7 @@ function run_test_basic_functions(distribution; nsamples = 10, test_gradients = @test logbasemeasure(ef, x) ≈ log(basemeasure(ef, x)) atol = 1e-8 @test all(@inferred(sufficientstatistics(ef, x)) .== map(f -> f(x), getsufficientstatistics(T, conditioner))) @test @inferred(logpartition(ef)) == getlogpartition(T, conditioner)(η) + @test @inferred(gradlogpartition(ef)) == getgradlogpartition(NaturalParametersSpace(), T, conditioner)(η) @test @inferred(fisherinformation(ef)) == getfisherinformation(T, conditioner)(η) # Double check the `conditioner` free methods @@ -307,6 +309,7 @@ function run_test_basic_functions(distribution; nsamples = 10, test_gradients = @test @inferred(logbasemeasure(ef, x)) == getlogbasemeasure(T)(x) @test all(@inferred(sufficientstatistics(ef, x)) .== map(f -> f(x), getsufficientstatistics(T))) @test @inferred(logpartition(ef)) == getlogpartition(T)(η) + @test @inferred(gradlogpartition(ef)) == getgradlogpartition(NaturalParametersSpace(), T)(η) @test @inferred(fisherinformation(ef)) == getfisherinformation(T)(η) end diff --git a/test/distributions/tensor_dirichlet_test.jl b/test/distributions/tensor_dirichlet_test.jl new file mode 100644 index 00000000..0057ac63 --- /dev/null +++ b/test/distributions/tensor_dirichlet_test.jl @@ -0,0 +1,353 @@ +@testitem "TensorDirichlet: common" begin + include("distributions_setuptests.jl") + + @test TensorDirichlet <: Distribution + @test TensorDirichlet <: ContinuousDistribution + @test TensorDirichlet <: ContinuousTensorDistribution + + @test value_support(TensorDirichlet) === Continuous + @test 
variate_form(TensorDirichlet) === ArrayLikeVariate + + @test_throws "ArgumentError: All elements of the alpha tensor should be positive" TensorDirichlet(zeros(3, 3, 3)) +end + +@testitem "TensorDirichlet: entropy" begin + include("distributions_setuptests.jl") + + for rank in (3, 5) + for d in (2, 5, 10) + for _ in 1:10 + alpha = rand([d for _ in 1:rank]...) + + distribution = TensorDirichlet(alpha) + mat_of_dir = Dirichlet.(eachslice(alpha, dims = Tuple(2:rank))) + + mat_entropy = sum(entropy.(mat_of_dir)) + @test entropy(distribution) ≈ mat_entropy + end + end + end +end + +@testitem "TensorDirichlet: var" begin + include("distributions_setuptests.jl") + + for rank in (3, 5) + for d in (2, 5, 10) + for _ in 1:10 + alpha = rand([d for _ in 1:rank]...) + + distribution = TensorDirichlet(alpha) + mat_of_dir = Dirichlet.(eachslice(alpha, dims = Tuple(2:rank))) + + temp = var.(mat_of_dir) + mat_var = similar(alpha) + for i in CartesianIndices(Base.tail(size(alpha))) + mat_var[:, i] = temp[i] + end + @test var(distribution) ≈ mat_var + end + end + end +end + +@testitem "TensorDirichlet: mean" begin + include("distributions_setuptests.jl") + + for rank in (3, 5) + for d in (2, 5, 10) + for _ in 1:10 + alpha = rand([d for _ in 1:rank]...) + + distribution = TensorDirichlet(alpha) + mat_of_dir = Dirichlet.(eachslice(alpha, dims = Tuple(2:rank))) + + temp = mean.(mat_of_dir) + mat_mean = similar(alpha) + for i in CartesianIndices(Base.tail(size(alpha))) + mat_mean[:, i] = temp[i] + end + @test mean(distribution) ≈ mat_mean + end + end + end +end + +@testitem "TensorDirichlet: std" begin + include("distributions_setuptests.jl") + + for rank in (3, 5) + for d in (2, 5, 10) + for _ in 1:10 + alpha = rand([d for _ in 1:rank]...) 
+
+                distribution = TensorDirichlet(alpha)
+                mat_of_dir = Dirichlet.(eachslice(alpha, dims = Tuple(2:rank)))
+
+                temp = std.(mat_of_dir)
+                mat_std = similar(alpha)
+                for i in CartesianIndices(Base.tail(size(alpha)))
+                    mat_std[:, i] = temp[i]
+                end
+                @test std(distribution) ≈ mat_std
+            end
+        end
+    end
+end
+
+@testitem "TensorDirichlet: cov" begin  # cov(TensorDirichlet) is compared against cov of each first-dimension slice treated as a Dirichlet
+    include("distributions_setuptests.jl")
+
+    for rank in (3, 5)
+        for d in (2, 5, 10)
+            for _ in 1:10
+                alpha = rand([d for _ in 1:rank]...)
+
+                distribution = TensorDirichlet(alpha)
+                mat_of_dir = Dirichlet.(eachslice(alpha, dims = Tuple(2:rank)))
+
+                temp = cov.(mat_of_dir)
+                old_shape = size(alpha)
+                new_shape = (first(old_shape), first(old_shape), Base.tail(old_shape)...)  # one (d × d) covariance block per trailing index
+                mat_cov = ones(new_shape)
+                for i in CartesianIndices(Base.tail(size(alpha)))
+                    mat_cov[:, :, i] = temp[i]
+                end
+                @test cov(distribution) ≈ mat_cov
+            end
+        end
+    end
+end
+
+@testitem "TensorDirichlet: ExponentialFamilyDistribution" begin  # generic EF interface checks, then isproper on hand-built 2×2×2 parameters
+    include("distributions_setuptests.jl")
+
+    for len in 3:5
+        α = rand(len, len, len) .+ 1
+        let d = TensorDirichlet(α)
+            ef = test_exponentialfamily_interface(d;
+                option_assume_no_allocations = false,
+                nsamples_for_gradlogpartition_properties = 20000)
+            η1 = getnaturalparameters(ef)
+            conditioner = getconditioner(ef)
+            for x in [rand(1.0:2.0, len, len) for _ in 1:3]
+                x = x ./ sum(x)
+                @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure()
+                @test @inferred(basemeasure(ef, x)) === 1.0
+                @test all(@inferred(sufficientstatistics(ef, x)) .≈ (map(log, x),))
+                @test @inferred(logpartition(ef)) ≈ mapreduce(
+                    col -> getlogpartition(NaturalParametersSpace(), Dirichlet)(convert(Vector, col)),  # `col`, not `d`: avoids shadowing the `let d` binding
+                    +,
+                    eachcol(reshape(first(unpack_parameters(TensorDirichlet, η1, conditioner)), len, len * len))
+                )
+            end
+        end
+    end
+
+    inf_test = [Inf, 1.0]
+    nan_test = [NaN, 2.0]
+    negative_num_test = [-1.0, -1.2]
+    negative_num_natural_param_test = [-0.1, -0.2]
+    a = [1.0, 1.0]
+    b = [1.2, 3.3]
+    c = [0.2, 3.4]
+    d = [4.0, 5.0]
+
+    tensorDiri = Array{Float64, 3}(undef, (2, 2, 2))
+
+    tensorDiri[:, 1, 1] .= a
+    tensorDiri[:, 1, 2] .= b
+    tensorDiri[:, 2, 1] .= c
+    tensorDiri[:, 2, 2] .= d
+
+    for space in (MeanParametersSpace(), NaturalParametersSpace())
+        @test isproper(space, TensorDirichlet, tensorDiri)
+        @test !isproper(space, TensorDirichlet, Inf)
+        tensorDiri[:, 1, 1] .= nan_test
+        @test !isproper(space, TensorDirichlet, tensorDiri)
+        tensorDiri[:, 1, 1] .= inf_test
+        @test !isproper(space, TensorDirichlet, tensorDiri)
+        tensorDiri[:, 1, 1] .= negative_num_test
+        @test !isproper(space, TensorDirichlet, tensorDiri)
+        tensorDiri[:, 1, 1] .= a  # restore the valid slice before the next parameter space is checked
+    end
+    tensorDiri[:, 1, 1] = negative_num_natural_param_test
+    @test !isproper(MeanParametersSpace(), TensorDirichlet, tensorDiri)
+    @test isproper(NaturalParametersSpace(), TensorDirichlet, tensorDiri)
+
+    @test_throws Exception convert(ExponentialFamilyDistribution, TensorDirichlet([Inf Inf; 2 3]))
+end
+
+@testitem "TensorDirichlet: prod with Distribution" begin  # products of hand-written 2×3 parameter matrices against literal expected values
+    include("distributions_setuptests.jl")
+
+    a = [1.0, 1.0]
+    b = [1.2, 3.3]
+    c = [0.2, 3.4]
+    d = [4.0, 5.0]
+    e = [5.0, 11.0]
+    f = [0.2, 0.6]
+    g = [2.0, 1.1]
+
+    D1 = Array{Float64, 2}(undef, (2, 3))
+    D1[:, 1] = c
+    D1[:, 2] = e
+    D1[:, 3] = f
+
+    D2 = Array{Float64, 2}(undef, (2, 3))
+    D2[:, 1] = b
+    D2[:, 2] = d
+    D2[:, 3] = g
+
+    D3 = Array{Float64, 2}(undef, (2, 3))
+    D3[:, 1] = D3[:, 2] = D3[:, 3] = a
+
+    d1 = TensorDirichlet(D1)
+    d2 = TensorDirichlet(D2)
+    d3 = TensorDirichlet(D3)
+    @test @inferred(  # wrap only the prod call so that prod's return type, not the Bool from ≈, is what gets inferred
+        prod(PreserveTypeProd(Distribution), d1, d2)
+    ) ≈ TensorDirichlet([0.3999999999999999 8.0 1.2000000000000002; 5.699999999999999 15.0 0.7000000000000002])
+    @test @inferred(prod(PreserveTypeProd(Distribution), d1, d3)) ≈ TensorDirichlet(
+        [0.19999999999999996 5.0 0.19999999999999996; 3.4000000000000004 11.0 0.6000000000000001]
+    )
+    @test @inferred(prod(PreserveTypeProd(Distribution), d2, d3)) ≈ TensorDirichlet([1.2000000000000002 4.0 2.0; 3.3 5.0 1.1])
+end
+
+@testitem "TensorDirichlet: prod with ExponentialFamilyDistribution" begin  # shared product helper run with ClosedProd and GenericProd
+    include("distributions_setuptests.jl")
+    for rank in 3:6
+        for d in 3:6
+            αleft = rand([d for _ in 1:rank]...) .+ 1
+            αright = rand([d for _ in 1:rank]...) .+ 1
+            @testset let (left, right) = (TensorDirichlet(αleft), TensorDirichlet(αright))
+                test_generic_simple_exponentialfamily_product(
+                    left,
+                    right,
+                    strategies = (
+                        ClosedProd(),
+                        GenericProd()
+                    )
+                )
+            end
+        end
+    end
+end
+
+@testitem "TensorDirichlet: promote_variate_type" begin  # variate-form promotion between Dirichlet and TensorDirichlet
+    include("distributions_setuptests.jl")
+
+    @test_throws MethodError promote_variate_type(Univariate, TensorDirichlet)
+
+    @test promote_variate_type(Multivariate, Dirichlet) === Dirichlet
+    @test promote_variate_type(ArrayLikeVariate, Dirichlet) === TensorDirichlet
+
+    @test promote_variate_type(Multivariate, TensorDirichlet) === TensorDirichlet
+end
+
+@testitem "TensorDirichlet: prod with PreserveTypeProd{Distribution}" begin  # tensor product must equal the slice-wise Dirichlet products
+    include("distributions_setuptests.jl")
+
+    for rank in (3, 5)
+        for d in (2, 5, 10)
+            for _ in 1:10
+                alpha1 = rand([d for _ in 1:rank]...) .+ 1
+                alpha2 = rand([d for _ in 1:rank]...) .+ 1
+                distribution1 = TensorDirichlet(alpha1)
+                distribution2 = TensorDirichlet(alpha2)
+
+                mat_of_dir_1 = Dirichlet.(eachslice(alpha1, dims = Tuple(2:rank)))
+                mat_of_dir_2 = Dirichlet.(eachslice(alpha2, dims = Tuple(2:rank)))
+                dim = rank - 1
+
+                prod_temp = Array{Dirichlet, dim}(undef, Base.tail(size(alpha1)))
+                for i in CartesianIndices(Base.tail(size(alpha1)))
+                    prod_temp[i] = prod(PreserveTypeProd(Distribution), mat_of_dir_1[i], mat_of_dir_2[i])
+                end
+                mat_prod = similar(alpha1)
+                for i in CartesianIndices(Base.tail(size(alpha1)))
+                    mat_prod[:, i] = prod_temp[i].alpha
+                end
+                @test @inferred(prod(PreserveTypeProd(Distribution), distribution1, distribution2)) ≈ TensorDirichlet(mat_prod)
+            end
+        end
+    end
+end
+
+@testitem "TensorDirichlet: rand" begin  # sampling with a re-seeded rng must match slice-wise Dirichlet sampling with the same seed
+    include("distributions_setuptests.jl")
+    using StableRNGs
+    import Random: seed!
+    rng = StableRNG(1234)
+
+    for rank in (3, 5)
+        for d in (2, 3, 4, 5)
+            seed!(rng, 1234)
+            alpha = rand(rng, [d for _ in 1:rank]...) .+ 2  # draw from rng so the seed! above actually makes alpha reproducible
+            distribution = TensorDirichlet(alpha)
+            seed!(rng, 1234)
+            sample = rand(rng, distribution)
+
+            @test size(sample) == size(alpha)
+            @test all(sum(sample; dims = 1) .≈ 1)
+
+            mat_of_dir = Dirichlet.(eachslice(alpha, dims = Tuple(2:rank)))
+            mat_sample = Array{Float64, rank}(undef, size(alpha))
+            seed!(rng, 1234)  # same seed as the tensor draw above, so both draws start from the same rng state
+            for i in CartesianIndices(Base.tail(size(alpha)))
+                mat_sample[:, i] = rand(rng, mat_of_dir[i])
+            end
+
+            @test sample ≈ mat_sample
+
+            seed!(rng, 1234)
+            sample = rand(rng, distribution, 10)
+            @test size(sample) == (10,)
+            @test(all(x -> all(sum(x; dims = 1) .≈ 1), sample))
+            @test(all(x -> size(x) == size(alpha), sample))
+        end
+    end
+end
+
+@testitem "TensorDirichlet: vague" begin  # shape and element type of the parameter array returned by vague
+    include("distributions_setuptests.jl")
+
+    dirichlet = vague(TensorDirichlet, 3)
+    @test typeof(dirichlet.a) <: Array{Float64, 2}
+    @test size(dirichlet.a) == (3, 3)
+
+    @test typeof(vague(TensorDirichlet, (2, 2, 2, 3)).a) <: Array{Float64, 4}
+
+    @test vague(TensorDirichlet, 3) == vague(TensorDirichlet, (3, 3))
+
+    @test_throws MethodError vague(TensorDirichlet)
+end
+
+@testitem "TensorDirichlet: logpdf" begin  # logpdf/pdf must equal the sum/product of the slice-wise Dirichlet values
+    include("distributions_setuptests.jl")
+
+    for rank in (3, 4, 5, 6)
+        for d in (2, 3, 5, 10)
+            for _ in 1:10  # index is unused; each pass draws a fresh alpha
+                alpha = rand([d for _ in 1:rank]...)
+
+                distribution = TensorDirichlet(alpha)
+                mat_of_dir = Dirichlet.(eachslice(alpha, dims = Tuple(2:rank)))
+
+                sample = rand(distribution)
+                sample ./= sum(sample, dims = 1)
+
+                mat_logpdf = sum(logpdf.(mat_of_dir, eachslice(sample, dims = Tuple(2:rank))))
+                @test logpdf(distribution, sample) ≈ mat_logpdf
+                @test pdf(distribution, sample) ≈ prod(pdf.(mat_of_dir, eachslice(sample, dims = Tuple(2:rank))))
+                sample = ones(size(sample))  # all-ones input: slices do not sum to 1; both sides must still agree
+                mat_logpdf = sum(logpdf.(mat_of_dir, eachslice(sample, dims = Tuple(2:rank))))
+                @test logpdf(distribution, sample) ≈ mat_logpdf
+
+                sample = rand(distribution, 10)
+                lpdf = logpdf(distribution, sample)
+                @test all(lpdf .≈ map(s -> sum(logpdf.(mat_of_dir, eachslice(s, dims = Tuple(2:rank)))), sample))
+            end
+        end
+    end
+end
\ No newline at end of file
diff --git a/test/exponential_family_setuptests.jl b/test/exponential_family_setuptests.jl
index 6226405f..2edef266 100644
--- a/test/exponential_family_setuptests.jl
+++ b/test/exponential_family_setuptests.jl
@@ -43,6 +43,7 @@ BayesBase.params(dist::ArbitraryDistributionFromExponentialFamily) = (dist.p1, d
 (::NaturalToMean{ArbitraryDistributionFromExponentialFamily})(params::Tuple) = (params[1] - 1, params[2] - 1)
 
 ExponentialFamily.unpack_parameters(::Type{ArbitraryDistributionFromExponentialFamily}, η) = (η[1], η[2])
+ExponentialFamily.unpack_parameters(::Type{ArbitraryDistributionFromExponentialFamily}, η, _) = (η[1], η[2])
 
 # Arbitrary distribution (conditioned)
 struct ArbitraryConditionedDistributionFromExponentialFamily <: ContinuousUnivariateDistribution
@@ -72,3 +73,4 @@ ExponentialFamily.join_conditioner(::Type{ArbitraryConditionedDistributionFromEx
 (::NaturalToMean{ArbitraryConditionedDistributionFromExponentialFamily})(params::Tuple, conditioner::Number) = (params[1] - conditioner,)
 
 ExponentialFamily.unpack_parameters(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, η) = (η[1],)
+ExponentialFamily.unpack_parameters(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, η, _) = (η[1],)  # three-argument form: the third positional argument is accepted but unused; returns the 1-tuple (η[1],)