Skip to content

Commit

Permalink
Try to convert types for parameter if value, valuebounds, and/or tran…
Browse files Browse the repository at this point in the history
…sform_parameterization types do not match.
  • Loading branch information
chenwilliam77 committed Jan 5, 2021
1 parent bbc65bd commit 4a94533
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
28 changes: 27 additions & 1 deletion src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ function parameter(key::Symbol,
scaling::Function = identity,
regimes::Dict{Symbol,OrderedDict{Int64,Any}} = Dict{Symbol,OrderedDict{Int64,Any}}(),
description::String = "No description available.",
tex_label::String = "") where {V<:Vector, T <: Float64, U <:Transform} #{V<:Vector, S<:Real, T <: Float64, U <:Transform}
tex_label::String = "") where {V<:Vector, T <: Real, U <:Transform} #{V<:Vector, S<:Real, T <: Float64, U <:Transform}

# If fixed=true, force bounds to match and leave prior as null. We need to define new
# variable names here because of lexical scoping.
Expand Down Expand Up @@ -445,6 +445,32 @@ function parameter(key::Symbol,
end
end

function parameter(key::Symbol,
value::Union{T1, V}, #value::Union{S,V},
valuebounds::Interval{T2} = (value,value),
transform_parameterization::Interval{T3} = (value,value),
transform::U = Untransformed(),
prior::Union{NullableOrPriorUnivariate, NullableOrPriorMultivariate} = NullablePriorUnivariate();
fixed::Bool = true,
scaling::Function = identity,
regimes::Dict{Symbol,OrderedDict{Int64,Any}} = Dict{Symbol,OrderedDict{Int64,Any}}(),
description::String = "No description available.",
tex_label::String = "") where {V<:Vector, T1 <: Real, T2 <: Real, T3 <: Real, U <:Transform}
warn_str = "The element types of the fields `value` ($(typeof(value))), `valuebounds` ($(eltype(valuebounds))), " *
"and `transform_parameterization` ($(eltype(transform_parameterization))) do not match. " *
"Attempting to convert all types to the same type as `value`. Note that the element type for the prior " *
"distribution should also be $(typeof(value))."
@warn warn_str

valuebounds_new = (convert(T1, valuebounds[1]), convert(T1, valuebounds[2]))
transform_parameterization_new = (convert(T1, transform_parameterization[1]),
convert(T1, transform_parameterization[2]))

return parameter(key, value, valuebounds_new, transform_parameterization_new,
transform, prior; fixed = fixed, scaling = scaling,
regimes = regimes, description = description, tex_label = tex_label)
end

function parameter_ad(key::Symbol,
value::Union{S,V},
valuebounds::Interval{T} = (value,value),
Expand Down
12 changes: 12 additions & 0 deletions test/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,19 @@ tomodel_answers[3] = 1.
@test differentiate_transform_to_real_line(u,u.value) != differentiate_transform_to_model_space(u,u.value)
end
end
end

@testset "Check type conversion for `value`, `valuebounds`, and `transform_parameterization`" begin
@info "The following warning is expected"
u1 = parameter(:σ_pist, 2.5230, (1, 5), (Float32(1), Float32(5)), Untransformed())
u2 = parameter(:σ_pist, 2.5230, (1., 5.), (1., 5.), Untransformed())

@test u1.value == u2.value
@test u1.valuebounds == u2.valuebounds
@test u1.transform_parameterization == u2.transform_parameterization
@test typeof(u1.value) == typeof(u2.value)
@test eltype(u1.valuebounds) == eltype(u2.valuebounds)
@test eltype(u1.transform_parameterization) == eltype(u2.transform_parameterization)
end

# probability
Expand Down

0 comments on commit 4a94533

Please sign in to comment.