Source code for jorbit.utils.states

"""A collection of Chex dataclasses for representing the state of a system of particles."""

import jax

jax.config.update("jax_enable_x64", True)
from dataclasses import field

import astropy.units as u
import chex
import jax.numpy as jnp
from astropy.time import Time

from jorbit import Ephemeris
from jorbit.astrometry.transformations import (
    cartesian_to_elements,
    elements_to_cartesian,
    horizons_ecliptic_to_icrs,
    icrs_to_horizons_ecliptic,
)
from jorbit.data.constants import (
    ALL_PLANET_LOG_GMS,
    SPEED_OF_LIGHT,
    TOTAL_SOLAR_SYSTEM_GM,
)

# Gravitational parameter (GM) of the Sun, recovered by exponentiating the
# stored log-GM entry. It is passed as the central-body parameter to the
# heliocentric element conversions further down in this module.
# NOTE(review): units presumably match the AU / AU-per-day state vectors used
# here (i.e. AU^3/day^2) -- confirm against jorbit.data.constants.
SUN_GM = jnp.exp(ALL_PLANET_LOG_GMS["sun"])


@chex.dataclass
class SystemState:
    """Full description of a particle system at a single instant.

    The particles are split into three groups, each with its own position and
    velocity arrays: tracers, massive particles (with ``log_gms``), and fixed
    perturbers (with ``fixed_perturber_log_gms``).

    NOTE(review): tracers are presumably massless test particles and the
    ``log_gms`` arrays presumably hold natural-log GM values (this module
    recovers ``SUN_GM`` with ``jnp.exp``) -- confirm against the acceleration
    functions that consume this state.
    """

    # --- tracer particles ---------------------------------------------------
    tracer_positions: jnp.ndarray
    tracer_velocities: jnp.ndarray

    # --- massive particles --------------------------------------------------
    massive_positions: jnp.ndarray
    massive_velocities: jnp.ndarray
    log_gms: jnp.ndarray

    # The states built by ``to_system`` in this module store their
    # ``relative_time`` here rather than an absolute JD.
    time: float

    # --- fixed perturbers ---------------------------------------------------
    # need a leading axis! (n_substeps, n_perturbers, 3)
    # NOTE(review): the ``to_system`` helpers in this module fill the perturber
    # arrays with ``jnp.empty((0, 3))`` (no leading axis) -- confirm that this
    # is acceptable for the "no perturbers" case.
    fixed_perturber_positions: jnp.ndarray
    fixed_perturber_velocities: jnp.ndarray
    fixed_perturber_log_gms: jnp.ndarray

    # Extra keyword arguments for the acceleration function;
    # at a minimum, {"c2": SPEED_OF_LIGHT**2}
    acceleration_func_kwargs: dict
@chex.dataclass
class KeplerianState:
    """*Barycentric* Keplerian elements of a particle.

    All angles are in degrees. Because the elements are barycentric, they will
    not match values quoted in a heliocentric frame.

    The epoch is stored as a ``(relative_time, time_reference)`` pair, which
    makes the state self-describing about its absolute epoch:

    - ``relative_time`` is what gets forwarded to ``SystemState.time`` and on
      to the integrator / acceleration functions. Acceleration functions built
      with a non-zero ``t_ref_jd`` add it back to recover an absolute JD, while
      standalone factories built with ``t_ref_jd=0`` use the value directly.
    - ``time_reference`` is the absolute JD (TDB) anchor that
      ``relative_time`` is measured from. ``Particle.__init__`` reads
      ``relative_time + time_reference`` to set the particle's reference
      epoch.

    Invariant: ``absolute_jd = relative_time + time_reference``.
    """

    # Orbital elements (barycentric, angles in degrees).
    semi: float
    ecc: float
    inc: float
    Omega: float
    omega: float
    nu: float

    # Keyword arguments carried along for the acceleration function.
    acceleration_func_kwargs: dict

    # Absolute JD (TDB) anchor against which relative_time is measured.
    time_reference: float

    # Offset in days from time_reference; `to_system` forwards this value to
    # SystemState.time.
    relative_time: float = field(default_factory=lambda: jnp.array(0.0))

    # 6x6 covariance matrix in Keplerian coordinates
    # (semi, ecc, inc, Omega, omega, nu). A (0, 0) shape means "not set".
    # Coordinate transforms do not propagate it.
    cov: jnp.ndarray = field(default_factory=lambda: jnp.empty((0, 0)))

    def to_cartesian(self) -> "CartesianState":
        """Return the equivalent barycentric ``CartesianState``.

        The elements are converted to position/velocity vectors using
        ``TOTAL_SOLAR_SYSTEM_GM`` and then rotated from the Horizons ecliptic
        frame to ICRS. Both time fields and ``acceleration_func_kwargs`` are
        carried over; ``cov`` is not.
        """
        ecliptic_pos, ecliptic_vel = elements_to_cartesian(
            self.semi,
            self.ecc,
            self.nu,
            self.inc,
            self.Omega,
            self.omega,
            TOTAL_SOLAR_SYSTEM_GM,
        )
        return CartesianState(
            x=horizons_ecliptic_to_icrs(ecliptic_pos),
            v=horizons_ecliptic_to_icrs(ecliptic_vel),
            relative_time=self.relative_time,
            time_reference=self.time_reference,
            acceleration_func_kwargs=self.acceleration_func_kwargs,
        )

    def to_keplerian(self) -> "KeplerianState":
        """Return ``self`` unchanged.

        The state is already Keplerian; the method exists so that
        ``KeplerianState`` and ``CartesianState`` expose the same interface.
        """
        return self

    def to_system(self) -> SystemState:
        """Build a tracer-only ``SystemState`` from this state.

        ``SystemState.time`` is taken from ``self.relative_time``. The
        ``time_reference`` anchor is dropped because the acceleration function
        carries its own ``t_ref_jd`` for converting back to absolute JD.
        """
        # The Cartesian counterpart carries the same relative_time and
        # acceleration kwargs, so its to_system() yields the identical result.
        return self.to_cartesian().to_system()
