Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement KeyedDistribution and KeyedSampleable #1

Merged
merged 33 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d305bc0
Implement prototype
bencottier Mar 5, 2021
318c726
Apply suggestions from code review
bencottier Mar 8, 2021
fff2948
Limit KeyedSampleable methods
bencottier Mar 9, 2021
1390c89
Merge branch 'bc/prototype' of https://github.com/invenia/KeyedDistri…
bencottier Mar 9, 2021
f827bfa
Rearrange tests after refactor
bencottier Mar 9, 2021
2d53943
Bugfix: incorrect `==`
bencottier Mar 9, 2021
177f619
Reorganise tests
bencottier Mar 9, 2021
64bedc0
Fix main branch references
bencottier Mar 9, 2021
5b42400
Update julia compat to 1.3
bencottier Mar 10, 2021
5951cd9
Bugfix: Documenter cannot find docstring
bencottier Mar 10, 2021
c120037
Remove CI on 1.0 and change julia compat to 1.5
bencottier Mar 10, 2021
7ae0b91
Set main as devbranch to build docs
bencottier Mar 10, 2021
590874f
Test _rand!
bencottier Mar 10, 2021
fbd6e19
Remove axiskeys type piracy of Sampleable
bencottier Mar 11, 2021
b4fd1af
Use StableRNGs
bencottier Mar 11, 2021
1ac5d27
Remove parent and keyless methods
bencottier Mar 11, 2021
e97b1da
Make keys a tuple of vectors
bencottier Mar 11, 2021
837af84
Add subtyping tests
bencottier Mar 11, 2021
2d7dc6d
Support matrix-variate and univariate distributions
bencottier Mar 11, 2021
d00c8c1
Add default keys for no-keys constructor
bencottier Mar 11, 2021
774e083
Remove unnecessary type checks
bencottier Mar 11, 2021
84da47f
Clarify docstrings
bencottier Mar 11, 2021
ec836a5
Generalise cov for Matrixvariate
bencottier Mar 11, 2021
104f353
Update tests for KeyedArray/PDMats issue
bencottier Mar 12, 2021
ae1ab09
Compare more properties to underlying distribution
bencottier Mar 12, 2021
41cf7b7
Edit comment
bencottier Mar 12, 2021
c001761
Clear up
bencottier Mar 12, 2021
1210eb2
Simplify _rand!
bencottier Mar 12, 2021
00d34d3
Fix docs references
bencottier Mar 12, 2021
f217487
Add workaround for _logpdf with inner KeyedArray
bencottier Mar 12, 2021
fdcdebb
Clean up
bencottier Mar 12, 2021
4479cb4
Add univariate-only methods
bencottier Mar 15, 2021
1ef92ce
Refactor keys checking in constructor
bencottier Mar 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,22 @@ uuid = "2576fb08-064d-4cab-b15d-8dda7fcb9a6d"
authors = ["Invenia Technical Computing Corporation"]
version = "0.1.0"

