Skip to content

Commit

Permalink
Fix MvNormal marginalisation when returning Normal (#50)
Browse files Browse the repository at this point in the history
* Correct docstring marginalisation description

* Specialise `mean` and `var` for univariate distributions

* Add failing variance test for marginialising to return `Normal`

* Test univariate `mean` and `var`

* Pass standard deviation instead of variance to `Normal` constructor

* Bump patch version

* Correct marginalisation notation to square brackets for integer indices
  • Loading branch information
mjp98 authored Jan 9, 2023
1 parent 1f77b79 commit dc5d9e0
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KeyedDistributions"
uuid = "2576fb08-064d-4cab-b15d-8dda7fcb9a6d"
authors = ["Invenia Technical Computing Corporation"]
version = "0.1.15"
version = "0.1.16"

[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
Expand Down
12 changes: 7 additions & 5 deletions src/KeyedDistributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ for T in (:Distribution, :Sampleable)
The length of each key vector in must match the length along each dimension.
!!! note
For distributions that can be marginalized exactly, the $($KeyedT) can be
For distributions that can be marginalised exactly, the $($KeyedT) can be
marginalised via the indexing or lookup syntax just like `KeyedArray`s.
i.e. One can use square or round brackets to retain certain indices or keys and
marginalise out the others. For example for `D::KeyedMvNormal` over `:a, :b, :c`:
- `D(:a)` or D(1) will marginalise out `:b, :c` and return a `KeyedMvNormal`
- `D(:a)` or D[1] will marginalise out `:b, :c` and return a `KeyedNormal`
over `:a`.
- `D([:a])` or D[[1]] will marginalise out `:b, :c` and return a `KeyedMvNormal`
over `:a`.
- `D([:a, :b])` or `D[[1, 2]]` will marginalise out `:c` and return a
`KeyedMvNormal` over `:a, :b`.
Expand Down Expand Up @@ -119,7 +121,7 @@ function Base.getindex(d::KeyedMvNormal, i::Vector)::KeyedMvNormal
end

function Base.getindex(d::KeyedMvNormal, i::Integer)::KeyedDistribution
return KeyedDistribution(Normal(d.d.μ[i], d.d.Σ[i, i]), [axiskeys(d)[1][i]])
return KeyedDistribution(Normal(d.d.μ[i], sqrt(d.d.Σ[i, i])), [axiskeys(d)[1][i]])
end

function Base.getindex(d::KeyedGenericMvTDist, i::Vector)::KeyedGenericMvTDist
Expand Down Expand Up @@ -291,7 +293,7 @@ for f in (:logpdf, :quantile, :mgf, :cf)
@eval Distributions.$f(d::KeyedDistribution{<:Univariate}, x) = $f(distribution(d), x)
end

for f in (:minimum, :maximum, :modes, :mode, :skewness, :kurtosis)
for f in (:mean, :var, :minimum, :maximum, :modes, :mode, :skewness, :kurtosis)
@eval Distributions.$f(d::KeyedDistribution{<:Univariate}) = $f(distribution(d))
end

Expand Down Expand Up @@ -341,7 +343,7 @@ end

Distributions.components(kd::KeyedMixtureModel) = Distributions.components(kd.d)

function (mm::KeyedMixtureModel)(keys...)
function (mm::KeyedMixtureModel)(keys...)
margcomps = map(Distributions.components(mm)) do c
inds = first(map(AxisKeys.findindex, keys, axiskeys(mm)))
_marginalize(c, inds)
Expand Down
14 changes: 12 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ using Test
@test entropy(kd) == entropy(d) -0.1904993792294276
@test entropy(kd, 2) == entropy(d, 2) -0.27483250970672124

@test mean(kd) == mean(d) == 0.5
@test var(kd) == var(d) 0.04
@test minimum(kd) == minimum(d) == -Inf
@test maximum(kd) == maximum(d) == Inf
@test modes(kd) == modes(d) == [0.5]
Expand Down Expand Up @@ -326,7 +328,11 @@ using Test
@test d([:a, :c]) == d[[1, 3]] == d13

@test d([:a]) == d[[1]] == KeyedDistribution(MvNormal(m[[1]], s[[1], [1]]); id=[:a])
@test d(:a) == d[1] == KeyedDistribution(Normal(m[1], s[1, 1]), [:a])
@test d(:a) == d[1] == KeyedDistribution(Normal(m[1], sqrt(s[1, 1])), [:a])

# Ensure correct variance when returning a KeyedNormal
# https://github.com/invenia/KeyedDistributions.jl/issues/49
@test var(d(:a)) == s[1, 1]
end

@testset "KeyedMvNormal constructed without named keys" begin
Expand All @@ -337,7 +343,11 @@ using Test
@test d([1, 3]) == d[[1, 3]] == d13

@test d([1]) == d[[1]] == KeyedDistribution(MvNormal(m[[1]], s[[1], [1]]), [1])
@test d(1) == d[1] == KeyedDistribution(Normal(m[1], s[1, 1]), [1])
@test d(1) == d[1] == KeyedDistribution(Normal(m[1], sqrt(s[1, 1])), [1])

# Ensure correct variance when returning a KeyedNormal
# https://github.com/invenia/KeyedDistributions.jl/issues/49
@test var(d(1)) == s[1, 1]
end


Expand Down

2 comments on commit dc5d9e0

@mjp98
Copy link
Member Author

@mjp98 mjp98 commented on dc5d9e0 Jan 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/75420

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.16 -m "<description of version>" dc5d9e097da985072f67b22e890dfd32c004462f
git push origin v0.1.16

Please sign in to comment.