"""The single-step IAS15 predictor-corrector and its helpers."""
import jax
jax.config.update("jax_enable_x64", True)
from collections.abc import Callable
import jax.numpy as jnp
from jorbit.data.constants import (
EPSILON,
IAS15_BEZIER_COEFFS,
IAS15_D_MATRIX,
IAS15_H,
IAS15_SAFETY_FACTOR,
IAS15_sub_cs,
IAS15_sub_rs,
)
from jorbit.integrators.ias15.helpers import _estimate_x_v_from_b, add_cs
from jorbit.utils.states import IAS15IntegratorState, SystemState
@jax.jit
def _refine_sub_g(
    at: jnp.ndarray, a0: jnp.ndarray, previous_gs: jnp.ndarray, r: jnp.ndarray
) -> jnp.ndarray:
    """Recompute one g coefficient from a freshly evaluated substep acceleration.

    Evaluates the nested form
    ``((((at - a0) * r[0] - previous_gs[0]) * r[1] - previous_gs[1]) * r[2] ...)``
    folding over the already-known lower-order g coefficients.

    Args:
        at: Acceleration evaluated at the current substep.
        a0: Acceleration at the beginning of the step.
        previous_gs: The lower-order g coefficients, stacked along axis 0
            (may be empty along that axis).
        r: Reciprocal node-spacing factors; ``r[0]`` scales the initial
            difference and ``r[1:]`` pairs one-to-one with ``previous_gs``.

    Returns:
        The refined g coefficient, same shape as ``at``.
    """

    def _fold(acc: jnp.ndarray, pair: tuple) -> tuple:
        g_prev, r_k = pair
        return (acc - g_prev) * r_k, None

    seed = r[0] * (at - a0)
    refined, _ = jax.lax.scan(_fold, seed, (previous_gs, r[1:]))
    return refined
