Deep Dive
The Particle and System classes are convenient for simple simulations, but they are built on top of much more flexible individual functions. Here we’ll demonstrate how some of them come together to move our particles around.
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from astropy.time import Time
from jorbit.accelerations import create_newtonian_ephemeris_acceleration_func
from jorbit.accelerations.newtonian import newtonian_gravity
from jorbit.accelerations.gr import ppn_gravity
from jorbit.ephemeris import Ephemeris
from jorbit.integrators.ias15 import (
    ias15_evolve,
    initialize_ias15_integrator_state,
    next_proposed_dt_PRS23,
)
from jorbit.utils.states import SystemState, IAS15IntegratorState
Here’s a simple setup: a handful of massless tracer particles and a few massive ones, all with random positions and velocities. The system is guaranteed to drift, since we’re not in the center-of-mass frame, and some of these particles are likely unbound. For our purposes, though, we’ll just let them go and see what happens.
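Removing that drift means shifting into the frame of the massive bodies’ barycenter. In the setup that follows, every velocity component is drawn uniformly from [0, 1). The barycenter therefore moves with a positive velocity along each axis, and the whole system slides off in that direction. Below is a minimal sketch of the shift, written against plain arrays. `to_center_of_mass_frame` is a hypothetical helper and is not part of jorbit. It also assumes that exp(log_gms) is proportional to each body’s mass, which holds whenever all the GMs share the same G.

```python
import jax
import jax.numpy as jnp


def to_center_of_mass_frame(massive_x, massive_v, log_gms, tracer_x, tracer_v):
    """Hypothetical helper (not part of jorbit): shift every particle so the
    massive bodies' barycenter sits at rest at the origin."""
    # GM is proportional to mass, so exp(log_gms) works as a barycentric weight;
    # massless tracers contribute nothing to the barycenter
    weights = jnp.exp(log_gms)
    total = jnp.sum(weights)
    com_x = jnp.sum(weights[:, None] * massive_x, axis=0) / total
    com_v = jnp.sum(weights[:, None] * massive_v, axis=0) / total
    # apply the same shift to every particle so relative positions are unchanged
    return (
        massive_x - com_x,
        massive_v - com_v,
        tracer_x - com_x,
        tracer_v - com_v,
    )


# random inputs shaped like the example: 3 massive bodies, 10 tracers
keys = jax.random.split(jax.random.PRNGKey(0), 5)
massive_x = jax.random.uniform(keys[0], (3, 3)) * 10
massive_v = jax.random.uniform(keys[1], (3, 3))
log_gms = jnp.log(jax.random.uniform(keys[2], (3,)) * 1e-3)
tracer_x = jax.random.uniform(keys[3], (10, 3)) * 10
tracer_v = jax.random.uniform(keys[4], (10, 3))

mx, mv, tx, tv = to_center_of_mass_frame(
    massive_x, massive_v, log_gms, tracer_x, tracer_v
)
# the mass-weighted velocity of the massive bodies is now ~0: no bulk drift
print(jnp.sum(jnp.exp(log_gms)[:, None] * mv, axis=0))
```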
n_tracer_particles = 10
n_massive_particles = 3
# the underlying state representation behind Particle and System
s = SystemState(
    tracer_positions=jax.random.uniform(
        jax.random.PRNGKey(0), (n_tracer_particles, 3)
    )
    * 10,
    tracer_velocities=jax.random.uniform(
        jax.random.PRNGKey(1), (n_tracer_particles, 3)
    ),
    massive_positions=jax.random.uniform(
        jax.random.PRNGKey(2), (n_massive_particles, 3)
    )
    * 10,
    massive_velocities=jax.random.uniform(
        jax.random.PRNGKey(3), (n_massive_particles, 3)
    ),
    # a fresh key: JAX PRNG keys shouldn't be reused across draws
    log_gms=jnp.log(
        jax.random.uniform(jax.random.PRNGKey(4), (n_massive_particles,)) * 1e-3
    ),
    acceleration_func_kwargs={},
    time_reference=0.0,
    relative_time=0.0,
    # no fixed perturbers in this example
    fixed_perturber_positions=jnp.empty((0, 3)),
    fixed_perturber_velocities=jnp.empty((0, 3)),
    fixed_perturber_log_gms=jnp.empty((0,)),
)
# ias15_evolve accepts any acceleration function wrapped in jax.tree_util.Partial
# that takes in a SystemState and returns an array of accelerations with the same
# shape as the positions, with the massive particles first. This function can be
# time-dependent and/or include time-dependent parameters: that's why SystemState
# carries acceleration_func_kwargs and the (relative_time, time_reference) pair.
# In the usual solar system integration case, we use relative_time + time_reference
# to compute the positions of the perturbing planets/asteroids at the timestep in
# question.
acceleration_func = jax.tree_util.Partial(newtonian_gravity)
# we need to initialize the integrator with the starting acceleration values
a0 = acceleration_func(s)
init_integrator = initialize_ias15_integrator_state(a0=a0)
# the step scheduler controls how IAS15 picks the next proposed step size from the
# converged predictor-corrector iterates. PRS23 is the controller from Pham+ 2023.
step_scheduler = jax.tree_util.Partial(next_proposed_dt_PRS23)
# now we run it
positions, velocities, final_system_state, final_integrator_state, n_steps = (
    ias15_evolve(
        initial_system_state=s,
        # the times to integrate to, not step sizes: the integrator picks its own
        # internal timesteps
        times=jnp.linspace(0, 10, 10),
        acceleration_func=acceleration_func,
        initial_integrator_state=init_integrator,
        step_scheduler=step_scheduler,
    )
)
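Both acceleration_func and step_scheduler are wrapped in jax.tree_util.Partial rather than passed as bare functions. Partial is JAX’s pytree-compatible version of functools.partial, so it can be passed as an ordinary argument to a jax.jit-compiled function. A plain Python function passed that way is rejected as an invalid JAX type. The toy below illustrates the mechanism. `spring_acceleration` and `one_euler_step` are hypothetical, and a single explicit Euler step is used purely for illustration; this is not how IAS15 works internally.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def spring_acceleration(x, k):
    # toy acceleration: a = -k x
    return -k * x


@jax.jit
def one_euler_step(acc_func, x, v, dt):
    # acc_func arrives as a pytree argument, so jit can trace straight through it
    a = acc_func(x)
    return x + dt * v, v + dt * a


# k is bound here and becomes a pytree leaf of the Partial
acc = Partial(spring_acceleration, k=4.0)
x1, v1 = one_euler_step(acc, jnp.array([1.0]), jnp.array([0.0]), 0.1)
print(x1, v1)  # x stays at 1.0, v becomes -0.4
```

Because k rides along as a pytree leaf, building a new Partial with a different k should reuse the compiled function. Swapping in a different underlying function changes the pytree structure and triggers a recompile.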
Right now, newtonian_gravity is the only built-in acceleration function that’s optimized for large systems. It handles the massless tracer particles separately from the massive ones to avoid unnecessary pairwise calculations, which lets us evaluate the accelerations of much larger systems without a problem. However, actually integrating those accelerations is still slower than ideal, so keep your systems small for now.
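The tracer/massive split pays off because tracers feel the massive bodies but never pull on anything. With N_m massive bodies and N_t tracers, a naive all-pairs evaluation touches roughly (N_m + N_t)^2 pairs. The split needs only N_m^2 + N_t N_m. For the 3 massive bodies and 10^6 tracers below, that works out to about 3 million pairs instead of about 10^12. Here’s a minimal plain-JAX sketch of the idea. `split_newtonian_accelerations` is hypothetical and is not jorbit’s implementation, though it follows the same massive-first output ordering described in the comments above.

```python
import jax
import jax.numpy as jnp


def _accelerations_from_massive(targets, massive_x, gms, exclude_self):
    # dx[i, j] = x_massive_j - x_target_i, shape (T, M, 3)
    dx = massive_x[None, :, :] - targets[:, None, :]
    r2 = jnp.sum(dx**2, axis=-1)  # (T, M)
    if exclude_self:
        # the targets ARE the massive bodies: push the diagonal to infinity so the
        # self-interaction term becomes 0 * inf**-1.5 = 0 instead of 0/0
        r2 = jnp.where(jnp.eye(massive_x.shape[0], dtype=bool), jnp.inf, r2)
    inv_r3 = r2**-1.5
    return jnp.sum(gms[None, :, None] * dx * inv_r3[..., None], axis=1)  # (T, 3)


@jax.jit
def split_newtonian_accelerations(massive_x, log_gms, tracer_x):
    gms = jnp.exp(log_gms)
    # M x M pairs: massive bodies pull on each other
    a_massive = _accelerations_from_massive(massive_x, massive_x, gms, True)
    # N x M pairs: massive bodies pull on tracers; tracers pull on nothing
    a_tracer = _accelerations_from_massive(tracer_x, massive_x, gms, False)
    # massive particles first, then tracers
    return jnp.concatenate([a_massive, a_tracer], axis=0)


massive_x = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
log_gms = jnp.log(jnp.array([1.0, 1e-3]))
tracer_x = jnp.array([[0.0, 2.0, 0.0], [-3.0, 0.0, 0.0]])
acc = split_newtonian_accelerations(massive_x, log_gms, tracer_x)
print(acc.shape)  # (4, 3): 2 massive rows, then 2 tracer rows
```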
# same as before, but now way more tracer particles
n_tracer_particles = int(1e6)
n_massive_particles = 3
s = SystemState(
    tracer_positions=jax.random.uniform(
        jax.random.PRNGKey(0), (n_tracer_particles, 3)
    )
    * 10,
    tracer_velocities=jax.random.uniform(
        jax.random.PRNGKey(1), (n_tracer_particles, 3)
    ),
    massive_positions=jax.random.uniform(
        jax.random.PRNGKey(2), (n_massive_particles, 3)
    )
    * 10,
    massive_velocities=jax.random.uniform(
        jax.random.PRNGKey(3), (n_massive_particles, 3)
    ),
    # a fresh key: JAX PRNG keys shouldn't be reused across draws
    log_gms=jnp.log(
        jax.random.uniform(jax.random.PRNGKey(4), (n_massive_particles,)) * 1e-3
    ),
    acceleration_func_kwargs={},
    time_reference=0.0,
    relative_time=0.0,
    fixed_perturber_positions=jnp.empty((0, 3)),
    fixed_perturber_velocities=jnp.empty((0, 3)),
    fixed_perturber_log_gms=jnp.empty((0,)),
)
acceleration_func = jax.tree_util.Partial(newtonian_gravity)
a0 = acceleration_func(s) # run it once to compile
%timeit acceleration_func(s).block_until_ready()
6.88 ms ± 405 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
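%timeit is an IPython magic, so the cell above only runs in a notebook or IPython shell. In a plain script, the standard library’s timeit module can produce the same kind of mean ± standard deviation summary. The `block_until_ready()` call matters in both cases. JAX dispatches work asynchronously, so without that call you would mostly be timing the dispatch. The jitted `squared_norms` function is a hypothetical stand-in for an acceleration function, so its timings won’t match the ones above.

```python
import statistics
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def squared_norms(x):
    # hypothetical stand-in for an acceleration function: one cheap O(N) pass
    return jnp.sum(x**2, axis=-1)


x = jax.random.uniform(jax.random.PRNGKey(0), (100_000, 3))
squared_norms(x).block_until_ready()  # first call compiles; keep it out of the timing

# 7 runs of 100 loops each, mirroring %timeit's default report
runs = timeit.repeat(
    lambda: squared_norms(x).block_until_ready(), number=100, repeat=7
)
per_loop = [t / 100 for t in runs]
mean, std = statistics.mean(per_loop), statistics.stdev(per_loop)
print(f"{mean * 1e3:.3f} ms ± {std * 1e6:.0f} μs per loop (mean ± std. dev. of 7 runs, 100 loops each)")
```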
When we’re within the solar system, instead of using these vanilla gravitational acceleration functions on their own, we use ones that also account for the perturbations from all the planets, as given by the DE440 ephemeris. This happens automatically in the Particle and System classes, but here we’ll do it manually.
First, let’s create an Ephemeris object that can extract data from our local copy of the DE440 ephemeris:
eph = Ephemeris(
ssos="default planets",
earliest_time=Time("1980-01-01"),
latest_time=Time("2050-01-01"),
)
This creates a user-facing object that can serve up the positions and velocities of the planets (along with the Sun, the Moon, and Pluto) at any time between earliest_time and latest_time:
eph.state(Time("2000-01-01"))
{'sun': {'x': <Quantity [-0.00713986, -0.00264396, -0.00092139] AU>,
'v': <Quantity [ 5.37426823e-06, -6.76193952e-06, -3.03437408e-06] AU / d>,
'log_gm': Array(-8.12544774, dtype=float64, weak_type=True)},
'mercury': {'x': <Quantity [-0.14785222, -0.40063289, -0.198918 ] AU>,
'v': <Quantity [ 0.02117455, -0.00551464, -0.00514067] AU / d>,
'log_gm': Array(-23.73665301, dtype=float64, weak_type=True)},
'venus': {'x': <Quantity [-0.7257697 , -0.03968176, 0.02789532] AU>,
'v': <Quantity [ 0.00051933, -0.01851507, -0.0083622 ] AU / d>,
'log_gm': Array(-21.045753, dtype=float64, weak_type=True)},
'earth': {'x': <Quantity [-0.17567731, 0.88619693, 0.3844338 ] AU>,
'v': <Quantity [-0.01722853, -0.00276646, -0.00119947] AU / d>,
'log_gm': Array(-20.84118348, dtype=float64, weak_type=True)},
'moon': {'x': <Quantity [-0.17780043, 0.88461595, 0.3840147 ] AU>,
'v': <Quantity [-0.01690458, -0.0031899 , -0.0013841 ] AU / d>,
'log_gm': Array(-25.23933649, dtype=float64, weak_type=True)},
'mars': {'x': <Quantity [ 1.38322176, -0.00813949, -0.0410353 ] AU>,
'v': <Quantity [0.00075319, 0.01380716, 0.00631275] AU / d>,
'log_gm': Array(-23.07194211, dtype=float64, weak_type=True)},
'jupiter': {'x': <Quantity [3.99631685, 2.73099757, 1.07327637] AU>,
'v': <Quantity [-0.00455811, 0.005878 , 0.00263057] AU / d>,
'log_gm': Array(-15.07946488, dtype=float64, weak_type=True)},
'saturn': {'x': <Quantity [6.40141168, 6.17025198, 2.27302953] AU>,
'v': <Quantity [-0.00428575, 0.00352277, 0.00163933] AU / d>,
'log_gm': Array(-16.28536632, dtype=float64, weak_type=True)},
'uranus': {'x': <Quantity [ 14.42337962, -12.51013934, -5.68313086] AU>,
'v': <Quantity [0.00268375, 0.00245501, 0.00103727] AU / d>,
'log_gm': Array(-18.16446878, dtype=float64, weak_type=True)},
'neptune': {'x': <Quantity [ 16.80361936, -22.98357741, -9.82565798] AU>,
'v': <Quantity [0.00258474, 0.00166154, 0.00061573] AU / d>,
'log_gm': Array(-17.99910783, dtype=float64, weak_type=True)},
'pluto': {'x': <Quantity [ -9.88400421, -27.98094909, -5.75398118] AU>,
'v': <Quantity [ 0.00303408, -0.0011345 , -0.00126819] AU / d>,
'log_gm': Array(-26.8539481, dtype=float64, weak_type=True)}}
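The log_gm entries appear to be natural logarithms of each body’s GM in AU^3/day^2, consistent with the AU and AU/d units on x and v. That’s an inference from the printed numbers rather than something documented here, but it’s easy to check. Exponentiating the Sun’s entry recovers k^2, the square of the Gaussian gravitational constant. The difference between the Sun’s and Earth’s entries gives the familiar Sun-to-Earth mass ratio of about 332,946 (Earth alone, since the Moon has its own entry).

```python
import math

# log_gm values copied from the eph.state(...) output above
log_gm_sun = -8.12544774
log_gm_earth = -20.84118348

# Gaussian gravitational constant; k**2 is, to high precision, the Sun's GM in AU^3 / day^2
k = 0.01720209895

gm_sun = math.exp(log_gm_sun)
print(gm_sun, k**2)  # both ~2.9591e-4

# G cancels in a ratio of GMs, leaving the Sun-to-Earth mass ratio
sun_to_earth = math.exp(log_gm_sun - log_gm_earth)
print(round(sun_to_earth))  # ~332946
```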
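It also contains a pytree-compatible JAX class with the same functionality, called an EphemerisProcessor. It takes a TDB Julian date rather than an astropy Time, and returns plain arrays of positions and velocities whose rows follow the same order as the dictionary above: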
eph.processor.state(Time("2000-01-01").tdb.jd)
(Array([[-7.13986335e-03, -2.64396337e-03, -9.21394198e-04],
[-1.47852217e-01, -4.00632892e-01, -1.98918003e-01],
[-7.25769699e-01, -3.96817640e-02, 2.78953240e-02],
[-1.75677314e-01, 8.86196930e-01, 3.84433804e-01],
[-1.77800434e-01, 8.84615947e-01, 3.84014702e-01],
[ 1.38322176e+00, -8.13948942e-03, -4.10352972e-02],
[ 3.99631685e+00, 2.73099757e+00, 1.07327637e+00],
[ 6.40141168e+00, 6.17025198e+00, 2.27302953e+00],
[ 1.44233796e+01, -1.25101393e+01, -5.68313086e+00],
[ 1.68036194e+01, -2.29835774e+01, -9.82565798e+00],
[-9.88400421e+00, -2.79809491e+01, -5.75398118e+00]], dtype=float64),
Array([[ 5.37426823e-06, -6.76193952e-06, -3.03437408e-06],
[ 2.11745508e-02, -5.51463941e-03, -5.14066968e-03],
[ 5.19329969e-04, -1.85150738e-02, -8.36219771e-03],
[-1.72285335e-02, -2.76645660e-03, -1.19946950e-03],
[-1.69045775e-02, -3.18990180e-03, -1.38409671e-03],
[ 7.53187821e-04, 1.38071602e-02, 6.31274981e-03],
[-4.55810624e-03, 5.87800299e-03, 2.63056670e-03],
[-4.28574727e-03, 3.52276973e-03, 1.63933448e-03],
[ 2.68375383e-03, 2.45501219e-03, 1.03727032e-03],
[ 2.58474369e-03, 1.66154265e-03, 6.15729144e-04],
[ 3.03407638e-03, -1.13450133e-03, -1.26819304e-03]], dtype=float64))
We can use this EphemerisProcessor to build an acceleration function:
# eph.processor is the pytree-compatible EphemerisProcessor from above
ephem_processor = eph.processor


def func(inputs: SystemState) -> jnp.ndarray:
    # look up the planets at the absolute time of this evaluation
    perturber_xs, perturber_vs = ephem_processor.state(
        inputs.relative_time + inputs.time_reference
    )
    perturber_log_gms = ephem_processor.log_gms
    # copy the incoming state, attaching the planets as fixed perturbers
    new_state = SystemState(
        massive_positions=inputs.massive_positions,
        massive_velocities=inputs.massive_velocities,
        tracer_positions=inputs.tracer_positions,
        tracer_velocities=inputs.tracer_velocities,
        log_gms=inputs.log_gms,
        time_reference=inputs.time_reference,
        relative_time=inputs.relative_time,
        acceleration_func_kwargs=inputs.acceleration_func_kwargs,
        fixed_perturber_positions=perturber_xs,
        fixed_perturber_velocities=perturber_vs,
        fixed_perturber_log_gms=perturber_log_gms,
    )
    accs = newtonian_gravity(new_state)
    # the perturbers come first in the output; drop their rows
    num_perturbers = perturber_xs.shape[0]
    return accs[num_perturbers:]


acceleration_func = jax.tree_util.Partial(func)
This can now be used with ias15_evolve just like the simpler newtonian_gravity function. Whenever we ask for the accelerations of a SystemState, it’ll compute the positions and velocities of the perturbing planets at that time, tack them onto the SystemState, compute self-consistent accelerations for everything, and then cleave off the perturbers again at the end.
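The same wrap-evaluate-cleave pattern can be tried without jorbit or an ephemeris file. Everything in the minimal sketch below is hypothetical. `ToyState` stands in for SystemState. `toy_ephemeris` replaces the EphemerisProcessor with a single planet on a circular orbit. `pairwise_newtonian` is a naive all-pairs Newtonian sum rather than newtonian_gravity. Massless particles use a log_gm of -inf, so that exp(log_gm) is exactly zero.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # hypothetical stand-in for SystemState, keeping only what this sketch needs
    positions: jnp.ndarray  # (N, 3) particles being integrated, AU
    log_gms: jnp.ndarray  # (N,), -inf for massless tracers
    relative_time: float
    time_reference: float


def toy_ephemeris(t):
    # hypothetical "ephemeris": one planet on a 5 AU circular orbit, 4000-day period
    phase = 2.0 * jnp.pi * t / 4000.0
    xs = 5.0 * jnp.stack([jnp.cos(phase), jnp.sin(phase), jnp.zeros_like(phase)])
    return xs[None, :], jnp.log(jnp.array([1e-3]))


def pairwise_newtonian(x, gms):
    dx = x[None, :, :] - x[:, None, :]  # dx[i, j] = x_j - x_i
    r2 = jnp.sum(dx**2, axis=-1)
    r2 = jnp.where(jnp.eye(x.shape[0], dtype=bool), jnp.inf, r2)  # no self-force
    return jnp.sum(gms[None, :, None] * dx * (r2**-1.5)[..., None], axis=1)


def acceleration_with_perturbers(state):
    # 1. look up the perturbers at the absolute time of this evaluation
    p_x, p_log_gms = toy_ephemeris(state.relative_time + state.time_reference)
    # 2. tack them onto the front of the particle list
    all_x = jnp.concatenate([p_x, state.positions])
    all_gms = jnp.exp(jnp.concatenate([p_log_gms, state.log_gms]))
    # 3. compute accelerations for everything at once
    accs = pairwise_newtonian(all_x, all_gms)
    # 4. cleave the perturbers back off
    return accs[p_x.shape[0]:]


acceleration_func = jax.tree_util.Partial(acceleration_with_perturbers)

state = ToyState(
    positions=jnp.zeros((1, 3)),
    log_gms=jnp.array([-jnp.inf]),
    relative_time=0.0,
    time_reference=0.0,
)
acc = acceleration_func(state)
print(acc)  # one row, pulled toward the planet at (5, 0, 0)
```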