Source code for jorbit.utils.kepler

"""Functions for solving Kepler's equation.

This is a fork of squishyplanet/engine/kepler.py, which itself is a fork of
jaxoplanet/src/jaxoplanet/core/kepler.py. Many thanks to the original authors.
"""

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.interpreters import ad

from jorbit.astrometry.transformations import (
    cartesian_to_elements,
    elements_to_cartesian,
)


@jax.jit
def kepler(M: float, ecc: float) -> float:
    """Solve Kepler's equation to compute the true anomaly.

    This implementation is based on that within
    `jaxoplanet <https://github.com/exoplanet-dev/jaxoplanet/>`_, many thanks
    to the authors.

    Args:
        M (Array [Radian]): Mean anomaly
        ecc (Array [Dimensionless]): Eccentricity

    Returns:
        Array: True anomaly in radians; negative angles are shifted up by
        2*pi.
    """
    sin_f, cos_f = _kepler(M, ecc)

    # The one departure from jaxoplanet: arctan2 yields angles in (-pi, pi],
    # so the negative half is moved up by a full turn.
    true_anomaly = jnp.arctan2(sin_f, cos_f)
    shifted = true_anomaly + 2 * jnp.pi
    return jnp.where(true_anomaly < 0, shifted, true_anomaly)
@jax.custom_jvp
def _kepler(M: float, ecc: float) -> tuple[float, float]:
    """Solve Kepler's equation and return the sine and cosine of the true anomaly.

    Args:
        M: Mean anomaly in radians; it is wrapped modulo 2*pi internally.
        ecc: Eccentricity.

    Returns:
        tuple: ``(sin(f), cos(f))`` where ``f`` is the true anomaly.
    """
    full_turn = 2 * jnp.pi

    # Bring the mean anomaly into a single revolution.
    M = M % full_turn

    # Only the lower half-turn is solved directly; inputs above pi are
    # mirrored about pi and flagged so the result can be mirrored back.
    mirrored = jnp.pi < M
    M = jnp.where(mirrored, full_turn - M, M)

    # Initial guess followed by one refinement step.
    one_minus_ecc = 1 - ecc
    guess = _starter(M, ecc, one_minus_ecc)
    ecc_anom = _refine(M, ecc, one_minus_ecc, guess)

    # Undo the mirroring.
    ecc_anom = jnp.where(mirrored, full_turn - ecc_anom, ecc_anom)

    # t = tan(f / 2), obtained from the eccentric anomaly.
    t = jnp.sqrt((1 + ecc) / (1 - ecc)) * jnp.tan(0.5 * ecc_anom)
    t_sq = jnp.square(t)

    # Half-angle identities:
    #   sin(f) = 2 t / (1 + t^2)
    #   cos(f) = (1 - t^2) / (1 + t^2)
    inv = 1 / (1 + t_sq)
    return 2 * t * inv, (1 - t_sq) * inv


@_kepler.defjvp
def _(primals: tuple, tangents: tuple) -> tuple[tuple, tuple]:
    """Custom JVP rule for ``_kepler``.

    Args:
        primals: ``(M, e)`` at which to evaluate.
        tangents: ``(M_dot, e_dot)`` input tangents.

    Returns:
        tuple: ``((sin f, cos f), (d sin f, d cos f))``.
    """
    mean_anom, e = primals
    mean_anom_dot, e_dot = tangents
    sinf, cosf = _kepler(mean_anom, e)

    # Quantities shared by both partial derivatives.
    e_cosf = e * cosf
    one_minus_e2 = 1 - e**2

    def _materialize(tangent: float) -> float:
        # A symbolic zero tangent is replaced by a concrete zero value so it
        # can take part in the arithmetic below.
        if type(tangent) is ad.Zero:
            return ad.zeros_like_aval(tangent.aval)
        return tangent

    # Tangent of the true anomaly: contribution from M plus contribution
    # from e.
    f_dot = (
        _materialize(mean_anom_dot) * (1 + e_cosf) ** 2 / one_minus_e2**1.5
        + _materialize(e_dot) * (2 + e_cosf) * sinf / one_minus_e2
    )

    # Chain rule through sin and cos.
    return (sinf, cosf), (cosf * f_dot, -sinf * f_dot)


def _starter(M: float, ecc: float, ome: float) -> float:
    """Compute a closed-form initial guess for the eccentric anomaly.

    Args:
        M: Mean anomaly in radians (the caller passes the value already
            reduced to the lower half-turn).
        ecc: Eccentricity.
        ome: ``1 - ecc``, precomputed by the caller.

    Returns:
        Initial estimate of the eccentric anomaly.
    """
    m_sq = jnp.square(M)
    pi_term = jnp.pi - 6 / jnp.pi
    alpha = 3 * jnp.pi / pi_term + 1.6 / pi_term * (jnp.pi - M) / (1 + ecc)
    scale = 3 * ome + alpha * ecc
    alpha_scale = alpha * scale
    r = (3 * alpha_scale * (scale - ome) + m_sq) * M
    q = 2 * alpha_scale * ome - m_sq
    q_sq = jnp.square(q)
    w = jnp.square(jnp.cbrt(jnp.abs(r) + jnp.sqrt(q_sq * q + r * r)))
    return (2 * r * w / (jnp.square(w) + w * q + q_sq) + M) / scale


