Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dirichlet tensor #219

Merged
merged 73 commits into from
Jan 15, 2025
Merged

Dirichlet tensor #219

merged 73 commits into from
Jan 15, 2025

Conversation

Raphael-Tresor
Copy link
Contributor

@Raphael-Tresor Raphael-Tresor commented Dec 24, 2024

Issue

The branch is not ExponentialFamily compatible yet. The pack unpack mechanism does not work.

Copy link
Member

@bvdmitri bvdmitri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm quite skeptical about this PR. First of all, is this distribution is from exponential family? If not, why should it belong to this package?

Second, why to do all these shenanigans with a[:, m, n] when you could simply to an array of Dirichlet distributions?

E.g.

a = [
    Dirichlet([ 0.1, 0.9 ]), Dirichlet([ 0.1, 0.9 ]),
    Dirichlet([ 0.1, 0.9 ]), Dirichlet([ 0.1, 0.9 ])
]

and then later on simply

a[1, 2] # returns Dirichlet 

I'm also not sure if the functionality is really correct, the logpdf can't be correct right? Why tests didn't pick it up?

@wouterwln
Copy link
Member

@bvdmitri This is a generalization of MatrixDirichlet (which is not really a Matrix-Dirichlet distribution but rather a collection of Dirichlets stored as a matrix). This distribution is the conjugate prior for the parameters of the general Transition node, which we need to do POMDPs. It is just a generalization of MatrixDirichlet, so I would say it is in the Exponential Family.

Co-authored-by: Bagaev Dmitry <bvdmitri@gmail.com>
@Raphael-Tresor
Copy link
Contributor Author

Raphael-Tresor commented Jan 7, 2025

I'm quite skeptical about this PR. First of all, is this distribution is from exponential family? If not, why should it belong to this package?

Second, why to do all these shenanigans with a[:, m, n] when you could simply to an array of Dirichlet distributions?

E.g.

a = [
    Dirichlet([ 0.1, 0.9 ]), Dirichlet([ 0.1, 0.9 ]),
    Dirichlet([ 0.1, 0.9 ]), Dirichlet([ 0.1, 0.9 ])
]

and then later on simply

a[1, 2] # returns Dirichlet 

I'm also not sure if the functionality is really correct, the logpdf can't be correct right? Why tests didn't pick it up?

Thank yo for the review @bvdmitri.

I discussed this with @Nimrais, and I think it is indeed a member of the exponential family.
About the implementation: creating an array of Dirichlets was my first idea, but after a discussion with @wouterwln I changed my mind. It seems that an array of Dirichlets might be slow and working directly with Float-array should be more efficient. I will think about it again.

@bvdmitri
Copy link
Member

bvdmitri commented Jan 7, 2025

@Raphael-Tresor

It seems that an array of Dirichlets might be slow and working directly with Float-array should be more efficient. I will think about it again.

I agree that working directly with Float-array, but only if you implement it correctly. I doubt that the current implementation is faster or more efficient because it creates a lot of slices and allocates new memory on every access, which is perhaps even slower than using arrays of Dirichlet distributions. You sometimes use @view but not everywhere. I don't have any benchmarks, it just more like a feeling. This being said, I don't really have a strong opinion on how to implement it and leave the choice to you.

@wouterwln
Copy link
Member

wouterwln commented Jan 7, 2025

@bvdmitri I'll work on performance, for now we just need a correct generalization of MatrixDirichlet for higher order tensors, such that we can parameterize Transition and TransitionMixture nodes.

@Nimrais
Copy link
Member

Nimrais commented Jan 7, 2025

@bvdmitri I'll work on performance, for now we just need a correct generalization of MatrixDirichlet for higher order tensors, such that we can parameterize Transition and TransitionMixture nodes.

If the performance is not a goal at this point, why not make an array of Dirichlet distributions? It's far easier to understand at least.

function BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)
l = size(getnaturalparameters(ef))
values = [x[:,i] for i in CartesianIndices(Base.tail(size(x)))]
## The element of the array should be the a categorical distribution (an vector of postive value that sum to 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a comment, this is the documentation make it a docstring, smt like this

"""
BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)

Check if the input `x` is within the support of a Tensor Dirichlet distribution.

Requirements:
- Each column of x must represent a valid categorical distribution (sum to 1, all values ≥ 0)
- Dimensions must match the natural parameters of the distribution

"""

Base.size(dist::TensorDirichlet) = size(dist.a)
Base.eltype(::TensorDirichlet{T}) where {T} = T

function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int)
Copy link
Member

