Skip to content

Commit

Permalink
feat: better jax
Browse files Browse the repository at this point in the history
feat: better jax
  • Loading branch information
lgrcia authored Mar 13, 2024
2 parents d96e185 + 4e24a46 commit 3a41fa6
Show file tree
Hide file tree
Showing 20 changed files with 1,426 additions and 846 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
id-token: write
needs: [tests, build]
runs-on: ubuntu-latest
# if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/download-artifact@v4
with:
Expand Down
Binary file added docs/source/_static/spotter.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 2 additions & 5 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
project = "spotter"
copyright = "2023, Lionel Garcia, Benjamin Rackham"
copyright = "2023 - 2024, Lionel Garcia, Benjamin Rackham"
author = "Lionel Garcia, Benjamin Rackham"
release = "0.0.2"

extensions = [
"myst_nb",
Expand Down Expand Up @@ -39,9 +38,7 @@
]

nb_execution_mode = "off"
html_short_title = "spotter"
html_title = f"{html_short_title}"

html_logo = "_static/spotter.png"
html_css_files = ["style.css"]
myst_url_schemes = ("http", "https")

Expand Down
33 changes: 12 additions & 21 deletions docs/source/index.md
Original file line number Diff line number Diff line change
@@ -1,43 +1,34 @@
# spotter

```{image} _static/spotter.jpg
:width: 400px
:align: center
```
*Approximate forward models of fluxes and spectra time-series of non-uniform stars.*

---