def _refine(M: float, ecc: float, ome: float, E: float) -> float:
    """Apply one high-order correction to an eccentric anomaly estimate.

    Args:
        M: Mean anomaly in radians.
        ecc: Eccentricity.
        ome: ``1 - ecc``, precomputed by the caller.
        E: Current estimate of the eccentric anomaly.

    Returns:
        The corrected eccentric anomaly.
    """
    e_minus_sin = E - jnp.sin(E)
    one_minus_cos = 1 - jnp.cos(E)

    # Residual of Kepler's equation, g(E) = E - ecc*sin(E) - M, and its first
    # three derivatives, each written in a form that reuses the terms above.
    g0 = ecc * e_minus_sin + E * ome - M  # E - ecc*sin(E) - M
    g1 = ecc * one_minus_cos + ome  # 1 - ecc*cos(E)
    g2 = ecc * (E - e_minus_sin)  # ecc*sin(E)
    g3 = 1 - g1  # ecc*cos(E)

    # Successively higher-order step estimates; each one feeds the next.
    step3 = -g0 / (g1 - 0.5 * g0 * g2 / g1)
    step4 = -g0 / (g1 + 0.5 * step3 * g2 + (step3 * step3) * g3 / 6)
    step4_sq = step4 * step4
    delta = -g0 / (
        g1 + 0.5 * step4 * g2 + step4_sq * g3 / 6 - step4_sq * step4 * g2 / 24
    )
    return E + delta
@jax.jit
def M_from_f(f: float, ecc: float) -> float:
    """Compute the mean anomaly from the true anomaly and eccentricity.

    Args:
        f (float): True anomaly in radians.
        ecc (float): Eccentricity.

    Returns:
        float: Mean anomaly in radians; negative values are shifted up by
        2*pi.
    """
    # Eccentric anomaly from the true anomaly via a two-argument arctangent.
    y = jnp.sqrt(1 - ecc**2) * jnp.sin(f)
    x = ecc + jnp.cos(f)
    ecc_anom = jnp.arctan2(y, x)

    # Kepler's equation: M = E - e sin(E).
    mean_anom = ecc_anom - ecc * jnp.sin(ecc_anom)

    shifted = mean_anom + 2 * jnp.pi
    return jnp.where(mean_anom < 0, shifted, mean_anom)
@jax.jit
def keplerian_propagate(
    x0: jnp.ndarray,
    v0: jnp.ndarray,
    t0: float,
    times: jnp.ndarray,
    gm: float,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Propagate a Keplerian orbit to new times.

    The initial state is converted to orbital elements, the mean anomaly is
    advanced linearly in time, Kepler's equation is solved at every target
    time, and the resulting elements are converted back to Cartesian states.

    Args:
        x0: Initial position (1, 3), ecliptic frame.
        v0: Initial velocity (1, 3), ecliptic frame.
        t0: Initial time (scalar, JD TDB).
        times: Target times (N,), JD TDB.
        gm: Gravitational parameter (GM) in AU^3/day^2.

    Returns:
        tuple: Positions (N, 3) and velocities (N, 3) in ecliptic frame.
    """
    a, ecc, nu, inc, Omega, omega = cartesian_to_elements(x0, v0, gm)
    semi_major, ecc0 = a[0], ecc[0]

    # The element helpers exchange the true anomaly in degrees (hence the
    # pi/180 and 180/pi factors); the Kepler routines work in radians.
    f0 = nu[0] * jnp.pi / 180
    mean_anom0 = M_from_f(f0, ecc0)

    # Mean motion, then mean anomaly at each requested time.
    mean_motion = jnp.sqrt(gm / semi_major**3)
    mean_anoms = mean_anom0 + mean_motion * (times - t0)

    # Solve per target time while sharing the single eccentricity.
    solve_all = jax.vmap(kepler, in_axes=(0, None))
    nu_deg = solve_all(mean_anoms, ecc0) * 180 / jnp.pi

    def _tile(value: float) -> jnp.ndarray:
        # Repeat a constant element so it matches the shape of `times`.
        return jnp.full_like(times, value)

    pos, vel = elements_to_cartesian(
        _tile(semi_major),
        _tile(ecc0),
        nu_deg,
        _tile(inc[0]),
        _tile(Omega[0]),
        _tile(omega[0]),
        gm,
    )
    return pos, vel