# Source code for jorbit.system.system

"""The System class.

The jitted/host helper functions that implement each integration branch live in sibling
modules of this subpackage:

- :mod:`jorbit.system.ephem` (generic integrate/ephem, leapfrog)
- :mod:`jorbit.system.ias15_dense` (IAS15 ``interpolate=True``)
- :mod:`jorbit.system.keplerian` (analytic two-body path)
"""

from __future__ import annotations

from collections.abc import Callable

import astropy.units as u
import jax
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time

from jorbit.accelerations import (
    create_default_ephemeris_acceleration_func,
    create_gr_ephemeris_acceleration_func,
    create_newtonian_ephemeris_acceleration_func,
    newtonian_gravity,
    ppn_gravity,
)
from jorbit.data.constants import Y4_C, Y4_D, Y6_C, Y6_D, Y8_C, Y8_D
from jorbit.ephemeris.ephemeris import Ephemeris
from jorbit.integrators import (
    create_leapfrog_times,
    ias15_evolve,
    initialize_ias15_integrator_state,
    leapfrog_evolve,
    next_proposed_dt_global,
    next_proposed_dt_PRS23,
    stitched_interpolate,
)
from jorbit.system.ephem import _ephem, _integrate
from jorbit.system.ias15_dense import _ephem_ias15_stitched
from jorbit.system.keplerian import _keplerian_system_ephem, _keplerian_system_integrate
from jorbit.utils.horizons import get_observer_positions
from jorbit.utils.states import (
    IAS15IntegratorState,
    LeapfrogIntegratorState,
    SystemState,
)