@chex.dataclass
class CartesianState:
    """*Barycentric* Cartesian position and velocity of a particle.

    The epoch is stored as a ``(relative_time, time_reference)`` pair so that
    the state is self-describing about its absolute epoch; the full convention
    is spelled out on :class:`KeplerianState`.

    Invariant: ``absolute_jd = relative_time + time_reference``.
    """

    # Position and velocity vectors.
    x: jnp.ndarray
    v: jnp.ndarray

    # Keyword arguments carried along for the acceleration function.
    acceleration_func_kwargs: dict

    # Absolute JD (TDB) anchor against which relative_time is measured.
    time_reference: float

    # Offset in days from time_reference; `to_system` forwards this value to
    # SystemState.time.
    relative_time: float = field(default_factory=lambda: jnp.array(0.0))

    # 6x6 covariance matrix in Cartesian coordinates
    # (x0, x1, x2, v0, v1, v2). A (0, 0) shape means "not set".
    # Coordinate transforms do not propagate it.
    cov: jnp.ndarray = field(default_factory=lambda: jnp.empty((0, 0)))

    def to_keplerian(self) -> KeplerianState:
        """Return the equivalent barycentric ``KeplerianState``.

        Position and velocity are rotated from ICRS to the Horizons ecliptic
        frame and converted to elements using ``TOTAL_SOLAR_SYSTEM_GM``. Both
        time fields and ``acceleration_func_kwargs`` are carried over; ``cov``
        is not.
        """
        elements = cartesian_to_elements(
            icrs_to_horizons_ecliptic(self.x),
            icrs_to_horizons_ecliptic(self.v),
            TOTAL_SOLAR_SYSTEM_GM,
        )
        # cartesian_to_elements returns (a, ecc, nu, inc, Omega, omega).
        semi_major, eccentricity, true_anom, inclination, node, peri = elements
        return KeplerianState(
            semi=semi_major,
            ecc=eccentricity,
            inc=inclination,
            Omega=node,
            omega=peri,
            nu=true_anom,
            relative_time=self.relative_time,
            time_reference=self.time_reference,
            acceleration_func_kwargs=self.acceleration_func_kwargs,
        )

    def to_cartesian(self) -> "CartesianState":
        """Return ``self`` unchanged.

        The state is already Cartesian; the method exists so that
        ``KeplerianState`` and ``CartesianState`` expose the same interface.
        """
        return self

    def to_system(self) -> SystemState:
        """Build a tracer-only ``SystemState`` from this state.

        ``SystemState.time`` is taken from ``self.relative_time``. The
        ``time_reference`` anchor is dropped because the acceleration function
        carries its own ``t_ref_jd`` for converting back to absolute JD.
        """
        return SystemState(
            # This state is the tracer content of the system.
            tracer_positions=self.x,
            tracer_velocities=self.v,
            time=self.relative_time,
            acceleration_func_kwargs=self.acceleration_func_kwargs,
            # No massive particles.
            massive_positions=jnp.empty((0, 3)),
            massive_velocities=jnp.empty((0, 3)),
            log_gms=jnp.empty((0,)),
            # No fixed perturbers.
            fixed_perturber_positions=jnp.empty((0, 3)),
            fixed_perturber_velocities=jnp.empty((0, 3)),
            fixed_perturber_log_gms=jnp.empty((0,)),
        )
