IAS15

A JAX implementation of the IAS15 integrator.

This is a pythonized/jaxified version of the IAS15 integrator from Rein & Spiegel (2015) (DOI: 10.1093/mnras/stu2164), as currently implemented in REBOUND. It used to follow the REBOUND source as closely as possible; see versions prior to v1.2 for that implementation.

The original code is available on GitHub (accessed Summer 2023, revisited Fall 2024). This implementation was refactored in early 2026.

Many thanks to the REBOUND developers for their work on this integrator, and for making it open source!

The implementation is split across sibling modules: helpers (low-level primitives), step_control (adaptive step-size controllers), interpolation (dense-output / light-travel-time utilities), step (the single-step predictor-corrector), and evolve (the driving loops). The names below are re-exported so that jorbit.integrators.ias15.<name> continues to resolve as before the split.

jorbit.integrators.ias15.add_cs(p, csp, inp)

Compensated summation.

Parameters:
  • p (Array) – The current sum.

  • csp (Array) – The current compensation.

  • inp (Array) – The input to add.

Returns:

The new sum and compensation.

Return type:

tuple
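For readers unfamiliar with compensated summation, here is a minimal sketch assuming the Kahan-style update that REBOUND's IAS15 uses for its running sums. add_cs_sketch is a hypothetical name, and this is not necessarily line-for-line what jorbit's add_cs does. The compensation term carries the low-order bits that a plain floating-point add would discard:

```python
import numpy as np


def add_cs_sketch(p, csp, inp):
    """Compensated (Kahan-style) add of inp into the running sum p.

    Returns the new sum and the new compensation term. Only arithmetic is
    used, so the same function works on NumPy scalars and JAX arrays.
    """
    y = inp - csp       # apply the correction left over from the previous add
    t = p + y           # low-order bits of y can be lost in this add...
    csp = (t - p) - y   # ...so recover them (zero in exact arithmetic)
    return t, csp


# Add 1e-8 to 1.0 ten thousand times in float32, with and without compensation.
p, c = np.float32(1.0), np.float32(0.0)
naive = np.float32(1.0)
for _ in range(10_000):
    p, c = add_cs_sketch(p, c, np.float32(1e-8))
    naive = naive + np.float32(1e-8)
```

Each increment is smaller than half the float32 spacing at 1.0, so naive never moves off 1.0, while the compensated sum p ends close to 1.0001.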

jorbit.integrators.ias15.ias15_evolve(initial_system_state, acceleration_func, times, initial_integrator_state, step_scheduler)

Evolve a system and recover positions/velocities at times via interpolation.

Takes natural adaptive IAS15 steps from the initial time past jnp.max(times), stores the per-step dense output (converged 7th-order b coefficients plus start-of-step acceleration/position/velocity) in a pre-allocated buffer, then evaluates the polynomial at each entry of times. This matches the approach used by ASSIST/REBOUND and avoids the small final jumps that forced-landing integration is prone to.

Supports forward-mode AD only (jax.lax.while_loop has no reverse-mode rule).

Parameters:
  • initial_system_state (SystemState) – The initial state of the system.

  • acceleration_func (Callable[[SystemState], Array]) – The acceleration function to use.

  • times (Array) – Times at which to return interpolated positions and velocities. Must be within [initial_system_state.relative_time, t_end_of_last_natural_step].

  • initial_integrator_state (IAS15IntegratorState) – The initial state of the integrator.

  • step_scheduler (Callable[[Array, Array, Array, float, Array, Array], float]) – The step scheduler function to use for determining the next proposed step size.

Returns:

Interpolated positions and velocities at times, the final system state, the final integrator state, and the iteration count.

Return type:

tuple[Array, Array, SystemState, IAS15IntegratorState, Array]
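The buffering strategy can be illustrated with a toy sketch under stated assumptions: toy_step (a kick-drift-kick step of a harmonic oscillator with a fixed dt) stands in for ias15_step, MAX_STEPS stands in for IAS15_MAX_DYNAMIC_STEPS, and only the start-of-step position, velocity, time, and step size are buffered (no b coefficients). It shows the structural points only: a jax.lax.while_loop that keeps stepping until the time passes the last requested time (or the buffer fills), writes each step's start state into pre-allocated arrays, and can be differentiated with jax.jvp but not with jax.grad.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 64  # hypothetical stand-in for IAS15_MAX_DYNAMIC_STEPS


def toy_step(x, v, dt):
    """Stand-in for ias15_step: one kick-drift-kick step of a = -x."""
    v_half = v - 0.5 * dt * x
    x_new = x + dt * v_half
    v_new = v_half - 0.5 * dt * x_new
    return x_new, v_new


def evolve_with_buffer(x0, v0, t_last, dt):
    """Step from t = 0 until t passes t_last, recording start-of-step data."""
    f32 = jnp.float32
    bufs = (
        jnp.zeros(MAX_STEPS, f32),          # x at step start
        jnp.zeros(MAX_STEPS, f32),          # v at step start
        jnp.full(MAX_STEPS, jnp.inf, f32),  # step start times (inf = unfilled)
        jnp.zeros(MAX_STEPS, f32),          # step sizes
    )

    def cond(carry):
        t, _, _, i, _ = carry
        return (t < t_last) & (i < MAX_STEPS)

    def body(carry):
        t, x, v, i, (xb, vb, tb, db) = carry
        xb, vb = xb.at[i].set(x), vb.at[i].set(v)
        tb, db = tb.at[i].set(t), db.at[i].set(dt)
        x, v = toy_step(x, v, dt)
        return t + dt, x, v, i + 1, (xb, vb, tb, db)

    init = (f32(0.0), x0, v0, jnp.int32(0), bufs)
    return jax.lax.while_loop(cond, body, init)


def final_x(x0):
    _, x, _, _, _ = evolve_with_buffer(x0, jnp.float32(0.0), 1.0, 0.1)
    return x


# Forward mode works through the while_loop; jax.grad(final_x) would raise.
x_end, dx_end = jax.jvp(final_x, (jnp.float32(1.0),), (jnp.float32(1.0),))
```

Because the loop overshoots t_last rather than landing on it, every requested time lies inside some buffered step, which is what makes the later polynomial evaluation possible.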

jorbit.integrators.ias15.ias15_evolve_forced_landing(initial_system_state, acceleration_func, times, initial_integrator_state, step_scheduler)

Forced-landing IAS15 evolve (internal testing reference only).

Clamps the adaptive step size so that a step always lands exactly on the next entry of times. It is not part of the public interface: the public ias15_evolve() uses dense-output polynomial interpolation instead, which avoids the small final jumps that the forced-landing scheme is prone to. This function is retained as an independent reference path for tests and benchmarks.

Warning

Caps the number of steps between requested times at 10,000.

Parameters:
  • initial_system_state (SystemState) – The initial state of the system.

  • acceleration_func (Callable[[SystemState], Array]) – The acceleration function to use.

  • times (Array) – The times to evolve the system to.

  • initial_integrator_state (IAS15IntegratorState) – The initial state of the integrator.

  • step_scheduler (Callable[[Array, Array, Array, float, Array, Array], float]) – The step scheduler function to use for determining the next proposed step size.

Returns:

The positions and velocities of the system at each entry of times, the final state of the system, and the final state of the integrator.

Return type:

tuple[Array, Array, SystemState, IAS15IntegratorState]
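The clamping idea fits in a few lines. This is an illustrative sketch, not jorbit's code: clamp_to_target is a hypothetical helper that limits the proposed step so it does not overshoot the next requested time, in either integration direction. A loop built on it may also want to set the time to the target explicitly after the final step, since t + (t_target - t) is not guaranteed to equal t_target exactly in floating point.

```python
import jax.numpy as jnp


def clamp_to_target(dt_proposed, t, t_target):
    """Hypothetical helper: shrink dt so a step lands on t_target.

    Assumes dt_proposed already carries the sign of the integration
    direction, so only its magnitude needs to be limited.
    """
    remaining = t_target - t
    return jnp.where(jnp.abs(dt_proposed) > jnp.abs(remaining), remaining, dt_proposed)
```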

jorbit.integrators.ias15.ias15_evolve_with_dense_output(initial_system_state, acceleration_func, times, initial_integrator_state, step_scheduler)

Evolve a system, returning interpolated states plus the underlying dense-output buffers.

Same integration logic as ias15_evolve(), but in addition to the interpolated positions and velocities at times it returns the converged 7th-order b coefficients plus the start-of-step state for every step. Callers that want to do their own polynomial evaluation (e.g. richer light-travel-time correction in on_sky() via make_ltt_propagator()) should use this function instead of ias15_evolve().

Returns:

(positions, velocities, final_system_state, final_integrator_state, iter_num, b_buf, a0_buf, x0_buf, v0_buf, dts_buf, t_step_starts, step_indices, h_values). b_buf has shape (IAS15_MAX_DYNAMIC_STEPS, 7, n_particles, 3); a0_buf, x0_buf, v0_buf have shape (IAS15_MAX_DYNAMIC_STEPS, n_particles, 3); dts_buf and t_step_starts have shape (IAS15_MAX_DYNAMIC_STEPS,); step_indices and h_values have shape (len(times),).

Return type:

tuple

Parameters:
  • initial_system_state (SystemState)

  • acceleration_func (collections.abc.Callable[[SystemState], jax.Array])

  • times (jax.Array)

  • initial_integrator_state (IAS15IntegratorState)

  • step_scheduler (collections.abc.Callable[[jax.Array, jax.Array, jax.Array, float, jax.Array, jax.Array], float])

jorbit.integrators.ias15.ias15_step(initial_system_state, acceleration_func, initial_integrator_state, step_scheduler)

Take a single step using the IAS15 integrator.

Contains all of the predictor/corrector logic and step validity checks. Does not accept any pre-computed perturber information, since the times at which it would be needed are not known until runtime. For a static version that accepts pre-computed perturber data, see ias15_static_step.

Parameters:
  • initial_system_state (SystemState) – The initial system state.

  • acceleration_func (Callable[[SystemState], Array]) – The acceleration function.

  • initial_integrator_state (IAS15IntegratorState) – The initial integrator state.

  • step_scheduler (Callable[[Array, Array, Array, float, Array, Array], float]) – The step scheduler function, typically next_proposed_dt_PRS23 or next_proposed_dt_global.

Returns:

The new system state, the new integrator state (with the predicted next-step b coefficients), and the converged b coefficients for the step just completed, shape (7, n_particles, 3). The converged b is what should be stored when building dense output for interpolation.

Return type:

tuple[SystemState, IAS15IntegratorState, Array]

jorbit.integrators.ias15.initialize_ias15_integrator_state(a0)

Initializes the IAS15IntegratorState dataclass with zeros.

Parameters:

a0 (Array) – The initial acceleration.

Returns:

An instance of the IAS15IntegratorState dataclass with zeros.

Return type:

IAS15IntegratorState

jorbit.integrators.ias15.interpolate_from_dense_output(b_all, a0_all, x0_all, v0_all, dts, step_indices, h_values)

Interpolate positions and velocities at arbitrary times from stored IAS15 polynomial data.

Uses the b coefficients from completed IAS15 steps to evaluate the 7th-order polynomial at fractional times within each step, without re-integrating.

The step_indices and h_values should be precomputed via precompute_interpolation_indices. Since they depend only on the fixed step structure and observation times (not the particle state), precomputing them keeps searchsorted out of the JIT graph and avoids redundant work on every forward and backward pass.

Parameters:
  • b_all (Array) – Per-step b coefficients, shape (n_steps, 7, n_particles, 3).

  • a0_all (Array) – Per-step initial accelerations, shape (n_steps, n_particles, 3).

  • x0_all (Array) – Per-step initial positions, shape (n_steps, n_particles, 3).

  • v0_all (Array) – Per-step initial velocities, shape (n_steps, n_particles, 3).

  • dts (Array) – Per-step time step sizes, shape (n_steps,).

  • step_indices (Array) – Index of the containing step for each query time, shape (n_queries,). From precompute_interpolation_indices.

  • h_values (Array) – Fractional time within each step (0 to 1), shape (n_queries,). From precompute_interpolation_indices.

Returns:

Interpolated positions and velocities, each shape (n_queries, n_particles, 3).

Return type:

tuple[Array, Array]
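The polynomial evaluation can be sketched as follows, assuming the standard IAS15 form from Rein & Spiegel (2015): within a step of length dt, with t = t0 + h·dt, the acceleration is a(h) = a0 + b0·h + b1·h² + ... + b6·h⁷, and velocity and position follow by integrating once and twice. The function names are hypothetical, and jorbit may evaluate the same polynomial in a different (e.g. nested Horner) order. The gather by step_indices followed by jax.vmap over queries mirrors the documented inputs.

```python
import jax
import jax.numpy as jnp


def eval_step_polynomial(b, a0, x0, v0, dt, h):
    """Position and velocity at fraction h of one step.

    Assumes a(h) = a0 + sum_{k=0}^{6} b[k] * h**(k + 1), with t = t0 + h * dt.
    b: (7, n_particles, 3); a0, x0, v0: (n_particles, 3); dt, h: scalars.
    """
    hp = jnp.cumprod(h * jnp.ones(9))  # [h, h**2, ..., h**9]
    k = jnp.arange(7)
    # Integrating b[k] h**(k+1) once gives h**(k+2)/(k+2) (times dt) ...
    v_coef = hp[1:8] / (k + 2)
    # ... and twice gives h**(k+3)/((k+2)(k+3)) (times dt**2).
    x_coef = hp[2:9] / ((k + 2) * (k + 3))
    v = v0 + dt * (a0 * h + jnp.einsum("k,kpd->pd", v_coef, b))
    x = x0 + dt * v0 * h + dt**2 * (0.5 * a0 * hp[1] + jnp.einsum("k,kpd->pd", x_coef, b))
    return x, v


def interpolate_sketch(b_all, a0_all, x0_all, v0_all, dts, step_indices, h_values):
    """Gather each query's step, then evaluate that step's polynomial."""

    def one_query(i, h):
        return eval_step_polynomial(b_all[i], a0_all[i], x0_all[i], v0_all[i], dts[i], h)

    # Both outputs have shape (n_queries, n_particles, 3).
    return jax.vmap(one_query)(step_indices, h_values)
```

Since this is pure array arithmetic with no while_loop, it can be differentiated in reverse mode as well as forward mode, which is why keeping searchsorted out of it pays off on every backward pass.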

jorbit.integrators.ias15.make_ltt_propagator(b_step, a0_step, x0_step, v0_step, dt_step, h_obs)

Build a closure that evaluates the IAS15 polynomial at a light-travel-delayed time.

Used inside on_sky to propagate a particle backward by the light travel time using the converged 7th-order IAS15 polynomial for the step containing the observation time, instead of a constant-acceleration Taylor expansion.

The returned closure maps a (negative) time offset dt to the particle's position at fractional time h_obs + dt / dt_step within the step. It accepts h slightly outside [0, 1] (it then extrapolates using the same step's polynomial). In practice the excursion is small, since the light travel time is much shorter than dt_step for ordinary solar-system geometries. For close flybys with very small steps, where the light travel time may exceed dt_step, this still gives a much higher-order correction than the constant-acceleration Taylor expansion.

Parameters:
  • b_step (Array) – Converged b coefficients for this step (single particle slice), shape (7, 3).

  • a0_step (Array) – Acceleration at the start of this step, shape (3,).

  • x0_step (Array) – Position at the start of this step, shape (3,).

  • v0_step (Array) – Velocity at the start of this step, shape (3,).

  • dt_step (Array) – Length of this step (scalar).

  • h_obs (Array) – Fractional position of the observation time within this step, in [0, 1] (scalar).

Returns:

A pytree-friendly callable f(dt) -> x_at_delayed_time of shape (3,).

Return type:

Partial
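A minimal sketch of such a closure, built with jax.tree_util.Partial (which is what makes the result a pytree that can be passed into jitted code). It assumes the IAS15 form a(h) = a0 + Σ b[k]·h^(k+1) for a single particle; make_ltt_propagator_sketch, _position_at, and delayed_position are hypothetical names. Powers of h come from a cumulative product, so negative h (extrapolating before the step start) needs no special handling.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def _position_at(b, a0, x0, v0, dt_step, h_obs, dt):
    """Single-particle position at offset dt from the observation time.

    b: (7, 3); a0, x0, v0: (3,); dt_step, h_obs, dt: scalars.
    """
    h = h_obs + dt / dt_step                # may fall slightly outside [0, 1]
    hp = jnp.cumprod(h * jnp.ones(9))       # [h, h**2, ..., h**9]
    k = jnp.arange(7)
    x_coef = hp[2:9] / ((k + 2) * (k + 3))  # h**(k+3) / ((k+2)(k+3))
    return x0 + dt_step * v0 * h + dt_step**2 * (0.5 * a0 * hp[1] + x_coef @ b)


def make_ltt_propagator_sketch(b_step, a0_step, x0_step, v0_step, dt_step, h_obs):
    # Partial stores the step data as pytree leaves; only dt is left free.
    return Partial(_position_at, b_step, a0_step, x0_step, v0_step, dt_step, h_obs)


@jax.jit
def delayed_position(propagator, light_time):
    # The closure is an ordinary jit argument; evaluate it light_time earlier.
    return propagator(-light_time)
```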

jorbit.integrators.ias15.next_proposed_dt_PRS23(a0, at_fresh, b, dt_done, x_end, v_end)

The PRS23 step controller.

Return type:

Array

Parameters:
  • a0 (jax.Array)

  • at_fresh (jax.Array)

  • b (jax.Array)

  • dt_done (float)

  • x_end (jax.Array)

  • v_end (jax.Array)

jorbit.integrators.ias15.next_proposed_dt_global(a0, at_fresh, b, dt_done, x_end, v_end)

REBOUND’s GLOBAL step controller (legacy, used by ASSIST).

Compares the magnitude of the highest-order polynomial coefficient (b[6]) to the freshly-evaluated end-of-step acceleration (at_fresh, taken from the last predictor-corrector sub-step at h = IAS15_H[7] = 0.977). Includes REBOUND’s “slow-acceleration” filter that skips particles with v²·dt²/x² < 1e-16, evaluated on the END-of-step predictor state (x_end, v_end) to match REBOUND’s particles[mi] semantics (integrator_ias15.c:543-558). Falls back to dt/safety_factor growth when no particle contributes. Finally clamps the proposed step to IAS15_MIN_DT. See REBOUND integrator_ias15.c:534-619. ASSIST forces this mode at assist.c:446.

Return type:

Array

Parameters:
  • a0 (jax.Array)

  • at_fresh (jax.Array)

  • b (jax.Array)

  • dt_done (float)

  • x_end (jax.Array)

  • v_end (jax.Array)
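The controller described above can be sketched as follows. This is a simplified sketch, not jorbit's implementation: EPSILON and SAFETY_FACTOR are assumed values, MIN_DT is a hypothetical placeholder for IAS15_MIN_DT, the per-particle filter and maxima are vectorized rather than looped, and step rejection and any cap on step growth are omitted. The error estimate max|b6| / max|a| scales as dt⁷, so the new step rescales dt_done by the seventh root of EPSILON / error.

```python
import jax.numpy as jnp

EPSILON = 1e-9        # assumed tolerance (REBOUND's default epsilon)
SAFETY_FACTOR = 0.25  # assumed value
MIN_DT = 0.0          # hypothetical placeholder for IAS15_MIN_DT


def global_dt_sketch(at_fresh, b, dt_done, x_end, v_end):
    """GLOBAL-style proposal. at_fresh, x_end, v_end: (n_particles, 3); b: (7, n_particles, 3)."""
    v2 = jnp.sum(v_end**2, axis=-1)
    x2 = jnp.sum(x_end**2, axis=-1)
    # Skip "slow" particles whose v^2 dt^2 / x^2 falls below 1e-16.
    active = ~(jnp.abs(v2 * dt_done**2 / x2) < 1e-16)
    max_a = jnp.max(jnp.where(active[:, None], jnp.abs(at_fresh), 0.0))
    max_b6 = jnp.max(jnp.where(active[:, None], jnp.abs(b[6]), 0.0))
    err = max_b6 / max_a  # NaN (0 / 0) when no particle contributes
    ok = jnp.isfinite(err) & (err > 0)
    safe_err = jnp.where(ok, err, 1.0)
    dt_new = jnp.where(
        ok,
        dt_done * (EPSILON / safe_err) ** (1.0 / 7.0),  # error scales as dt**7
        dt_done / SAFETY_FACTOR,                        # fallback: grow the step
    )
    # Keep the integration direction, but never drop below MIN_DT in magnitude.
    return jnp.sign(dt_done) * jnp.maximum(jnp.abs(dt_new), MIN_DT)
```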

jorbit.integrators.ias15.precompute_interpolation_indices(t_step_starts, dts, query_times)

Precompute the step indices and fractional times for interpolation.

Call this once during setup, then pass the results into interpolate_from_dense_output to avoid redundant searchsorted calls inside the JIT’d residuals function.

Parameters:
  • t_step_starts (Array) – Start time of each step, shape (n_steps,).

  • dts (Array) – Per-step time step sizes, shape (n_steps,).

  • query_times (Array) – Times at which to interpolate, shape (n_queries,).

Handles both integration directions. jnp.searchsorted requires an ascending sequence, but a backward integration (negative dts) produces a descending t_step_starts, so the lookup is done in direction-normalized coordinates. Unfilled buffer slots carry a large positive dts sentinel; their key is forced past every real step so valid queries always route into the filled prefix.

Returns:
  • step_indices – Integer index of the containing step for each query time, shape (n_queries,).

  • h_values – Fractional time within each step (0 to 1), shape (n_queries,).

Return type:

tuple[Array, Array]
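The direction normalization can be sketched as follows. This is a sketch under assumptions, not jorbit's implementation: SENTINEL_DT is a hypothetical threshold standing in for whatever sentinel the real buffer uses, and the first slot is assumed to be filled so that its sign gives the integration direction. Multiplying times by that sign makes the filled prefix ascending in both directions, and mapping unfilled slots to +inf keeps searchsorted from routing any query into them.

```python
import jax.numpy as jnp

SENTINEL_DT = 1e30  # hypothetical: unfilled slots are assumed to hold dts >= this


def precompute_indices_sketch(t_step_starts, dts, query_times):
    filled = dts < SENTINEL_DT
    direction = jnp.sign(dts[0])  # +1 forward, -1 backward (slot 0 assumed filled)
    # Direction-normalized keys: ascending over the filled prefix, +inf afterwards.
    keys = jnp.where(filled, direction * t_step_starts, jnp.inf)
    q = direction * query_times
    # side="right" sends a query sitting exactly on a step start into that step.
    idx = jnp.searchsorted(keys, q, side="right") - 1
    idx = jnp.clip(idx, 0, jnp.sum(filled) - 1)
    h = (query_times - t_step_starts[idx]) / dts[idx]
    return idx, h
```

A query exactly at the end of the last filled step resolves to that step with h = 1 rather than spilling into an unfilled slot.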