Source code for jorbit.system

"""The System class and its supporting functions."""

import jax

jax.config.update("jax_enable_x64", True)
from collections.abc import Callable

import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time

from jorbit.accelerations import (
    create_default_ephemeris_acceleration_func,
    create_gr_ephemeris_acceleration_func,
    create_newtonian_ephemeris_acceleration_func,
    newtonian_gravity,
    ppn_gravity,
)
from jorbit.astrometry.sky_projection import on_sky
from jorbit.data.constants import Y4_C, Y4_D, Y6_C, Y6_D, Y8_C, Y8_D
from jorbit.ephemeris.ephemeris import Ephemeris
from jorbit.integrators import (
    create_leapfrog_times,
    ias15_evolve,
    ias15_evolve_with_dense_output,
    initialize_ias15_integrator_state,
    leapfrog_evolve,
    make_ltt_propagator,
    next_proposed_dt_global,
    next_proposed_dt_PRS23,
)
from jorbit.particle import _keplerian_integrate, _keplerian_on_sky
from jorbit.utils.horizons import get_observer_positions
from jorbit.utils.states import (
    IAS15IntegratorState,
    LeapfrogIntegratorState,
    SystemState,
)


class System:
    """A system of particles in the solar system.

    Very similar in spirit to the `Particle` class, but now for multiple massless
    particles.
    """

    def __init__(
        self,
        particles: list | None = None,
        state: SystemState | None = None,
        gravity: str | Callable = "default solar system",
        de_ephemeris_version: str | None = "440",
        integrator: str = "ias15",
        earliest_time: Time = Time("1980-01-01"),
        latest_time: Time = Time("2050-01-01"),
        max_step_size: u.Quantity | None = None,
    ) -> None:
        """Initialize a System.

        Args:
            particles (list, optional):
                A list of Particle objects. None if state is provided. All
                particles must share the same reference time (to within 1e-6
                days). Defaults to None.
            state (SystemState, optional):
                A SystemState object. None if particles is provided. Its ``time``
                field must be an astropy Time giving the absolute epoch; it is
                replaced by a zero offset internally. Defaults to None.
            gravity (str | Callable):
                The gravitational acceleration function to use when integrating
                the particle's orbit. Defaults to "default solar system", which
                corresponds to parameterized post-Newtonian interactions with the
                10 bodies in the JPL DE440 ephemeris, plus Newtonian interactions
                with the 16 largest asteroids in the
                asteroids_de441/sb441-n16.bsp ephemeris. Can also be a
                jax.tree_util.Partial object that follows the same signature as
                the acceleration functions in jorbit.accelerations. Passing
                "keplerian" skips the acceleration-function and integrator setup
                entirely; the keplerian helper functions are used instead.
            de_ephemeris_version (str | None):
                Which version of the JPL DE ephemeris to use for perturber
                positions when using one of the built-in gravity models. Accepts
                either "440" or "430", default is "440".
            integrator (str):
                The integrator to use for the particle. Choices are "ias15",
                which is a 15th order adaptive step-size integrator, or "Y4",
                "Y6", or "Y8", which are 4th, 6th, and 8th order Yoshida leapfrog
                integrators with fixed step sizes. Defaults to "ias15".
            earliest_time (Time):
                The earliest time we expect to integrate the particle to.
                Defaults to Time("1980-01-01"). Larger time windows will result
                in larger in-memory ephemeris objects.
            latest_time (Time):
                The latest time we expect to integrate the particle to. Defaults
                to Time("2050-01-01"). Larger time windows will result in larger
                in-memory ephemeris objects.
            max_step_size (u.Quantity, optional):
                The fixed step size to use for leapfrog integrators. Required if
                integrator is "Y4", "Y6", or "Y8", and must be left as None for
                "ias15". Note that this is the maximum step size; the actual
                step size may be smaller to ensure that the particle lands
                exactly on the requested output times, and that the step size
                may change if the spacing between output times is not constant.
                Defaults to None.

        Raises:
            ValueError: If ``state`` is given with a non-Time ``time`` field, or
                if ``gravity`` / ``integrator`` is not a recognized option.
        """
        self._earliest_time = earliest_time
        self._latest_time = latest_time
        self._de_ephemeris_version = de_ephemeris_version
        self._is_keplerian = gravity == "keplerian"

        # Mirrors the Particle rebase: the JAX-visible state always carries
        # time=0.0, the absolute epoch lives on self._t_ref_astropy /
        # self._t_ref_jd, and any user-supplied query times are converted to
        # offsets via _times_to_offsets before they hit the integrator.
        if state is None:
            assert (
                particles is not None
            ), "Must provide either particles or state."
            t_ref_jds = jnp.array([p._t_ref_jd for p in particles])
            assert jnp.allclose(
                t_ref_jds, t_ref_jds[0], atol=1e-6, rtol=0
            ), "All particles must have the same reference time (tolerance: 1e-6 days ~ 0.1 s)"
            self._t_ref_astropy = particles[0]._t_ref_astropy
            self._t_ref_jd = particles[0]._t_ref_jd
            self._state = SystemState(
                tracer_positions=jnp.array([p._x for p in particles]),
                tracer_velocities=jnp.array([p._v for p in particles]),
                massive_positions=jnp.empty((0, 3)),
                massive_velocities=jnp.empty((0, 3)),
                log_gms=jnp.empty((0,)),
                time=jnp.array(0.0),
                fixed_perturber_positions=jnp.empty((0, 3)),
                fixed_perturber_velocities=jnp.empty((0, 3)),
                fixed_perturber_log_gms=jnp.empty((0,)),
                acceleration_func_kwargs={},
            )
        else:
            t_in = state.time
            if isinstance(t_in, Time):
                t_ref_astropy = t_in.tdb
            else:
                raise ValueError(
                    "Cannot determine the absolute epoch from a SystemState whose "
                    "time field is a numeric offset. SystemState.time is an "
                    "integration offset (typically 0.0), not an absolute JD. "
                    "To build a System from an existing particle state, use "
                    "System(particles=[particle]) or set state.time to an "
                    "astropy Time object representing the absolute epoch."
                )
            self._t_ref_astropy = t_ref_astropy
            self._t_ref_jd = jnp.array(float(t_ref_astropy.tdb.jd))
            self._state = state.replace(time=jnp.array(0.0))

        if self._is_keplerian:
            # No acceleration function or numerical integrator in this mode.
            self.gravity = "keplerian"
            self._integrator_state = None
            self._integrator = None
            self._integrator_method = "keplerian"
            self._max_step_size = None
        else:
            self.gravity = self._setup_acceleration_func(gravity)
            self._integrator_state, self._integrator = self._setup_integrator(
                integrator, max_step_size
            )
            self._integrator_method = integrator
            self._max_step_size = max_step_size

    def __repr__(self) -> str:
        """Return a string representation of the System."""
        return f"*************\njorbit System\n time: {Time(self._t_ref_jd, format='jd', scale='tdb').utc.iso}\n*************"

    @property
    def t_ref(self) -> Time:
        """Reference time (astropy Time, TDB) — the System's epoch.

        All JAX-visible times inside the System are offsets in days from this
        reference, which keeps the integrator's internal arithmetic
        well-conditioned at decadal timescales.
        """
        return self._t_ref_astropy

    @property
    def t_ref_jd(self) -> float:
        """Reference time as a float JD (TDB), matching ``t_ref``."""
        return self._t_ref_jd

    def _times_to_offsets(self, times: Time | jnp.ndarray) -> jnp.ndarray:
        """Convert user-provided query times to offsets from ``self._t_ref_astropy``.

        Astropy ``Time`` inputs preserve their internal jd1+jd2 high-precision
        pair through the subtraction, so the returned offsets retain sub-ns
        precision regardless of how far ``times`` is from ``t_ref``. Plain
        float/jnp inputs are assumed to be absolute JD (TDB) and are subtracted
        in float64 — this is Sterbenz-exact when the magnitudes match, but the
        offset precision is capped at ulp(JD) ≈ 40 μs (≈1 m for typical solar
        system velocities).

        Args:
            times (Time | jnp.ndarray): Absolute query times.

        Returns:
            jnp.ndarray: Offsets from the reference epoch, in days.
        """
        if isinstance(times, Time):
            return jnp.asarray((times.tdb - self._t_ref_astropy).to_value(u.day))
        return jnp.asarray(times) - self._t_ref_jd

    @staticmethod
    def _resolve_step_scheduler(name: str) -> Callable:
        """Translate a string name to a JIT-friendly Partial step scheduler.

        Args:
            name (str): Either "prs23" or "global" (case-insensitive).

        Returns:
            Callable: A jax.tree_util.Partial wrapping the matching scheduler.

        Raises:
            ValueError: If ``name`` is not one of the known schedulers.
        """
        key = name.lower()
        if key == "prs23":
            return jax.tree_util.Partial(next_proposed_dt_PRS23)
        if key == "global":
            return jax.tree_util.Partial(next_proposed_dt_global)
        raise ValueError(
            f"Unknown step_scheduler '{name}'. Choices are 'prs23' or 'global'."
        )

    def _setup_acceleration_func(self, gravity: str | Callable) -> Callable:
        """Build the acceleration function selected by ``gravity``.

        Args:
            gravity (str | Callable): One of the built-in model names, or a
                jax.tree_util.Partial supplying a custom acceleration function.

        Returns:
            Callable: The acceleration function to hand to the integrator.

        Raises:
            ValueError: If ``gravity`` is neither a Partial nor a known name.
        """
        if isinstance(gravity, jax.tree_util.Partial):
            # See Particle._setup_acceleration_func for the full rationale.
            # CONTRACT: the custom function must be built with t_ref_jd=0 and
            # must treat state.time as an absolute Julian Date. Do NOT pass a
            # jorbit factory function built with a non-zero t_ref_jd — that
            # would double-count the offset and silently query the ephemeris
            # at the wrong absolute time.
            user_func = gravity
            t_ref_jd = self._t_ref_jd

            def _wrapped_user_acc(state: SystemState) -> jnp.ndarray:
                # Shift the internal offset back to an absolute JD for the user.
                shifted = state.replace(time=state.time + t_ref_jd)
                return user_func(shifted)

            return jax.tree_util.Partial(_wrapped_user_acc)

        # Built-in models that need an Ephemeris: name -> (ssos, factory).
        ephemeris_models = {
            "newtonian planets": (
                "default planets",
                create_newtonian_ephemeris_acceleration_func,
            ),
            "newtonian solar system": (
                "default solar system",
                create_newtonian_ephemeris_acceleration_func,
            ),
            "gr planets": (
                "default planets",
                create_gr_ephemeris_acceleration_func,
            ),
            "gr solar system": (
                "default solar system",
                create_gr_ephemeris_acceleration_func,
            ),
            "default solar system": (
                "default solar system",
                create_default_ephemeris_acceleration_func,
            ),
        }
        # newtonian_gravity / ppn_gravity don't read inputs.time, so the
        # rebase is invisible to them — no shim needed.
        generic_models = {
            "generic newtonian": newtonian_gravity,
            "generic gr": ppn_gravity,
        }

        if isinstance(gravity, str):
            if gravity in ephemeris_models:
                ssos, factory = ephemeris_models[gravity]
                eph = Ephemeris(
                    earliest_time=self._earliest_time,
                    latest_time=self._latest_time,
                    ssos=ssos,
                )
                return factory(eph.processor, t_ref_jd=self._t_ref_jd)
            if gravity in generic_models:
                return jax.tree_util.Partial(generic_models[gravity])

        raise ValueError(
            f"Unrecognized gravity '{gravity}'. Valid options are: 'newtonian planets', "
            "'newtonian solar system', 'gr planets', 'gr solar system', "
            "'default solar system', 'generic newtonian', 'generic gr', 'keplerian'. "
            "For a custom acceleration function, pass a jax.tree_util.Partial."
        )

    def _setup_integrator(
        self, integrator: str, max_step_size: u.Quantity | None
    ) -> tuple[IAS15IntegratorState | LeapfrogIntegratorState, Callable]:
        """Create the integrator state and evolve function for ``integrator``.

        Args:
            integrator (str): "ias15", "Y4", "Y6", or "Y8".
            max_step_size (u.Quantity | None): Step size for the leapfrog
                integrators; must be None for "ias15".

        Returns:
            tuple: The initial integrator state and a jax.tree_util.Partial
            wrapping the matching evolve function.

        Raises:
            ValueError: If ``integrator`` is not a recognized name. Previously
                this case fell through to the return statement and surfaced as
                an UnboundLocalError.
        """
        if integrator == "ias15":
            assert (
                max_step_size is None
            ), "max_step_size should not be provided for IAS15 integrator."
            # IAS15 is seeded with the acceleration at the initial state.
            a0 = self.gravity(self._state)
            return (
                initialize_ias15_integrator_state(a0),
                jax.tree_util.Partial(ias15_evolve),
            )

        leapfrog_coefficients = {
            "Y4": (Y4_C, Y4_D),
            "Y6": (Y6_C, Y6_D),
            "Y8": (Y8_C, Y8_D),
        }
        if isinstance(integrator, str) and integrator in leapfrog_coefficients:
            assert (
                max_step_size is not None
            ), "Must provide max_step_size for leapfrog integrators."
            dt = max_step_size.to(u.day).value
            c, d = leapfrog_coefficients[integrator]
            return (
                LeapfrogIntegratorState(dt=dt, C=c, D=d),
                jax.tree_util.Partial(leapfrog_evolve),
            )

        raise ValueError(
            f"Unrecognized integrator '{integrator}'. "
            "Valid options are: 'ias15', 'Y4', 'Y6', 'Y8'."
        )

    ################
    # PUBLIC METHODS
    ################

    def integrate(
        self, times: Time | jnp.ndarray, step_scheduler: str = "prs23"
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Integrate the System to a given time.

        Note:
            This method does not change the state of the system. It returns the
            positions and velocities at the given times, but the system itself
            is not changed.

        Args:
            times (Time | jnp.ndarray):
                The times to integrate to. Can be a single time or an array of
                times. If provided as a jnp.array, the entries are assumed to
                be in TDB JD.
            step_scheduler (str):
                The scheduler to use for determining step sizes. Choices are
                "prs23", which uses the PRS23 controller from Pham+ 2023, or
                "global", which uses the controller from the original IAS15
                paper. Default is "prs23". Ignored for leapfrog integrators and
                keplerian systems.

        Returns:
            tuple[jnp.ndarray, jnp.ndarray]:
                The positions of the particles at the given times, in AU, and
                the velocities of the particles at the given times, in AU/day.
        """
        times = self._times_to_offsets(times)
        if times.shape == ():
            # Promote a single time to a length-1 array.
            times = jnp.array([times])

        if self._is_keplerian:
            return _keplerian_system_integrate(
                self._state.tracer_positions,
                self._state.tracer_velocities,
                self._state.time,
                times,
            )

        if self._integrator_method in ["Y4", "Y6", "Y8"]:
            # For leapfrog, need to create intermediate times
            times, inds = create_leapfrog_times(
                t0=self._state.time,
                times=times,
                biggest_allowed_dt=self._integrator_state.dt,
            )
        else:
            inds = jnp.arange(times.shape[0])

        scheduler = self._resolve_step_scheduler(step_scheduler)
        positions, velocities, _final_system_state, _final_integrator_state = (
            _integrate(
                times,
                self._state,
                self.gravity,
                self._integrator,
                self._integrator_state,
                inds,
                scheduler,
            )
        )
        return positions, velocities

    def ephemeris(
        self,
        times: Time | jnp.ndarray,
        observer: str | jnp.ndarray,
        step_scheduler: str = "prs23",
    ) -> SkyCoord:
        """Compute an ephemeris for the system.

        Args:
            times (Time | jnp.ndarray):
                The times to compute the ephemeris for. Can be a single time or
                an array of times. If provided as a jnp.array, the entries are
                assumed to be in TDB JD.
            observer (str | jnp.ndarray):
                The observer to compute the ephemeris for. Can be a string
                representing an observatory name, or a 3D position vector in
                AU. A single 1-D vector is reused for every requested time; an
                array with one row per time is passed through unchanged. For
                more info on acceptable strings, see the
                get_observer_positions function.
            step_scheduler (str):
                The scheduler to use for determining step sizes. Choices are
                "prs23", which uses the PRS23 controller from Pham+ 2023, or
                "global", which uses the controller from the original IAS15
                paper. Default is "prs23". Ignored for leapfrog integrators and
                keplerian systems.

        Returns:
            coords (SkyCoord):
                The ephemeris of each particle in the system at the given
                times, in ICRS coordinates and as seen from that specific
                observer. Each particle has its own light travel time
                correction applied individually.
        """
        observer_is_name = isinstance(observer, str)
        if observer_is_name:
            # Must be looked up with the original (absolute) times.
            observer_positions = get_observer_positions(
                times, observer, self._de_ephemeris_version
            )
        else:
            observer_positions = jnp.asarray(observer)

        times = self._times_to_offsets(times)
        if times.shape == ():
            times = jnp.array([times])

        if not observer_is_name and observer_positions.ndim == 1:
            # Downstream code maps over the leading axis of observer_positions
            # in lockstep with the times, so give a lone vector one row per
            # requested time. Done before any leapfrog padding of `times`.
            observer_positions = jnp.broadcast_to(
                observer_positions, (times.shape[0], *observer_positions.shape)
            )

        if self._is_keplerian:
            ras, decs = _keplerian_system_ephem(
                self._state.tracer_positions,
                self._state.tracer_velocities,
                self._state.time,
                times,
                observer_positions,
            )
            return SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs")

        is_leapfrog = self._integrator_method in ("Y4", "Y6", "Y8")
        if is_leapfrog:
            # For leapfrog, need to create intermediate times
            times, inds = create_leapfrog_times(
                t0=self._state.time,
                times=times,
                biggest_allowed_dt=self._integrator_state.dt,
            )
        else:
            inds = jnp.arange(times.shape[0])

        scheduler = self._resolve_step_scheduler(step_scheduler)

        # IAS15 has dense-output b-coefficients available, so use them in on_sky's
        # LTT loop instead of a constant-acceleration Taylor expansion. Leapfrog
        # has no b-coefficients; fall back to the original Taylor.
        if is_leapfrog:
            ras, decs = _ephem(
                times,
                self._state,
                self.gravity,
                self._integrator,
                self._integrator_state,
                observer_positions,
                inds,
                scheduler,
            )
        else:
            ras, decs = _ephem_ias15(
                times,
                self._state,
                self.gravity,
                self._integrator_state,
                observer_positions,
                inds,
                scheduler,
            )
        return SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs")
@jax.jit
def _integrate(
    times: jnp.ndarray,
    state: SystemState,
    acc_func: Callable,
    integrator_func: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[
    jnp.ndarray,
    jnp.ndarray,
    SystemState,
    IAS15IntegratorState | LeapfrogIntegratorState,
]:
    """Run ``integrator_func`` over ``times`` and keep the requested outputs.

    Args:
        times: Time offsets handed to the integrator.
        state: Initial system state.
        acc_func: Acceleration function.
        integrator_func: Evolve function called as
            ``integrator_func(state, acc_func, times, integrator_state,
            step_scheduler)``.
        integrator_state: Initial integrator state.
        relevant_inds: Indices into the integrator's outputs to return.
        step_scheduler: Step-size scheduler forwarded to ``integrator_func``.

    Returns:
        The positions and velocities selected by ``relevant_inds``, followed by
        the final system state and final integrator state.
    """
    positions, velocities, final_system_state, final_integrator_state, _steps = (
        integrator_func(state, acc_func, times, integrator_state, step_scheduler)
    )
    return (
        positions[relevant_inds],
        velocities[relevant_inds],
        final_system_state,
        final_integrator_state,
    )


@jax.jit
def _ephem(
    times: jnp.ndarray,
    state: SystemState,
    acc_func: Callable,
    integrator_func: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Integrate, then project every particle onto the sky at each output time.

    Args:
        times: Time offsets handed to the integrator. May include intermediate
            steps in addition to the requested output times.
        state: Initial system state.
        acc_func: Acceleration function, also forwarded to ``on_sky``.
        integrator_func: Evolve function forwarded to ``_integrate``.
        integrator_state: Initial integrator state.
        observer_positions: Observer positions, one entry per output time.
        relevant_inds: Indices of the requested output times within ``times``.
        step_scheduler: Step-size scheduler forwarded to ``_integrate``.

    Returns:
        The (ras, decs) arrays, with the particle axis first and the
        output-time axis second.
    """
    positions, velocities, _, _ = _integrate(
        times,
        state,
        acc_func,
        integrator_func,
        integrator_state,
        relevant_inds,
        step_scheduler,
    )
    # `positions`/`velocities` were already subset by relevant_inds; subset the
    # times the same way so every array scanned below has one entry per output
    # (matches what _ephem_ias15 does with obs_times).
    obs_times = times[relevant_inds]

    def interior(px: jnp.ndarray, pv: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        # Handles one particle: walk its outputs one observation at a time.
        def scan_func(
            carry: None,
            scan_over: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
        ) -> tuple[None, tuple[jnp.ndarray, jnp.ndarray]]:
            position, velocity, time, observer_position = scan_over
            ra, dec = on_sky(position, velocity, time, observer_position, acc_func)
            return None, (ra, dec)

        _, (ras, decs) = jax.lax.scan(
            scan_func,
            None,
            (px, pv, obs_times, observer_positions),
        )
        return ras, decs

    # Axis 1 of positions/velocities is the particle axis.
    ras, decs = jax.vmap(interior, in_axes=(1, 1))(positions, velocities)
    return ras, decs


@jax.jit
def _ephem_ias15(
    times: jnp.ndarray,
    state: SystemState,
    acc_func: Callable,
    integrator_state: IAS15IntegratorState,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """IAS15-only variant of ``_ephem`` that uses dense-output b-coefficients for LTT.

    Vmaps the per-observation polynomial-LTT closure over both the observation
    axis and the particle axis; each particle gets its own light-travel-time
    correction.
    """
    (
        _positions,
        _velocities,
        _final_system_state,
        _final_integrator_state,
        _iter_num,
        b_buf,
        a0_buf,
        x0_buf,
        v0_buf,
        dts_buf,
        _t_step_starts,
        step_indices,
        h_values,
    ) = ias15_evolve_with_dense_output(
        state, acc_func, times, integrator_state, step_scheduler
    )

    obs_times = times[relevant_inds]
    obs_step_indices = step_indices[relevant_inds]
    obs_h_values = h_values[relevant_inds]

    # Per-obs gather. Shapes: b (n_obs, 7, n_particles, 3), a0/v0/x0 (n_obs, n_particles, 3),
    # dt (n_obs,).
    b_per_obs_all = b_buf[obs_step_indices]
    a0_per_obs_all = a0_buf[obs_step_indices]
    x0_per_obs_all = x0_buf[obs_step_indices]
    v0_per_obs_all = v0_buf[obs_step_indices]
    dt_per_obs = dts_buf[obs_step_indices]

    def per_particle_per_obs(
        b_step: jnp.ndarray,
        a0_step: jnp.ndarray,
        x0_step: jnp.ndarray,
        v0_step: jnp.ndarray,
        dt_step: jnp.ndarray,
        h_obs: jnp.ndarray,
        time: jnp.ndarray,
        observer_pos: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        propagator = make_ltt_propagator(
            b_step, a0_step, x0_step, v0_step, dt_step, h_obs
        )
        # Evaluate the propagator at zero to get the position passed to on_sky.
        x_obs = propagator(jnp.array(0.0))
        return on_sky(
            x_obs,
            jnp.zeros(3),
            time,
            observer_pos,
            acc_func,
            ltt_position_fn=propagator,
        )

    def for_single_particle(
        b_obs_p: jnp.ndarray,
        a0_obs_p: jnp.ndarray,
        x0_obs_p: jnp.ndarray,
        v0_obs_p: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        # b_obs_p: (n_obs, 7, 3); a0/v0/x0_obs_p: (n_obs, 3)
        return jax.vmap(per_particle_per_obs, in_axes=(0, 0, 0, 0, 0, 0, 0, 0))(
            b_obs_p,
            a0_obs_p,
            x0_obs_p,
            v0_obs_p,
            dt_per_obs,
            obs_h_values,
            obs_times,
            observer_positions,
        )

    # Vmap over particle axis: 2 in b_per_obs_all (axes are obs/coeff/particle/xyz),
    # 1 in a0/v0/x0_per_obs_all (axes are obs/particle/xyz).
    ras, decs = jax.vmap(for_single_particle, in_axes=(2, 1, 1, 1))(
        b_per_obs_all, a0_per_obs_all, x0_per_obs_all, v0_per_obs_all
    )
    return ras, decs


@jax.jit
def _keplerian_system_integrate(
    xs: jnp.ndarray,
    vs: jnp.ndarray,
    t0: float,
    times: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Apply ``_keplerian_integrate`` to every particle.

    Args:
        xs: Initial positions, one row per particle.
        vs: Initial velocities, one row per particle.
        t0: Time of the initial state, shared by all particles.
        times: Output times, shared by all particles.

    Returns:
        Positions and velocities with the time axis first and the particle
        axis second.
    """
    # vmap _keplerian_integrate over particles: (N,T,3) for each
    positions, velocities = jax.vmap(_keplerian_integrate, in_axes=(0, 0, None, None))(
        xs, vs, t0, times
    )
    # transpose to (T,N,3) to match existing convention
    positions = jnp.transpose(positions, (1, 0, 2))
    velocities = jnp.transpose(velocities, (1, 0, 2))
    return positions, velocities


@jax.jit
def _keplerian_system_ephem(
    xs: jnp.ndarray,
    vs: jnp.ndarray,
    t0: float,
    times: jnp.ndarray,
    observer_positions: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute sky coordinates for every particle in keplerian mode.

    Args:
        xs: Initial positions, one row per particle.
        vs: Initial velocities, one row per particle.
        t0: Time of the initial state.
        times: Output times.
        observer_positions: Observer positions, one entry per output time.

    Returns:
        The (ras, decs) arrays, with the particle axis first and the time axis
        second.
    """
    positions, velocities = _keplerian_system_integrate(xs, vs, t0, times)
    # _keplerian_on_sky operates on a single (position, velocity, time, observer)
    # vmap over times (axis 0 of positions[n]), then over particles (axis 1)
    _on_sky_over_times = jax.vmap(_keplerian_on_sky, in_axes=(0, 0, 0, 0))
    _on_sky_over_particles = jax.vmap(_on_sky_over_times, in_axes=(1, 1, None, None))
    ras, decs = _on_sky_over_particles(positions, velocities, times, observer_positions)
    return ras, decs