Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement KeyedDistribution and KeyedSampleable #1

Merged
merged 33 commits into from
Mar 16, 2021
Merged

Conversation

bencottier
Copy link
Contributor

@bencottier bencottier commented Mar 5, 2021

This proof of concept for KeyedDistributions extends the AxisKeys.jl ecosystem with KeyedSampleable and KeyedDistribution types. The types thinly wrap Sampleables with a vector of keys, corresponding to the variates of the Sampleable. This is analogous to KeyedArray wrapping an array.

Not implemented:

@rofinn
Copy link
Member

rofinn commented Mar 5, 2021

Callable lookup syntax (not sure what this would do)

I feel like that would likely just return the same thing as marginalise?

@glennmoy
Copy link
Member

glennmoy commented Mar 8, 2021

Should Sampleable have all the statistical functions e.g. mean, in addition to Distribution? This doesn't mention it https://juliastats.org/Distributions.jl/latest/extends/#Create-a-Distribution

To subtype Sampleable I believe you only have to implement rand (univariate) or length and _rand! (multivariate).

src/KeyedDistributions.jl Outdated Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
@glennmoy
Copy link
Member

glennmoy commented Mar 8, 2021

One of the equality tests is failing, so there may be a problem with my struct implementation

Has it anything to do with using ==() instead of isequal(), or constructing the KeyedDistribution inside the function?

@glennmoy
Copy link
Member

glennmoy commented Mar 8, 2021

Callable lookup syntax (not sure what this would do)

I feel like that would likely just return the same thing as marginalise?

Yeah I think this means extracting the relevant elements by calling the keyed distribution like a function (as AxisKeys can do)

julia> A = KeyedArray([0.1, 0.2, 0.3, 0.4, 0.5], :obj=>[:a, :b, :c, :d, :e])

julia> d = MvNormal(A)

julia> kd = KeyedDistribution(d)

julia> kd(obj=[:a, :c, :d])  # returns another KD marginalised on [a, c, d]

@glennmoy
Copy link
Member

glennmoy commented Mar 8, 2021

I think we can open issues for the remaining tasks and in this PR nail down the Distributions API at least.

What's your sense of the time/effort this would take?

Co-authored-by: Glenn Moynihan <glenn.moynihan@invenialabs.co.uk>
@bencottier
Copy link
Contributor Author

Should Sampleable have all the statistical functions e.g. mean, in addition to Distribution? This doesn't mention it https://juliastats.org/Distributions.jl/latest/extends/#Create-a-Distribution

To subtype Sampleable I believe you only have to implement rand (univariate) or length and _rand! (multivariate).

I thought so, it's just that IndexedDistributions implemented the statistical functions for both, so I wonder if it's expected somewhere in our code.

Also I believe we have to overload rand regardless to return a KeyedArray. It was a regular array without.

@bencottier
Copy link
Contributor Author

bencottier commented Mar 8, 2021

Callable lookup syntax (not sure what this would do)

I feel like that would likely just return the same thing as marginalise?

Yeah I think this means extracting the relevant elements by calling the keyed distribution like a function (as AxisKeys can do)

julia> A = KeyedArray([0.1, 0.2, 0.3, 0.4, 0.5], :obj=>[:a, :b, :c, :d, :e])

julia> d = MvNormal(A)

julia> kd = KeyedDistribution(d)

julia> kd(obj=[:a, :c, :d])  # returns another KD marginalised on [a, c, d]

That makes sense, just note it only corresponds to marginalising for some distributions (I think in the exponential family e.g. normal, t distribution). With this syntax, we could have a marginalize method for clarity that simply uses it under the hood.

@bencottier
Copy link
Contributor Author

I think we can open issues for the remaining tasks and in this PR nail down the Distributions API at least.

What's your sense of the time/effort this would take?

I estimate at most 1 day spent on this.

@bencottier
Copy link
Contributor Author