@Nimrais Nimrais Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In one of the vague methods (with a tuple), you specify the type of ones, while in another, you don't. Is there any reason why you aren't writing it like this?

function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int) 
    return TensorDirichlet(ones(Float64, dims, dims))
end

Copy link
Contributor Author

@Raphael-Tresor Raphael-Tresor Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No specific reason, I missed this difference.
Float64 seems to be the default type, so it should be the same.
I will add the Float64 argument for clarity.
Thanks

a::A
end

extract_collection(dist::TensorDirichlet) = [dist.a[:,i] for i in CartesianIndices(Base.tail(size(dist.a)))]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extract_collection(dist::TensorDirichlet) = (view(dist.a, :, i) for i in CartesianIndices(Base.tail(size(dist.a))))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes with a '@view' it will be way better. I will take your suggestion.

@wouterwln
Copy link
Member

@Nimrais Because in the end we do need a performant implementation of this, which would require the data to be stored the way Raphael coded it. It is a lot easier to implement performant versions of methods if the underlying data structure does not change anymore.

@Nimrais
Copy link
Member

Nimrais commented Jan 7, 2025

Because in the end we do need a performant implementation of this, which would require the data to be stored the way Raphael coded it. It is a lot easier to implement performant versions of methods if the underlying data structure does not change anymore.

