Source code for jorbit.integrators.ias15.interpolation

"""Dense-output interpolation and light-travel-time utilities for IAS15.

Evaluate the converged 7th-order IAS15 polynomial at arbitrary times within completed
steps, without re-integrating.
"""

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

from jorbit.integrators.ias15.helpers import _estimate_x_v_from_b


def precompute_interpolation_indices(
    t_step_starts: jnp.ndarray,
    dts: jnp.ndarray,
    query_times: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Precompute the step indices and fractional times for interpolation.

    Call this once during setup, then pass the results into
    interpolate_from_dense_output to avoid redundant searchsorted calls inside
    the JIT'd residuals function.

    Handles both integration directions. ``jnp.searchsorted`` requires an
    ascending sequence, but a backward integration (negative ``dts``) produces
    a *descending* ``t_step_starts``. The lookup is therefore done in
    direction-normalized coordinates: times are multiplied by the sign of
    ``dts[0]``. Unfilled buffer slots carry a large positive ``dts`` sentinel
    (``|dts| >= 1e29``). Their key is forced to ``+inf``, so valid queries
    always route into the filled prefix.

    Args:
        t_step_starts (jnp.ndarray): Start time of each step, shape (n_steps,).
        dts (jnp.ndarray): Per-step time step sizes, shape (n_steps,).
        query_times (jnp.ndarray): Times at which to interpolate, shape
            (n_queries,).

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]:
            step_indices: Integer index of the containing step for each query
                time, shape (n_queries,).
            h_values: Fractional time within each step, shape (n_queries,).
                Nominally in [0, 1]. A query before the first step start (in
                the integration direction), or past the end of the last filled
                step, falls outside that range, and its step polynomial is
                extrapolated.
    """
    direction = jnp.sign(dts[0])
    filled = jnp.abs(dts) < 1e29
    key = jnp.where(filled, direction * t_step_starts, jnp.inf)

    # side="right" selects the last slot whose key is <= the query key. A query
    # exactly on a step start is therefore assigned to that step.
    step_indices = jnp.searchsorted(key, direction * query_times, side="right") - 1

    # A query whose key is smaller than every slot key yields -1. This happens
    # in two cases:
    #   * a zero-span run, where every dts is the unfilled sentinel and every
    #     key is +inf;
    #   * a query lying before the first step start in the integration
    #     direction.
    # As a gather index, -1 would wrap around to the last buffer slot. Clamp it
    # into slot 0, whose state buffers hold the initial condition.
    step_indices = jnp.maximum(step_indices, 0)

    h_values = (query_times - t_step_starts[step_indices]) / dts[step_indices]
    return step_indices, h_values
def _ltt_position(
    a0: jnp.ndarray,
    v0: jnp.ndarray,
    x0: jnp.ndarray,
    bp: jnp.ndarray,
    dt_step: jnp.ndarray,
    h_obs: jnp.ndarray,
    dt: jnp.ndarray,
) -> jnp.ndarray:
    """Evaluate one step's IAS15 position polynomial at ``h_obs + dt / dt_step``.

    This is the module-level target of the Partial built by make_ltt_propagator.
    All step data arrive as positional arguments, so inside the Partial they
    are pytree leaves rather than hidden closure state.

    Args:
        a0 (jnp.ndarray): Step-start acceleration with a singleton particle
            axis, shape (1, 3).
        v0 (jnp.ndarray): Step-start velocity, shape (1, 3).
        x0 (jnp.ndarray): Step-start position, shape (1, 3).
        bp (jnp.ndarray): b coefficients reversed along the coefficient axis,
            with a singleton particle axis, shape (7, 1, 3).
        dt_step (jnp.ndarray): Length of the step (scalar).
        h_obs (jnp.ndarray): Fractional position of the observation time within
            the step (scalar).
        dt (jnp.ndarray): Time offset from the observation time (scalar).

    Returns:
        jnp.ndarray: Position at the offset time, shape (3,). This is the
        helper's output with the singleton particle axis stripped.
    """
    h = h_obs + dt / dt_step
    x_at_delayed_time, _ = _estimate_x_v_from_b(a0, v0, x0, h, dt_step, bp)
    return x_at_delayed_time[0]


def make_ltt_propagator(
    b_step: jnp.ndarray,
    a0_step: jnp.ndarray,
    x0_step: jnp.ndarray,
    v0_step: jnp.ndarray,
    dt_step: jnp.ndarray,
    h_obs: jnp.ndarray,
) -> jax.tree_util.Partial:
    """Build a callable that evaluates the IAS15 polynomial at a light-travel-delayed time.

    Used inside ``on_sky`` to propagate a particle backward by the light travel
    time. It uses the converged 7th-order IAS15 polynomial of the step
    containing the observation time, instead of a constant-acceleration Taylor
    expansion.

    The returned callable maps a (negative) time offset ``dt`` to the
    particle's position at fractional position ``h_obs + dt / dt_step`` within
    the step. It accepts ``h`` slightly outside ``[0, 1]``, i.e. it
    extrapolates with the same step's polynomial. Normally the excursion is
    small, since the LTT is much shorter than ``dt_step`` for typical
    solar-system geometries. For close flybys with very small steps, the LTT
    may exceed ``dt_step``; even then this gives a much higher-order correction
    than the constant-acceleration Taylor expansion.

    The step data are bound as positional arguments of a
    ``jax.tree_util.Partial`` around the module-level ``_ltt_position``. They
    are therefore genuine pytree leaves, traced correctly when the callable
    crosses a ``jit``/``vmap`` boundary. The treedef (which holds the function
    object) is also identical across calls, so passing a new propagator into a
    jitted function does not force a retrace.

    Args:
        b_step (jnp.ndarray): Converged b coefficients for this step (single
            particle slice), shape (7, 3).
        a0_step (jnp.ndarray): Acceleration at the start of this step, shape (3,).
        x0_step (jnp.ndarray): Position at the start of this step, shape (3,).
        v0_step (jnp.ndarray): Velocity at the start of this step, shape (3,).
        dt_step (jnp.ndarray): Length of this step (scalar).
        h_obs (jnp.ndarray): Fractional position of the observation time within
            this step, in ``[0, 1]`` (scalar).

    Returns:
        jax.tree_util.Partial: A pytree callable ``f(dt) -> x_at_delayed_time``
        whose output has shape (3,).
    """
    # _estimate_x_v_from_b assumes a per-particle axis (IAS15_BX_DENOMS
    # broadcasts against shape (7, n_particles, 3)). Add a singleton particle
    # axis here; _ltt_position strips it from the output, so callers work with
    # plain (3,) / (7, 3) shapes.
    bp = b_step[::-1][:, None, :]
    a0 = a0_step[None, :]
    v0 = v0_step[None, :]
    x0 = x0_step[None, :]
    return jax.tree_util.Partial(_ltt_position, a0, v0, x0, bp, dt_step, h_obs)
@jax.jit
def interpolate_from_dense_output(
    b_all: jnp.ndarray,
    a0_all: jnp.ndarray,
    x0_all: jnp.ndarray,
    v0_all: jnp.ndarray,
    dts: jnp.ndarray,
    step_indices: jnp.ndarray,
    h_values: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate stored IAS15 step polynomials at arbitrary query times.

    Nothing is re-integrated. Each query is answered from the b coefficients
    and start-of-step state saved for the completed step that contains it,
    evaluated at the query's fractional time within that step.

    ``step_indices`` and ``h_values`` should come from
    ``precompute_interpolation_indices``. They depend only on the fixed step
    layout and the query times, not on the particle state. Computing them once
    up front keeps ``searchsorted`` out of this JIT graph and out of every
    forward and backward pass.

    Args:
        b_all (jnp.ndarray): Per-step b coefficients, shape
            (n_steps, 7, n_particles, 3).
        a0_all (jnp.ndarray): Per-step initial accelerations, shape
            (n_steps, n_particles, 3).
        x0_all (jnp.ndarray): Per-step initial positions, shape
            (n_steps, n_particles, 3).
        v0_all (jnp.ndarray): Per-step initial velocities, shape
            (n_steps, n_particles, 3).
        dts (jnp.ndarray): Per-step time step sizes, shape (n_steps,).
        step_indices (jnp.ndarray): Containing-step index for each query, shape
            (n_queries,).
        h_values (jnp.ndarray): Fractional time of each query within its step
            (nominally 0 to 1), shape (n_queries,).

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]: Interpolated ``(positions, velocities)``,
        each of shape (n_queries, n_particles, 3).
    """
    # For every query, gather the stored data of the step it falls in.
    per_query_b, per_query_a0, per_query_x0, per_query_v0, per_query_dt = (
        stored[step_indices] for stored in (b_all, a0_all, x0_all, v0_all, dts)
    )

    def _evaluate_single_query(a0, v0, x0, h, dt, b):
        # The helper is handed the b coefficients reversed along their leading
        # (coefficient) axis.
        return _estimate_x_v_from_b(a0, v0, x0, h, dt, b[::-1])

    batched_evaluate = jax.vmap(_evaluate_single_query)
    positions, velocities = batched_evaluate(
        per_query_a0,
        per_query_v0,
        per_query_x0,
        h_values,
        per_query_dt,
        per_query_b,
    )
    return positions, velocities