From 2017a0e9f505b67d95870c0873bf23f59ef72201 Mon Sep 17 00:00:00 2001 From: lgrcia Date: Mon, 4 Nov 2024 14:05:10 -0500 Subject: [PATCH] fix: light curve of multiple y single u --- spotter/light_curves.py | 13 ++++++++++--- tests/test_light_curves.py | 1 + 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/spotter/light_curves.py b/spotter/light_curves.py index a4653a4..5a5391c 100644 --- a/spotter/light_curves.py +++ b/spotter/light_curves.py @@ -17,9 +17,16 @@ def design_matrix(star: Star, time: ArrayLike) -> ArrayLike: lambda u: core.design_matrix(star.y[0], star.phase(time), star.inc, u) )(star.u) else: - return jax.vmap( - lambda y, u: core.design_matrix(y, star.phase(time), star.inc, u) - )(star.y, star.u) + if len(star.u) == 1: + return jax.vmap( + lambda y: core.design_matrix( + y, star.phase(time), star.inc, star.u[0] + ) + )(star.y) + else: + return jax.vmap( + lambda y, u: core.design_matrix(y, star.phase(time), star.inc, u) + )(star.y, star.u) else: return jax.vmap( lambda y: core.design_matrix(y, star.phase(time), star.inc, star.u) diff --git a/tests/test_light_curves.py b/tests/test_light_curves.py index a57beff..00589e6 100644 --- a/tests/test_light_curves.py +++ b/tests/test_light_curves.py @@ -21,6 +21,7 @@ def test_dark_hemisphere(): ( (np.ones(N), None, (1, 3)), (np.ones((2, N)), None, (2, 3)), + (np.ones((2, N)), ((0.1,)), (2, 3)), (np.ones(N), ((0.1,), (0.2,)), (2, 3)), (np.ones((2, N)), ((0.1, 0.4), (0.2, 0.3)), (2, 3)), (np.ones(N), ((0.1, 0.4), (0.2, 0.3)), (2, 3)),