# Source code for jorbit.integrators.ias15.step_control

"""Adaptive step-size controllers for the IAS15 integrator.

These are passed to :func:`ias15_step` as the ``step_scheduler`` callable; they are not
called internally by the step machinery.
"""

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

from jorbit.data.constants import (
    IAS15_EPSILON,
    IAS15_MIN_DT,
    IAS15_SAFETY_FACTOR,
    IAS15_EPS_Modified,
)


@jax.jit
def next_proposed_dt_PRS23(
    a0: jnp.ndarray,
    at_fresh: jnp.ndarray,
    b: jnp.ndarray,
    dt_done: float,
    x_end: jnp.ndarray,
    v_end: jnp.ndarray,
) -> jnp.ndarray:
    """The PRS23 step controller (REBOUND ``adaptive_mode == 2``).

    For every particle (row), estimate the squared timescale
    ``tau^2 = 2 y2 / (y3 + sqrt(y4 * y2))``. Here ``y2``, ``y3`` and ``y4`` are
    the squared norms of the polynomial ``a(h) = a0 + sum_k b[k] h^(k+1)`` and
    of its first and second ``h``-derivatives, all evaluated at ``h = 1``. The
    smallest usable ``tau^2`` sets the next step:
    ``dt_new = sqrt(min tau^2) * dt_done * IAS15_EPS_Modified``.

    As in REBOUND's ``integrator_ias15.c``, a particle is skipped when its
    ``y4`` or ``tau^2`` is not a normal float (zero, subnormal, inf or NaN).
    If no particle remains, the step grows to ``dt_done / IAS15_SAFETY_FACTOR``
    instead of becoming NaN, 0 or inf.

    NOTE(review): REBOUND uses ``sqrt7(epsilon * 5040)`` as the final factor.
    ``IAS15_EPS_Modified`` is assumed to hold that value; confirm in
    ``jorbit.data.constants``. No ``IAS15_MIN_DT`` clamp is applied here,
    whereas REBOUND applies ``min_dt`` after both adaptive modes. Confirm
    whether the caller handles it.

    Args:
        a0: Start-of-step acceleration, one row per particle.
        at_fresh: Unused. Accepted so both controllers share one signature.
        b: Stacked predictor-corrector coefficients ``b[0..6]``. The leading
            axis has length 7 and the remaining axes match ``a0``.
        dt_done: The step size that was just taken.
        x_end: Unused. Accepted for signature compatibility.
        v_end: Unused. Accepted for signature compatibility.

    Returns:
        Scalar proposed size of the next step.
    """
    del at_fresh, x_end, v_end  # PRS23 only needs a0 and b

    # y2 = |a(1)|^2
    acc_end = a0 + jnp.sum(b, axis=0)
    y2 = jnp.sum(acc_end * acc_end, axis=1)

    # y3 = |a'(1)|^2: d/dh h^(k+1) = (k+1) h^k, so b[0..6] get weights 1..7.
    d1_weights = jnp.arange(1, 8)
    d1 = jnp.sum(d1_weights[:, None, None] * b, axis=0)
    y3 = jnp.sum(d1 * d1, axis=1)

    # y4 = |a''(1)|^2: b[1..6] get weights (k+1) k = 2, 6, 12, 20, 30, 42.
    # b[0] drops out.
    d2_weights = jnp.arange(2, 8) * jnp.arange(1, 7)
    d2 = jnp.sum(d2_weights[:, None, None] * b[1:], axis=0)
    y4 = jnp.sum(d2 * d2, axis=1)

    timescale2 = 2.0 * y2 / (y3 + jnp.sqrt(y4 * y2))  # PRS23

    # Emulate C's isnormal() for these non-negative quantities. Zero timescales
    # would stall the integrator, and NaN/inf ones carry no information.
    tiny = jnp.finfo(timescale2.dtype).tiny
    usable = (
        jnp.isfinite(y4)
        & (y4 >= tiny)
        & jnp.isfinite(timescale2)
        & (timescale2 >= tiny)
    )
    min_timescale2 = jnp.min(jnp.where(usable, timescale2, jnp.inf))

    # min stays +inf only when no particle was usable.
    has_estimate = jnp.isfinite(min_timescale2)
    # The placeholder keeps the unselected sqrt branch finite.
    safe_min = jnp.where(has_estimate, min_timescale2, 1.0)
    return jnp.where(
        has_estimate,
        jnp.sqrt(safe_min) * dt_done * IAS15_EPS_Modified,
        dt_done / IAS15_SAFETY_FACTOR,
    )