*spotter* is a Python package to produce forward models of non-uniform stars spectra. It uses the [HEALPix](https://healpix.sourceforge.io/) subdivision scheme and is powered by the high-performance numerical package [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html), enabling its use on GPUs.
```{warning}
Use at your own risk as the code is completely untested and its API subject to change.
```

**Note**
*spotter* uses the [HEALPix](https://healpix.sourceforge.io/) subdivision scheme and is powered by the high-performance numerical package [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html), enabling its use on GPUs.

In its beta version, *spotter* is mainly developed to estimate transmission spectra stellar contamination from stellar rotational light curves. Use at your own risk as the code is completely untested and its API subject to change.

## Features

- Adjustable surface resolution <span style="color:grey">- *in beta*</span>
- Small-scale surface features modeling (e.g. beyond limitations of [starry]()) <span style="color:grey">- *in beta*</span>
- Small-scale surface features (e.g. beyond limitations of [starry]()) <span style="color:grey">
- Modeling of any active regions with their limb laws (e.g. limb-brightened faculae)
- GPU compatible <span style="color:grey">- *in beta*</span>
- GPU compatible <span style="color:grey">
- Possibility to input any stellar spectra model

```{toctree}
:maxdepth: 1
:caption: Get started
api
```

```{toctree}
:maxdepth: 1
:caption: Examples
notebooks/simple_example
notebooks/experiments
notebooks/amplitude_constraints.ipynb
notebooks/introduction
```

```{toctree}
:maxdepth: 1
:caption: Notes
:caption: Reference
notebooks/rotation.ipynb
api
```
336 changes: 272 additions & 64 deletions docs/source/notebooks/amplitude_constraints.ipynb

Large diffs are not rendered by default.

179 changes: 115 additions & 64 deletions docs/source/notebooks/experiments.ipynb

Large diffs are not rendered by default.

403 changes: 403 additions & 0 deletions docs/source/notebooks/introduction.ipynb

Large diffs are not rendered by default.

91 changes: 53 additions & 38 deletions docs/source/notebooks/jax_features.ipynb

Large diffs are not rendered by default.

89 changes: 29 additions & 60 deletions docs/source/notebooks/rotation.ipynb

Large diffs are not rendered by default.

26 changes: 14 additions & 12 deletions docs/source/notebooks/simple_example.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
[project]
name = "starspotter"
version = "0.0.6-beta"
version = "0.0.7-beta"
description = "Stellar contamination estimates from rotational light curves"
authors = [{name="Lionel Garcia"}, {name="Benjamin Rackham"}]
license = "MIT"
readme = "readme.md"
requires-python = ">=3.9"
packages = [{ include = "spotter" },]
dependencies = ["numpy", "healpy", "matplotlib", "jax", "jaxlib"]
dependencies = ["numpy", "healpy", "jax", "jaxlib"]

[project.optional-dependencies]
dev = ["black", "pytest", "nox"]
Expand Down
21 changes: 8 additions & 13 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,32 @@
# spotter

<p align="center" style="margin-bottom:-50px">
<img src="docs/_static/spotter.jpg" width="380">
<p align="center">
<img src="docs/source/_static/spotter.png" width="270">
</p>

<p align="center">
Forward models of non-uniform stellar photospheres and their spectra
Approximate forward models of fluxes and spectra time-series of non-uniform stars
<br>
<p align="center">
<a href="https://github.com/lgrcia/spotter">
<img src="https://img.shields.io/badge/github-lgrcia/spotter-e3a8a1.svg?style=flat" alt="github"/></a>
<img src="https://img.shields.io/badge/github-lgrcia/spotter-white.svg?style=flat" alt="github"/></a>
<a href="LICENCE">
<img src="https://img.shields.io/badge/license-MIT-lightgray.svg?style=flat" alt="license"/>
</a>
</p>
</p>

*spotter* is a Python package to produce forward models of non-uniform stellar photospheres and their spectra. It uses the [HEALPix](https://healpix.sourceforge.io/) subdivision scheme and is powered by the high-performance numerical package [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html), enabling its use on GPUs.

**Note**
*spotter* uses the [HEALPix](https://healpix.sourceforge.io/) subdivision scheme and is powered by the high-performance numerical package [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html), enabling its use on GPUs.

In its beta version, *spotter* is mainly developed to estimate transmission spectra stellar contamination from stellar rotational light curves. Use at your own risk as the code is completely untested and its API subject to change.

## Features

- Adjustable surface resolution <span style="color:grey">- *in beta*</span>
- Small-scale surface feature modeling (e.g., beyond limitations of [starry]()) <span style="color:grey">- *in beta*</span>
- Modeling of active regions with unique angular dependence on brightness (e.g., limb-brightened faculae)
- GPU compatible <span style="color:grey">- *in beta*</span>
- Small-scale surface features (e.g. beyond limitations of [starry]()) <span style="color:grey">
- Modeling of any active regions with their limb laws (e.g. limb-brightened faculae)
- GPU compatible <span style="color:grey">
- Possibility to input any stellar spectra model


## Installation

For now only locally with
Expand Down
94 changes: 73 additions & 21 deletions spotter/core.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,86 @@
import healpy as hp
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def hemisphere_mask(thetas):
def mask(phase):
a = (phase + jnp.pi / 2) % (2 * jnp.pi)
b = (phase - jnp.pi / 2) % (2 * jnp.pi)
mask_1 = jnp.logical_and((thetas < a), (thetas > b))
mask_2 = jnp.logical_or((thetas > b), (thetas < a))
cond1 = a > phase % (2 * jnp.pi)
cond2 = b < phase % (2 * jnp.pi)
cond = cond1 * cond2
return jnp.where(cond, mask_1, mask_2)
def hemisphere_mask(theta, phase):
theta = jnp.atleast_1d(theta)
a = (phase + jnp.pi / 2) % (2 * jnp.pi)
b = (phase - jnp.pi / 2) % (2 * jnp.pi)
mask_1 = jnp.logical_and((theta < a), (theta > b))
mask_2 = jnp.logical_or((theta > b), (theta < a))
cond1 = a > phase % (2 * jnp.pi)
cond2 = b < phase % (2 * jnp.pi)
cond = cond1 * cond2
return jnp.where(cond, mask_1, mask_2)

return mask

def polynomial_limb_darkening(theta, phi, u=None, phase=0.0):
if u is None:
return 1.0
else:
theta = jnp.atleast_1d(theta)
phi = jnp.atleast_1d(phi)
u = jnp.atleast_1d(u)
z = jnp.sin(phi) * jnp.cos(theta - phase)
terms = jnp.array([un * (1 - z) ** (n + 1) for n, un in enumerate(u)])
return 1 - jnp.sum(terms, axis=theta.ndim - 1)

def polynomial_limb_darkening(thetas, phis):
def ld(u, phase):
z = jnp.sin(phis) * jnp.cos(thetas - phase)
terms = jnp.array([u * (1 - z) ** (n + 1) for n, u in enumerate(u)])
return 1 - jnp.sum(terms, 0)

return ld
def projected_area(theta, phi, phase):
return jnp.cos(theta - phase) * jnp.sin(phi)


def projected_area(thetas, phis):
def area(phase):
return jnp.cos(thetas - phase) * jnp.sin(phis)
def covering_fraction(x):
return jnp.mean(x > 0)

return area

def distance(thetas, phis):

p1 = phis - jnp.pi / 2
t1 = thetas
sp1 = jnp.sin(p1)
cp1 = jnp.cos(p1)

def fun(theta0, phi0):
# https://en.wikipedia.org/wiki/Great-circle_distance
# Vincenty formula
p2 = theta0 - jnp.pi / 2
t2 = phi0
dl = jnp.abs((t1 - t2))

sp2 = jnp.sin(p2)
cp2 = jnp.cos(p2)
cdl = jnp.cos(dl)
sdl = jnp.sin(dl)

a = (cp2 * sdl) ** 2 + (cp1 * sp2 - sp1 * cp2 * cdl) ** 2
b = sp1 * sp2 + cp1 * cp2 * cdl
return jnp.arctan2(jnp.sqrt(a), b)

return fun


def query_disk(thetas, phis):

distance_fn = distance(thetas, phis)

def fun(theta, phi, radius):
d = distance_fn(theta, phi)
return jnp.array(d <= radius, dtype=jnp.int8)

return fun


def smooth_spot(thetas, phis):

distance_fn = distance(thetas, phis)

def fun(theta, phi, r, c):
A = c * distance_fn(theta, phi) / (2 * r)
C = c / 2
return 0.5 * jnp.tanh(C - A) + 0.5 * jnp.tanh(C + A)

return fun
27 changes: 27 additions & 0 deletions spotter/distributions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,31 @@
import jax.numpy as jnp
import numpy as np
from jax import random


def jax_butterfly(key, latitudes=0.0, latitude_sigma=0.0, n=1):
new_key, subkey = random.split(key)
theta = jnp.pi / 2 - (
latitudes
+ random.normal(key, shape=(n,))
* latitude_sigma
* random.choice(subkey, jnp.array([-1.0, 1.0]), shape=(n,))
)
new_key, subkey = random.split(new_key)
phi = random.uniform(subkey, minval=0.0, maxval=2.0 * jnp.pi, shape=(n,))
return theta, phi


def jax_uniform(key, n=1):
# Latitude
theta = jnp.pi / 2 - jnp.arcsin(
random.uniform(key, minval=-1.0, maxval=1.0, shape=(n,))
)
_, subkey = random.split(key)
# Longitude
phi = random.uniform(subkey, minval=0.0, maxval=2.0 * jnp.pi, shape=(n,))

return theta, phi


def butterfly(latitudes=0, latitude_sigma=0, n=1):
Expand Down
Loading

0 comments on commit 3a41fa6

Please sign in to comment.