bencottier commented Mar 9, 2021

One of the equality tests is failing, so there may be a problem with my struct implementation

Has it anything to do with using ==() instead of isequal(), or constructing the KeyedDistribution inside the function?

EDIT: the fix was to specify F<:VariateForm, S<:ValueSupport.

If I do kd2 = KeyedDistribution(d, keys); kd == kd2 it's the same error:

ERROR: type KeyedDistribution has no field KeyedDistribution
Stacktrace:
 [1] getproperty(::KeyedDistribution{Multivariate,Continuous,MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}}}, ::Symbol) at ./Base.jl:33
 [2] ==(::KeyedDistribution{Multivariate,Continuous,MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}}}, ::KeyedDistribution{Multivariate,Continuous,MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}}}) at /Users/bencottier/.julia/packages/KeyedDistributions/i1TQS/src/KeyedDistributions.jl:15
 [3] top-level scope at REPL[36]:1

Whereas isequal(kd, kd2) == true.

The fields and types are all equal. I assume it relates to the way @auto_hash_equals processes the struct.

Fixed by specifying the type of `F` and `S`
- Increase coverage
- Separate Distribution-only methods
@bencottier bencottier changed the title POC: Implement KeyedDistribution and KeyedSampleable Implement KeyedDistribution and KeyedSampleable Mar 9, 2021
@codecov
Copy link

codecov bot commented Mar 9, 2021

Codecov Report

Merging #1 (1ef92ce) into main (883d9c9) will increase coverage by 98.00%.
The diff coverage is 98.00%.

Impacted file tree graph

@@            Coverage Diff            @@
##           main       #1       +/-   ##
=========================================
+ Coverage      0   98.00%   +98.00%     
=========================================
  Files         0        1        +1     
  Lines         0       50       +50     
=========================================
+ Hits          0       49       +49     
- Misses        0        1        +1     
Impacted Files Coverage Δ
src/KeyedDistributions.jl 98.00% <98.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 883d9c9...1ef92ce. Read the comment docs.

@bencottier
Copy link
Contributor Author

bencottier commented Mar 9, 2021

I think we can open issues for the remaining tasks and in this PR nail down the Distributions API at least.

Issues: #2 #3 #4 #5

src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
test/runtests.jl Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
@bencottier
Copy link
Contributor Author

@mcabbott We're interested in what you think of this package and this PR. It applies the idea of a KeyedArray to a Distribution, and is meant to extend the AxisKeys ecosystem.

src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Outdated Show resolved Hide resolved
src/KeyedDistributions.jl Show resolved Hide resolved
@glennmoy
Copy link
Member

glennmoy commented Mar 12, 2021

I think this is missing a few methods from the Distributions interface for Univariates (we have all the Sampler functions).
These are listed here.
All of these can be done by just looping over them and using an @eval macro to delegate to the underlying distribution.

If you wanna push these to another PR that's fine.

Univariate

  • rand(::AbstractRNG, d::UnivariateDistribution)
  • sampler(d::Distribution)
  • logpdf(d::UnivariateDistribution, x::Real)
  • cdf(d::UnivariateDistribution, x::Real)
  • quantile(d::UnivariateDistribution, q::Real)
  • minimum(d::UnivariateDistribution)
  • maximum(d::UnivariateDistribution)
  • insupport(d::UnivariateDistribution, x::Real)

these are recommended

  • mean(d::UnivariateDistribution)
  • var(d::UnivariateDistribution)
  • modes(d::UnivariateDistribution)
  • mode(d::UnivariateDistribution)
  • skewness(d::UnivariateDistribution)
  • kurtosis(d::Distribution, ::Bool)
  • entropy(d::UnivariateDistribution, ::Real)
  • mgf(d::UnivariateDistribution, ::Any)
  • cf(d::UnivariateDistribution, ::Any)