[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
AutoHashEquals = "0.2"
AxisKeys = "0.1"
Distributions = "0.24"
julia = "1"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

# Packages made available to `test/runtests.jl` in addition to the package's own deps.
[targets]
test = ["LinearAlgebra", "Statistics", "Test"]
116 changes: 115 additions & 1 deletion src/KeyedDistributions.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,119 @@
module KeyedDistributions

# Write your package code here.
using AutoHashEquals
using AxisKeys
using AxisKeys: keyless
using Distributions
using Random: AbstractRNG

export KeyedDistribution, KeyedSampleable
export axiskeys, distribution


# `KeyedDistribution` and `KeyedSampleable` are generated from a single template; the
# only difference between them is the abstract supertype they wrap and subtype.
for (KeyedT, T) in (
    (:KeyedDistribution, :Distribution),
    (:KeyedSampleable, :Sampleable),
)
    @eval begin
        @auto_hash_equals struct $KeyedT{F, S, D<:$T{F, S}} <: $T{F, S}
            d::D
            keys::AbstractVector

            """
                $($KeyedT)(d<:$($T), keys::AbstractVector)

            Wrap the `$($T)` `d` together with `keys`, one key per variate.

            Throws a `DimensionMismatch` when `length(keys) != length(d)`.
            """
            function $KeyedT(d::$T{F, S}, keys::AbstractVector) where {F, S}
                # Validate before constructing so a mismatched wrapper can never exist.
                if length(d) != length(keys)
                    throw(DimensionMismatch(
                        "number of keys ($(length(keys))) must match " *
                        "number of variates ($(length(d)))"
                    ))
                end
                return new{F, S, typeof(d)}(d, keys)
            end
        end
    end
end

"""
KeyedDistribution(d::Distribution)

Constructs a [`KeyedDistribution`](@ref) using keys stored in `d`.
The keys are copied from the first axis of the first parameter in `d`.
"""
function KeyedDistribution(d::Distribution)
    # The keys are looked up on the first field of `d` (positional `getfield`).
    first_param = getfield(d, 1)
    # Only the keys of the first axis are used.
    # NOTE(review): assumes `axiskeys(first_param)` is a non-empty tuple whose first
    # element is an `AbstractVector` — confirm for parameters that are not `KeyedArray`s.
    keys = first(axiskeys(first_param))
    return KeyedDistribution(d, keys)
end

# Union of the two wrapper types, so methods shared by both can be defined once.
const KeyedDistOrSampleable = Union{KeyedDistribution, KeyedSampleable}

# Access methods

"""
    distribution(::KeyedDistribution) -> Distribution
    distribution(::KeyedSampleable{F, S, D}) -> D

Return the object wrapped by the keyed distribution or sampleable (its `d` field).
"""
function distribution(d::KeyedDistOrSampleable)
    return d.d
end

# AxisKeys functionality

# `parent` and `keyless` both return the wrapped object, i.e. the same value as
# `distribution(d)`.
Base.parent(d::KeyedDistOrSampleable) = d.d

AxisKeys.keyless(d::KeyedDistOrSampleable) = parent(d)

"""
    axiskeys(s::Sampleable)

Return the keys for the variates of the Sampleable, as a 1-tuple.

For a [`KeyedDistribution`](@ref) or [`KeyedSampleable`](@ref) this
is the keys it was constructed with.
For any other `Sampleable` this is equal to `1:length(s)`.
"""
AxisKeys.axiskeys(d::KeyedDistOrSampleable) = tuple(d.keys)
# NOTE(review): this fallback is type piracy — neither `axiskeys` (AxisKeys) nor
# `Sampleable` (Distributions) is owned by this module. It is kept so existing callers
# keep working; consider removing it or moving it upstream.
AxisKeys.axiskeys(d::Sampleable) = tuple(Base.OneTo(length(d)))

# Standard functions to overload for new Distribution and/or Sampleable
# https://juliastats.org/Distributions.jl/latest/extends/#Create-a-Distribution

# The sampler is that of the wrapped object; the keys are not involved.
function Distributions.sampler(d::KeyedDistOrSampleable)
    return sampler(keyless(d))
end

# Delegate the draw to `Distributions._rand!` on the wrapped object, then attach the
# variate keys by wrapping the result in a `KeyedArray`.
function Distributions._rand!(
    rng::AbstractRNG,
    d::KeyedDistOrSampleable,
    x::AbstractVector{<:Real},
)
    sample = Distributions._rand!(rng, parent(d), x)
    return KeyedArray(sample, axiskeys(d))
end

# The log-density is computed by the wrapped object; the keys play no part.
Distributions._logpdf(d::KeyedDistOrSampleable, x::AbstractArray) =
    Distributions._logpdf(parent(d), x)

# `length` and `eltype` are forwarded unchanged to the wrapped object.
function Base.length(d::KeyedDistOrSampleable)
    return length(keyless(d))
end

function Base.eltype(d::KeyedDistOrSampleable)
    return eltype(keyless(d))
end

# Also need to overload `rand` methods to return a KeyedArray

# Single draw: sample from the wrapped object and label it with the variate keys.
function Base.rand(rng::AbstractRNG, d::KeyedDistOrSampleable)
    draw = rand(rng, parent(d))
    return KeyedArray(draw, axiskeys(d))
end

# `n` draws: the first axis carries the variate keys, the second is `Base.OneTo(n)`.
function Base.rand(rng::AbstractRNG, d::KeyedDistOrSampleable, n::Int)
    draws = rand(rng, parent(d), n)
    variate_keys = first(axiskeys(d))
    return KeyedArray(draws, (variate_keys, Base.OneTo(n)))
end

# Statistics functions

# Each statistic is computed on the wrapped object and then labelled with the keys.
function Distributions.mean(d::KeyedDistOrSampleable)
    return KeyedArray(mean(keyless(d)), axiskeys(d))
end

function Distributions.var(d::KeyedDistOrSampleable)
    return KeyedArray(var(keyless(d)), axiskeys(d))
end

# Both axes of the covariance matrix are labelled with the same variate keys.
function Distributions.cov(d::KeyedDistOrSampleable)
    covariance = cov(keyless(d))
    variate_keys = first(axiskeys(d))
    return KeyedArray(covariance, (variate_keys, variate_keys))
end

# Entropy is forwarded to the wrapped object, with and without the extra argument `b`.
function Distributions.entropy(d::KeyedDistOrSampleable)
    return entropy(keyless(d))
end

function Distributions.entropy(d::KeyedDistOrSampleable, b::Real)
    return entropy(keyless(d), b)
end

end
110 changes: 110 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,116 @@
using AxisKeys
using Distributions
using KeyedDistributions
using LinearAlgebra
using Random
using Statistics
using Test

@testset "KeyedDistributions.jl" begin

    # Keys are read from the distribution's first parameter when none are passed.
    @testset "Inner keys constructor" begin
        keys = [:a, :b]
        m = KeyedArray([1., 2.], keys)
        d = MvNormal(m, [1., 1.])
        kd = KeyedDistribution(d)

        @test distribution(kd) == d
        @test axiskeys(kd) == (keys, )
        @test mean(kd) == m
    end

    @testset "Common" begin
        X = rand(MersenneTwister(1234), 10, 5)
        m = vec(mean(X; dims=1))
        s = cov(X; dims=1)
        d = MvNormal(m, s)
        keys = [:a, :b, :c, :d, :e]

        # Both wrapper types must behave the same way.
        @testset for T in (KeyedDistribution, KeyedSampleable)
            kd = T(d, keys)

            @testset "base functions" begin
                @test kd isa Sampleable
                @test distribution(kd) == d
                @test parent(kd) == d
                @test axiskeys(kd) == (keys, )
                @test length(kd) == length(d) == 5
                @test eltype(kd) == eltype(d) == Float64
                @test isequal(kd, T(d, [:a, :b, :c, :d, :e]))
                @test kd == T(d, keys)
            end

            @testset "statistical functions" begin
                @test mean(kd) isa KeyedArray{Float64, 1}
                @test parent(mean(kd)) == mean(d) == m
                # @test axisnames(mean(kd)) == (:variates,)

                @test var(kd) isa KeyedArray{Float64, 1}
                @test parent(var(kd)) == var(d) == diag(s)
                # @test axisnames(var(kd)) == (:variates,)

                @test cov(kd) isa KeyedArray{Float64, 2}
                @test parent(cov(kd)) == cov(d) == s
                # @test axisnames(cov(kd)) == (:variates, :variates_)

                @test entropy(kd) isa Number
                @test entropy(kd) == entropy(d)
                @test entropy(kd, 2) == entropy(d, 2)

                @test Distributions._logpdf(kd, m) isa Number
                @test Distributions._logpdf(kd, m) == Distributions._logpdf(d, m)

                # statistical functions commute with parent on KeyedArray/KeyedDistribution
                for f in (mean, var, cov)
                    @test f(parent(kd)) == parent(f(kd))
                end
            end

            @testset "sampling" begin
                # Samples from the distribution both wrapped and unwrapped should be the same.
                @test rand(MersenneTwister(1), d) == rand(MersenneTwister(1), kd)
                @test rand(MersenneTwister(1), d, 3) == rand(MersenneTwister(1), kd, 3)

                # One generator is shared by the two test sets below, so the expected
                # values depend on them running in this order.
                rng = MersenneTwister(1)

                @testset "one-sample method" begin
                    expected = [
                        0.6048058078690228,
                        0.5560878435408365,
                        0.41599188102577894,
                        0.4756226986245742,
                        0.15366818427047801,
                    ]
                    observed = rand(rng, kd)
                    @test observed isa KeyedArray
                    @test isapprox(observed, expected)
                    @test isapprox(observed(:a), expected[1])
                    # @test axisnames(observed) == (:variates,)
                    @test first(axiskeys(observed)) == first(axiskeys(kd))
                end

                @testset "multi-sample method" begin
                    expected = [
                        0.6080151671094673,
                        1.2415182218538203,
                        -4.4285504138930065e-5,
                        0.7298398256256964,
                        0.2103467702699237,
                    ]
                    observed = rand(rng, kd, 1)
                    @test observed isa KeyedArray
                    @test isapprox(observed, expected)
                    @test isapprox(observed(:a), [expected[1]])
                    # @test axisnames(observed) == (:variates, :samples)
                    @test first(axiskeys(observed)) == first(axiskeys(kd))
                end
            end
        end
    end

    # A key vector whose length differs from the number of variates is rejected.
    @testset "Invalid keys $T" for T in (KeyedDistribution, KeyedSampleable)
        @test_throws DimensionMismatch T(MvNormal(ones(3)), ["foo"])
    end

end