@chex.dataclass
class IAS15IntegratorState:
    """Internal bookkeeping carried between steps of the IAS15 integrator.

    NOTE(review): this module does not document the individual fields. The
    per-field notes below are assumptions based on the usual IAS15
    formulation and should be confirmed against the integrator code.
    """

    # presumably the "g" coefficient array of the IAS15 scheme -- confirm
    g: jnp.ndarray
    # presumably the "b" coefficient array of the IAS15 scheme -- confirm
    b: jnp.ndarray
    # presumably the "e" array used to predict b for the next step -- confirm
    e: jnp.ndarray
    # presumably compensated-summation terms for positions -- confirm
    csx: jnp.ndarray
    # presumably compensated-summation terms for velocities -- confirm
    csv: jnp.ndarray
    # presumably the acceleration at the start of the step -- confirm
    a0: jnp.ndarray
    # Step size.
    dt: float
    # Step size recorded for the last completed step (going by the name).
    dt_last_done: float
@chex.dataclass
class LeapfrogIntegratorState:
    """Internal bookkeeping for a leapfrog integrator.

    NOTE(review): ``C`` and ``D`` are presumably the coefficient arrays of the
    leapfrog composition scheme -- confirm against the integrator code.
    """

    # Step size.
    dt: float
    # Coefficient arrays of the scheme (see the review note above).
    C: jnp.ndarray
    D: jnp.ndarray
def _get_sun_state(time: Time, de_ephemeris_version: str = "440") -> dict:
    """Helper to get the state of the Sun at a given time.

    Builds an ``Ephemeris`` for the default planets covering a window of 30
    days either side of ``time`` and returns its ``"sun"`` entry.

    Args:
        time: astropy.time.Time
            The time at which to get the Sun's state.
        de_ephemeris_version: str
            The version of the JPL DE ephemeris to use. Defaults to "440",
            accepts "430".

    Returns:
        dict:
            The Sun's entry from ``Ephemeris.state``. Callers in this module
            read ``["x"].value`` and ``["v"].value`` from it, i.e. the
            position and velocity are stored as unit-carrying quantities
            (presumably AU and AU/day) rather than as a plain tuple of arrays.
    """
    # A +/- 30 day window around the requested epoch is all the ephemeris
    # needs to cover for a single-time query.
    window = 30 * u.day
    eph = Ephemeris(
        ssos="default planets",
        earliest_time=time - window,
        latest_time=time + window,
        de_ephemeris_version=de_ephemeris_version,
    )
    return eph.state(time)["sun"]
def barycentric_to_heliocentric(
    state: CartesianState | KeplerianState,
    time: Time,
    de_ephemeris_version: str = "440",
) -> dict:
    """Helper to compute heliocentric quantities from barycentric states.

    Queries the state vector of the Sun at ``time`` from the local JPL DE
    ephemeris and subtracts it from the barycentric state.

    Args:
        state: CartesianState or KeplerianState
            The barycentric state to convert.
        time: astropy.time.Time
            The time at which to compute the heliocentric quantities. Note
            that this is the epoch used for the Sun's state; the time fields
            stored on ``state`` are not consulted.
        de_ephemeris_version: str
            The version of the JPL DE ephemeris to use. Defaults to "440",
            accepts "430".

    Returns:
        dict:
            A dictionary containing heliocentric quantities. If the input
            state is Cartesian, returns 'x_helio' and 'v_helio'. If Keplerian,
            returns 'a_helio', 'ecc_helio', 'inc_helio', 'Omega_helio',
            'omega_helio', and 'nu_helio'.

    Raises:
        ValueError: If ``state`` is neither a CartesianState nor a
            KeplerianState.
    """
    # Validate before doing any work: otherwise an unsupported input would
    # fail inside ``state.to_cartesian()`` (after the ephemeris has already
    # been loaded) and the intended ValueError would never be reached.
    if not isinstance(state, (CartesianState, KeplerianState)):
        raise ValueError(
            "state must be either a barycentric CartesianState or KeplerianState"
        )

    sun_state = _get_sun_state(time=time, de_ephemeris_version=de_ephemeris_version)

    # Shift the origin from the barycenter to the Sun.
    cart = state.to_cartesian()
    helio_x = cart.x - sun_state["x"].value
    helio_v = cart.v - sun_state["v"].value

    if isinstance(state, CartesianState):
        return {"x_helio": helio_x, "v_helio": helio_v}

    # Keplerian input: express the heliocentric vectors as elements, using the
    # Sun's GM instead of the total solar-system GM used for barycentric
    # elements.
    helio_x = icrs_to_horizons_ecliptic(helio_x)
    helio_v = icrs_to_horizons_ecliptic(helio_v)
    a_helio, ecc_helio, nu_helio, inc_helio, Omega_helio, omega_helio = (
        cartesian_to_elements(helio_x, helio_v, SUN_GM)
    )
    return {
        "a_helio": a_helio,
        "ecc_helio": ecc_helio,
        "inc_helio": inc_helio,
        "Omega_helio": Omega_helio,
        "omega_helio": omega_helio,
        "nu_helio": nu_helio,
    }