@jax.jit
def next_proposed_dt_global(
    a0: jnp.ndarray,
    at_fresh: jnp.ndarray,
    b: jnp.ndarray,
    dt_done: float,
    x_end: jnp.ndarray,
    v_end: jnp.ndarray,
) -> jnp.ndarray:
    """REBOUND's GLOBAL step controller (legacy, used by ASSIST).

    Compares the magnitude of the highest-order polynomial coefficient (`b[6]`)
    to the freshly-evaluated end-of-step acceleration (`at_fresh`, taken from the
    last predictor-corrector sub-step at h = IAS15_H[7] = 0.977).

    Includes REBOUND's "slow-acceleration" filter, which skips particles with
    `v²·dt²/x² < 1e-16`. The filter is evaluated on the END-of-step predictor
    state (`x_end, v_end`) to match REBOUND's `particles[mi]` semantics
    (`integrator_ias15.c:543-558`). A NaN ratio (particle at the origin at rest)
    is NOT skipped, exactly as in REBOUND. Only components whose magnitude is a
    normal float enter the maxima (REBOUND's `isnormal` checks).

    Falls back to `dt/safety_factor` growth when no particle contributes or the
    error estimate is not a normal positive number. Finally clamps the proposed
    step to `IAS15_MIN_DT`, using `copysign` so a zero step becomes
    `+IAS15_MIN_DT` as in REBOUND.

    See REBOUND `integrator_ias15.c:534-619`. ASSIST forces this mode at
    `assist.c:446`.

    Args:
        a0: Unused. Accepted so both controllers share one signature.
        at_fresh: End-of-step acceleration, one row per particle (same shape
            as `x_end`).
        b: Predictor-corrector coefficients. Only `b[6]` (shaped like
            `at_fresh`) is used.
        dt_done: The step size that was just taken.
        x_end: End-of-step predictor positions, one row per particle.
        v_end: End-of-step predictor velocities, one row per particle.

    Returns:
        Scalar proposed size of the next step.
    """
    del a0

    # Written as ~(ratio < 1e-16) rather than ratio >= 1e-16 so that a NaN
    # ratio keeps the particle, like REBOUND's `if (... < 1e-16) continue;`.
    v2 = jnp.sum(v_end * v_end, axis=1)
    x2 = jnp.sum(x_end * x_end, axis=1)
    keep = ~((v2 * dt_done * dt_done / x2) < 1e-16)

    def _normal_abs(values):
        # |values| where the particle is kept and isnormal(|value|) holds, else 0.
        # A zero never raises a maximum that effectively starts at 0.
        mag = jnp.abs(values)
        normal = jnp.isfinite(mag) & (mag >= jnp.finfo(mag.dtype).tiny)
        return jnp.where(keep[:, None] & normal, mag, 0.0)

    maxa = jnp.max(_normal_abs(at_fresh))
    maxj = jnp.max(_normal_abs(b[6]))

    # Double-where: keep every evaluated branch finite so reverse-mode gradients
    # through jnp.where do not become NaN. Forward values are unchanged.
    integrator_error = maxj / jnp.where(maxa > 0, maxa, 1.0)
    tiny = jnp.finfo(integrator_error.dtype).tiny
    valid = (
        (maxa > 0)
        & jnp.isfinite(integrator_error)
        & (integrator_error >= tiny)
    )
    safe_error = jnp.where(valid, integrator_error, 1.0)
    dt_new = jnp.where(
        valid,
        dt_done * jnp.power(IAS15_EPSILON / safe_error, 1.0 / 7.0),
        dt_done / IAS15_SAFETY_FACTOR,
    )

    # copysign (not sign * MIN_DT): jnp.sign(0.0) == 0 would leave a zero step at 0.
    return jnp.where(
        jnp.abs(dt_new) < IAS15_MIN_DT,
        jnp.copysign(IAS15_MIN_DT, dt_new),
        dt_new,
    )