From e5ae7d99c653af9b4a297a5640af3d089e4d57b0 Mon Sep 17 00:00:00 2001 From: Matteo Bachetti Date: Wed, 19 Jun 2024 12:29:21 +0200 Subject: [PATCH] Separate parameter guessing and test properly --- hendrics/ml_timing.py | 99 +++++++++++++++++++++++++++++++------------ 1 file changed, 72 insertions(+), 27 deletions(-) diff --git a/hendrics/ml_timing.py b/hendrics/ml_timing.py index b1657a73..332ffcc2 100644 --- a/hendrics/ml_timing.py +++ b/hendrics/ml_timing.py @@ -252,6 +252,55 @@ def normalized_template(template, tomax=False, subtract_min=True): # return np.std(bases), np.std(amps), np.std(phases) +def _func_for_toa_fitting(phases, pars, template_fun): + amp, shift = pars[:2] + base = 0 + if len(pars) > 2: + base = pars[2] + return base + amp * template_fun(phases - shift) + + +def _guess_start_pars(profile, template, fit_base=True, mean_phase=None): + """Guess the starting parameters for the fit. + + Examples + -------- + >>> phases = np.linspace(0, 1, 10001) + >>> template_fun = lambda x : 3.11 * np.exp(-(x - 0.5)**2 / (2 * 0.05**2)) + 355 + >>> template = template_fun(phases) + >>> profile1 = 34 * template_fun(phases - 0.2) + 11. + >>> profile2 = 53 * template_fun(phases - 0.2) + >>> pars1 = _guess_start_pars(profile1, template, fit_base=True) + >>> newprof1 = _func_for_toa_fitting(phases, pars1, template_fun) + >>> assert np.allclose(profile1, newprof1, atol=1e-10) + >>> pars2 = _guess_start_pars(profile2, template, fit_base=False) + >>> newprof2 = _func_for_toa_fitting(phases, pars2, template_fun) + >>> assert np.allclose(profile2, newprof2, atol=1e-10) + """ + minp = np.min(profile) + maxp = np.max(profile) + mint = np.min(template) + maxt = np.max(template) + + dph = 1 / profile.size + if mean_phase is None: + mean_phase = ((np.argmax(profile) - np.argmax(template))) * dph + + if fit_base: + amp_tr = (maxp - minp) / (maxt - mint) + x0 = ( + amp_tr, + phases_from_zero_to_one(mean_phase), + minp - mint * amp_tr, + ) + else: + x0 = ( + maxp / maxt, + phases_from_zero_to_one(mean_phase), + ) + return x0 + + def ml_pulsefit( profile, template, @@ -321,28 +370,11 @@ def func(pars): return np.inf return ll - minp = np.min(profile) - maxp = np.max(profile) - mint = np.min(template) - maxt = np.max(template) - - dph = 1 / profile.size - if mean_phase is None: - mean_phase = ((np.argmax(profile) - np.argmax(template)) + 0.5) * dph + x0 = _guess_start_pars(profile, template, fit_base=fit_base, mean_phase=mean_phase) if fit_base: - x0 = ( - (maxp - minp) / (maxt - mint), - phases_from_zero_to_one(mean_phase), - minp - mint, - ) bounds = [(0, np.inf), (0, 1), (0, np.inf)] - else: - x0 = ( - maxp / maxt, - phases_from_zero_to_one(mean_phase), - ) bounds = [(0, np.inf), (0, 1)] res = minimize(func, x0, bounds=bounds) @@ -368,25 +400,38 @@ def func(pars): errs = np.concatenate((errs, [0])) # import matplotlib.pyplot as plt + + # amp_tr, shift_tr = x0[:2] + # base_tr = x0[2] if fit_base else 0 + # plt.figure() # phases_fine = np.linspace(0, 1, 300) # amp, shift, base = final_pars - # amp_tr, shift_tr = x0[:2] - # base_tr = x0[2] if fit_base else 0 + # shift = phases_from_zero_to_one(shift) # plt.title(f"{template.size} {shift}") - # plt.plot(phases_fine, base + amp * template_fun(phases_fine - shift), label="Best fit") - # plt.plot(phases_fine, - # base_tr + amp_tr * template_fun(phases_fine - shift_tr), - # color="grey", label="Start guess") - # plt.plot(phases_fine, base + amp * template_fun(phases_fine), color="grey", alpha=0.5, - # label="Template") + # plt.plot( + # phases_fine, base + amp * template_fun(phases_fine - shift), label="Best fit" + # ) + # plt.plot( + # phases_fine, + # base_tr + amp_tr * template_fun(phases_fine - shift_tr), + # color="grey", + # label="Start guess", + # ) + # plt.plot( + # phases_fine, + # base + amp * template_fun(phases_fine), + # color="grey", + # alpha=0.5, + # label="Template", + # ) # plt.axvline(shift - errs[1]) # plt.axvline(shift + errs[1]) # plt.axvline(phases_from_zero_to_one(mean_phase), color="k") # plt.plot(phases, profile, label="Data") - # plt.show() # plt.legend() + # plt.show() # plt.savefig(f"{np.random.random()}.png") return final_pars, errs