def heliocentric_to_barycentric(
    heliocentric_dict: dict,
    time: Time,
    de_ephemeris_version: str = "440",
    acceleration_func_kwargs: dict | None = None,
) -> CartesianState | KeplerianState:
    """Helper to compute barycentric quantities from heliocentric states.

    Queries the state vector of the Sun at ``time`` from the local JPL DE
    ephemeris and adds it to the heliocentric state.

    Args:
        heliocentric_dict: dict
            A dictionary containing heliocentric quantities. If the input
            state is Cartesian, must contain 'x_helio' and 'v_helio'. If
            Keplerian, must contain 'a_helio', 'ecc_helio', 'inc_helio',
            'Omega_helio', 'omega_helio', and 'nu_helio'. When both kinds of
            keys are present the Cartesian ones take precedence.
        time: astropy.time.Time
            The time at which to compute the barycentric quantities. Its TDB
            Julian date becomes the ``time_reference`` of the returned state.
        de_ephemeris_version: str
            The version of the JPL DE ephemeris to use. Defaults to "440",
            accepts "430".
        acceleration_func_kwargs: dict or None
            Additional arguments to associate with the final CartesianState or
            KeplerianState. Defaults to None, in which case a fresh
            {"c2": SPEED_OF_LIGHT**2} dict is created for this call.

    Returns:
        CartesianState or KeplerianState:
            The barycentric state. If the input dict had Cartesian quantities,
            returns a CartesianState. If the inputs were Keplerian, returns a
            KeplerianState.

    Raises:
        ValueError: If ``heliocentric_dict`` contains neither 'x_helio' nor
            'a_helio'.
    """
    # Reject unusable input before paying for the ephemeris load.
    is_cartesian = "x_helio" in heliocentric_dict
    if not is_cartesian and "a_helio" not in heliocentric_dict:
        raise ValueError(
            "heliocentric_dict must contain either heliocentric Cartesian or Keplerian quantities"
        )

    # Build the default per call. The dict is stored by reference on the
    # returned state, so a shared default object would let a mutation on one
    # state leak into every later call.
    if acceleration_func_kwargs is None:
        acceleration_func_kwargs = {"c2": SPEED_OF_LIGHT**2}

    sun_state = _get_sun_state(time=time, de_ephemeris_version=de_ephemeris_version)

    if is_cartesian:
        cart_x = heliocentric_dict["x_helio"] + sun_state["x"].value
        cart_v = heliocentric_dict["v_helio"] + sun_state["v"].value
        return CartesianState(
            x=cart_x,
            v=cart_v,
            time_reference=time.tdb.jd,
            acceleration_func_kwargs=acceleration_func_kwargs,
        )

    # Use .ravel() to handle both scalar floats (from manual user input) and
    # 1-D arrays (from barycentric_to_heliocentric). jnp.array([array]) would
    # produce a (1, N) shape for array inputs; .ravel() always gives (N,).
    helio_x, helio_v = elements_to_cartesian(
        jnp.asarray(heliocentric_dict["a_helio"]).ravel(),
        jnp.asarray(heliocentric_dict["ecc_helio"]).ravel(),
        jnp.asarray(heliocentric_dict["nu_helio"]).ravel(),
        jnp.asarray(heliocentric_dict["inc_helio"]).ravel(),
        jnp.asarray(heliocentric_dict["Omega_helio"]).ravel(),
        jnp.asarray(heliocentric_dict["omega_helio"]).ravel(),
        SUN_GM,
    )
    # Rotate to ICRS and shift the origin from the Sun to the barycenter.
    cart_x = horizons_ecliptic_to_icrs(helio_x) + sun_state["x"].value
    cart_v = horizons_ecliptic_to_icrs(helio_v) + sun_state["v"].value
    state = CartesianState(
        x=cart_x,
        v=cart_v,
        time_reference=time.tdb.jd,
        acceleration_func_kwargs=acceleration_func_kwargs,
    )
    # Barycentric elements are derived from the barycentric Cartesian state.
    return state.to_keplerian()