diff --git a/spotter/light_curves.py b/spotter/light_curves.py index a4653a4..5a5391c 100644 --- a/spotter/light_curves.py +++ b/spotter/light_curves.py @@ -17,9 +17,16 @@ def design_matrix(star: Star, time: ArrayLike) -> ArrayLike: lambda u: core.design_matrix(star.y[0], star.phase(time), star.inc, u) )(star.u) else: - return jax.vmap( - lambda y, u: core.design_matrix(y, star.phase(time), star.inc, u) - )(star.y, star.u) + if len(star.u) == 1: + return jax.vmap( + lambda y: core.design_matrix( + y, star.phase(time), star.inc, star.u[0] + ) + )(star.y) + else: + return jax.vmap( + lambda y, u: core.design_matrix(y, star.phase(time), star.inc, u) + )(star.y, star.u) else: return jax.vmap( lambda y: core.design_matrix(y, star.phase(time), star.inc, star.u) diff --git a/tests/test_light_curves.py b/tests/test_light_curves.py index a57beff..00589e6 100644 --- a/tests/test_light_curves.py +++ b/tests/test_light_curves.py @@ -21,6 +21,7 @@ def test_dark_hemisphere(): ( (np.ones(N), None, (1, 3)), (np.ones((2, N)), None, (2, 3)), + (np.ones((2, N)), ((0.1,)), (2, 3)), (np.ones(N), ((0.1,), (0.2,)), (2, 3)), (np.ones((2, N)), ((0.1, 0.4), (0.2, 0.3)), (2, 3)), (np.ones(N), ((0.1, 0.4), (0.2, 0.3)), (2, 3)),