Sure, but then it would be nice to see the proof that it works faster. (@Raphael-Tresor it's more or less comment for you for the next time)

For example, a test that shows that the current way of implementing things is better than the naïve. Because if, at some point in time, it is starting to become slower, what is the point?

For example, smt like this, but written carefully

using ExponentialFamily
using BenchmarkTools
using Random
using Distributions
using Test

struct ArrayDirichlet{T}
    distributions::Array{Dirichlet{T}, 2}
end

function array_dirichlet(params::Array{T, 3}) where T
    m, n, k = size(params)
    dists = Array{Dirichlet{T}, 2}(undef, n, k)
    for i in 1:n, j in 1:k
        dists[i,j] = Dirichlet(params[:,i,j])
    end
    return ArrayDirichlet(dists)
end

struct TensorDirichlet{T}
    a::Array{T, 3}
end

dim = 10
grid_size = 20 
params = rand(dim, grid_size, grid_size) .+ 0.1

tensor_impl = TensorDirichlet(params)
array_impl = array_dirichlet(params)

test_points = [rand(dim, grid_size, grid_size) for _ in 1:100]
for x in test_points
    x ./= sum(x, dims=1)
end

tensor_bench = @benchmark for x in $test_points
    for i in 1:grid_size, j in 1:grid_size
        logpdf(Dirichlet(view($tensor_impl.a,:,i,j)), view(x,:,i,j))
    end
end

array_bench = @benchmark for x in $test_points
    for i in 1:grid_size, j in 1:grid_size
        logpdf($array_impl.distributions[i,j], view(x,:,i,j))
    end
end

println("Tensor Implementation:")
println("---------------------")
show(stdout, MIME("text/plain"), tensor_bench)
println("\n\nArray Implementation:")
println("--------------------")
show(stdout, MIME("text/plain"), array_bench)
@test min(array_bench.times...) > min(tensor_bench.times...)

@wouterwln
Copy link
Member

@Nimrais you mean something like this?

using ExponentialFamily
using BenchmarkTools
using Random
using Distributions
using Test
using BayesBase

struct ArrayDirichlet{T,S,P, N}
    distributions::AbstractArray{Dirichlet{T, S, P}, N}
end

function array_dirichlet(params::Array{T, N}) where {T, N}
    d, k... = size(params)
    return ArrayDirichlet(Dirichlet.(eachslice(params;dims=Tuple(2:N))))
end
function BayesBase.mean(d::ArrayDirichlet) where {T, N}
    return mean.(d.distributions)
end
dim = 10
grid_size = 20 
params = rand(dim, grid_size, grid_size) .+ 0.1

tensor_impl = TensorDirichlet(params)
array_impl = array_dirichlet(params)

tensor_bench = @benchmark mean($tensor_impl)

array_bench = @benchmark mean($array_impl)

println("Tensor Implementation:")
println("---------------------")
show(stdout, MIME("text/plain"), tensor_bench)
println("\n\nArray Implementation:")
println("--------------------")
show(stdout, MIME("text/plain"), array_bench)
@test min(array_bench.times...) > min(tensor_bench.times...)
Tensor Implementation:
---------------------
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  2.380 μs … 296.157 μs  ┊ GC (min … max):  0.00% … 97.94%
 Time  (median):     4.148 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   5.627 μs ±  19.098 μs  ┊ GC (mean ± σ):  27.80% ±  8.09%

                   ▃▆██▆▃                                      
  ▄▇▅▄▄▂▂▁▁▁▁▁▁▂▃▅▇███████▆▅▄▃▂▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  2.38 μs         Histogram: frequency by time        7.62 μs <

 Memory estimate: 34.69 KiB, allocs estimate: 3.

Array Implementation:
--------------------
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   9.500 μs …  1.076 ms  ┊ GC (min … max): 0.00% … 75.59%
 Time  (median):     10.500 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.139 μs ± 24.553 μs  ┊ GC (mean ± σ):  4.44% ±  3.29%

  ▃▇██▆▄▁                                                     ▂
  ███████▇▆██████▇▆▅▆▄▅▅▃▄▄▄▄▄▄▄▁▄▃▅▄▄▅▃▅▄▅▃▄▁▄▃▄▄▄▅▁▅▃▄▃▃▃▄▄ █
  9.5 μs       Histogram: log(frequency) by time      28.9 μs <

 Memory estimate: 59.66 KiB, allocs estimate: 403.
Test Passed

@Nimrais
Copy link
Member

Nimrais commented Jan 7, 2025

@Nimrais you mean something like this

Yeah, but we have more than only the mean method; I did an example of this test for the logpdf you did for the mean method.

In my example, because the logpdf "method" for TensorDirechlet is not optimized, it takes longer than the naïve implementation.

min(array_bench.times...) > min(tensor_bench.times...) is not passing in my example.

Test Failed at /Users/mykola/repos/biaslab/ExponentialFamily.jl/test_2.jl:54
  Expression: min(array_bench.times...) > min(tensor_bench.times...)
   Evaluated: 1.2759209e7 > 3.3939458e7

@Nimrais Nimrais marked this pull request as ready for review January 14, 2025 11:38
Copy link

codecov bot commented Jan 14, 2025

Codecov Report

Attention: Patch coverage is 90.76923% with 18 lines in your changes missing coverage. Please review.

Project coverage is 81.76%. Comparing base (95af252) to head (2841038).
Report is 131 commits behind head on main.

Files with missing lines Patch % Lines
src/distributions/tensor_dirichlet.jl 89.15% 18 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #219      +/-   ##
==========================================
+ Coverage   80.49%   81.76%   +1.27%     
==========================================
  Files          39       41       +2     
  Lines        3117     3477     +360     
==========================================
+ Hits         2509     2843     +334     
- Misses        608      634      +26     
Flag Coverage Δ
?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@wouterwln wouterwln left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're ready to merge. There is still some minor stuff I will work on but everything considered I think we're ready to merge.

end

function BayesBase.pdf(dist::TensorDirichlet, xs::AbstractVector)
return map(x -> pdf(dist, x), xs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this function @Nimrais ? Because I can spot a bug as it is right now. Either we delete this function or I can patch the bug. Let me know.

Copy link
Member

@Nimrais Nimrais Jan 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the implemented rand(Td, n) returns a vector of samples, I don't think it's easy to fix. To be compatible with both Distributions.jl and ExponentialFamily.jl, the distribution needs to implement a method that will evaluate the pdf of a collection of samples. And we have a test for that; if you will remove this method, the test set will fail.

So the short answer is yes, we need it.

The long answer is that the bug is not only here but also in the rand method because the rand method doesn't work according to the Distributions.jl API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and strictly speaking, there is a bug if TensorDirichlet samples are vectors themselves, but this is only possible if we have only one Dirichlet. So first we can just raise an error if someone wants to use TensorDirichlet for that.

@wouterwln wouterwln requested review from Nimrais and bvdmitri January 14, 2025 12:08
@wouterwln
Copy link
Member

by the way @Nimrais @bvdmitri This completely replaces MatrixDirichlet. We should probably remove it. What do you think? Strictly speaking because we break some existing functionality we should then release a new major version of the package.

@Nimrais
Copy link
Member

Nimrais commented Jan 14, 2025

by the way @Nimrais @bvdmitri This completely replaces MatrixDirichlet. We should probably remove it. What do you think? Strictly speaking because we break some existing functionality we should then release a new major version of the package.

We can do it in a separate PR. We can have it anyway but re-use code here to implement it. Also, MatrixDirechlet could be slightly more performant in some edge cases.

@wouterwln wouterwln merged commit 48f3392 into main Jan 15, 2025
4 checks passed
@wouterwln wouterwln deleted the DirichletTensor branch January 15, 2025 10:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants