From 6d3bdbd9650ab1e9b49645a64553023cf1422166 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 7 Nov 2022 16:57:56 +0000 Subject: [PATCH 1/2] match inner and outer keys --- Project.toml | 2 +- src/KeyedDistributions.jl | 12 ++++++++++++ test/runtests.jl | 8 ++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 93b28de..58bccee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KeyedDistributions" uuid = "2576fb08-064d-4cab-b15d-8dda7fcb9a6d" authors = ["Invenia Technical Computing Corporation"] -version = "0.1.11" +version = "0.1.12" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" diff --git a/src/KeyedDistributions.jl b/src/KeyedDistributions.jl index 9f7bb8d..2038d5a 100644 --- a/src/KeyedDistributions.jl +++ b/src/KeyedDistributions.jl @@ -54,6 +54,12 @@ for T in (:Distribution, :Sampleable) "lengths of key vectors $key_lengths must match " * "size of distribution $(_size(d))" )) + if mean(d) isa KeyedArray && !(axiskeys(mean(d)) == keys) + throw(ArgumentError( + "Distribution keys $(axiskeys(mean(d))) do not match " * + "KeyedDistribution keys $(keys)" + )) + end L = Tuple(:_ for _ in 1:length(key_lengths)) return new{F, S, typeof(d), L}(d, keys) end @@ -65,6 +71,12 @@ for T in (:Distribution, :Sampleable) "lengths of key vectors $key_lengths must match " * "size of distribution $(_size(d))" )) + if mean(d) isa KeyedArray && !(named_axiskeys(mean(d)) == named_keys) + throw(ArgumentError( + "Distribution keys $(named_axiskeys(mean(d))) do not match " * + "KeyedDistribution keys $(named_keys)" + )) + end return new{F, S, typeof(d), keys(named_keys)}(d, values(named_keys)) end diff --git a/test/runtests.jl b/test/runtests.jl index 28e9c54..73ff686 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -347,6 +347,14 @@ using Test @test d([1]) == d[[1]] == KeyedDistribution(GenericMvTDist(3, m[[1]], submat(W, [1])), [1]) end + + @testset "construct with distribution backed by KeyedArray" begin + ka = wrapdims(rand(3); t=["a", "b", "c"]) + mvn = MvNormal(ka, ones(3)) + @test_throws ArgumentError KeyedDistribution(mvn, ["a", "b", "not c"]) + @test_throws ArgumentError KeyedDistribution(mvn; t=["a", "b", "not c"]) + @test_throws ArgumentError KeyedDistribution(mvn; not_t=["a", "b", "c"]) + end end @testset "NamedDims functions" begin From d3e3cd0ea146bf2501d59501c00dda8fc8667b1c Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 8 Nov 2022 14:02:45 +0000 Subject: [PATCH 2/2] fix Sampleable case --- src/KeyedDistributions.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/KeyedDistributions.jl b/src/KeyedDistributions.jl index 2038d5a..988d14b 100644 --- a/src/KeyedDistributions.jl +++ b/src/KeyedDistributions.jl @@ -54,7 +54,7 @@ for T in (:Distribution, :Sampleable) "lengths of key vectors $key_lengths must match " * "size of distribution $(_size(d))" )) - if mean(d) isa KeyedArray && !(axiskeys(mean(d)) == keys) + if d isa Distribution && mean(d) isa KeyedArray && !(axiskeys(mean(d)) == keys) throw(ArgumentError( "Distribution keys $(axiskeys(mean(d))) do not match " * "KeyedDistribution keys $(keys)" @@ -71,7 +71,7 @@ for T in (:Distribution, :Sampleable) "lengths of key vectors $key_lengths must match " * "size of distribution $(_size(d))" )) - if mean(d) isa KeyedArray && !(named_axiskeys(mean(d)) == named_keys) + if d isa Distribution && mean(d) isa KeyedArray && !(named_axiskeys(mean(d)) == named_keys) throw(ArgumentError( "Distribution keys $(named_axiskeys(mean(d))) do not match " * "KeyedDistribution keys $(named_keys)"