diff --git a/spotter/core.py b/spotter/core.py index 9dff0ba..62029fa 100644 --- a/spotter/core.py +++ b/spotter/core.py @@ -167,24 +167,12 @@ def doppler_shift(theta, phi, period, radius, phase): def shifted_spectra(spectra, shift): - n_spectra, n_wavelength = spectra.shape - - # Fourier transform along the wavelength axis (axis=1) for each spectrum independently - spectra_ft = np.fft.fft(spectra, axis=1) - - # Generate the frequency indices for each element along the wavelength axis - k = np.fft.fftfreq(n_wavelength).reshape( - 1, -1 - ) # Shape (1, n_wavelength) to broadcast along rows - - # Compute the phase shift matrix for each element in shift - phase_shift = np.exp(-2j * np.pi * k * shift) # Shape (n_spectra, n_wavelength) - - # Apply the phase shift and inverse Fourier transform - shifted = np.fft.ifft(spectra_ft * phase_shift, axis=1) - - # Return the real part, assuming the input was real - return np.real(shifted) + _, n_wavelength = spectra.shape + spectra_ft = jnp.fft.fft(spectra, axis=1) + k = np.fft.fftfreq(n_wavelength).reshape(1, -1) + phase_shift = jnp.exp(-2j * np.pi * k * shift) + shifted = jnp.fft.ifft(spectra_ft * phase_shift, axis=1) + return jnp.real(shifted) def integrated_spectrum(N, theta, phi, period, radius, wv, spectra, phase, inc):