Skip to content

Commit

Permalink
Merge branch 'main' into level_matching
Browse files Browse the repository at this point in the history
  • Loading branch information
ajwheeler authored Mar 11, 2024
2 parents 343240b + 4d6b68f commit f6b3bd3
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 125 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Korg"
uuid = "acafc109-a718-429c-b0e5-afd7f8c7ae46"
authors = ["Adam Wheeler <a.wheeler@columbia.edu>"]
version = "0.29.2"
version = "0.29.3"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Expand Down
41 changes: 15 additions & 26 deletions src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,34 +269,23 @@ function fit_spectrum(obs_wls, obs_flux, obs_err, linelist, initial_guesses, fix
end
end

res, invH = if length(p0) == 1
# if we are fitting a single parameter, experimentation shows that Nelder-Mead (the default)
# is faster than BFGS

# there seems to be a problem with trace storage for this optimizer, so we don't request it
# the precision keyword is also ignored
optimize(chi2, p0, Optim.Options(x_tol=precision); autodiff=:forward), nothing
else
# if we are fitting a multiple parameters, use BFGS with autodiff
res = optimize(chi2, p0, BFGS(linesearch=LineSearches.BackTracking()),
Optim.Options(x_tol=precision, time_limit=10_000, store_trace=true,
extended_trace=true); autodiff=:forward)

# derivate relating the scaled parameters to the unscaled parameters
# (used to convert the approximate hessian to a covariance matrix in the unscaled params)
dp_dscaledp = map(res.minimizer, params_to_fit) do scaled_param, param_name
ForwardDiff.derivative(scaled_param) do scaled_param
unscale(Dict(param_name=>scaled_param))[param_name]
end
# if we are fitting a multiple parameters, use BFGS with autodiff
res = optimize(chi2, p0, BFGS(linesearch=LineSearches.BackTracking()),
Optim.Options(x_tol=precision, time_limit=10_000, store_trace=true,
extended_trace=true); autodiff=:forward)

# derivate relating the scaled parameters to the unscaled parameters
# (used to convert the approximate hessian to a covariance matrix in the unscaled params)
dp_dscaledp = map(res.minimizer, params_to_fit) do scaled_param, param_name
ForwardDiff.derivative(scaled_param) do scaled_param
unscale(Dict(param_name=>scaled_param))[param_name]
end
# the fact that the scaling is a diagonal operation means that we can do this as an element-wise
# product. If we think of ds/dp as a (diagonal) matrix, this is equivalent to
# (ds/dp)^T * invH * (ds/dp)
invH_scaled = res.trace[end].metadata["~inv(H)"]
invH = invH_scaled .* dp_dscaledp .* dp_dscaledp'

res, invH
end
# the fact that the scaling is a diagonal operation means that we can do this as an element-wise
# product. If we think of ds/dp as a (diagonal) matrix, this is equivalent to
# (ds/dp)^T * invH * (ds/dp)
invH_scaled = res.trace[end].metadata["~inv(H)"]
invH = invH_scaled .* dp_dscaledp .* dp_dscaledp'
solution = unscale(Dict(params_to_fit .=> res.minimizer))

trace = map(res.trace) do t
Expand Down
7 changes: 6 additions & 1 deletion src/synthesize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,13 @@ function synthesize(atm::ModelAtmosphere, linelist, A_X::AbstractVector{<:Real},
#sort the lines if necessary
issorted(linelist; by=l->l.wl) || sort!(linelist, by=l->l.wl)
#discard lines far from the wavelength range being synthesized
linelist = filter(linelist) do line
nlines_before = length(linelist)
linelist = filter(linelist) do line # don't "filter!". It mutates the linelist.
map(wl_ranges) do wl_range
wl_range[1] - line_buffer <= line.wl <= wl_range[end]
end |> any
end

#TODO clean up filtering
for i in 1:length(NLTE_lines)
lines, bs = NLTE_lines[i]
Expand All @@ -166,6 +168,9 @@ function synthesize(atm::ModelAtmosphere, linelist, A_X::AbstractVector{<:Real},
end |> any
end
NLTE_lines[i] = (lines, bs)

if nlines_before != 0 && length(linelist) == 0
@warn "The provided linelist was not empty, but none of the lines were within the provided wavelength range."
end

if length(A_X) != MAX_ATOMIC_NUMBER || (A_X[1] != 12)
Expand Down
96 changes: 1 addition & 95 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ include("fit.jl")
include("autodiff.jl")
include("autodiffable_conv.jl")
include("atmosphere.jl")
include("synthesize.jl")

@testset "atomic data" begin
@test (Korg.MAX_ATOMIC_NUMBER
Expand Down Expand Up @@ -116,62 +117,6 @@ end
end
end

@testset "synthesis" begin

@testset "abundances" begin
@test (format_A_X()
== format_A_X(0)
== format_A_X(0, 0)
== format_A_X(Dict{String, Float64}())
== format_A_X(Dict{Int, Float64}())
== format_A_X(0, Dict(1=>0.0); solar_relative=true)
== format_A_X(0, 0, Dict(1=>0.0); solar_relative=true)
== format_A_X(0, Dict("H"=>0.0); solar_relative=true)
== format_A_X(0, Dict(1=>12.0); solar_relative=false)
== format_A_X(0, Dict("H"=>12.0); solar_relative=false))

# make sure silly H abundances are caught
@test_throws ArgumentError format_A_X(0.0, Dict("H"=>0); solar_relative=false)
@test_throws ArgumentError format_A_X(0.0, Dict(1=>0); solar_relative=false)
@test_throws ArgumentError format_A_X(0.0, Dict("H"=>12); solar_relative=true)
@test_throws ArgumentError format_A_X(0.0, Dict(1=>12); solar_relative=true)

atol = 1e-5
@test Korg.get_alpha_H(format_A_X(0.1)) 0.1 atol=atol
@test Korg.get_alpha_H(format_A_X(0.0, 0.1)) 0.1 atol=atol
@test Korg.get_alpha_H(format_A_X(-0.2)) -0.2 atol=atol
@test Korg.get_alpha_H(format_A_X(-2, -0.2)) -0.2 atol=atol
@test Korg.get_metals_H(format_A_X(0.1)) 0.1 atol=atol
@test Korg.get_metals_H(format_A_X(-0.2)) -0.2 atol=atol
@test Korg.get_metals_H(format_A_X(0.1, 0.5)) 0.1 atol=atol
@test Korg.get_metals_H(format_A_X(-0.2, 0.5)) -0.2 atol=atol
@test Korg.get_metals_H(Korg.grevesse_2007_solar_abundances;
solar_abundances=Korg.grevesse_2007_solar_abundances) 0 atol=atol
@test Korg.get_alpha_H(Korg.grevesse_2007_solar_abundances;
solar_abundances=Korg.grevesse_2007_solar_abundances) 0 atol=atol

@test format_A_X(1.1) != format_A_X(1.1, 0)
@test format_A_X(1.1)[50] == format_A_X(1.1, 0)[50] == format_A_X(-1, -2, Dict(50=>1.1))[50]

@testset for metallicity in [0.0, 0.5], abundances in [Dict(), Dict("C"=>1.1)], solar_relative in [true, false]
A_X = format_A_X(metallicity, abundances;
solar_abundances=Korg.asplund_2020_solar_abundances,
solar_relative=solar_relative)

#correct absolute abundances?
if "C" in keys(abundances)
if solar_relative
@test A_X[6] Korg.asplund_2020_solar_abundances[6] + 1.1
else
@test A_X[6] 1.1
end
end
@test A_X[7:end] Korg.asplund_2020_solar_abundances[7:end] .+ metallicity
@test A_X[1:2] == Korg.asplund_2020_solar_abundances[1:2]
end
end
end

@testset "LSF" begin
wls = 5900:0.35:6100
R = 1800.0
Expand Down Expand Up @@ -305,45 +250,6 @@ end
@test Korg.air_to_vacuum.(Korg.vacuum_to_air.(wls)*1e8)*1e-8 wls rtol=1e-3
end

@testset "synthesize wavelength handling" begin
atm = read_model_atmosphere("data/sun.mod")
wls = 15000:0.01:15500
A_X = format_A_X()
@test synthesize(atm, [], A_X, 15000, 15500).wavelengths wls
@test synthesize(atm, [], A_X, 15000, 15500; air_wavelengths=true).wavelengths Korg.air_to_vacuum.(wls)
@test_throws ArgumentError synthesize(atm, [], A_X, 15000, 15500; air_wavelengths=true,
wavelength_conversion_warn_threshold=1e-20)
@test_throws ArgumentError synthesize(atm, [], A_X, 2000, 8000, air_wavelengths=true)

# test multiple line windows
r1 = 5000:0.01:5001
r2 = 6000:0.01:6001
sol1 = synthesize(atm, [], A_X, [r1]; hydrogen_lines=true)
sol2 = synthesize(atm, [], A_X, [r2]; hydrogen_lines=true)
sol3 = synthesize(atm, [], A_X, [r1, r2]; hydrogen_lines=true)

@test sol1.wavelengths == sol3.wavelengths[sol3.subspectra[1]]
@test sol2.wavelengths == sol3.wavelengths[sol3.subspectra[2]]
@test sol1.flux == sol3.flux[sol3.subspectra[1]]
@test sol2.flux == sol3.flux[sol3.subspectra[2]]
end

@testset "line buffer" begin
#strong line at 4999 Å
line1 = Korg.Line(4999e-8, 1.0, Korg.species"Na I", 0.0)
#strong line at 4997 Å
line2 = Korg.Line(4997e-8, 1.0, Korg.species"Na I", 0.0)
atm = read_model_atmosphere("data/sun.mod")

#use a 2 Å line buffer so only line1 in included
sol_no_lines = synthesize(atm, [], format_A_X(), 5000, 5000; line_buffer=2.0) #synthesize at 5000 Å only
sol_one_lines = synthesize(atm, [line1], format_A_X(), 5000, 5000; line_buffer=2.0)
sol_two_lines = synthesize(atm, [line1, line2], format_A_X(), 5000, 5000; line_buffer=2.0)

@test sol_no_lines.flux != sol_one_lines.flux
@test sol_two_lines.flux == sol_one_lines.flux
end

@testset "linelists" begin
atm = read_model_atmosphere("data/sun.mod")

Expand Down
100 changes: 100 additions & 0 deletions test/synthesize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
@testset "synthesize" begin
@testset "line buffer" begin
#strong line at 4999 Å
line1 = Korg.Line(4999e-8, 1.0, Korg.species"Na I", 0.0)
#strong line at 4997 Å
line2 = Korg.Line(4997e-8, 1.0, Korg.species"Na I", 0.0)
atm = read_model_atmosphere("data/sun.mod")

#use a 2 Å line buffer so only line1 in included
sol_no_lines = synthesize(atm, [], format_A_X(), 5000, 5000; line_buffer=2.0) #synthesize at 5000 Å only
sol_one_lines = synthesize(atm, [line1], format_A_X(), 5000, 5000; line_buffer=2.0)
sol_two_lines = synthesize(atm, [line1, line2], format_A_X(), 5000, 5000; line_buffer=2.0)

@test sol_no_lines.flux != sol_one_lines.flux
@test sol_two_lines.flux == sol_one_lines.flux
end

@testset "synthesize wavelength handling" begin
atm = read_model_atmosphere("data/sun.mod")
wls = 15000:0.01:15500
A_X = format_A_X()
@test synthesize(atm, [], A_X, 15000, 15500).wavelengths wls
@test synthesize(atm, [], A_X, 15000, 15500; air_wavelengths=true).wavelengths Korg.air_to_vacuum.(wls)
@test_throws ArgumentError synthesize(atm, [], A_X, 15000, 15500; air_wavelengths=true,
wavelength_conversion_warn_threshold=1e-20)
@test_throws ArgumentError synthesize(atm, [], A_X, 2000, 8000, air_wavelengths=true)

# test multiple line windows
r1 = 5000:0.01:5001
r2 = 6000:0.01:6001
sol1 = synthesize(atm, [], A_X, [r1]; hydrogen_lines=true)
sol2 = synthesize(atm, [], A_X, [r2]; hydrogen_lines=true)
sol3 = synthesize(atm, [], A_X, [r1, r2]; hydrogen_lines=true)

@test sol1.wavelengths == sol3.wavelengths[sol3.subspectra[1]]
@test sol2.wavelengths == sol3.wavelengths[sol3.subspectra[2]]
@test sol1.flux == sol3.flux[sol3.subspectra[1]]
@test sol2.flux == sol3.flux[sol3.subspectra[2]]
end

@testset "abundances" begin
@test (format_A_X()
== format_A_X(0)
== format_A_X(0, 0)
== format_A_X(Dict{String, Float64}())
== format_A_X(Dict{Int, Float64}())
== format_A_X(0, Dict(1=>0.0); solar_relative=true)
== format_A_X(0, 0, Dict(1=>0.0); solar_relative=true)
== format_A_X(0, Dict("H"=>0.0); solar_relative=true)
== format_A_X(0, Dict(1=>12.0); solar_relative=false)
== format_A_X(0, Dict("H"=>12.0); solar_relative=false))

# make sure silly H abundances are caught
@test_throws ArgumentError format_A_X(0.0, Dict("H"=>0); solar_relative=false)
@test_throws ArgumentError format_A_X(0.0, Dict(1=>0); solar_relative=false)
@test_throws ArgumentError format_A_X(0.0, Dict("H"=>12); solar_relative=true)
@test_throws ArgumentError format_A_X(0.0, Dict(1=>12); solar_relative=true)

atol = 1e-5
@test Korg.get_alpha_H(format_A_X(0.1)) 0.1 atol=atol
@test Korg.get_alpha_H(format_A_X(0.0, 0.1)) 0.1 atol=atol
@test Korg.get_alpha_H(format_A_X(-0.2)) -0.2 atol=atol
@test Korg.get_alpha_H(format_A_X(-2, -0.2)) -0.2 atol=atol
@test Korg.get_metals_H(format_A_X(0.1)) 0.1 atol=atol
@test Korg.get_metals_H(format_A_X(-0.2)) -0.2 atol=atol
@test Korg.get_metals_H(format_A_X(0.1, 0.5)) 0.1 atol=atol
@test Korg.get_metals_H(format_A_X(-0.2, 0.5)) -0.2 atol=atol
@test Korg.get_metals_H(Korg.grevesse_2007_solar_abundances;
solar_abundances=Korg.grevesse_2007_solar_abundances) 0 atol=atol
@test Korg.get_alpha_H(Korg.grevesse_2007_solar_abundances;
solar_abundances=Korg.grevesse_2007_solar_abundances) 0 atol=atol

@test format_A_X(1.1) != format_A_X(1.1, 0)
@test format_A_X(1.1)[50] == format_A_X(1.1, 0)[50] == format_A_X(-1, -2, Dict(50=>1.1))[50]

@testset for metallicity in [0.0, 0.5], abundances in [Dict(), Dict("C"=>1.1)], solar_relative in [true, false]
A_X = format_A_X(metallicity, abundances;
solar_abundances=Korg.asplund_2020_solar_abundances,
solar_relative=solar_relative)

#correct absolute abundances?
if "C" in keys(abundances)
if solar_relative
@test A_X[6] Korg.asplund_2020_solar_abundances[6] + 1.1
else
@test A_X[6] 1.1
end
end
@test A_X[7:end] Korg.asplund_2020_solar_abundances[7:end] .+ metallicity
@test A_X[1:2] == Korg.asplund_2020_solar_abundances[1:2]
end
end

@testset "linelist checking" begin
msg = "The provided linelist was not empty"
atm = interpolate_marcs(5000.0, 4.4)
linelist = [Korg.Line(5000e-8, 1.0, Korg.species"Na I", 0.0)]
@test_warn msg synthesize(atm, linelist, format_A_X(), 6000, 6000)
end
end
4 changes: 2 additions & 2 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ function assert_allclose(actual, reference; rtol = 1e-7, atol = 0.0, err_msg = n
relative_diff[diff .== 0] .= 0

if print_rachet_info
if all(diff .< 0.5*atol)
if all(diff .< 0.1*atol)
@info "test can be racheted down: atol=$(atol), but the max diff is $(maximum(diff))"
display(stacktrace())
end
if all(relative_diff .< 0.5*rtol)
if all(relative_diff .< 0.1*rtol)
@info ("test can be racheted down: rtol=$(rtol), but the max relative diff is "
* "$(maximum(relative_diff))")
display(stacktrace())
Expand Down

0 comments on commit f6b3bd3

Please sign in to comment.