Multi-variate

  • length(d::MultivariateDistribution)
  • sampler(d::Distribution)
  • eltype(d::Distribution)
  • Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractArray)
  • Distributions._logpdf(d::MultivariateDistribution, x::AbstractArray)Matrix-variate

Matrix-variate

  • size(d::MatrixDistribution)
  • _rand!(rng::AbstractRNG, s::Spl, x::DenseMatrix{<:Real})
  • rand(d::MatrixDistribution)
  • sampler(d::MatrixDistribution)
  • Distributions._logpdf(d::MatrixDistribution, x::AbstractArray)

key_lengths = map(length, keys)
key_lengths == _size(d) || throw(ArgumentError(
"lengths of key vectors $key_lengths must match " *
"size of distribution $(_size(d))"))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may not be pirating size, but "size" in this message may be misleading. I did it for the sake of generalising to one error message.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe remove reference to size in the error message?

"Dimensions of key vectors $key_lengths must match the distribution $(_size(d))"))

unless you think referring to dimensions is even more confusing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of dimensions, because it's ambiguous whether it's the number of dimensions or the size on each dimension. But you know what, this is testing both of those things, so maybe it's just right.

@eval Distributions.$f(d::KeyedDistribution{<:Univariate}) = $f(distribution(d))
end

# Needed to avoid method ambiguity
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Method ambiguity errors from the tests, from old code when I put the below functions in the first @eval loop

MethodError: cdf(::KeyedDistribution{Univariate,Continuous,Normal{Float64}}, ::Int64) is ambiguous. Candidates:
    cdf(d::Distribution{Univariate,Continuous}, x::Real) in Distributions at /Users/bencottier/.julia/packages/Distributions/cNe2C/src/univariates.jl:367
    cdf(d::KeyedDistribution{var"#s12",S,D} where D<:Distribution{var"#s12",S} where S<:ValueSupport where var"#s12"<:Univariate, args...) in KeyedDistributions at /Users/bencottier/JuliaEnvs/KeyedDistributions/src/KeyedDistributions.jl:173
  Possible fix, define
    cdf(::KeyedDistribution{Univariate,Continuous,D} where D<:Distribution{Univariate,Continuous}, ::Real)
  MethodError: insupport(::KeyedDistribution{Univariate,Continuous,Normal{Float64}}, ::Int64) is ambiguous. Candidates:
    insupport(d::Union{Type{D}, D}, x::Real) where D<:Distribution{Univariate,Continuous} in Distributions at /Users/bencottier/.julia/packages/Distributions/cNe2C/src/univariates.jl:127
    insupport(d::KeyedDistribution{var"#s12",S,D} where D<:Distribution{var"#s12",S} where S<:ValueSupport where var"#s12"<:Univariate, b) in KeyedDistributions at /Users/bencottier/JuliaEnvs/KeyedDistributions/src/KeyedDistributions.jl:179
  Possible fix, define
    insupport(::D, ::Real) where D<:(KeyedDistribution{Univariate,Continuous,D} where D<:Distribution{Univariate,Continuous})

Copy link
Member

@glennmoy glennmoy Mar 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's because we haven't provided enough type information to the function signature to restrict it to using our method.

We've only specified that it needs to be Univariate but Distributions provides the ValueSupport. So when it sees KeyedDistribution{Univariate, Continuous} it's conflicted between which version to use.

FWIW I don't understand why it still doesn't break outside the @eval loop, but we can/should fix it with the following

Distributions.cdf(d::KeyedArray{F, S, <:Distribution{F, S}, x) where {F, S} = cdf(distribution(d), x)

same with insupport.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I don't know why it worked this way.

end

@testset "Distributions types" begin
@testset "Univariate" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to note we only test on Continuous distributions. We should probably test on Discrete at some point.
#7

@bencottier bencottier merged commit 92bb274 into main Mar 16, 2021
@bencottier bencottier deleted the bc/prototype branch March 16, 2021 10:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants