Skip to content

Commit

Permalink
test: add nsamples for gradlogpartition test
Browse files Browse the repository at this point in the history
  • Loading branch information
Nimrais committed Jan 13, 2025
1 parent 66601a4 commit 4af1f19
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions test/distributions/distributions_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ function test_exponentialfamily_interface(distribution;
test_fisherinformation_against_hessian = true,
test_fisherinformation_against_jacobian = true,
test_plogpdf_interface = true,
option_assume_no_allocations = false
option_assume_no_allocations = false,
nsamples_for_gradlogpartition_properties = 6000
)
T = ExponentialFamily.exponential_family_typetag(distribution)

Expand All @@ -84,7 +85,7 @@ function test_exponentialfamily_interface(distribution;
test_packing_unpacking && run_test_packing_unpacking(distribution)
test_isproper && run_test_isproper(distribution; assume_no_allocations = option_assume_no_allocations)
test_basic_functions && run_test_basic_functions(distribution; assume_no_allocations = option_assume_no_allocations)
test_gradlogpartition_properties && run_test_gradlogpartition_properties(distribution)
test_gradlogpartition_properties && run_test_gradlogpartition_properties(distribution, nsamples = nsamples_for_gradlogpartition_properties)
test_fisherinformation_properties && run_test_fisherinformation_properties(distribution)
test_fisherinformation_against_hessian && run_test_fisherinformation_against_hessian(distribution; assume_no_allocations = option_assume_no_allocations)
test_fisherinformation_against_jacobian && run_test_fisherinformation_against_jacobian(distribution; assume_no_allocations = option_assume_no_allocations)
Expand All @@ -96,7 +97,7 @@ function run_test_plogpdf_interface(distribution)
ef = convert(ExponentialFamily.ExponentialFamilyDistribution, distribution)
η = getnaturalparameters(ef)
samples = rand(StableRNG(42), distribution, 10)
_, _samples = ExponentialFamily.check_logpdf(variate_form(typeof(ef)), typeof(samples), eltype(samples), ef, samples)
_, _samples = ExponentialFamily.check_logpdf(ef, samples)
ss_vectors = map(s -> ExponentialFamily.pack_parameters(ExponentialFamily.sufficientstatistics(ef, s)), _samples)
unnormalized_logpdfs = map(v -> dot(v, η), ss_vectors)
@test all(unnormalized_logpdfs map(x -> ExponentialFamily._plogpdf(ef, x, 0, 0), _samples))
Expand Down

0 comments on commit 4af1f19

Please sign in to comment.