Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kk/close approach #14

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
192 changes: 93 additions & 99 deletions src/adam_assist/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
from adam_core.constants import Constants as c
from adam_core.coordinates import CartesianCoordinates, Origin, transform_coordinates
from adam_core.coordinates.origin import OriginCodes
from adam_core.dynamics.impacts import EarthImpacts, ImpactMixin
from adam_core.dynamics.impacts import CollisionConditions, CollisionEvent, ImpactMixin
from adam_core.orbits import Orbits
from adam_core.orbits.variants import VariantOrbits
from adam_core.propagator.propagator import OrbitType, Propagator, TimestampType
from adam_core.time import Timestamp
from jpl_small_bodies_de441_n16 import de441_n16
from naif_de440 import de440
from quivr.concat import concatenate

from adam_core.propagator.propagator import OrbitType, Propagator, TimestampType

C = c.C

try:
Expand Down Expand Up @@ -253,9 +254,12 @@ def _propagate_orbits_inner(

return results

def _detect_impacts(
self, orbits: OrbitType, num_days: int
) -> Tuple[VariantOrbits, EarthImpacts]:
def _detect_collisions(
self,
orbits: OrbitType,
num_days: int,
collision_conditions: CollisionConditions,
) -> Tuple[VariantOrbits, CollisionEvent]:
# Assert that the time for each orbit definition is the same for the simulator to work
assert len(pc.unique(orbits.coordinates.time.mjd())) == 1

Expand All @@ -264,7 +268,9 @@ def _detect_impacts(
# For units we use solar masses, astronomical units, and days.
# The time coordinate is Barycentric Dynamical Time (TDB) in Julian days.

# Convert coordinates to ICRF using TDB time
# KK Note: do we want to specify the version of spice kernels that were used- if we're doing
# addtional work down stream, to ensure that the same kernels are used? de440, 441 for asteroid position

coords = transform_coordinates(
orbits.coordinates,
origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
Expand All @@ -281,13 +287,12 @@ def _detect_impacts(
sim = None
sim = rebound.Simulation()

backward_propagation = num_days < 0

# Set the simulation time, relative to the jd_ref
start_tdb_time = orbits.coordinates.time.jd().to_numpy()[0]
start_tdb_time = start_tdb_time - ephem.jd_ref
sim.t = start_tdb_time

backward_propagation = num_days < 0
if backward_propagation:
sim.dt = sim.dt * -1

Expand All @@ -308,7 +313,6 @@ def _detect_impacts(
# Add the orbits as particles to the simulation
coords_df = orbits.coordinates.to_dataframe()

# ASSIST _must_ be initialized before adding particles
assist.Extras(sim, ephem)

for i in range(len(coords_df)):
Expand Down Expand Up @@ -337,7 +341,7 @@ def _detect_impacts(
# Results stores the final positions of the objects
# If an object is an impactor, this represents its position at impact time
results = None
earth_impacts = None
collision_events = CollisionEvent.empty()
past_integrator_time = False
time_step_results: Union[None, OrbitType] = None

Expand Down Expand Up @@ -429,75 +433,84 @@ def _detect_impacts(
frame_out="ecliptic",
),
)

# Get the Earth's position at the current time
# earth_geo = get_perturber_state(OriginCodes.EARTH, results.coordinates.time[0], origin=OriginCodes.SUN)
# diff = time_step_results.coordinates.values - earth_geo.coordinates.values
earth_geo = ephem.get_particle("Earth", sim.t)
earth_geo = CartesianCoordinates.from_kwargs(
x=[earth_geo.x],
y=[earth_geo.y],
z=[earth_geo.z],
vx=[earth_geo.vx],
vy=[earth_geo.vy],
vz=[earth_geo.vz],
time=Timestamp.from_jd([sim.t + ephem.jd_ref], scale="tdb"),
origin=Origin.from_kwargs(
code=["SOLAR_SYSTEM_BARYCENTER"],
),
frame="equatorial",
)
earth_geo = transform_coordinates(
earth_geo,
origin_out=OriginCodes.SUN,
frame_out="ecliptic",
)
diff = time_step_results.coordinates.values - earth_geo.values

# Calculate the distance in KM
# We use the IAU definition of the astronomical unit (149_597_870.7 km)
normalized_distance = np.linalg.norm(diff[:, :3], axis=1) * KM_P_AU

# Calculate which particles are within an Earth radius
within_radius = normalized_distance < EARTH_RADIUS_KM

# If any are within our earth radius, we record the impact
# and do bookkeeping to remove the particle from the simulation
if np.any(within_radius):
distances = normalized_distance[within_radius]
impacting_orbits = time_step_results.apply_mask(within_radius)

if isinstance(orbits, VariantOrbits):
new_impacts = EarthImpacts.from_kwargs(
orbit_id=impacting_orbits.orbit_id,
distance=distances,
coordinates=impacting_orbits.coordinates,
variant_id=impacting_orbits.variant_id,
)
elif isinstance(orbits, Orbits):
new_impacts = EarthImpacts.from_kwargs(
orbit_id=impacting_orbits.orbit_id,
distance=distances,
coordinates=impacting_orbits.coordinates,
)
if earth_impacts is None:
earth_impacts = new_impacts
else:
earth_impacts = qv.concatenate([earth_impacts, new_impacts])

# Remove the particle from the simulation, orbits, and store in results
for hash_id in orbit_id_hashes[within_radius]:
sim.remove(hash=c_uint32(hash_id))
# For some reason, it fails if we let rebound convert the hash to c_uint32

# Remove the particle from the input / running orbits
# This allows us to carry through object_id, weights, and weights_cov
orbits = orbits.apply_mask(~within_radius)
# Put the orbits / variants of the impactors into the results set
if results is None:
results = impacting_orbits
else:
results = qv.concatenate([results, impacting_orbits])
for particle in collision_conditions:
particle_location = ephem.get_particle(
particle.collision_object_name.to_numpy(
zero_copy_only=False
).astype(str)[0],
sim.t,
)
particle_location = CartesianCoordinates.from_kwargs(
x=[particle_location.x],
y=[particle_location.y],
z=[particle_location.z],
vx=[particle_location.vx],
vy=[particle_location.vy],
vz=[particle_location.vz],
time=Timestamp.from_jd([sim.t + ephem.jd_ref], scale="tdb"),
origin=Origin.from_kwargs(
code=["SOLAR_SYSTEM_BARYCENTER"],
),
frame="equatorial",
)
particle_location = transform_coordinates(
particle_location,
origin_out=OriginCodes.SUN,
frame_out="ecliptic",
)
diff = time_step_results.coordinates.values - particle_location.values

# Calculate the distance in KM
# We use the IAU definition of the astronomical unit (149_597_870.7 km)
normalized_distance = np.linalg.norm(diff[:, :3], axis=1) * KM_P_AU

# Calculate which particles are within the collision distance
within_radius = normalized_distance < particle.collision_distance

# If any are within our collision distance, we record the impact
# and do bookkeeping to remove the particle from the simulation
if np.any(within_radius):
distances = normalized_distance[within_radius]
colliding_orbits = time_step_results.apply_mask(within_radius)

if isinstance(orbits, VariantOrbits):
new_impacts = CollisionEvent.from_kwargs(
orbit_id=colliding_orbits.orbit_id,
distance=distances,
coordinates=colliding_orbits.coordinates,
variant_id=colliding_orbits.variant_id,
collision_object_name=particle.collision_object_name,
collision_distance=particle.collision_distance,
stopping_condition=particle.stopping_condition,
)
elif isinstance(orbits, Orbits):
new_impacts = CollisionEvent.from_kwargs(
orbit_id=colliding_orbits.orbit_id,
distance=distances,
coordinates=colliding_orbits.coordinates,
collision_object_name=particle.collision_object_name,
collision_distance=particle.collision_distance,
stopping_condition=particle.stopping_condition,
)
collision_events = qv.concatenate([collision_events, new_impacts])

stopping_condition = particle.stopping_condition.to_numpy(
zero_copy_only=False
)[0]

if stopping_condition:
for hash_id in orbit_id_hashes[within_radius]:
sim.remove(hash=c_uint32(hash_id))
# For some reason, it fails if we let rebound convert the hash to c_uint32

# Remove the particle from the input / running orbits
# This allows us to carry through object_id, weights, and weights_cov
orbits = orbits.apply_mask(~within_radius)
# Put the orbits / variants of the impactors into the results set
if results is None:
results = colliding_orbits
else:
results = qv.concatenate([results, colliding_orbits])

# Add the final positions of the particles that are not already in the results
if time_step_results is not None:
Expand All @@ -516,23 +529,4 @@ def _detect_impacts(
[results, time_step_results.apply_mask(still_in_simulation)]
)

if earth_impacts is None:
earth_impacts = EarthImpacts.from_kwargs(
orbit_id=[],
distance=[],
coordinates=CartesianCoordinates.from_kwargs(
x=[],
y=[],
z=[],
vx=[],
vy=[],
vz=[],
time=Timestamp.from_jd([], scale="tdb"),
origin=Origin.from_kwargs(
code=[],
),
frame="ecliptic",
),
variant_id=[],
)
return results, earth_impacts
return results, collision_events
36 changes: 36 additions & 0 deletions tests/test_collisions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from adam_core.orbits import Orbits
from src.adam_core.propagator.adam_assist import ASSISTPropagator, CollisionConditions

IMPACTOR_FILE_PATH_60 = "tests/data/I00007_orbit.parquet"
# Contains a likely impactor with 100% chance of impact in 30 days
IMPACTOR_FILE_PATH_100 = "tests/data/I00008_orbit.parquet"
# Contains a likely impactor with 0% chance of impact in 30 days
IMPACTOR_FILE_PATH_0 = "tests/data/I00009_orbit.parquet"


def test_detect_collisions():
orbits = Orbits.from_parquet(IMPACTOR_FILE_PATH_100)[0]
propagator = ASSISTPropagator()

collision_conditions = CollisionConditions.from_kwargs(
collision_object_name=["Earth"],
collision_distance=[7000],
stopping_condition=[True],
)
results, collisions = propagator._detect_collisions(
orbits, 60, collision_conditions
)

assert len(collisions) == 1
assert collisions.distance.to_numpy()[0] <= 7000

collision_conditions = CollisionConditions.from_kwargs(
collision_object_name=["Earth", "Earth"],
collision_distance=[10000, 7000],
stopping_condition=[False, True],
)
results, collisions = propagator._detect_collisions(
orbits, 60, collision_conditions
)

assert len(collisions) > 1
26 changes: 16 additions & 10 deletions tests/test_impacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_calculate_impacts_benchmark(benchmark, processes):
propagator,
num_samples=200,
processes=processes,
seed=42 # This allows us to predict exact number of impactors empirically
seed=42, # This allows us to predict exact number of impactors empirically
)
assert len(variants) == 200, "Should have 200 variants"
assert len(impacts) == 138, "Should have exactly 138 impactors"
Expand All @@ -35,7 +35,7 @@ def test_calculate_impacts_benchmark(benchmark, processes):
@pytest.mark.benchmark
@pytest.mark.parametrize("processes", [1, 2])
def test_calculate_impacts_benchmark(benchmark, processes):

impactor = Orbits.from_parquet(IMPACTOR_FILE_PATH_100)[0]
propagator = ASSISTPropagator()
variants, impacts = benchmark(
Expand All @@ -45,7 +45,7 @@ def test_calculate_impacts_benchmark(benchmark, processes):
propagator,
num_samples=200,
processes=processes,
seed=42 # This allows us to predict exact number of impactors empirically
seed=42, # This allows us to predict exact number of impactors empirically
)
assert len(variants) == 200, "Should have 200 variants"
assert len(impacts) == 200, "Should have exactly 200 impactors"
Expand All @@ -63,20 +63,26 @@ def test_calculate_impacts_benchmark(benchmark, processes):
propagator,
num_samples=200,
processes=processes,
seed=42 # This allows us to predict exact number of impactors empirically
seed=42, # This allows us to predict exact number of impactors empirically
)
assert len(variants) == 200, "Should have 200 variants"
assert len(impacts) == 0, "Should have exactly 0 impactors"


def test_detect_impacts_time_direction():
def test_detect_collisions_time_direction():
start_time = Timestamp.from_mjd([60000], scale="utc")
orbit = query_horizons(["1980 PA"], start_time)

propagator = ASSISTPropagator()

results, impacts = propagator._detect_impacts(orbit, 60)
assert results.coordinates.time.mjd().to_numpy()[0] >= orbit.coordinates.time.add_days(60).mjd().to_numpy()[0]
results, impacts = propagator._detect_collisions(orbit, 60)
assert (
results.coordinates.time.mjd().to_numpy()[0]
>= orbit.coordinates.time.add_days(60).mjd().to_numpy()[0]
)

results, impacts = propagator._detect_impacts(orbit, -60)
assert results.coordinates.time.mjd().to_numpy()[0] <= orbit.coordinates.time.add_days(-60).mjd().to_numpy()[0]
results, impacts = propagator._detect_collisions(orbit, -60)
assert (
results.coordinates.time.mjd().to_numpy()[0]
<= orbit.coordinates.time.add_days(-60).mjd().to_numpy()[0]
)
Loading