Skip to content

Commit

Permalink
feat: add obliquity + sigmoid instead of steps to allow grads
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrcia committed Nov 22, 2024
1 parent b94a5ec commit cf9a0ea
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 39 deletions.
39 changes: 27 additions & 12 deletions spotter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import jax.numpy as jnp
import numpy as np

from spotter import utils


def distance(X, x):
return jnp.arctan2(
Expand All @@ -14,7 +16,7 @@ def npix(N):
return hp.nside2npix(N)


def equator_coords(phi=None, inc=None):
def equator_coords(phi=None, inc=None, obl=None):
# inc = None -> inc = pi/2
if inc is None:
si = 1.0
Expand All @@ -30,15 +32,28 @@ def equator_coords(phi=None, inc=None):
cp = jnp.cos(phi)
sp = jnp.sin(phi)

if obl is None:
co = 1.0
so = 0.0
else:
co = jnp.cos(obl)
so = jnp.sin(obl)

x = si * cp
y = si * sp
z = ci

x, y, z = (
x,
co * y - so * z,
so * y + co * z,
)

return jnp.array([x, y, z])


def mask_projected_limb(X, phase=None, inc=None, u=None):
d = distance(X, equator_coords(phase, inc))
def mask_projected_limb(X, phase=None, inc=None, u=None, obl=None):
d = distance(X, equator_coords(phase, inc, obl))
mask = d < jnp.pi / 2
z = jnp.cos(d)
projected_area = z
Expand Down Expand Up @@ -67,15 +82,15 @@ def vec(N_or_y):
return np.array(hp.pix2vec(N, range(n))).T


def design_matrix(N_or_y, phase=None, inc=None, u=None):
def design_matrix(N_or_y, phase=None, inc=None, u=None, obl=None):
X = vec(N_or_y)
mask, projected_area, limb_darkening = mask_projected_limb(X, phase, inc, u)
mask, projected_area, limb_darkening = mask_projected_limb(X, phase, inc, u, obl)
geometry = mask * projected_area
return jnp.pi * limb_darkening * geometry / (geometry * limb_darkening).sum()


def flux(y, inc=None, u=None, phase=None):
return design_matrix(y, inc=inc, u=u, phase=phase) @ y
def flux(y, inc=None, u=None, phase=None, obl=None):
return design_matrix(y, inc=inc, u=u, phase=phase, obl=obl) @ y


def spherical_to_cartesian(theta, phi):
Expand All @@ -86,10 +101,10 @@ def spherical_to_cartesian(theta, phi):
return jnp.array([x, y, z])


def spot(N, latitude, longitude, radius):
def spot(N, latitude, longitude, radius, sharpness=1000):
X = vec(N)
d = distance(X, spherical_to_cartesian(jnp.pi / 2 - latitude, longitude))
return d < radius
d = distance(X, spherical_to_cartesian(jnp.pi / 2 - latitude, -longitude))
return 1 - utils.sigmoid(d - radius, sharpness)


def soft_spot(N, latitude, longitude, radius):
Expand All @@ -101,14 +116,14 @@ def soft_spot(N, latitude, longitude, radius):
return profile / jnp.max(profile)


def render(y, inc=None, u=None, phase=0.0):
def render(y, inc=None, u=None, phase=0.0, obl=0.0):
import matplotlib.pyplot as plt

X = vec(y)

limb_darkening = mask_projected_limb(X, phase, inc, u)[2]
rotated = hp.Rotator(
rot=[phase, np.pi / 2 - inc or 0.0], deg=False
rot=[phase, np.pi / 2 - inc or 0.0, obl or 0.0], deg=False
).rotate_map_pixel(y * limb_darkening)

projected_map = hp.orthview(rotated, half_sky=True, return_projected_map=True)
Expand Down
9 changes: 4 additions & 5 deletions spotter/light_curves.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import jax

from functools import partial

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

Expand Down Expand Up @@ -74,8 +73,8 @@ def impl(star, time):
def transit_light_curve(
star: Star, x: float = 0.0, y: float = 0.0, r: float = 0.0, time: float = 0.0
):
"""Light curve of a transited Star.
"""Light curve of a transited Star. The x-axis cross the star in the horizontal direction (→),
and the y-axis cross the star in the vertical up direction (↑).
Parameters
----------
star : Star
Expand All @@ -94,4 +93,4 @@ def transit_light_curve(
ArrayLike
Light curve array.
"""
return light_curve(transited_star(star, x, y, r), star.phase(time))
return light_curve(transited_star(star, y, x, r), star.phase(time))
50 changes: 32 additions & 18 deletions spotter/star.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import equinox as eqx
import healpy as hp
import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike

from spotter import core, viz
from spotter import core, utils, viz


class Star(eqx.Module):
Expand Down Expand Up @@ -51,6 +52,9 @@ class Star(eqx.Module):
inc: float | None = None
"""Inclination of the star, in radians. 0 is pole-on, pi/2 is equator-on."""

obl: float | None = None
"""Obliquity of the star, in radians. 0 is no obliquity, pi/2 is maximum obliquity."""

radius: float | None = None
"""Radius of the star, in solar radii."""

Expand All @@ -65,13 +69,15 @@ def __init__(
y: ArrayLike | None = None,
u: ArrayLike | None = None,
inc: float | None = None,
obl: float | None = None,
period: float | None = None,
radius: float | None = None,
wv: float | None = None,
):
self.y = jnp.atleast_2d(y)
self.u = jnp.atleast_2d(u) if u is not None else None
self.inc = inc
self.obl = obl
self.period = period
self.sides = core._N_or_Y_to_N_n(self.y[0])[0]
self.radius = radius if radius is not None else 1.0
Expand Down Expand Up @@ -124,7 +130,7 @@ def __mul__(self, other):
y = self.y * other.y
else:
y = self.y * other
return self.__class__(y, self.u, self.inc, self.period)
return self.set(y=y)

def __rmul__(self, other):
return self.__mul__(other)
Expand All @@ -134,7 +140,7 @@ def __add__(self, other):
y = self.y + other.y
else:
y = self.y + other
return self.__class__(y, self.u, self.inc, self.period)
return self.set(y=y)

def __radd__(self, other):
return self.__add__(other)
Expand All @@ -144,7 +150,7 @@ def __sub__(self, other):
y = self.y - other.y
else:
y = self.y - other
return self.__class__(y, self.u, self.inc, self.period)
return self.set(y=y)

def __rsub__(self, other):
return self.__sub__(other)
Expand All @@ -161,6 +167,7 @@ def set(self, **kwargs):
"y": self.y,
"u": self.u,
"inc": self.inc,
"obl": self.obl,
"period": self.period,
"radius": self.radius,
"wv": self.wv,
Expand All @@ -184,6 +191,7 @@ def show(star: Star, phase: ArrayLike = 0.0, ax=None, **kwargs):
viz.show(
star.y[0],
star.inc if star.inc is not None else np.pi / 2,
star.obl if star.obl is not None else 0.0,
star.u[0] if star.u is not None else None,
phase,
ax=ax,
Expand Down Expand Up @@ -232,18 +240,24 @@ def transited_star(star: Star, x: float = 0.0, y: float = 0.0, r: float = 0.0):
Star
Star object transited by the disk.
"""
if star.inc is None:
c = 0.0
s = 1.0
else:
c = jnp.sin(star.inc)
s = jnp.cos(star.inc)
from jax.scipy.spatial.transform import Rotation

_z, _y, _x = core.vec(star.sides).T
_x = _x * c - _z * s
return (
1
- (
jnp.linalg.norm(jnp.array([_x, _y]) - jnp.array([x, -y])[:, None], axis=0)
< r
)
) * star
v = jnp.stack((_x, _y, _z), axis=-1)

inc_angle = -jnp.pi / 2 + star.inc if star.inc is not None else 0.0
_inc_angle = jnp.where(inc_angle == 0.0, 1.0, inc_angle)
_rv = Rotation.from_rotvec([0.0, _inc_angle, 0.0]).apply(v)
rv = jnp.where(inc_angle == 0.0, v, _rv)

if star.obl is not None:
obl_angle = jnp.where(star.obl == 0.0, 1.0, star.obl)
_rv = Rotation.from_rotvec([0.0, 0.0, obl_angle]).apply(rv)
rv = jnp.where(obl_angle == 0.0, v, _rv)

_x, _y, _ = rv.T

distance = jnp.linalg.norm(
jnp.array([_x, _y]) - jnp.array([x, -y])[:, None], axis=0
)
return utils.sigmoid(distance - r, 1000.0) * star
6 changes: 6 additions & 0 deletions spotter/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections import defaultdict

import healpy as hp
import jax
import jax.numpy as jnp
import numpy as np


Expand Down Expand Up @@ -29,3 +31,7 @@ def ylm2healpix(y):
hy[i] = _hy[i]

return hy


def sigmoid(x, scale=1000):
return (jnp.tanh(x * scale / 2) + 1) / 2
8 changes: 4 additions & 4 deletions spotter/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def rotation(inc, obl, theta):
u /= np.linalg.norm(u)
u *= inc

R = Rotation.from_rotvec(u)
R = Rotation.from_rotvec(np.array(u))
R *= Rotation.from_rotvec([0, 0, obl])
R *= Rotation.from_rotvec([np.pi / 2, 0, 0])
R *= Rotation.from_rotvec([0, 0, -theta])
Expand Down Expand Up @@ -121,7 +121,7 @@ def graticule(
ax.plot(sqrt_radius * np.cos(theta), sqrt_radius * np.sin(theta), c="w", lw=3)


def show(y, inc=np.pi / 2, u=None, phase=0.0, ax=None, **kwargs):
def show(y, inc=np.pi / 2, obl=0.0, u=None, phase=0.0, ax=None, **kwargs):
import matplotlib.pyplot as plt

kwargs.setdefault("cmap", _DEFAULT_CMAP)
Expand All @@ -130,13 +130,13 @@ def show(y, inc=np.pi / 2, u=None, phase=0.0, ax=None, **kwargs):
# kwargs.setdefault("vmax", 1.0)
ax = ax or plt.gca()

img = core.render(y, inc, u, phase)
img = core.render(y, inc, u, phase, obl)
plt.setp(ax.spines.values(), visible=False)
ax.tick_params(left=False, labelleft=False)
ax.tick_params(bottom=False, labelbottom=False)
ax.patch.set_visible(False)
ax.imshow(img, extent=(-1, 1, -1, 1), **kwargs)
graticule(inc, 0.0, phase, ax=ax)
graticule(inc, obl, phase, ax=ax)


def video(y, inc=None, u=None, duration=4, fps=10, **kwargs):
Expand Down

0 comments on commit cf9a0ea

Please sign in to comment.