@jax.jit
def _update_bs(
    current_bs: jnp.ndarray,
    current_csbs: jnp.ndarray,
    g_diff: jnp.ndarray,
    c: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Fold the change in one g coefficient into the leading b coefficients.

    Row ``j`` of the b stack receives ``g_diff * c[j]``; the addition goes through
    ``add_cs`` so the compensated-summation residuals are updated alongside.

    Args:
        current_bs: The first ``len(c)`` b coefficients, stacked along axis 0.
        current_csbs: The matching compensated-summation residuals.
        g_diff: Change in the g coefficient that was just refined.
        c: One conversion coefficient per updated b coefficient (1-D).

    Returns:
        ``(new_bs, new_csbs)`` exactly as produced by ``add_cs``.
    """
    # Broadcast (1, ...) against (len(c), 1, 1) to get one increment per b row.
    increments = g_diff[jnp.newaxis, :] * c[:, jnp.newaxis, jnp.newaxis]
    new_bs, new_csbs = add_cs(current_bs, current_csbs, increments)
    return new_bs, new_csbs
@jax.jit
def _predict_next_step(ratio: float, e: jnp.ndarray, b: jnp.ndarray) -> tuple:
    """Extrapolate the b coefficients to seed the next step's predictor-corrector.

    Args:
        ratio: Ratio of the next step size to the step just taken.
        e: The prediction for ``b`` made at the end of the previous step.
        b: The converged b coefficients of the step just completed; stacked with
            7 entries along axis 0 (``ias15_step`` documents shape
            (7, n_particles, 3)).

    Returns:
        ``(e, b)``: the new prediction and the corrected starting guess for ``b``.
        When ``ratio >= 1 / IAS15_SAFETY_FACTOR`` no extrapolation is done:
        ``e`` is zeroed and ``b`` is returned unchanged.
    """

    def _keep_b(ratio: float, e: jnp.ndarray, b: jnp.ndarray) -> tuple:
        # Saturated growth: clear only the stored prediction and reuse the
        # converged b as the next step's starting point.
        return jnp.zeros_like(e), b

    def _extrapolate(ratio: float, e: jnp.ndarray, b: jnp.ndarray) -> tuple:
        # Carry last step's prediction error (b - e) forward as a correction.
        correction = b - e
        powers = ratio ** jnp.arange(1, 8)
        predicted = jnp.einsum("i,ij,j...->i...", powers, IAS15_BEZIER_COEFFS, b)
        return predicted, predicted + correction

    saturated = ratio >= 1 / IAS15_SAFETY_FACTOR
    return jax.lax.cond(saturated, _keep_b, _extrapolate, ratio, e, b)
@jax.jit
def ias15_step(
    initial_system_state: SystemState,
    acceleration_func: Callable[[SystemState], jnp.ndarray],
    initial_integrator_state: IAS15IntegratorState,
    step_scheduler: Callable[
        [jnp.ndarray, jnp.ndarray, jnp.ndarray, float, jnp.ndarray, jnp.ndarray], float
    ],
) -> tuple[SystemState, IAS15IntegratorState, jnp.ndarray]:
    """Take a single step using the IAS15 integrator.

    Contains all of the predictor/corrector logic and step validity checks. Does not
    accept any pre-computed perturber information, since we don't know the times this
    will be needed until runtime. For a static version that accepts pre-computed
    perturber data, see ias15_static_step.

    If the proposed next step is smaller than ``IAS15_SAFETY_FACTOR`` times the step
    just attempted, the step is rejected. In that case the positions, velocities and
    compensated-summation residuals are returned unchanged, ``relative_time`` does
    not advance, and ``dt_last_done`` is 0.

    Args:
        initial_system_state (SystemState):
            The initial system state.
        acceleration_func (Callable[[SystemState], jnp.ndarray]):
            The acceleration function.
        initial_integrator_state (IAS15IntegratorState):
            The initial integrator state.
        step_scheduler (Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray, float, jnp.ndarray, jnp.ndarray], float]):
            The step scheduler function, which is either going to be
            next_proposed_dt_PRS23 or next_proposed_dt_global. It is called as
            ``step_scheduler(a0, a_end, b, dt, x_end, v_end)``.

    Returns:
        tuple[SystemState, IAS15IntegratorState, jnp.ndarray]:
            The new system state, the new integrator state (with the *predicted*
            next-step b coefficients and the updated compensated-summation
            residuals), and the *converged* b coefficients for the step just
            completed, shape (7, n_particles, 3). The converged b is what should be
            stored when building dense output for interpolation.
    """
    t_beginning = initial_system_state.relative_time
    # The absolute-JD anchor is constant across the step; the integrator marches
    # relative_time and carries time_reference through unchanged.
    t_ref = initial_system_state.time_reference
    # Positions/velocities are stacked as [massive; tracer]; M marks the split.
    M = initial_system_state.massive_positions.shape[0]
    x0 = jnp.concatenate(
        (initial_system_state.massive_positions, initial_system_state.tracer_positions)
    )
    v0 = jnp.concatenate(
        (
            initial_system_state.massive_velocities,
            initial_system_state.tracer_velocities,
        )
    )

    dt = initial_integrator_state.dt
    a0 = initial_integrator_state.a0
    csx = initial_integrator_state.csx
    csv = initial_integrator_state.csv
    e = initial_integrator_state.e
    b = initial_integrator_state.b
    # Compensated-summation residuals for b start fresh every step.
    csb = jnp.zeros_like(b)
    g = jnp.einsum("ij,jnk->ink", IAS15_D_MATRIX, b)

    def _do_nothing(
        b: jnp.ndarray,
        csb: jnp.ndarray,
        g: jnp.ndarray,
        predictor_corrector_error: jnp.ndarray,
        at_last: jnp.ndarray,
        x_end: jnp.ndarray,
        v_end: jnp.ndarray,
    ) -> tuple:
        # Iteration has already stopped: pass the carry through. Reporting
        # "last" == "current" keeps the `>=` divergence test in scan_func true,
        # so once stopped the loop stays stopped.
        return (
            b,
            csb,
            g,
            predictor_corrector_error,
            predictor_corrector_error,
            at_last,
            x_end,
            v_end,
        )

    def _predictor_corrector_iteration(
        b: jnp.ndarray,
        csb: jnp.ndarray,
        g: jnp.ndarray,
        predictor_corrector_error: float,
        at_last: jnp.ndarray,
        x_end: jnp.ndarray,
        v_end: jnp.ndarray,
    ) -> tuple:
        del at_last, x_end, v_end
        predictor_corrector_error_last = predictor_corrector_error
        for n, h, c, r in zip(
            range(1, 8), IAS15_H[1:], IAS15_sub_cs, IAS15_sub_rs, strict=True
        ):
            step_time = t_beginning + dt * h
            x, v = _estimate_x_v_from_b(
                a0=a0,
                v0=v0,
                x0=x0,
                h=h,
                dt=dt,
                bp=b[::-1],
            )
            # note that the fixed perturber bits likely can/will be overwritten by the
            # acceleration function- see ias15_static_step + create_static_default_acceleration_func
            acc_state = SystemState(
                massive_positions=x[:M],
                massive_velocities=v[:M],
                tracer_positions=x[M:],
                tracer_velocities=v[M:],
                log_gms=initial_system_state.log_gms,
                time_reference=t_ref,
                relative_time=step_time,
                fixed_perturber_positions=jnp.empty(
                    (0, 3),
                ),
                fixed_perturber_velocities=jnp.empty(
                    (0, 3),
                ),
                fixed_perturber_log_gms=jnp.empty((0,)),
                acceleration_func_kwargs=initial_system_state.acceleration_func_kwargs,
            )
            at = acceleration_func(acc_state)
            g_old = g[n - 1]
            g_new = _refine_sub_g(at, a0, g[: n - 1], r)
            g_diff = g_new - g_old
            new_bs, new_csbs = _update_bs(b[:n], csb[:n], g_diff, c)
            g = g.at[n - 1].set(g_new)
            b = b.at[:n].set(new_bs)
            csb = csb.at[:n].set(new_csbs)

        # Convergence metric: change in the highest-order g (hence b[6]) relative
        # to the largest acceleration, both from the final substep (n=7).
        predictor_corrector_error = jnp.abs(
            jnp.max(jnp.abs(g_diff)) / jnp.max(jnp.abs(at))
        )
        # `at`, `x`, `v` here are from the last sub-step (n=7, h=IAS15_H[7]=0.977),
        # i.e. the freshly-evaluated end-of-step acceleration and predictor state.
        # REBOUND's GLOBAL controller uses these (integrator_ias15.c:382-385, 547).
        return (
            b,
            csb,
            g,
            predictor_corrector_error,
            predictor_corrector_error_last,
            at,
            x,
            v,
        )

    def scan_func(carry: tuple, scan_over: int) -> tuple:
        (
            b,
            csb,
            g,
            predictor_corrector_error,
            predictor_corrector_error_last,
            at_last,
            x_end,
            v_end,
        ) = carry
        # Stop once converged, or (after 3 iterations) once the error stops
        # decreasing. `>=` matches REBOUND's `last <= current` and, together with
        # _do_nothing's carry, makes the stop permanent for the remaining scan steps.
        condition = (predictor_corrector_error < EPSILON) | (
            (scan_over > 2)
            & (predictor_corrector_error >= predictor_corrector_error_last)
        )
        carry = jax.lax.cond(
            condition,
            _do_nothing,
            _predictor_corrector_iteration,
            b,
            csb,
            g,
            predictor_corrector_error,
            at_last,
            x_end,
            v_end,
        )
        return carry, None

    # At most 12 PC iterations; the huge initial error forces at least one.
    initial_carry = (b, csb, g, 1e300, 2.0, a0, x0, v0)
    (b, csb, g, _pc_error, _pc_error_last, at_final, x_end, v_end), _ = jax.lax.scan(
        scan_func, initial_carry, jnp.arange(12)
    )

    dt_done = dt
    next_dt = step_scheduler(a0, at_final, b, dt, x_end, v_end)

    def step_too_ambitious(
        x0: jnp.ndarray,
        v0: jnp.ndarray,
        csx: jnp.ndarray,
        csv: jnp.ndarray,
        dt_done: float,
        next_dt: float,
    ) -> tuple:
        # Reject: keep state and residuals, report that no time elapsed.
        dt_done = 0.0
        return x0, v0, csx, csv, dt_done, next_dt

    def step_was_good(
        x0: jnp.ndarray,
        v0: jnp.ndarray,
        csx: jnp.ndarray,
        csv: jnp.ndarray,
        dt_done: float,
        next_dt: float,
    ) -> tuple:
        # Limit step growth to a factor of 1 / IAS15_SAFETY_FACTOR.
        safe_next_dt = jnp.where(
            next_dt / dt_done > 1 / IAS15_SAFETY_FACTOR,
            dt_done / IAS15_SAFETY_FACTOR,
            next_dt,
        )
        # Advance x and v with the converged b, adding the smallest terms first
        # and carrying the compensated-summation residuals csx/csv.
        x0, csx = add_cs(x0, csx, b[6] / 72.0 * dt_done * dt_done)
        x0, csx = add_cs(x0, csx, b[5] / 56.0 * dt_done * dt_done)
        x0, csx = add_cs(x0, csx, b[4] / 42.0 * dt_done * dt_done)
        x0, csx = add_cs(x0, csx, b[3] / 30.0 * dt_done * dt_done)
        x0, csx = add_cs(x0, csx, b[2] / 20.0 * dt_done * dt_done)
        x0, csx = add_cs(x0, csx, b[1] / 12.0 * dt_done * dt_done)
        x0, csx = add_cs(x0, csx, b[0] / 6.0 * dt_done * dt_done)
        x0, csx = add_cs(x0, csx, a0 / 2.0 * dt_done * dt_done)
        x0, csx = add_cs(x0, csx, v0 * dt_done)
        v0, csv = add_cs(v0, csv, b[6] / 8.0 * dt_done)
        v0, csv = add_cs(v0, csv, b[5] / 7.0 * dt_done)
        v0, csv = add_cs(v0, csv, b[4] / 6.0 * dt_done)
        v0, csv = add_cs(v0, csv, b[3] / 5.0 * dt_done)
        v0, csv = add_cs(v0, csv, b[2] / 4.0 * dt_done)
        v0, csv = add_cs(v0, csv, b[1] / 3.0 * dt_done)
        v0, csv = add_cs(v0, csv, b[0] / 2.0 * dt_done)
        v0, csv = add_cs(v0, csv, a0 * dt_done)
        # The residuals must be returned too, otherwise the compensation is lost
        # and the next step would restart from stale residuals.
        return x0, v0, csx, csv, dt_done, safe_next_dt

    x0, v0, csx, csv, dt_done, next_dt = jax.lax.cond(
        jnp.abs(next_dt / dt_done) < IAS15_SAFETY_FACTOR,
        step_too_ambitious,
        step_was_good,
        x0,
        v0,
        csx,
        csv,
        dt_done,
        next_dt,
    )

    new_system_state = SystemState(
        massive_positions=x0[:M],
        massive_velocities=v0[:M],
        tracer_positions=x0[M:],
        tracer_velocities=v0[M:],
        log_gms=initial_system_state.log_gms,
        time_reference=t_ref,
        relative_time=t_beginning + dt_done,
        fixed_perturber_positions=initial_system_state.fixed_perturber_positions * 0,
        fixed_perturber_velocities=initial_system_state.fixed_perturber_velocities * 0,
        fixed_perturber_log_gms=initial_system_state.fixed_perturber_log_gms * 0,
        acceleration_func_kwargs=initial_system_state.acceleration_func_kwargs,
    )

    # On rejection (dt_done == 0), force ratio into the large_ratio no-op
    # branch of _predict_next_step (zeros e, keeps b). The denominator is made
    # safe first so the unselected `x / 0` branch cannot inject NaNs into
    # reverse-mode gradients.
    rejected = dt_done == 0.0
    safe_dt_done = jnp.where(rejected, 1.0, dt_done)
    ratio = jnp.where(rejected, 100.0, next_dt / safe_dt_done)
    predicted_next_e, predicted_next_b = _predict_next_step(ratio, e, b)

    new_integrator_state = IAS15IntegratorState(
        g=g,
        b=predicted_next_b,
        e=predicted_next_e,
        csx=csx,
        csv=csv,
        a0=acceleration_func(new_system_state),
        dt=next_dt,
        dt_last_done=dt_done,
    )
    return new_system_state, new_integrator_state, b