General N-body simulations
While jorbit was designed primarily for solar system orbits and usually assumes that simulations should account for perturbations from the Sun and planets, it can also be used for more general N-body simulations.
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from astropy.time import Time
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from jorbit import Particle, System
from jorbit.utils.states import CartesianState, SystemState
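We usually create System objects from a collection of massless Particle objects, but we can also create one directly from a SystemState object that doesn’t need to reference anything to do with the solar system. Here we’ll create a 3-body system whose unusual-looking initial conditions are chosen deliberately. Three bodies, each with GM = 1, sit at the vertices of an equilateral triangle with unit side length. Each moves tangentially at exactly the speed needed for a circular orbit about the common center, so that (ideally) the triangle rotates rigidly. This is Lagrange’s equilateral solution.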
# Lagrange's equilateral three-body configuration: three equal masses at the
# vertices of an equilateral triangle, each moving on a circle about the
# common center of mass (the origin) so the triangle rotates rigidly.
side = 1.0  # side length of the triangle
gm = 1.0  # GM of each body, in the simulation's units

# Circumradius of an equilateral triangle (equal to side / sqrt(3)).
r = side / (2 * jnp.sin(jnp.pi / 3))

# Circular-orbit speed. The two companions pull each body toward the center
# with net acceleration 2 * (gm / side**2) * cos(30 deg) = sqrt(3) * gm / side**2.
# Setting that equal to v**2 / r with r = side / sqrt(3) gives
# v = sqrt(gm / side), which is exactly 1.0 for gm = side = 1.
v = jnp.sqrt(gm / side)

# Angles 0, 2*pi/3, 4*pi/3 place the bodies 120 degrees apart.
# The operation order reproduces the literal 2*pi/3 and 4*pi/3 bit-for-bit.
angles = jnp.arange(3) * 2 * jnp.pi / 3
zeros = jnp.zeros_like(angles)

initial_state = SystemState(
    # No massless tracer particles in this system.
    tracer_positions=jnp.empty((0, 3)),
    tracer_velocities=jnp.empty((0, 3)),
    # Bodies on the circumscribed circle, in the z = 0 plane.
    massive_positions=r
    * jnp.stack([jnp.cos(angles), jnp.sin(angles), zeros], axis=1),
    # Velocities tangent to that circle (counter-clockwise) with magnitude v.
    massive_velocities=v
    * jnp.stack([-jnp.sin(angles), jnp.cos(angles), zeros], axis=1),
    # NOTE(review): assumes jorbit treats log_gms as the natural log of GM.
    # Confirm before changing gm; with gm = 1 the value is 0 in any base.
    log_gms=jnp.full(3, jnp.log(gm)),
    time_reference=0.0,
    relative_time=0.0,
    acceleration_func_kwargs={},
    # No fixed (non-integrated) perturbers such as the Sun or planets.
    fixed_perturber_positions=jnp.empty((0, 3)),
    fixed_perturber_velocities=jnp.empty((0, 3)),
    fixed_perturber_log_gms=jnp.empty((0,)),
)
From this state we can create a System object. This time we set the gravity argument to either "generic newtonian" for purely Newtonian gravity or "generic gr" to include post-Newtonian (PPN) corrections.
# Build the system directly from the hand-made state. "generic newtonian"
# selects plain Newtonian gravity; "generic gr" would add PPN corrections.
s = System(state=initial_state, gravity="generic newtonian")

# 500 evenly spaced output times over t in [0, 10]. jnp.linspace already
# returns a jax array, so wrapping it in jnp.array() only made a copy.
times = jnp.linspace(0, 10.0, 500)

# Integrate and keep the state at every requested time.
# NOTE(review): positions is assumed to be indexed [time, body, xyz]; confirm
# against System.integrate's documentation.
positions, velocities = s.integrate(times=times)
# Square axes large enough to hold the orbit (radius ~0.58 for the default state).
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
ax.set_aspect("equal")

# One default-cycle color per body ("C0", "C1", ...), so the plot adapts to
# any number of massive bodies instead of assuming exactly three. The default
# color cycle has 10 entries, hence the modulo.
n_bodies = positions.shape[1]
colors = [f"C{i % 10}" for i in range(n_bodies)]

# Initial x-y positions of every body at the first output time.
scat = ax.scatter(positions[0, :, 0], positions[0, :, 1], s=50, c=colors)
# Time label placed in axes coordinates (upper-left corner); filled in per frame.
time_text = ax.text(0.02, 0.95, "", transform=ax.transAxes)
def update(frame):
    """Move the markers and time label to output index ``frame``.

    FuncAnimation calls this once per frame. It returns the modified artists
    so that blitting knows what to redraw.
    """
    # set_offsets expects (N, 2) x-y pairs. Drop the z column explicitly
    # rather than relying on matplotlib to ignore extra columns.
    scat.set_offsets(positions[frame, :, :2])
    time_text.set_text(f"t = {times[frame]:.2f}")
    return scat, time_text
# One frame per output time, drawn 50 ms apart. blit=True redraws only the
# artists returned by update().
anim = FuncAnimation(fig, update, frames=len(times), interval=50, blit=True)
# Close this specific figure (not whichever is "current") so the notebook
# shows only the video, not an extra static plot.
plt.close(fig)
# Encode as an HTML5 <video> using matplotlib's movie writer (ffmpeg by
# default) and display it inline.
HTML(anim.to_html5_video())