diff --git a/Project.toml b/Project.toml index 93b28de..58bccee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KeyedDistributions" uuid = "2576fb08-064d-4cab-b15d-8dda7fcb9a6d" authors = ["Invenia Technical Computing Corporation"] -version = "0.1.11" +version = "0.1.12" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" diff --git a/src/KeyedDistributions.jl b/src/KeyedDistributions.jl index 9f7bb8d..988d14b 100644 --- a/src/KeyedDistributions.jl +++ b/src/KeyedDistributions.jl @@ -54,6 +54,12 @@ for T in (:Distribution, :Sampleable) "lengths of key vectors $key_lengths must match " * "size of distribution $(_size(d))" )) + if d isa Distribution && mean(d) isa KeyedArray && !(axiskeys(mean(d)) == keys) + throw(ArgumentError( + "Distribution keys $(axiskeys(mean(d))) do not match " * + "KeyedDistribution keys $(keys)" + )) + end L = Tuple(:_ for _ in 1:length(key_lengths)) return new{F, S, typeof(d), L}(d, keys) end @@ -65,6 +71,12 @@ for T in (:Distribution, :Sampleable) "lengths of key vectors $key_lengths must match " * "size of distribution $(_size(d))" )) + if d isa Distribution && mean(d) isa KeyedArray && !(named_axiskeys(mean(d)) == named_keys) + throw(ArgumentError( + "Distribution keys $(named_axiskeys(mean(d))) do not match " * + "KeyedDistribution keys $(named_keys)" + )) + end return new{F, S, typeof(d), keys(named_keys)}(d, values(named_keys)) end diff --git a/test/runtests.jl b/test/runtests.jl index 28e9c54..73ff686 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -347,6 +347,14 @@ using Test @test d([1]) == d[[1]] == KeyedDistribution(GenericMvTDist(3, m[[1]], submat(W, [1])), [1]) end + + @testset "construct with distribution backed by KeyedArray" begin + ka = wrapdims(rand(3); t=["a", "b", "c"]) + mvn = MvNormal(ka, ones(3)) + @test_throws ArgumentError KeyedDistribution(mvn, ["a", "b", "not c"]) + @test_throws ArgumentError KeyedDistribution(mvn; t=["a", "b", "not c"]) + @test_throws ArgumentError KeyedDistribution(mvn; not_t=["a", "b", "c"]) + end end @testset "NamedDims functions" begin