class System:
    """A system of massless (tracer) particles in the solar system.

    Very similar in spirit to the `Particle` class, but for multiple massless
    particles that share one reference epoch and are propagated together
    under the same gravity model and integrator.
    """

    # Built-in gravity models that are driven by an Ephemeris of perturbers:
    # name -> (Ephemeris ``ssos`` preset, acceleration-function factory).
    _EPHEMERIS_GRAVITY_MODELS: dict[str, tuple[str, Callable]] = {
        "newtonian planets": (
            "default planets",
            create_newtonian_ephemeris_acceleration_func,
        ),
        "newtonian solar system": (
            "default solar system",
            create_newtonian_ephemeris_acceleration_func,
        ),
        "gr planets": (
            "default planets",
            create_gr_ephemeris_acceleration_func,
        ),
        "gr solar system": (
            "default solar system",
            create_gr_ephemeris_acceleration_func,
        ),
        "default solar system": (
            "default solar system",
            create_default_ephemeris_acceleration_func,
        ),
    }

    # Fixed-step Yoshida leapfrog integrators: name -> (C, D) coefficients.
    _LEAPFROG_COEFFICIENTS: dict[str, tuple] = {
        "Y4": (Y4_C, Y4_D),
        "Y6": (Y6_C, Y6_D),
        "Y8": (Y8_C, Y8_D),
    }

    def __init__(
        self,
        particles: list | None = None,
        state: SystemState | None = None,
        gravity: str | Callable = "default solar system",
        de_ephemeris_version: str | None = "440",
        integrator: str = "ias15",
        earliest_time: Time = Time("1980-01-01"),
        latest_time: Time = Time("2050-01-01"),
        max_step_size: u.Quantity | None = None,
    ) -> None:
        """Initialize a System.

        If ``state`` is given it takes precedence and ``particles`` is ignored;
        otherwise ``particles`` must be a non-empty list.

        Args:
            particles (list, optional): A list of Particle objects, all sharing
                the same reference time (to within 1e-6 days). Ignored if
                ``state`` is provided. Defaults to None.
            state (SystemState, optional): A SystemState object. If provided,
                its absolute epoch (relative_time + time_reference) becomes the
                System's reference time. Defaults to None.
            gravity (str | Callable): The gravitational acceleration function
                to use when integrating the particles' orbits. Built-in options
                are "newtonian planets", "newtonian solar system",
                "gr planets", "gr solar system", "default solar system",
                "generic newtonian", "generic gr", and "keplerian" (analytic
                two-body path, no numerical integrator). Defaults to
                "default solar system", which corresponds to parameterized
                post-Newtonian interactions with the 10 bodies in the JPL DE440
                ephemeris, plus Newtonian interactions with the 16 largest
                asteroids in the asteroids_de441/sb441-n16.bsp ephemeris. Can
                also be a jax.tree_util.Partial object that follows the same
                signature as the acceleration functions in
                jorbit.accelerations.
            de_ephemeris_version (str | None): Which version of the JPL DE
                ephemeris to use ("440" or "430"). Within this class it is
                forwarded to get_observer_positions when ``ephemeris`` is
                called with an observer name. Defaults to "440".
            integrator (str): The integrator to use. Choices are "ias15", a
                15th order adaptive step-size integrator, or "Y4", "Y6", or
                "Y8", the 4th, 6th, and 8th order Yoshida leapfrog integrators
                with fixed step sizes. Ignored when gravity is "keplerian".
                Defaults to "ias15".
            earliest_time (Time): The earliest time we expect to integrate the
                particles to. Defaults to Time("1980-01-01"). Larger time
                windows will result in larger in-memory ephemeris objects.
            latest_time (Time): The latest time we expect to integrate the
                particles to. Defaults to Time("2050-01-01"). Larger time
                windows will result in larger in-memory ephemeris objects.
            max_step_size (u.Quantity, optional): The fixed step size to use
                for leapfrog integrators. Required if integrator is "Y4", "Y6",
                or "Y8"; must be None for "ias15". Note that this is the
                maximum step size; the actual step size may be smaller so that
                the particles land exactly on the requested output times, and
                it may change if the spacing between output times is not
                constant. Defaults to None.

        Raises:
            AssertionError: If no state and no non-empty particle list is
                given, if the particles' reference times disagree, or if
                ``max_step_size`` is inconsistent with ``integrator``.
            ValueError: If ``gravity`` or ``integrator`` is not recognized.
        """
        self._earliest_time = earliest_time
        self._latest_time = latest_time
        self._de_ephemeris_version = de_ephemeris_version
        # Only a string can select the analytic Keplerian path; never compare a
        # custom callable against a string.
        self._is_keplerian = isinstance(gravity, str) and gravity == "keplerian"

        # Mirrors the Particle rebase: the JAX-visible state always carries
        # relative_time=0.0, the absolute epoch lives on self._t_ref_astropy /
        # self._t_ref_jd (and on the state's time_reference), and any
        # user-supplied query times are converted to offsets via
        # _times_to_offsets before they hit the integrator.
        if state is None:
            # Materialize once so any iterable is consumed exactly once below.
            particles = list(particles) if particles is not None else []
            assert particles, (
                "Must provide either a SystemState or a non-empty list of particles."
            )
            t_ref_jds = jnp.array([p._t_ref_jd for p in particles])
            assert jnp.allclose(
                t_ref_jds, t_ref_jds[0], atol=1e-6, rtol=0
            ), "All particles must have the same reference time (tolerance: 1e-6 days ~ 0.1 s)"
            self._t_ref_astropy = particles[0]._t_ref_astropy
            self._t_ref_jd = particles[0]._t_ref_jd
            self._state = SystemState(
                tracer_positions=jnp.array([p._x for p in particles]),
                tracer_velocities=jnp.array([p._v for p in particles]),
                massive_positions=jnp.empty((0, 3)),
                massive_velocities=jnp.empty((0, 3)),
                log_gms=jnp.empty((0,)),
                time_reference=self._t_ref_jd,
                relative_time=jnp.array(0.0),
                fixed_perturber_positions=jnp.empty((0, 3)),
                fixed_perturber_velocities=jnp.empty((0, 3)),
                fixed_perturber_log_gms=jnp.empty((0,)),
                acceleration_func_kwargs={},
            )
        else:
            # The state self-describes its absolute epoch as
            # relative_time + time_reference (the same convention as
            # CartesianState/KeplerianState and Particle.__init__). Use that
            # sum as the System's reference epoch, then rebase the stored
            # state to relative_time=0.0 against the new anchor.
            time_reference = float(state.time_reference)
            relative_time = float(state.relative_time)
            # Hand astropy the two parts separately so it keeps them as a
            # high-precision jd1+jd2 pair instead of a single rounded float;
            # _times_to_offsets relies on that precision for Time inputs.
            self._t_ref_astropy = Time(
                time_reference, relative_time, format="jd", scale="tdb"
            )
            self._t_ref_jd = jnp.array(time_reference + relative_time)
            self._state = state.replace(
                relative_time=jnp.array(0.0),
                time_reference=self._t_ref_jd,
            )

        if self._is_keplerian:
            # Analytic two-body propagation: no acceleration function or
            # numerical integrator is set up.
            self.gravity = "keplerian"
            self._integrator_state = None
            self._integrator = None
            self._integrator_method = "keplerian"
            self._max_step_size = None
        else:
            # Gravity must be resolved first: IAS15 setup evaluates it.
            self.gravity = self._setup_acceleration_func(gravity)
            self._integrator_state, self._integrator = self._setup_integrator(
                integrator, max_step_size
            )
            self._integrator_method = integrator
            self._max_step_size = max_step_size

    def __repr__(self) -> str:
        """Return a string representation of the System."""
        return f"*************\njorbit System\n time: {self._t_ref_astropy.utc.iso}\n*************"

    @property
    def t_ref(self) -> Time:
        """Reference epoch of the System as an astropy Time.

        All JAX-visible times inside the System are offsets in days from this
        reference, which keeps the integrator's internal arithmetic
        well-conditioned at decadal timescales.
        """
        return self._t_ref_astropy

    @property
    def t_ref_jd(self) -> float | jnp.ndarray:
        """Reference epoch as a scalar JD (TDB), matching ``t_ref``.

        May be a Python float or a 0-d jnp array depending on how the System
        was constructed.
        """
        return self._t_ref_jd

    def _times_to_offsets(self, times: Time | jnp.ndarray) -> jnp.ndarray:
        """Convert user-provided query times to offsets from ``self._t_ref_astropy``.

        Astropy ``Time`` inputs preserve their internal jd1+jd2 high-precision
        pair through the subtraction, so the returned offsets retain sub-ns
        precision regardless of how far ``times`` is from ``t_ref``. Plain
        float/jnp inputs are assumed to be absolute JD (TDB) and are
        subtracted in float64 — this is Sterbenz-exact when the magnitudes
        match, but the offset precision is capped at ulp(JD) ≈ 40 μs (≈1 m for
        typical solar system velocities).

        Args:
            times (Time | jnp.ndarray): Absolute query times.

        Returns:
            jnp.ndarray: Offsets in days from the reference epoch.
        """
        if isinstance(times, Time):
            return jnp.asarray((times.tdb - self._t_ref_astropy).to_value(u.day))
        return jnp.asarray(times) - self._t_ref_jd

    @staticmethod
    def _resolve_step_scheduler(name: str) -> Callable:
        """Translate a step-scheduler name into a JIT-friendly Partial.

        Args:
            name (str): "prs23" or "global" (case-insensitive).

        Returns:
            Callable: A jax.tree_util.Partial wrapping the scheduler.

        Raises:
            ValueError: If ``name`` is not a known scheduler.
        """
        key = name.lower()
        if key == "prs23":
            return jax.tree_util.Partial(next_proposed_dt_PRS23)
        if key == "global":
            return jax.tree_util.Partial(next_proposed_dt_global)
        raise ValueError(
            f"Unknown step_scheduler '{name}'. Choices are 'prs23' or 'global'."
        )

    def _setup_acceleration_func(self, gravity: str | Callable) -> Callable:
        """Resolve ``gravity`` into an acceleration function.

        Args:
            gravity (str | Callable): A built-in model name or a
                jax.tree_util.Partial acceleration function.

        Returns:
            Callable: The acceleration function to hand to the integrators.

        Raises:
            ValueError: If ``gravity`` is neither a jax.tree_util.Partial nor a
                recognized model name.
        """
        if isinstance(gravity, jax.tree_util.Partial):
            # See Particle._setup_acceleration_func for the full rationale.
            # CONTRACT: the custom function must recover the absolute JD from
            # the state itself as inputs.relative_time + inputs.time_reference
            # (what jorbit's own factory functions now do). The integrator
            # hands it a state carrying time_reference == self._t_ref_jd, so no
            # shim is needed.
            return gravity

        if isinstance(gravity, str) and gravity in self._EPHEMERIS_GRAVITY_MODELS:
            ssos, make_acc_func = self._EPHEMERIS_GRAVITY_MODELS[gravity]
            # NOTE(review): self._de_ephemeris_version is not forwarded to this
            # Ephemeris; confirm whether the perturber ephemeris should honor
            # it.
            eph = Ephemeris(
                earliest_time=self._earliest_time,
                latest_time=self._latest_time,
                ssos=ssos,
            )
            return make_acc_func(eph.processor)

        # newtonian_gravity / ppn_gravity don't read the state's time at all,
        # so they're used directly with no ephemeris lookup.
        if gravity == "generic newtonian":
            return jax.tree_util.Partial(newtonian_gravity)
        if gravity == "generic gr":
            return jax.tree_util.Partial(ppn_gravity)

        raise ValueError(
            f"Unrecognized gravity '{gravity}'. Valid options are: 'newtonian planets', "
            "'newtonian solar system', 'gr planets', 'gr solar system', "
            "'default solar system', 'generic newtonian', 'generic gr', 'keplerian'. "
            "For a custom acceleration function, pass a jax.tree_util.Partial."
        )

    def _setup_integrator(
        self, integrator: str, max_step_size: u.Quantity | None
    ) -> tuple[IAS15IntegratorState | LeapfrogIntegratorState, Callable]:
        """Build the integrator state and evolve function for ``integrator``.

        Args:
            integrator (str): "ias15", "Y4", "Y6", or "Y8".
            max_step_size (u.Quantity | None): Fixed step for leapfrog; must be
                None for IAS15.

        Returns:
            tuple: ``(integrator_state, evolve_function)``.

        Raises:
            AssertionError: If ``max_step_size`` is inconsistent with
                ``integrator``.
            ValueError: If ``integrator`` is not a recognized name.
        """
        if integrator == "ias15":
            assert (
                max_step_size is None
            ), "max_step_size should not be provided for IAS15 integrator."
            # The IAS15 state is initialized from the acceleration at the
            # starting state.
            a0 = self.gravity(self._state)
            return (
                initialize_ias15_integrator_state(a0),
                jax.tree_util.Partial(ias15_evolve),
            )

        if integrator in self._LEAPFROG_COEFFICIENTS:
            assert (
                max_step_size is not None
            ), "Must provide max_step_size for leapfrog integrators."
            c, d = self._LEAPFROG_COEFFICIENTS[integrator]
            dt = max_step_size.to(u.day).value
            return (
                LeapfrogIntegratorState(dt=dt, C=c, D=d),
                jax.tree_util.Partial(leapfrog_evolve),
            )

        # Previously fell through to an UnboundLocalError.
        raise ValueError(
            f"Unknown integrator '{integrator}'. Choices are 'ias15', 'Y4', 'Y6', or 'Y8'."
        )

    ################
    # PUBLIC METHODS
    ################

    def integrate(
        self,
        times: Time | jnp.ndarray,
        step_scheduler: str = "prs23",
        return_steps: bool = False,
    ) -> (
        tuple[jnp.ndarray, jnp.ndarray]
        | tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray | None]
    ):
        """Integrate the System to the given times.

        Note:
            This method does not change the state of the system. It returns
            the positions and velocities at the given times, but the system
            itself is not changed.

        Args:
            times (Time | jnp.ndarray): The times to integrate to. Can be a
                single time or an array of times. If provided as a jnp.array,
                the entries are assumed to be in TDB JD.
            step_scheduler (str): The scheduler to use for determining step
                sizes. Choices are "prs23", which uses the PRS23 controller
                from Pham+ 2023, or "global", which uses the controller from
                the original IAS15 paper. Default is "prs23". Ignored for
                keplerian systems; ignored by leapfrog integrators, though the
                name is still validated there.
            return_steps (bool): If True, also return the total number of
                integration steps taken (summed across any stitched
                interpolation chunks), appended as the final element of the
                returned tuple. ``None`` for the analytic Keplerian and
                fixed-step leapfrog paths, which cannot truncate. Defaults to
                False.

        Returns:
            tuple[jnp.ndarray, jnp.ndarray]: The positions of the particles at
            the given times, in AU, and their velocities, in AU/day. If
            return_steps is True, the total integration step count is
            appended as a third element.
        """
        offsets = jnp.atleast_1d(self._times_to_offsets(times))

        if self._is_keplerian:
            positions, velocities = _keplerian_system_integrate(
                self._state.tracer_positions,
                self._state.tracer_velocities,
                self._state.relative_time,
                offsets,
            )
            steps = None
        elif self._integrator_method in self._LEAPFROG_COEFFICIENTS:
            # Leapfrog pre-expands its (uncapped) time array, so it never
            # truncates.
            leapfrog_times, inds = create_leapfrog_times(
                t0=self._state.relative_time,
                times=offsets,
                biggest_allowed_dt=self._integrator_state.dt,
            )
            scheduler = self._resolve_step_scheduler(step_scheduler)
            positions, velocities, _, _ = _integrate(
                leapfrog_times,
                self._state,
                self.gravity,
                self._integrator,
                self._integrator_state,
                inds,
                scheduler,
            )
            steps = None
        else:
            # IAS15: stitch dense-output chunks so the 15k buffer can't
            # silently truncate. See jorbit.integrators.budgeted.
            scheduler = self._resolve_step_scheduler(step_scheduler)
            positions, velocities, steps = stitched_interpolate(
                self._state,
                self.gravity,
                offsets,
                self._integrator_state,
                scheduler,
            )

        if return_steps:
            return positions, velocities, steps
        return positions, velocities

    def ephemeris(
        self,
        times: Time | jnp.ndarray,
        observer: str | jnp.ndarray,
        step_scheduler: str = "prs23",
        return_steps: bool = False,
    ) -> SkyCoord | tuple[SkyCoord, jnp.ndarray | None]:
        """Compute an ephemeris for the system.

        Args:
            times (Time | jnp.ndarray): The times to compute the ephemeris
                for. Can be a single time or an array of times. If provided as
                a jnp.array, the entries are assumed to be in TDB JD.
            observer (str | jnp.ndarray): The observer to compute the
                ephemeris for. Can be a string representing an observatory
                name, or a 3D position vector in AU. For more info on
                acceptable strings, see the get_observer_positions function.
            step_scheduler (str): The scheduler to use for determining step
                sizes. Choices are "prs23", which uses the PRS23 controller
                from Pham+ 2023, or "global", which uses the controller from
                the original IAS15 paper. Default is "prs23". Ignored for
                keplerian systems; ignored by leapfrog integrators, though the
                name is still validated there.
            return_steps (bool): If True, return a tuple of the SkyCoord and
                the total number of integration steps taken (summed across any
                stitched interpolation chunks). ``None`` for the analytic
                Keplerian and fixed-step leapfrog paths, which cannot
                truncate. Defaults to False.

        Returns:
            coords (SkyCoord): The ephemeris of each particle in the system at
            the given times, in ICRS coordinates and as seen from that
            specific observer. Each particle has its own light travel time
            correction applied individually. If return_steps is True, returns
            ``(coords, steps)``.
        """
        # get_observer_positions receives the caller's original (absolute)
        # times, so resolve the observer before converting to offsets.
        if isinstance(observer, str):
            observer_positions = get_observer_positions(
                times, observer, self._de_ephemeris_version
            )
        else:
            observer_positions = observer

        offsets = jnp.atleast_1d(self._times_to_offsets(times))

        steps = None
        if self._is_keplerian:
            ras, decs = _keplerian_system_ephem(
                self._state.tracer_positions,
                self._state.tracer_velocities,
                self._state.relative_time,
                offsets,
                observer_positions,
            )
        elif self._integrator_method in self._LEAPFROG_COEFFICIENTS:
            # For leapfrog, need to create intermediate times. Leapfrog has no
            # b-coefficients; fall back to the original constant-acceleration
            # Taylor expansion for the light-travel-time correction.
            leapfrog_times, inds = create_leapfrog_times(
                t0=self._state.relative_time,
                times=offsets,
                biggest_allowed_dt=self._integrator_state.dt,
            )
            scheduler = self._resolve_step_scheduler(step_scheduler)
            ras, decs = _ephem(
                leapfrog_times,
                self._state,
                self.gravity,
                self._integrator,
                self._integrator_state,
                observer_positions,
                inds,
                scheduler,
            )
        else:
            # IAS15: stitch dense-output chunks (truncation-proof) and use the
            # b-coefficients for each particle's light-travel-time correction.
            inds = jnp.arange(offsets.shape[0])
            scheduler = self._resolve_step_scheduler(step_scheduler)
            ras, decs, steps = _ephem_ias15_stitched(
                offsets,
                self._state,
                self.gravity,
                self._integrator_state,
                observer_positions,
                inds,
                scheduler,
            )

        coords = SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs")
        if return_steps:
            return coords, steps
        return coords