"""The JAX-compatible functions for manipulating JPL DE ephemeris data."""
from collections.abc import Callable
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
[docs]
@jax.tree_util.register_pytree_node_class
class EphemerisProcessor:
"""A pytree-compatible class for processing JPL DE ephemeris data.
This class provides functionality to evaluate Chebyshev polynomials for computing
planetary positions and velocities from JPL Development Ephemerides (DE) data.
It is compatible with JAX's pytree system for automatic differentiation and
parallelization.
"""
def __init__(
self,
init: jnp.ndarray,
intlen: jnp.ndarray,
coeffs: jnp.ndarray,
log_gms: jnp.ndarray,
) -> None:
"""Initializes the EphemerisProcessor with Chebyshev polynomial data.
Args:
init (jnp.ndarray):
Initial epoch times for each body in Julian days TDB.
intlen (jnp.ndarray):
Length of each Chebyshev polynomial interval in seconds.
coeffs (jnp.ndarray):
Chebyshev polynomial coefficients array of shape
(3, degree+1, n_intervals) where 3 represents x,y,z coordinates.
log_gms (jnp.ndarray):
Natural log of gravitational parameters (GM) for each body in
AU^3/day^2.
"""
self.init = init
self.intlen = intlen
self.coeffs = coeffs
self.log_gms = log_gms
[docs]
def tree_flatten(self: "EphemerisProcessor") -> tuple:
"""Flattens the class instance for JAX pytree compatibility.
Returns:
tuple: A tuple of (children, auxiliary_data) where children contains
the arrays to be transformed and auxiliary_data is None.
"""
children = (self.init, self.intlen, self.coeffs, self.log_gms)
aux_data = None
return children, aux_data
[docs]
@classmethod
def tree_unflatten(cls, aux_data: None, children: tuple) -> "EphemerisProcessor":
"""Reconstructs a class instance from flattened data.
Args:
aux_data: Auxiliary data (unused).
children (tuple): Tuple of arrays to reconstruct the instance.
Returns:
EphemerisProcessor: A new instance of the class.
"""
return cls(*children)
[docs]
@jax.jit
def eval_cheby(self, coefficients: jnp.ndarray, x: float) -> tuple:
"""Evaluates a Chebyshev polynomial using Clenshaw's algorithm.
Implements Clenshaw's recurrence formula for evaluating Chebyshev polynomials
in a numerically stable way.
Args:
coefficients (jnp.ndarray): Chebyshev polynomial coefficients.
x (float): Input value in the domain [-1, 1].
Returns:
tuple:
A tuple containing (jnp.ndarray, The evaluated polynomial value;
jnp.ndarray, Intermediate values used for velocity computation).
"""
b_ii = jnp.zeros(3)
b_i = jnp.zeros(3)
def scan_func(X: tuple, a: jnp.ndarray) -> tuple:
b_i, b_ii = X
tmp = b_i
b_i = a + 2 * x * b_i - b_ii
b_ii = tmp
return (b_i, b_ii), b_i
(b_i, b_ii), s = jax.lax.scan(scan_func, (b_i, b_ii), coefficients[:-1])
return coefficients[-1] + x * b_i - b_ii, s
@jax.jit
def _individual_state(
self, init: jnp.ndarray, intlen: jnp.ndarray, coeffs: jnp.ndarray, tdb: float
) -> tuple:
"""Computes position and velocity for a single body in the ephemeris at a given time.
Args:
init (float): Initial epoch time in Julian days TDB.
intlen (float): Length of Chebyshev polynomial interval in seconds.
coeffs (jnp.ndarray): Chebyshev coefficients for the body.
tdb (float): Requested time in Julian days TDB.
Returns:
tuple:
A tuple of jnp.ndarrays, first positions in AU, then velocities in
AU/day.
"""
tdb2 = 0.0 # leaving in case we ever decide to increase the time precision and use 2 floats
_, _, n = coeffs.shape
# 2451545.0 is the J2000 epoch in TDB
index1, offset1 = jnp.divmod((tdb - 2451545.0) * 86400.0 - init, intlen)
index2, offset2 = jnp.divmod(tdb2 * 86400.0, intlen)
index3, offset = jnp.divmod(offset1 + offset2, intlen)
index = (index1 + index2 + index3).astype(int)
omegas = index == n
index = jnp.where(omegas, index - 1, index)
offset = jnp.where(omegas, offset + intlen, offset)
coefficients = coeffs[:, :, index]
s = 2.0 * offset / intlen - 1.0
# Position
x, As = self.eval_cheby(coefficients, s) # in km here
# Velocity
Q = self.eval_cheby(2 * As, s)
v = Q[0] - As[-1]
v /= intlen
v *= 2.0 # in km/s here
# # Acceleration
# a = self.eval_cheby(4 * Q[1], s)[0] - 2 * Q[1][-1]
# a /= intlen**2
# a *= 4.0 # in km/s^2 here
# Convert to AU, AU/day, AU/day^2
return (
x.T * 6.684587122268446e-09,
v.T * 0.0005775483273639937,
# a.T * 49.900175484249054,
)
[docs]
@jax.jit
def state(self, tdb: float) -> tuple:
"""Computes positions and velocities for all bodies in the ephemeris at a given time.
Args:
tdb (float): Requested time in Julian days TDB (Barycentric Dynamical Time).
Returns:
tuple:
A tuple of jnp.ndarrays, first positions in AU, then velocities in AU/day.
"""
x, v = jax.vmap(self._individual_state, in_axes=(0, 0, 0, None))(
self.init, self.intlen, self.coeffs, tdb
)
return x, v
[docs]
@jax.tree_util.register_pytree_node_class
class EphemerisPostProcessor:
"""A pytree-compatible class for post-processing multiple ephemeris calculations.
Useful when one ephemeris (e.g. the asteroids) provides coordinates based on the
positions of an object in another (e.g. the sun).
Attributes:
log_gms (jnp.ndarray):
Concatenated array of natural log of gravitational parameters (GM) from all
ephemeris processors.
"""
def __init__(self, ephs: list, postprocessing_func: Callable) -> None:
"""Initializes the EphemerisPostProcessor with multiple ephemeris processors.
Args:
ephs (list[EphemerisProcessor]):
List of ephemeris processor instances.
postprocessing_func (callable):
Function to process the combined state vectors.
"""
self.ephs = ephs
self.postprocessing_func = postprocessing_func
log_gms = jnp.empty(0)
for eph in ephs:
log_gms = jnp.concatenate([log_gms, eph.log_gms])
self.log_gms = log_gms
[docs]
def tree_flatten(self: "EphemerisPostProcessor") -> tuple:
"""Flattens the class instance for JAX pytree compatibility.
Returns:
tuple: A tuple of (children, auxiliary_data) where children contains
the ephemeris processors and post-processing function, and
auxiliary_data is None.
"""
children = (self.ephs, self.postprocessing_func)
aux_data = None
return children, aux_data
[docs]
@classmethod
def tree_unflatten(
cls, aux_data: None, children: tuple
) -> "EphemerisPostProcessor":
"""Reconstructs a class instance from flattened data.
Args:
aux_data: Auxiliary data (unused).
children (tuple): Tuple containing ephemeris processors and
post-processing function.
Returns:
EphemerisPostProcessor: A new instance of the class.
"""
return cls(*children)
[docs]
@jax.jit
def state(self, tdb: float) -> tuple:
"""Computes combined and post-processed state vectors for all bodies.
This method:
1. Initializes empty arrays for position and velocity
2. Computes states for each ephemeris processor
3. Concatenates the results vertically
4. Applies the post-processing function to the combined states
Args:
tdb (float): Requested time in Julian days TDB (Barycentric Dynamical Time).
Returns:
tuple:
The post-processed state vectors. The exact structure depends on
the post-processing function, but typically includes
- jnp.ndarray: Processed position vectors, shape (n_total_bodies, 3)
- jnp.ndarray: Processed velocity vectors, shape (n_total_bodies, 3)
"""
x = jnp.empty((0, 3))
v = jnp.empty((0, 3))
for eph in self.ephs:
_x, _v = eph.state(tdb)
x = jnp.vstack([x, _x])
v = jnp.vstack([v, _v])
return self.postprocessing_func(x, v)