Source code for jorbit.accelerations.newtonian

"""Simple Newtonian gravity acceleration function."""

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jorbit.utils.states import SystemState


@jax.jit
def newtonian_gravity(inputs: SystemState) -> jnp.ndarray:
    """Compute the acceleration felt by each particle due to Newtonian gravity.

    Set up to separate massive and tracer particles, so systems with many tracers
    will not compute useless pairwise interactions.

    Note:
        We use "stop_gradient" on perturbers that are passed as fixed inputs, so
        any gradients with respect to these perturber quantities will not be
        correct. To track gradients with respect to perturbers, they must be
        included as "massive" particles, not "fixed perturbers".

        The self-interaction (diagonal) entries of the massive-massive block are
        masked both *before* the square root / division and after it. Masking
        only the result would leave the forward values correct but make
        reverse-mode derivatives with respect to the massive positions NaN,
        because a zero cotangent would be multiplied by the infinite local
        derivatives of ``sqrt`` and ``1 / x`` at zero separation.

    Args:
        inputs (SystemState): The instantaneous state of the system. The fields
            read here are ``massive_positions`` (M, 3), ``tracer_positions``
            (T, 3), ``log_gms`` (M,), ``fixed_perturber_positions`` (P, 3) and
            ``fixed_perturber_log_gms`` (P,).

    Returns:
        jnp.ndarray: The 3D acceleration felt by each particle, shape (M + T, 3),
        ordered by massive particles first followed by tracer particles.
    """
    M = inputs.massive_positions.shape[0]  # number of massive particles
    gms = jnp.exp(inputs.log_gms)  # (M,)

    # Massive particles due to other massive particles.
    dx_massive = (
        inputs.massive_positions[:, None, :] - inputs.massive_positions[None, :, :]
    )  # (M,M,3)
    r2_massive = jnp.sum(dx_massive * dx_massive, axis=-1)  # (M,M)
    mask_massive = ~jnp.eye(M, dtype=bool)  # (M,M), False on the diagonal
    # Swap the zero self-separations for a harmless 1.0 so that sqrt and the
    # reciprocal are only ever evaluated (and differentiated) at non-singular
    # points; the outer where below still zeroes the self-interaction terms.
    r2_massive_safe = jnp.where(mask_massive, r2_massive, 1.0)  # (M,M)
    r3_massive = r2_massive_safe * jnp.sqrt(r2_massive_safe)  # (M,M)
    prefac_massive = jnp.where(mask_massive, 1.0 / r3_massive, 0.0)  # (M,M)
    a_massive = -jnp.sum(
        prefac_massive[:, :, None] * dx_massive * gms[None, :, None],
        axis=1,
    )  # (M,3)

    # Tracer particles due to massive particles.
    dx_tracers = (
        inputs.tracer_positions[:, None, :] - inputs.massive_positions[None, :, :]
    )  # (T,M,3)
    r2_tracers = jnp.sum(dx_tracers * dx_tracers, axis=-1)  # (T,M)
    r3_tracers = r2_tracers * jnp.sqrt(r2_tracers)  # (T,M)
    a_tracers = -jnp.sum(
        (1.0 / r3_tracers)[:, :, None] * dx_tracers * gms[None, :, None],
        axis=1,
    )  # (T,3)

    all_as = jnp.concatenate([a_massive, a_tracers], axis=0)  # (N,3), N = M + T

    # Fixed perturbers acting on all (massive + tracer) particles.
    # stop_gradient since we never need gradients through fixed perturber
    # quantities.
    p_pos = jax.lax.stop_gradient(inputs.fixed_perturber_positions)  # (P,3)
    p_gms = jax.lax.stop_gradient(jnp.exp(inputs.fixed_perturber_log_gms))  # (P,)
    all_positions = jnp.concatenate(
        [inputs.massive_positions, inputs.tracer_positions], axis=0
    )  # (N,3)
    dx_perturbers = all_positions[:, None, :] - p_pos[None, :, :]  # (N,P,3)
    r2_perturbers = jnp.sum(dx_perturbers * dx_perturbers, axis=-1)  # (N,P)
    r3_perturbers = r2_perturbers * jnp.sqrt(r2_perturbers)  # (N,P)
    a_perturbers = -jnp.sum(
        (1.0 / r3_perturbers)[:, :, None] * dx_perturbers * p_gms[None, :, None],
        axis=1,
    )  # (N,3)

    return all_as + a_perturbers