diff --git a/pyproject.toml b/pyproject.toml index de30d97..7ac518c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "starspotter" -version = "0.0.1" +version = "0.0.2" description = "Stellar contamination estimates from rotational light curves" authors = [{ name = "Lionel Garcia" }, { name = "Benjamin Rackham" }] license = "MIT" diff --git a/spotter/core.py b/spotter/core.py index beb2946..3dcf58f 100644 --- a/spotter/core.py +++ b/spotter/core.py @@ -3,6 +3,8 @@ import jax.numpy as jnp import numpy as np +from spotter import utils + def distance(X, x): return jnp.arctan2( @@ -14,7 +16,7 @@ def npix(N): return hp.nside2npix(N) -def equator_coords(phi=None, inc=None): +def equator_coords(phi=None, inc=None, obl=None): # inc = None -> inc = pi/2 if inc is None: si = 1.0 @@ -30,15 +32,28 @@ def equator_coords(phi=None, inc=None): cp = jnp.cos(phi) sp = jnp.sin(phi) + if obl is None: + co = 1.0 + so = 0.0 + else: + co = jnp.cos(obl) + so = jnp.sin(obl) + x = si * cp y = si * sp z = ci + x, y, z = ( + x, + co * y - so * z, + so * y + co * z, + ) + return jnp.array([x, y, z]) -def mask_projected_limb(X, phase=None, inc=None, u=None): - d = distance(X, equator_coords(phase, inc)) +def mask_projected_limb(X, phase=None, inc=None, u=None, obl=None): + d = distance(X, equator_coords(phase, inc, obl)) mask = d < jnp.pi / 2 z = jnp.cos(d) projected_area = z @@ -67,15 +82,15 @@ def vec(N_or_y): return np.array(hp.pix2vec(N, range(n))).T -def design_matrix(N_or_y, phase=None, inc=None, u=None): +def design_matrix(N_or_y, phase=None, inc=None, u=None, obl=None): X = vec(N_or_y) - mask, projected_area, limb_darkening = mask_projected_limb(X, phase, inc, u) + mask, projected_area, limb_darkening = mask_projected_limb(X, phase, inc, u, obl) geometry = mask * projected_area return jnp.pi * limb_darkening * geometry / (geometry * limb_darkening).sum() -def flux(y, inc=None, u=None, phase=None): - return design_matrix(y, inc=inc, u=u, phase=phase) @ y +def flux(y, inc=None, u=None, phase=None, obl=None): + return design_matrix(y, inc=inc, u=u, phase=phase, obl=obl) @ y def spherical_to_cartesian(theta, phi): @@ -86,10 +101,10 @@ def spherical_to_cartesian(theta, phi): return jnp.array([x, y, z]) -def spot(N, latitude, longitude, radius): +def spot(N, latitude, longitude, radius, sharpness=1000): X = vec(N) - d = distance(X, spherical_to_cartesian(jnp.pi / 2 - latitude, longitude)) - return d < radius + d = distance(X, spherical_to_cartesian(jnp.pi / 2 - latitude, -longitude)) + return 1 - utils.sigmoid(d - radius, sharpness) def soft_spot(N, latitude, longitude, radius): @@ -101,14 +116,14 @@ def soft_spot(N, latitude, longitude, radius): return profile / jnp.max(profile) -def render(y, inc=None, u=None, phase=0.0): +def render(y, inc=None, u=None, phase=0.0, obl=0.0): import matplotlib.pyplot as plt X = vec(y) limb_darkening = mask_projected_limb(X, phase, inc, u)[2] rotated = hp.Rotator( - rot=[phase, np.pi / 2 - inc or 0.0], deg=False + rot=[phase, np.pi / 2 - inc or 0.0, obl or 0.0], deg=False ).rotate_map_pixel(y * limb_darkening) projected_map = hp.orthview(rotated, half_sky=True, return_projected_map=True) diff --git a/spotter/light_curves.py b/spotter/light_curves.py index d75eb93..a122fc6 100644 --- a/spotter/light_curves.py +++ b/spotter/light_curves.py @@ -1,7 +1,6 @@ -import jax - from functools import partial +import jax import jax.numpy as jnp from jax.typing import ArrayLike @@ -74,8 +73,8 @@ def impl(star, time): def transit_light_curve( star: Star, x: float = 0.0, y: float = 0.0, r: float = 0.0, time: float = 0.0 ): - """Light curve of a transited Star. - + """Light curve of a transited Star. The x-axis cross the star in the horizontal direction (→), + and the y-axis cross the star in the vertical up direction (↑). Parameters ---------- star : Star @@ -94,4 +93,4 @@ def transit_light_curve( ArrayLike Light curve array. """ - return light_curve(transited_star(star, x, y, r), star.phase(time)) + return light_curve(transited_star(star, y, x, r), star.phase(time)) diff --git a/spotter/star.py b/spotter/star.py index 630e575..459f945 100644 --- a/spotter/star.py +++ b/spotter/star.py @@ -1,10 +1,11 @@ import equinox as eqx import healpy as hp +import jax import jax.numpy as jnp import numpy as np from jax.typing import ArrayLike -from spotter import core, viz +from spotter import core, utils, viz class Star(eqx.Module): @@ -51,6 +52,9 @@ class Star(eqx.Module): inc: float | None = None """Inclination of the star, in radians. 0 is pole-on, pi/2 is equator-on.""" + obl: float | None = None + """Obliquity of the star, in radians. 0 is no obliquity, pi/2 is maximum obliquity.""" + radius: float | None = None """Radius of the star, in solar radii.""" @@ -65,6 +69,7 @@ def __init__( y: ArrayLike | None = None, u: ArrayLike | None = None, inc: float | None = None, + obl: float | None = None, period: float | None = None, radius: float | None = None, wv: float | None = None, @@ -72,6 +77,7 @@ def __init__( self.y = jnp.atleast_2d(y) self.u = jnp.atleast_2d(u) if u is not None else None self.inc = inc + self.obl = obl self.period = period self.sides = core._N_or_Y_to_N_n(self.y[0])[0] self.radius = radius if radius is not None else 1.0 @@ -124,7 +130,7 @@ def __mul__(self, other): y = self.y * other.y else: y = self.y * other - return self.__class__(y, self.u, self.inc, self.period) + return self.set(y=y) def __rmul__(self, other): return self.__mul__(other) @@ -134,7 +140,7 @@ def __add__(self, other): y = self.y + other.y else: y = self.y + other - return self.__class__(y, self.u, self.inc, self.period) + return self.set(y=y) def __radd__(self, other): return self.__add__(other) @@ -144,7 +150,7 @@ def __sub__(self, other): y = self.y - other.y else: y = self.y - other - return self.__class__(y, self.u, self.inc, self.period) + return self.set(y=y) def __rsub__(self, other): return self.__sub__(other) @@ -161,6 +167,7 @@ def set(self, **kwargs): "y": self.y, "u": self.u, "inc": self.inc, + "obl": self.obl, "period": self.period, "radius": self.radius, "wv": self.wv, @@ -184,6 +191,7 @@ def show(star: Star, phase: ArrayLike = 0.0, ax=None, **kwargs): viz.show( star.y[0], star.inc if star.inc is not None else np.pi / 2, + star.obl if star.obl is not None else 0.0, star.u[0] if star.u is not None else None, phase, ax=ax, @@ -232,18 +240,24 @@ def transited_star(star: Star, x: float = 0.0, y: float = 0.0, r: float = 0.0): Star Star object transited by the disk. """ - if star.inc is None: - c = 0.0 - s = 1.0 - else: - c = jnp.sin(star.inc) - s = jnp.cos(star.inc) + from jax.scipy.spatial.transform import Rotation + _z, _y, _x = core.vec(star.sides).T - _x = _x * c - _z * s - return ( - 1 - - ( - jnp.linalg.norm(jnp.array([_x, _y]) - jnp.array([x, -y])[:, None], axis=0) - < r - ) - ) * star + v = jnp.stack((_x, _y, _z), axis=-1) + + inc_angle = -jnp.pi / 2 + star.inc if star.inc is not None else 0.0 + _inc_angle = jnp.where(inc_angle == 0.0, 1.0, inc_angle) + _rv = Rotation.from_rotvec([0.0, _inc_angle, 0.0]).apply(v) + rv = jnp.where(inc_angle == 0.0, v, _rv) + + if star.obl is not None: + obl_angle = jnp.where(star.obl == 0.0, 1.0, star.obl) + _rv = Rotation.from_rotvec([0.0, 0.0, obl_angle]).apply(rv) + rv = jnp.where(obl_angle == 0.0, v, _rv) + + _x, _y, _ = rv.T + + distance = jnp.linalg.norm( + jnp.array([_x, _y]) - jnp.array([x, -y])[:, None], axis=0 + ) + return utils.sigmoid(distance - r, 1000.0) * star diff --git a/spotter/utils.py b/spotter/utils.py index 3435adc..d54e9fc 100644 --- a/spotter/utils.py +++ b/spotter/utils.py @@ -1,6 +1,8 @@ from collections import defaultdict import healpy as hp +import jax +import jax.numpy as jnp import numpy as np @@ -29,3 +31,7 @@ def ylm2healpix(y): hy[i] = _hy[i] return hy + + +def sigmoid(x, scale=1000): + return (jnp.tanh(x * scale / 2) + 1) / 2 diff --git a/spotter/viz.py b/spotter/viz.py index f8030c5..efa9655 100644 --- a/spotter/viz.py +++ b/spotter/viz.py @@ -47,7 +47,7 @@ def rotation(inc, obl, theta): u /= np.linalg.norm(u) u *= inc - R = Rotation.from_rotvec(u) + R = Rotation.from_rotvec(np.array(u)) R *= Rotation.from_rotvec([0, 0, obl]) R *= Rotation.from_rotvec([np.pi / 2, 0, 0]) R *= Rotation.from_rotvec([0, 0, -theta]) @@ -121,7 +121,7 @@ def graticule( ax.plot(sqrt_radius * np.cos(theta), sqrt_radius * np.sin(theta), c="w", lw=3) -def show(y, inc=np.pi / 2, u=None, phase=0.0, ax=None, **kwargs): +def show(y, inc=np.pi / 2, obl=0.0, u=None, phase=0.0, ax=None, **kwargs): import matplotlib.pyplot as plt kwargs.setdefault("cmap", _DEFAULT_CMAP) @@ -130,13 +130,13 @@ def show(y, inc=np.pi / 2, u=None, phase=0.0, ax=None, **kwargs): # kwargs.setdefault("vmax", 1.0) ax = ax or plt.gca() - img = core.render(y, inc, u, phase) + img = core.render(y, inc, u, phase, obl) plt.setp(ax.spines.values(), visible=False) ax.tick_params(left=False, labelleft=False) ax.tick_params(bottom=False, labelbottom=False) ax.patch.set_visible(False) ax.imshow(img, extent=(-1, 1, -1, 1), **kwargs) - graticule(inc, 0.0, phase, ax=ax) + graticule(inc, obl, phase, ax=ax) def video(y, inc=None, u=None, duration=4, fps=10, **kwargs):