General N-body simulations

General N-body simulations

While jorbit was designed primarily for solar system orbits and usually assumes that simulations should account for perturbations from the Sun and planets, it can also be used for more general N-body simulations.

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from astropy.time import Time
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

from jorbit import Particle, System
from jorbit.utils.states import CartesianState, SystemState

We usually create System objects using a collection of massless Particle objects, but we can also create one directly from a SystemState object that doesn’t need to reference anything having to do with the solar system. Here we’ll create a 3-body system whose funny-looking initial conditions will be explained shortly.

# Three bodies sit 120 degrees apart on a circle of radius r in the x-y plane,
# each moving tangentially (counter-clockwise) with speed v.
angles = (0, 2 * jnp.pi / 3, 4 * jnp.pi / 3)
r = 1 / (2 * jnp.sin(jnp.pi / 3))  # circumradius of an equilateral triangle with unit sides
v = 1.0

vertex_positions = jnp.array(
    [[r * jnp.cos(a), r * jnp.sin(a), 0] for a in angles]
)
tangential_velocities = jnp.array(
    [[-v * jnp.sin(a), v * jnp.cos(a), 0] for a in angles]
)

initial_state = SystemState(
    # No massless tracer particles in this system.
    tracer_positions=jnp.empty((0, 3)),
    tracer_velocities=jnp.empty((0, 3)),
    # The three mutually interacting bodies.
    massive_positions=vertex_positions,
    massive_velocities=tangential_velocities,
    # NOTE(review): presumably log(GM), i.e. GM = 1 for every body -- confirm.
    log_gms=jnp.array([0.0, 0.0, 0.0]),
    time_reference=0.0,
    relative_time=0.0,
    acceleration_func_kwargs={},
    # No fixed external perturbers either.
    fixed_perturber_positions=jnp.empty((0, 3)),
    fixed_perturber_velocities=jnp.empty((0, 3)),
    fixed_perturber_log_gms=jnp.empty((0,)),
)

From this we can create a System object, but now we set the gravity argument to either “generic newtonian” for Newtonian gravity or “generic gr” for PPN corrections.

# Build the system directly from the hand-made state; "generic newtonian"
# selects plain Newtonian gravity (as opposed to "generic gr").
s = System(state=initial_state, gravity="generic newtonian")

# jnp.linspace already returns a JAX array, so wrapping it in jnp.array()
# would only add a redundant copy.
times = jnp.linspace(0, 10.0, 500)

# Positions and velocities evaluated at each of the requested times.
positions, velocities = s.integrate(times=times)
# Square canvas with fixed, equal-scaled axes so the view stays put while the
# markers move.
fig, ax = plt.subplots(figsize=(6, 6))
ax.set(xlim=(-1.5, 1.5), ylim=(-1.5, 1.5), aspect="equal")

# Draw the first snapshot (x and y of every body); the animation callback
# moves these markers afterwards.
initial_x = positions[0, :, 0]
initial_y = positions[0, :, 1]
scat = ax.scatter(initial_x, initial_y, s=50, c=["C0", "C1", "C2"])

# Empty label in the upper-left corner (axes coordinates), filled in per frame.
time_text = ax.text(0.02, 0.95, "", transform=ax.transAxes)


def update(frame):
    """Move the markers and refresh the time label for one animation frame.

    Args:
        frame: Integer index into ``times`` and the leading axis of
            ``positions``.

    Returns:
        The artists that were modified (needed by ``FuncAnimation`` when
        blitting is enabled).
    """
    # set_offsets is documented to take (N, 2) points. Pass only the x and y
    # columns explicitly instead of relying on matplotlib to ignore any extra
    # coordinate columns in the snapshot.
    scat.set_offsets(positions[frame, :, :2])
    time_text.set_text(f"t = {times[frame]:.2f}")
    return scat, time_text


# One frame per integration output time, 50 ms between frames. The animation
# is bound to a name so it stays alive until it has been rendered.
n_frames = len(times)
anim = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)

# Close the current figure; presumably so the notebook shows only the video
# and not an extra static plot.
plt.close()

# Encode the animation as an HTML5 <video> tag and display it as the cell output.
video_markup = anim.to_html5_video()
HTML(video_markup)