diff --git a/Project.toml b/Project.toml index 813eea5..d3e78cf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KeyedDistributions" uuid = "2576fb08-064d-4cab-b15d-8dda7fcb9a6d" authors = ["Invenia Technical Computing Corporation"] -version = "0.1.15" +version = "0.1.16" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" diff --git a/src/KeyedDistributions.jl b/src/KeyedDistributions.jl index 7eed39f..a12d99c 100644 --- a/src/KeyedDistributions.jl +++ b/src/KeyedDistributions.jl @@ -35,11 +35,13 @@ for T in (:Distribution, :Sampleable) The length of each key vector in must match the length along each dimension. !!! note - For distributions that can be marginalized exactly, the $($KeyedT) can be + For distributions that can be marginalised exactly, the $($KeyedT) can be marginalised via the indexing or lookup syntax just like `KeyedArray`s. i.e. One can use square or round brackets to retain certain indices or keys and marginalise out the others. For example for `D::KeyedMvNormal` over `:a, :b, :c`: - - `D(:a)` or D(1) will marginalise out `:b, :c` and return a `KeyedMvNormal` + - `D(:a)` or `D[1]` will marginalise out `:b, :c` and return a `KeyedNormal` + over `:a`. + - `D([:a])` or `D[[1]]` will marginalise out `:b, :c` and return a `KeyedMvNormal` + over `:a`. - `D([:a, :b])` or `D[[1, 2]]` will marginalise out `:c` and return a `KeyedMvNormal` over `:a, :b`. 
@@ -119,7 +121,7 @@ function Base.getindex(d::KeyedMvNormal, i::Vector)::KeyedMvNormal end function Base.getindex(d::KeyedMvNormal, i::Integer)::KeyedDistribution - return KeyedDistribution(Normal(d.d.μ[i], d.d.Σ[i, i]), [axiskeys(d)[1][i]]) + return KeyedDistribution(Normal(d.d.μ[i], sqrt(d.d.Σ[i, i])), [axiskeys(d)[1][i]]) end function Base.getindex(d::KeyedGenericMvTDist, i::Vector)::KeyedGenericMvTDist @@ -291,7 +293,7 @@ for f in (:logpdf, :quantile, :mgf, :cf) @eval Distributions.$f(d::KeyedDistribution{<:Univariate}, x) = $f(distribution(d), x) end -for f in (:minimum, :maximum, :modes, :mode, :skewness, :kurtosis) +for f in (:mean, :var, :minimum, :maximum, :modes, :mode, :skewness, :kurtosis) @eval Distributions.$f(d::KeyedDistribution{<:Univariate}) = $f(distribution(d)) end @@ -341,7 +343,7 @@ end Distributions.components(kd::KeyedMixtureModel) = Distributions.components(kd.d) -function (mm::KeyedMixtureModel)(keys...) +function (mm::KeyedMixtureModel)(keys...) margcomps = map(Distributions.components(mm)) do c inds = first(map(AxisKeys.findindex, keys, axiskeys(mm))) _marginalize(c, inds) diff --git a/test/runtests.jl b/test/runtests.jl index 76d0e23..8ecd224 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -205,6 +205,8 @@ using Test @test entropy(kd) == entropy(d) ≈ -0.1904993792294276 @test entropy(kd, 2) == entropy(d, 2) ≈ -0.27483250970672124 + @test mean(kd) == mean(d) == 0.5 + @test var(kd) == var(d) ≈ 0.04 @test minimum(kd) == minimum(d) == -Inf @test maximum(kd) == maximum(d) == Inf @test modes(kd) == modes(d) == [0.5] @@ -326,7 +328,11 @@ using Test @test d([:a, :c]) == d[[1, 3]] == d13 @test d([:a]) == d[[1]] == KeyedDistribution(MvNormal(m[[1]], s[[1], [1]]); id=[:a]) - @test d(:a) == d[1] == KeyedDistribution(Normal(m[1], s[1, 1]), [:a]) + @test d(:a) == d[1] == KeyedDistribution(Normal(m[1], sqrt(s[1, 1])), [:a]) + + # Ensure correct variance when returning a KeyedNormal + # https://github.com/invenia/KeyedDistributions.jl/issues/49 
+ @test var(d(:a)) == s[1, 1] end @testset "KeyedMvNormal constructed without named keys" begin @@ -337,7 +343,11 @@ using Test @test d([1, 3]) == d[[1, 3]] == d13 @test d([1]) == d[[1]] == KeyedDistribution(MvNormal(m[[1]], s[[1], [1]]), [1]) - @test d(1) == d[1] == KeyedDistribution(Normal(m[1], s[1, 1]), [1]) + @test d(1) == d[1] == KeyedDistribution(Normal(m[1], sqrt(s[1, 1])), [1]) + + # Ensure correct variance when returning a KeyedNormal + # https://github.com/invenia/KeyedDistributions.jl/issues/49 + @test var(d(1)) == s[1, 1] end