Source code for jorbit.integrators.ias15.helpers
"""Low-level IAS15 primitives shared across the step and interpolation modules."""
import jax
# This module explicitly requests float64 arrays (jnp.float64). Without 64-bit
# mode JAX falls back to float32 (emitting a warning), so enable it here at
# import time, before any arrays are created. Note: this is a global JAX
# setting and affects the whole process, not just this module.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jorbit.data.constants import (
IAS15_BV_DENOMS,
IAS15_BX_DENOMS,
)
from jorbit.utils.states import IAS15IntegratorState
[docs]
def initialize_ias15_integrator_state(a0: jnp.ndarray) -> IAS15IntegratorState:
    """Build an IAS15IntegratorState whose working arrays are all zero.

    Args:
        a0 (jnp.ndarray):
            The initial acceleration. It is stored unchanged on the returned
            state, and its leading dimension sets the particle count
            (presumably shaped ``(n_particles, 3)`` -- confirm with callers).

    Returns:
        IAS15IntegratorState:
            A state where ``g``, ``b`` and ``e`` are float64 zeros of shape
            ``(7, n_particles, 3)``. ``csx`` and ``csv`` (presumably the
            compensated-summation terms) are float64 zeros of shape
            ``(n_particles, 3)``. ``dt`` is ``0.001`` and ``dt_last_done``
            is ``0.0``.
    """
    per_particle_shape = (a0.shape[0], 3)

    # g, b and e share one zero array; JAX arrays are immutable, so the
    # sharing is safe.
    coeff_zeros = jnp.zeros((7, *per_particle_shape), dtype=jnp.float64)

    return IAS15IntegratorState(
        a0=a0,
        g=coeff_zeros,
        b=coeff_zeros,
        e=coeff_zeros,
        csx=jnp.zeros(per_particle_shape, dtype=jnp.float64),
        csv=jnp.zeros(per_particle_shape, dtype=jnp.float64),
        dt_last_done=0.0,
        dt=0.001,
    )
[docs]
@jax.jit
def add_cs(p: jnp.ndarray, csp: jnp.ndarray, inp: jnp.ndarray) -> tuple:
    """Add ``inp`` to a running sum using Kahan (compensated) summation.

    The compensation term records the rounding error of each addition.
    Feeding it back into the next call keeps long accumulations accurate.

    Args:
        p (jnp.ndarray):
            The running sum before this addition.
        csp (jnp.ndarray):
            The compensation carried over from the previous addition.
        inp (jnp.ndarray):
            The increment to add.

    Returns:
        tuple:
            ``(new_sum, new_compensation)``.
    """
    # Remove the error left over from the previous addition first.
    corrected_inp = inp - csp
    new_sum = p + corrected_inp
    # (new_sum - p) is how much of corrected_inp actually landed in the sum;
    # subtracting corrected_inp leaves the rounding error of this addition.
    # The next call subtracts that error from its own input.
    new_csp = (new_sum - p) - corrected_inp
    return new_sum, new_csp
@jax.jit
def _estimate_x_v_from_b(
    a0: jnp.ndarray,
    v0: jnp.ndarray,
    x0: jnp.ndarray,
    h: jnp.ndarray,
    dt: jnp.ndarray,
    bp: jnp.ndarray,  # remember to flip it! (highest-order coefficient first)
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate the IAS15 position and velocity predictors at fraction ``h``.

    Computes::

        x = x0 + (v0*dt)*h + (a0*dt*dt/2)*h*h + h**3 * Px(h)
        v = v0 + (a0*dt)*h + h**2 * Pv(h)

    ``Px`` and ``Pv`` are polynomials in ``h``, evaluated with Horner's rule.
    Their coefficients are ``bp*dt*dt / IAS15_BX_DENOMS`` and
    ``bp*dt / IAS15_BV_DENOMS``. Horner's rule consumes ``bp`` along its
    leading axis, so ``bp[0]`` multiplies the highest power of ``h``; that
    is why the b coefficients must be flipped before calling.

    Args:
        a0: Acceleration at the start of the step (presumably
            ``(n_particles, 3)`` -- confirm with callers).
        v0: Velocity at the start of the step, same shape as ``a0``.
        x0: Position at the start of the step. The Horner accumulator is
            initialised with ``zeros_like(x0)``.
        h: Evaluation point, presumably the fraction of the step in [0, 1].
        dt: Step size.
        bp: b coefficients reversed along axis 0. The denominator constants
            must broadcast against it (assumed; not visible here).

    Returns:
        tuple: ``(x, v)``, the predicted position and velocity.
    """

    def horner(coeffs):
        # Fold coefficients highest-order first: acc <- acc*h + c.
        total, _ = jax.lax.scan(
            lambda acc, c: (acc * h + c, None), jnp.zeros_like(x0), coeffs
        )
        return total

    h_sq = h * h
    h_cube = h * h * h

    position_poly = horner(bp * dt * dt / IAS15_BX_DENOMS)
    position = position_poly * h_cube
    position += (v0 * dt) * h + (a0 * dt * dt / 2.0) * h * h + x0

    velocity_poly = horner(bp * dt / IAS15_BV_DENOMS)
    velocity = velocity_poly * h_sq
    velocity += v0 + (a0 * dt) * h

    return position, velocity