Dirichlet tensor #219
Conversation
I'm quite skeptical about this PR. First of all, is this distribution from the exponential family? If not, why should it belong to this package?
Second, why do all these shenanigans with `a[:, m, n]` when you could simply use an array of Dirichlet distributions?
E.g.

```julia
a = [
    Dirichlet([0.1, 0.9]), Dirichlet([0.1, 0.9]),
    Dirichlet([0.1, 0.9]), Dirichlet([0.1, 0.9])
]
```

and then later on simply

```julia
a[1, 2] # returns a Dirichlet
```
I'm also not sure the functionality is really correct; the `logpdf` can't be correct, right? Why didn't the tests pick it up?
@bvdmitri This is a generalization of
Thank you for the review @bvdmitri. I discussed this with @Nimrais, and I think it is indeed a member of the exponential family.
I agree that working directly with a Float array is reasonable, but only if you implement it correctly. I doubt that the current implementation is faster or more efficient: it creates a lot of slices and allocates new memory on every access, which is perhaps even slower than using an array of Dirichlet distributions. You sometimes use
@bvdmitri I'll work on performance; for now we just need a correct generalization of
If the performance is not a goal at this point, why not make an array of Dirichlet distributions? It's far easier to understand, at least.
```julia
function BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)
    l = size(getnaturalparameters(ef))
    values = [x[:, i] for i in CartesianIndices(Base.tail(size(x)))]
    ## Each element of the array should be a categorical distribution (a vector of positive values that sum to 1)
```
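The excerpt above is cut off before the actual check. A minimal sketch of how the column-wise validation could be completed, assuming the support condition stated in the comment (the function name here is illustrative, not the PR's final code):

```julia
# Hypothetical sketch: check that every column of `x` (over the trailing
# dimensions) is a valid categorical distribution, i.e. nonnegative entries
# that sum to 1. Not the actual BayesBase.insupport implementation.
function tensor_dirichlet_insupport(x::AbstractArray)
    for i in CartesianIndices(Base.tail(size(x)))
        col = view(x, :, i)  # one categorical distribution per trailing index
        (all(>=(0), col) && isapprox(sum(col), 1; atol = 1e-8)) || return false
    end
    return true
end
```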
It's not a comment, this is documentation; make it a docstring, something like this:

```julia
"""
    BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)

Check if the input `x` is within the support of a TensorDirichlet distribution.

Requirements:
- Each column of `x` must represent a valid categorical distribution (sum to 1, all values ≥ 0)
- Dimensions must match the natural parameters of the distribution
"""
```
```julia
Base.size(dist::TensorDirichlet) = size(dist.a)
Base.eltype(::TensorDirichlet{T}) where {T} = T

function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int)
```
In one of the `vague` methods (the one with a tuple), you specify the element type of `ones`, while in the other you don't. Is there any reason why you aren't writing it like this?

```julia
function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int)
    return TensorDirichlet(ones(Float64, dims, dims))
end
```
No specific reason, I missed this difference. `Float64` seems to be the default type, so it should be the same. I will add the `Float64` argument for clarity. Thanks.
```julia
    a::A
end

extract_collection(dist::TensorDirichlet) = [dist.a[:, i] for i in CartesianIndices(Base.tail(size(dist.a)))]
```
```julia
extract_collection(dist::TensorDirichlet) = (view(dist.a, :, i) for i in CartesianIndices(Base.tail(size(dist.a))))
```

?
Yes, with `@view` it will be much better. I will take your suggestion.
@Nimrais Because in the end we do need a performant implementation of this, which would require the data to be stored the way Raphael coded it. It is a lot easier to implement performant versions of methods if the underlying data structure does not change anymore.
Sure, but then it would be nice to see proof that it works faster (@Raphael-Tresor it's more or less a comment for you for the next time). For example, a test that shows that the current way of implementing things is better than the naïve one. Because if, at some point in time, it is starting to become slower, what is the point? For example, something like this, but written carefully:
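A rough sketch of such a comparison might look like the following. This is a hypothetical micro-benchmark, not the PR's actual test: the storage layouts and helper names (`sum_flat`, `sum_array`) are illustrative stand-ins, not ExponentialFamily.jl code.

```julia
# Hypothetical benchmark sketch: iterate over Dirichlet parameters stored as
# one flat array (the TensorDirichlet layout) versus an array of per-distribution
# parameter vectors (the naive array-of-Dirichlet layout).
a_flat  = rand(3, 10, 10)                     # flat TensorDirichlet-style storage
a_array = [rand(3) for _ in 1:10, _ in 1:10]  # array-of-Dirichlet-style storage

# sum all parameters, touching every "distribution" once in each layout
sum_flat(a)  = sum(sum(view(a, :, i)) for i in CartesianIndices(Base.tail(size(a))))
sum_array(a) = sum(sum(v) for v in a)

# warm up to exclude compilation time, then compare timings and allocations
sum_flat(a_flat); sum_array(a_array)
@time sum_flat(a_flat)
@time sum_array(a_array)
```

A careful version would use BenchmarkTools.jl rather than a single `@time` call, but the shape of the comparison is the same.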
@Nimrais you mean something like this?
Yeah, but we have more than only the
In my example, because the
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #219      +/-   ##
==========================================
+ Coverage   80.49%   81.76%   +1.27%
==========================================
  Files          39       41       +2
  Lines        3117     3477     +360
==========================================
+ Hits         2509     2843     +334
- Misses        608      634      +26
```

Flags with carried forward coverage won't be shown.
There is still some minor stuff I will work on, but everything considered I think we're ready to merge.
```julia
end

function BayesBase.pdf(dist::TensorDirichlet, xs::AbstractVector)
    return map(x -> pdf(dist, x), xs)
```
Do we really need this function, @Nimrais? Because I can spot a bug as it is right now. Either we delete this function or I can patch the bug. Let me know.
Since the implemented `rand(Td, n)` returns a vector of samples, I don't think it's easy to fix. To be compatible with both Distributions.jl and ExponentialFamily.jl, the distribution needs to implement a method that evaluates the pdf of a collection of samples. And we have a test for that; if you remove this method, the test set will fail.
So the short answer is yes, we need it.
The long answer is that the bug is not only here but also in the `rand` method, because the `rand` method doesn't work according to the Distributions.jl API.
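To make the convention being discussed concrete, here is a toy, self-contained sketch of "rand returns a vector of samples, pdf maps over that collection". `ToyDist` and `toy_pdf` are hypothetical stand-ins, not the package's code or the Distributions.jl API.

```julia
# Toy stand-in for a multivariate distribution; not ExponentialFamily.jl code.
struct ToyDist end

# per-sample density (an arbitrary positive function, for illustration only)
toy_pdf(::ToyDist, x::AbstractVector) = exp(-sum(abs2, x))

# vector-of-samples layout, matching what the PR's rand(Td, n) produces:
# pdf over the collection is just a map of the per-sample pdf
pdf_over_collection(d::ToyDist, xs::AbstractVector) = map(x -> toy_pdf(d, x), xs)

xs = [zeros(2), ones(2)]
pdf_over_collection(ToyDist(), xs)  # [1.0, exp(-2.0)]
```

Note that Distributions.jl itself stores `rand(d, n)` samples of a multivariate distribution as matrix columns rather than as a vector of vectors, which is the API mismatch mentioned above.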
Yes, and strictly speaking, there is a bug if TensorDirichlet samples are vectors themselves, but that is only possible if we have only one Dirichlet. So for now we can just raise an error if someone wants to use TensorDirichlet for that.
We can do it in a separate PR. We can have it anyway but reuse code here to implement it. Also, MatrixDirichlet could be slightly more performant in some edge cases.
Issue
The branch is not ExponentialFamily compatible yet. The pack/unpack mechanism does not work.