Skip to content

Commit

Permalink
fix: shift_spectra jaxified
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrcia committed Nov 1, 2024
1 parent 33eb505 commit 0757a1f
Showing 1 changed file with 6 additions and 18 deletions.
24 changes: 6 additions & 18 deletions spotter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,24 +167,12 @@ def doppler_shift(theta, phi, period, radius, phase):


def shifted_spectra(spectra, shift):
n_spectra, n_wavelength = spectra.shape

# Fourier transform along the wavelength axis (axis=1) for each spectrum independently
spectra_ft = np.fft.fft(spectra, axis=1)

# Generate the frequency indices for each element along the wavelength axis
k = np.fft.fftfreq(n_wavelength).reshape(
1, -1
) # Shape (1, n_wavelength) to broadcast along rows

# Compute the phase shift matrix for each element in shift
phase_shift = np.exp(-2j * np.pi * k * shift) # Shape (n_spectra, n_wavelength)

# Apply the phase shift and inverse Fourier transform
shifted = np.fft.ifft(spectra_ft * phase_shift, axis=1)

# Return the real part, assuming the input was real
return np.real(shifted)
_, n_wavelength = spectra.shape
spectra_ft = jnp.fft.fft(spectra, axis=1)
k = np.fft.fftfreq(n_wavelength).reshape(1, -1)
phase_shift = jnp.exp(-2j * np.pi * k * shift)
shifted = jnp.fft.ifft(spectra_ft * phase_shift, axis=1)
return jnp.real(shifted)


def integrated_spectrum(N, theta, phi, period, radius, wv, spectra, phase, inc):
Expand Down

0 comments on commit 0757a1f

Please sign in to comment.