diff --git a/Project.toml b/Project.toml index 2cf2679..fa000fe 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KeyedDistributions" uuid = "2576fb08-064d-4cab-b15d-8dda7fcb9a6d" authors = ["Invenia Technical Computing Corporation"] -version = "0.1.9" +version = "0.2.0" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" diff --git a/src/KeyedDistributions.jl b/src/KeyedDistributions.jl index 866b41a..584a5f9 100644 --- a/src/KeyedDistributions.jl +++ b/src/KeyedDistributions.jl @@ -265,4 +265,13 @@ function Distributions.insupport(d::KeyedDistribution{<:Univariate}, x::Real) return insupport(distribution(d), x) end +# Overload equality comparison between `KeyedT` and underlying `T` +for T in (:Distribution, :Sampleable) + KeyedT = Symbol(:Keyed, T) + @eval begin + Base.:(==)(kd::$KeyedT, d::$T) = distribution(kd) == d + Base.:(==)(d::$T, kd::$KeyedT) = kd == d + end +end + end diff --git a/test/runtests.jl b/test/runtests.jl index 14ff3d4..084e0d4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -364,4 +364,18 @@ using Test @test dimnames(rename(kd, (:name,))) == (:name, ) @test dimnames(rename(kd, :id => :name)) == (:name, ) end + + @testset "Equality comparison with wrapped type" begin + d = MvNormal([1.0, 2.0], [1.0, 1.0]); + + kd = KeyedDistribution(d, 1:length(d)); + + @test kd == d + @test d == kd + + ks = KeyedSampleable(d, 1:length(d)); + + @test ks == d + @test d == ks + end end