Source code for jorbit.utils.generate_coefficients

"""Generate the high-precision coefficients for IAS15-style integrators."""

import jax

jax.config.update("jax_enable_x64", True)

import mpmath
from tqdm import tqdm


def create_gauss_radau_spacings(internal_points: int) -> list[mpmath.mpf]:
    """Compute the spacings for the Gauss-Radau quadrature rule for an arbitrary number of internal points.

    IAS15 uses 7 internal points, which results in a 15th order integrator. If
    we wanted a higher or lower order integrator, we could change the number of
    internal points. But, then we need to know where to actually evaluate those
    points: this function uses the mpmath library to compute those locations,
    which are the roots of the sum of the Legendre polynomials of degree n-1
    and n, divided by x+1 (with n = internal_points + 1). Confirmed that the
    resulting values match those tabulated in Table 12 of `Stroud and Secrest
    1966 <https://archive.org/details/gaussianquadratu00stro/page/340/mode/2up>`_.

    Note:
        Sets the global mpmath working precision (``mpmath.mp.dps``) to 75 and
        leaves it there.

    Args:
        internal_points (int):
            The number of internal points to use.

    Returns:
        list[mpmath.mpf]:
            The spacings for the Gauss-Radau quadrature rule, rescaled to
            [0, 1], sorted, and with the endpoint 0 prepended. The 'H' array
            in IAS15.

    Raises:
        AssertionError:
            If the scan does not recover exactly ``internal_points`` distinct
            roots.
    """
    # for ias15 H spacings, internal_points = 7
    # matches https://archive.org/details/gaussianquadratu00stro/page/340/mode/2up
    mpmath.mp.dps = 75
    n = internal_points + 1  # include the endpoint
    tol = mpmath.mpf("1e-50")

    def f(x: mpmath.mpf) -> mpmath.mpf:
        return (mpmath.legendre(n - 1, x) + mpmath.legendre(n, x)) / (x + 1)

    # Scan [-1, 1] with many narrow starting brackets; each secant run either
    # converges to one of the roots (possibly one already found) or fails.
    # Index 0 is skipped because f is 0/0 at x = -1.
    slices = mpmath.linspace(-1, 1, 1000)
    sols = []
    for i in tqdm(range(1, len(slices) - 1)):
        try:
            s = mpmath.findroot(
                f,
                (slices[i], slices[i + 1]),
                solver="secant",
                tol=tol,
            )
        except Exception:
            # Deliberate best-effort: most brackets do not contain a root, so
            # a failed or non-converged search here is expected, not an error.
            continue

        # Skip roots we already have (to within the tolerance).
        if any(mpmath.fabs(sol - s) < tol for sol in sols):
            continue

        # Only keep candidates whose residual is below the tolerance. This is
        # an explicit filter rather than an `assert` so that it still runs
        # under `python -O`.
        if mpmath.fabs(f(s)) < tol:
            sols.append(s)

    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if len(sols) != n - 1:
        raise AssertionError(
            f"expected {n - 1} distinct Gauss-Radau roots, found {len(sols)}"
        )

    sols.sort()
    sols = [((s + 1) / 2) for s in sols]  # rescale to 0, 1 from the range -1, 1
    sols = [mpmath.mpf(0), *sols]  # add the endpoint
    return sols
def create_iasnn_r_array(h: list[mpmath.mpf]) -> list[mpmath.mpf]:
    """Build the analogue of the IAS15 'R' array for an arbitrary-order integrator.

    The result is the flat list of differences ``h[j] - h[k]`` for every pair
    with ``j > k``, ordered by ``j`` first and then by ``k``.

    Note:
        Sets the global mpmath working precision (``mpmath.mp.dps``) to 75.

    Args:
        h (list[mpmath.mpf]):
            The spacings for the Gauss-Radau quadrature rule.

    Returns:
        list[mpmath.mpf]:
            The 'R' array for the integrator.
    """
    mpmath.mp.dps = 75

    differences = []
    for upper, later in enumerate(h):
        # Every node that precedes `later`; empty for the very first node.
        differences.extend(later - earlier for earlier in h[:upper])
    return differences
def create_iasnn_c_d_arrays(
    h: list[mpmath.mpf],
) -> tuple[list[mpmath.mpf], list[mpmath.mpf]]:
    """Build the analogues of the IAS15 'C' and 'D' arrays for an arbitrary-order integrator.

    Both arrays are flat lists filled by a recurrence over ``j`` in
    ``range(2, len(h) - 1)``. Each ``j`` contributes ``j`` new entries, all
    computed from the entries written for the previous ``j``.

    Note:
        Sets the global mpmath working precision (``mpmath.mp.dps``) to 75.

    Args:
        h (list[mpmath.mpf]):
            The spacings for the Gauss-Radau quadrature rule.

    Returns:
        tuple[list[mpmath.mpf], list[mpmath.mpf]]:
            The 'C' and 'D' arrays for the integrator.
    """
    mpmath.mp.dps = 75

    # Seed entries; everything else follows from the recurrence.
    c_coeffs = [-h[1]]
    d_coeffs = [h[1]]

    for j in range(2, len(h) - 1):
        # `pos` is always the index of the entry about to be appended, so
        # `pos - j` and `pos - j + 1` refer to entries from the previous row.

        # First entry of this row.
        pos = len(c_coeffs)
        c_coeffs.append(-h[j] * c_coeffs[pos - j + 1])
        d_coeffs.append(h[1] * d_coeffs[pos - j + 1])

        # Interior entries of this row.
        for k in range(2, j):
            pos = len(c_coeffs)
            c_coeffs.append(c_coeffs[pos - j] - h[j] * c_coeffs[pos - j + 1])
            d_coeffs.append(d_coeffs[pos - j] + h[k] * d_coeffs[pos - j + 1])

        # Final entry of this row.
        pos = len(c_coeffs)
        c_coeffs.append(c_coeffs[pos - j] - h[j])
        d_coeffs.append(d_coeffs[pos - j] + h[j])

    return c_coeffs, d_coeffs
def create_iasnn_constants(
    n_internal_points: int,
) -> tuple[list[mpmath.mpf], list[mpmath.mpf], list[mpmath.mpf], list[mpmath.mpf]]:
    """Produce the 'H', 'R', 'C', and 'D' arrays of an IAS15-style integrator of arbitrary order.

    Convenience wrapper that chains create_gauss_radau_spacings,
    create_iasnn_r_array, and create_iasnn_c_d_arrays.

    Note:
        Sets the global mpmath working precision (``mpmath.mp.dps``) to 75.

    Args:
        n_internal_points (int):
            The number of internal points to use.

    Returns:
        tuple[list[mpmath.mpf], list[mpmath.mpf], list[mpmath.mpf], list[mpmath.mpf]]:
            The 'H', 'R', 'C', and 'D' arrays for the integrator, in that order.
    """
    mpmath.mp.dps = 75

    spacings = create_gauss_radau_spacings(n_internal_points)
    differences = create_iasnn_r_array(spacings)
    return (spacings, differences, *create_iasnn_c_d_arrays(spacings))