"""The System class and its supporting functions."""
import jax
jax.config.update("jax_enable_x64", True)
from collections.abc import Callable
import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time
from jorbit.accelerations import (
create_default_ephemeris_acceleration_func,
create_gr_ephemeris_acceleration_func,
create_newtonian_ephemeris_acceleration_func,
newtonian_gravity,
ppn_gravity,
)
from jorbit.astrometry.sky_projection import on_sky
from jorbit.data.constants import Y4_C, Y4_D, Y6_C, Y6_D, Y8_C, Y8_D
from jorbit.ephemeris.ephemeris import Ephemeris
from jorbit.integrators import (
create_leapfrog_times,
ias15_evolve,
ias15_evolve_with_dense_output,
initialize_ias15_integrator_state,
leapfrog_evolve,
make_ltt_propagator,
next_proposed_dt_global,
next_proposed_dt_PRS23,
)
from jorbit.particle import _keplerian_integrate, _keplerian_on_sky
from jorbit.utils.horizons import get_observer_positions
from jorbit.utils.states import (
IAS15IntegratorState,
LeapfrogIntegratorState,
SystemState,
)
class System:
    """A system of particles in the solar system.

    Very similar in spirit to the `Particle` class, but now for multiple massless
    particles.
    """

    # Integrator names that select a fixed-step Yoshida leapfrog scheme.
    _LEAPFROG_INTEGRATORS = ("Y4", "Y6", "Y8")

    def __init__(
        self,
        particles: list | None = None,
        state: SystemState | None = None,
        gravity: str | Callable = "default solar system",
        de_ephemeris_version: str | None = "440",
        integrator: str = "ias15",
        earliest_time: Time = Time("1980-01-01"),
        latest_time: Time = Time("2050-01-01"),
        max_step_size: u.Quantity | None = None,
    ) -> None:
        """Initialize a System.

        Args:
            particles (list, optional):
                A list of Particle objects, all sharing the same reference time.
                None if state is provided. Defaults to None.
            state (SystemState, optional):
                A SystemState object. None if particles is provided. Its ``time``
                field must be an astropy Time giving the absolute epoch; it is
                replaced by a 0.0 offset internally. Defaults to None.
            gravity (str | Callable):
                The gravitational acceleration function to use when integrating the
                particles' orbits. Defaults to "default solar system", which
                corresponds to parameterized post-Newtonian interactions with the 10
                bodies in the JPL DE440 ephemeris, plus Newtonian interactions with
                the 16 largest asteroids in the asteroids_de441/sb441-n16.bsp
                ephemeris. Other built-in choices are "newtonian planets",
                "newtonian solar system", "gr planets", "gr solar system",
                "generic newtonian", "generic gr" and "keplerian". Can also be a
                jax.tree_util.Partial object that follows the same signature as the
                acceleration functions in jorbit.accelerations.
            de_ephemeris_version (str | None):
                Which version of the JPL DE ephemeris to use. Accepts either "440"
                or "430", default is "440". In this class it is forwarded to
                ``get_observer_positions`` when computing observatory positions.
            integrator (str):
                The integrator to use for the system. Choices are "ias15", which is
                a 15th order adaptive step-size integrator, or "Y4", "Y6", or "Y8",
                which are 4th, 6th, and 8th order Yoshida leapfrog integrators with
                fixed step sizes. Defaults to "ias15". Ignored for "keplerian"
                gravity.
            earliest_time (Time):
                The earliest time we expect to integrate the particles to. Defaults
                to Time("1980-01-01"). Larger time windows will result in larger
                in-memory ephemeris objects.
            latest_time (Time):
                The latest time we expect to integrate the particles to. Defaults to
                Time("2050-01-01"). Larger time windows will result in larger
                in-memory ephemeris objects.
            max_step_size (u.Quantity, optional):
                The fixed step size to use for leapfrog integrators. Required if
                integrator is "Y4", "Y6", or "Y8"; must be None for "ias15".
                Note that this is the maximum step size; the actual step size may be
                smaller to ensure that the particles land exactly on the requested
                output times, and the step size may change if the spacing between
                output times is not constant. Defaults to None.

        Raises:
            AssertionError: If neither ``particles`` nor ``state`` is given, if
                ``particles`` is empty, if the particles do not share a reference
                time, or if ``max_step_size`` is inconsistent with ``integrator``.
            ValueError: If ``state.time`` is not an astropy Time, or if
                ``gravity`` or ``integrator`` is not recognized.
        """
        self._earliest_time = earliest_time
        self._latest_time = latest_time
        self._de_ephemeris_version = de_ephemeris_version
        self._is_keplerian = gravity == "keplerian"

        # Mirrors the Particle rebase: the JAX-visible state always carries
        # time=0.0, the absolute epoch lives on self._t_ref_astropy /
        # self._t_ref_jd, and any user-supplied query times are converted to
        # offsets via _times_to_offsets before they hit the integrator.
        if state is None:
            assert particles is not None
            assert len(particles) > 0, "Must provide at least one particle."
            t_ref_jds = jnp.array([p._t_ref_jd for p in particles])
            assert jnp.allclose(
                t_ref_jds, t_ref_jds[0], atol=1e-6, rtol=0
            ), "All particles must have the same reference time (tolerance: 1e-6 days ~ 0.1 s)"
            self._t_ref_astropy = particles[0]._t_ref_astropy
            self._t_ref_jd = particles[0]._t_ref_jd
            self._state = SystemState(
                tracer_positions=jnp.array([p._x for p in particles]),
                tracer_velocities=jnp.array([p._v for p in particles]),
                massive_positions=jnp.empty((0, 3)),
                massive_velocities=jnp.empty((0, 3)),
                log_gms=jnp.empty((0,)),
                time=jnp.array(0.0),
                fixed_perturber_positions=jnp.empty((0, 3)),
                fixed_perturber_velocities=jnp.empty((0, 3)),
                fixed_perturber_log_gms=jnp.empty((0,)),
                acceleration_func_kwargs={},
            )
        else:
            if not isinstance(state.time, Time):
                raise ValueError(
                    "Cannot determine the absolute epoch from a SystemState whose "
                    "time field is a numeric offset. SystemState.time is an "
                    "integration offset (typically 0.0), not an absolute JD. "
                    "To build a System from an existing particle state, use "
                    "System(particles=[particle]) or set state.time to an "
                    "astropy Time object representing the absolute epoch."
                )
            self._t_ref_astropy = state.time.tdb
            self._t_ref_jd = jnp.array(float(self._t_ref_astropy.tdb.jd))
            self._state = state.replace(time=jnp.array(0.0))

        if self._is_keplerian:
            # Keplerian propagation is analytic: no acceleration model or
            # numerical integrator is needed.
            self.gravity = "keplerian"
            self._integrator_state = None
            self._integrator = None
            self._integrator_method = "keplerian"
            self._max_step_size = None
        else:
            self.gravity = self._setup_acceleration_func(gravity)
            self._integrator_state, self._integrator = self._setup_integrator(
                integrator, max_step_size
            )
            self._integrator_method = integrator
            self._max_step_size = max_step_size

    def __repr__(self) -> str:
        """Return a string representation of the System."""
        # Use the stored astropy epoch directly rather than round-tripping it
        # through a single float64 JD.
        return f"*************\njorbit System\n time: {self._t_ref_astropy.utc.iso}\n*************"

    @property
    def t_ref(self) -> Time:
        """Reference time (astropy Time) — the System's epoch.

        All JAX-visible times inside the System are offsets in days from this
        reference, which keeps the integrator's internal arithmetic
        well-conditioned at decadal timescales.
        """
        return self._t_ref_astropy

    @property
    def t_ref_jd(self) -> jnp.ndarray:
        """Reference time as a 0-d float JD (TDB) array, matching ``t_ref``."""
        return self._t_ref_jd

    def _times_to_offsets(self, times: Time | jnp.ndarray) -> jnp.ndarray:
        """Convert user-provided query times to offsets from ``self._t_ref_astropy``.

        Astropy ``Time`` inputs preserve their internal jd1+jd2 high-precision pair
        through the subtraction, so the returned offsets retain sub-ns precision
        regardless of how far ``times`` is from ``t_ref``. Plain float/jnp inputs
        are assumed to be absolute JD (TDB) and are subtracted in float64 — this
        is Sterbenz-exact when the magnitudes match, but the offset precision is
        capped at ulp(JD) ≈ 40 μs (≈1 m for typical solar system velocities).
        """
        if isinstance(times, Time):
            return jnp.asarray((times.tdb - self._t_ref_astropy).to_value(u.day))
        return jnp.asarray(times) - self._t_ref_jd

    def _expand_query_times(
        self, offsets: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return the integrator's evaluation grid and the indices of ``offsets`` in it.

        Leapfrog integrators need intermediate times so that no step exceeds the
        fixed step size; ``create_leapfrog_times`` inserts them and reports where
        the requested times ended up. IAS15 evaluates exactly at ``offsets``.
        """
        if self._integrator_method in self._LEAPFROG_INTEGRATORS:
            return create_leapfrog_times(
                t0=self._state.time,
                times=offsets,
                biggest_allowed_dt=self._integrator_state.dt,
            )
        return offsets, jnp.arange(offsets.shape[0])

    @staticmethod
    def _resolve_step_scheduler(name: str) -> Callable:
        """Translate a string name to a JIT-friendly Partial step scheduler.

        Raises:
            ValueError: If ``name`` is not "prs23" or "global" (case-insensitive).
        """
        if name.lower() == "prs23":
            return jax.tree_util.Partial(next_proposed_dt_PRS23)
        elif name.lower() == "global":
            return jax.tree_util.Partial(next_proposed_dt_global)
        raise ValueError(
            f"Unknown step_scheduler '{name}'. Choices are 'prs23' or 'global'."
        )

    def _setup_acceleration_func(self, gravity: str | Callable) -> Callable:
        """Build the acceleration function, taking time offsets from ``t_ref``.

        Raises:
            ValueError: If ``gravity`` is neither a recognized name nor a
                ``jax.tree_util.Partial``.
        """
        if isinstance(gravity, jax.tree_util.Partial):
            # See Particle._setup_acceleration_func for the full rationale.
            # CONTRACT: the custom function must be built with t_ref_jd=0 and
            # must treat state.time as an absolute Julian Date. Do NOT pass a
            # jorbit factory function built with a non-zero t_ref_jd — that
            # would double-count the offset and silently query the ephemeris
            # at the wrong absolute time.
            user_func = gravity
            t_ref_jd = self._t_ref_jd

            def _wrapped_user_acc(state: SystemState) -> jnp.ndarray:
                shifted = state.replace(time=state.time + t_ref_jd)
                return user_func(shifted)

            return jax.tree_util.Partial(_wrapped_user_acc)

        # name -> (Ephemeris ssos selection, acceleration-function factory)
        ephemeris_models = {
            "newtonian planets": (
                "default planets",
                create_newtonian_ephemeris_acceleration_func,
            ),
            "newtonian solar system": (
                "default solar system",
                create_newtonian_ephemeris_acceleration_func,
            ),
            "gr planets": ("default planets", create_gr_ephemeris_acceleration_func),
            "gr solar system": (
                "default solar system",
                create_gr_ephemeris_acceleration_func,
            ),
            "default solar system": (
                "default solar system",
                create_default_ephemeris_acceleration_func,
            ),
        }
        # The isinstance guard keeps unhashable custom callables out of the
        # dict lookup so they reach the ValueError below.
        if isinstance(gravity, str) and gravity in ephemeris_models:
            ssos, factory = ephemeris_models[gravity]
            # NOTE(review): self._de_ephemeris_version is not forwarded to
            # Ephemeris here, only to get_observer_positions — confirm whether
            # Ephemeris should receive it too.
            eph = Ephemeris(
                earliest_time=self._earliest_time,
                latest_time=self._latest_time,
                ssos=ssos,
            )
            return factory(eph.processor, t_ref_jd=self._t_ref_jd)

        # newtonian_gravity / ppn_gravity don't read inputs.time, so the
        # rebase is invisible to them — no shim needed.
        if gravity == "generic newtonian":
            return jax.tree_util.Partial(newtonian_gravity)
        if gravity == "generic gr":
            return jax.tree_util.Partial(ppn_gravity)

        raise ValueError(
            f"Unrecognized gravity '{gravity}'. Valid options are: 'newtonian planets', "
            "'newtonian solar system', 'gr planets', 'gr solar system', "
            "'default solar system', 'generic newtonian', 'generic gr', 'keplerian'. "
            "For a custom acceleration function, pass a jax.tree_util.Partial."
        )

    def _setup_integrator(
        self, integrator: str, max_step_size: u.Quantity | None
    ) -> tuple[IAS15IntegratorState | LeapfrogIntegratorState, Callable]:
        """Create the initial integrator state and the matching evolve function.

        Args:
            integrator (str): "ias15", "Y4", "Y6" or "Y8".
            max_step_size (u.Quantity | None): Fixed step for leapfrog schemes;
                must be None for IAS15.

        Returns:
            tuple: ``(integrator_state, evolve_partial)``.

        Raises:
            AssertionError: If ``max_step_size`` is inconsistent with ``integrator``.
            ValueError: If ``integrator`` is not a recognized name.
        """
        if integrator == "ias15":
            assert (
                max_step_size is None
            ), "max_step_size should not be provided for IAS15 integrator."
            # IAS15's state is seeded with the accelerations at the initial state.
            a0 = self.gravity(self._state)
            return (
                initialize_ias15_integrator_state(a0),
                jax.tree_util.Partial(ias15_evolve),
            )

        yoshida_coefficients = {
            "Y4": (Y4_C, Y4_D),
            "Y6": (Y6_C, Y6_D),
            "Y8": (Y8_C, Y8_D),
        }
        if integrator in yoshida_coefficients:
            assert (
                max_step_size is not None
            ), "Must provide max_step_size for leapfrog integrators."
            dt = max_step_size.to(u.day).value
            c, d = yoshida_coefficients[integrator]
            return (
                LeapfrogIntegratorState(dt=dt, C=c, D=d),
                jax.tree_util.Partial(leapfrog_evolve),
            )

        # Previously an unknown name fell through and crashed with an
        # UnboundLocalError; fail loudly with the valid choices instead.
        raise ValueError(
            f"Unrecognized integrator '{integrator}'. Valid options are: "
            "'ias15', 'Y4', 'Y6', 'Y8'."
        )

    ################
    # PUBLIC METHODS
    ################

    def integrate(
        self, times: Time | jnp.ndarray, step_scheduler: str = "prs23"
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Integrate the System to a given time.

        Note: This method does not change the state of the system. It returns the
        positions and velocities at the given times, but the system itself is not
        changed.

        Args:
            times (Time | jnp.ndarray):
                The times to integrate to. Can be a single time or an array of times.
                If provided as a jnp.array, the entries are assumed to be in TDB JD.
            step_scheduler (str):
                The scheduler to use for determining step sizes. Choices are "prs23",
                which uses the PRS23 controller from Pham+ 2023, or "global", which
                uses the controller from the original IAS15 paper. Default is "prs23".
                Ignored for leapfrog integrators and keplerian systems.

        Returns:
            tuple[jnp.ndarray, jnp.ndarray]:
                The positions (AU) and velocities (AU/day) of every particle at the
                requested times, each shaped (n_times, n_particles, 3).
        """
        offsets = jnp.atleast_1d(self._times_to_offsets(times))

        if self._is_keplerian:
            return _keplerian_system_integrate(
                self._state.tracer_positions,
                self._state.tracer_velocities,
                self._state.time,
                offsets,
            )

        eval_times, inds = self._expand_query_times(offsets)
        scheduler = self._resolve_step_scheduler(step_scheduler)
        positions, velocities, _final_system_state, _final_integrator_state = (
            _integrate(
                eval_times,
                self._state,
                self.gravity,
                self._integrator,
                self._integrator_state,
                inds,
                scheduler,
            )
        )
        return positions, velocities

    def ephemeris(
        self,
        times: Time | jnp.ndarray,
        observer: str | jnp.ndarray,
        step_scheduler: str = "prs23",
    ) -> SkyCoord:
        """Compute an ephemeris for the system.

        Args:
            times (Time | jnp.ndarray):
                The times to compute the ephemeris for. Can be a single time or an
                array of times. If provided as a jnp.array, the entries are assumed
                to be in TDB JD.
            observer (str | jnp.ndarray):
                The observer to compute the ephemeris for. Can be a string
                representing an observatory name, or observer position(s) in AU:
                either a single 3-vector (reused for every time) or an array of
                shape (n_times, 3). For more info on acceptable strings, see the
                get_observer_positions function.
            step_scheduler (str):
                The scheduler to use for determining step sizes. Choices are "prs23",
                which uses the PRS23 controller from Pham+ 2023, or "global", which
                uses the controller from the original IAS15 paper. Default is "prs23".
                Ignored for leapfrog integrators and keplerian systems.

        Returns:
            coords (SkyCoord):
                The ephemeris of each particle in the system at the given times, in
                ICRS coordinates and as seen from that specific observer, shaped
                (n_particles, n_times). Each particle has its own light travel time
                correction applied individually.
        """
        offsets = jnp.atleast_1d(self._times_to_offsets(times))

        if isinstance(observer, str):
            # Observatory lookups need the absolute times, not the offsets.
            observer_positions = get_observer_positions(
                times, observer, self._de_ephemeris_version
            )
        else:
            # Every downstream path walks observer_positions along axis 0 in
            # lockstep with the query times, so a lone (3,) vector must be
            # broadcast to (n_times, 3).
            observer_positions = jnp.broadcast_to(
                jnp.atleast_2d(jnp.asarray(observer)), (offsets.shape[0], 3)
            )

        if self._is_keplerian:
            ras, decs = _keplerian_system_ephem(
                self._state.tracer_positions,
                self._state.tracer_velocities,
                self._state.time,
                offsets,
                observer_positions,
            )
            return SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs")

        eval_times, inds = self._expand_query_times(offsets)
        scheduler = self._resolve_step_scheduler(step_scheduler)

        if self._integrator_method in self._LEAPFROG_INTEGRATORS:
            # Leapfrog has no dense-output b-coefficients; on_sky falls back to a
            # constant-acceleration Taylor expansion for light travel time.
            ras, decs = _ephem(
                eval_times,
                self._state,
                self.gravity,
                self._integrator,
                self._integrator_state,
                observer_positions,
                inds,
                scheduler,
            )
        else:
            # IAS15 has dense-output b-coefficients available, so use them in
            # on_sky's LTT loop instead of a Taylor expansion.
            ras, decs = _ephem_ias15(
                eval_times,
                self._state,
                self.gravity,
                self._integrator_state,
                observer_positions,
                inds,
                scheduler,
            )
        return SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs")
@jax.jit
def _integrate(
    times: jnp.ndarray,
    state: SystemState,
    acc_func: Callable,
    integrator_func: Callable,
    integrator_state: IAS15IntegratorState,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[jnp.ndarray, jnp.ndarray, SystemState, IAS15IntegratorState]:
    """Evolve ``state`` over ``times`` and keep only the requested output rows.

    Args:
        times: Evaluation grid (offsets in days) handed to ``integrator_func``.
        state: Initial system state.
        acc_func: Acceleration function forwarded to the integrator.
        integrator_func: Evolve function; must return a 5-tuple of
            (positions, velocities, final state, final integrator state, steps).
        integrator_state: Initial integrator state.
        relevant_inds: Indices into the output rows that correspond to the
            user's requested times.
        step_scheduler: Step-size controller forwarded to the integrator.

    Returns:
        The selected positions and velocities, then the final system state and
        final integrator state (the step count is discarded).
    """
    (
        all_positions,
        all_velocities,
        end_system_state,
        end_integrator_state,
        _n_steps,
    ) = integrator_func(state, acc_func, times, integrator_state, step_scheduler)

    selected_positions = all_positions[relevant_inds]
    selected_velocities = all_velocities[relevant_inds]
    return (
        selected_positions,
        selected_velocities,
        end_system_state,
        end_integrator_state,
    )
@jax.jit
def _ephem(
    times: jnp.ndarray,
    state: SystemState,
    acc_func: Callable,
    integrator_func: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Integrate, then project each particle onto the sky (Taylor-expansion LTT).

    Used for integrators without dense output (the leapfrog family).

    Args:
        times: Full evaluation grid in day offsets. For leapfrog this includes the
            intermediate times inserted by ``create_leapfrog_times``.
        state: Initial system state.
        acc_func: Acceleration function (also used by ``on_sky`` for LTT).
        integrator_func: Evolve function passed through to ``_integrate``.
        integrator_state: Initial integrator state.
        observer_positions: One observer position per requested time, in AU.
        relevant_inds: Indices of the requested times within ``times``.
        step_scheduler: Step-size controller passed through to ``_integrate``.

    Returns:
        (ras, decs), each shaped (n_particles, n_requested_times).
    """
    positions, velocities, _, _ = _integrate(
        times,
        state,
        acc_func,
        integrator_func,
        integrator_state,
        relevant_inds,
        step_scheduler,
    )

    # ``positions``/``velocities``/``observer_positions`` only have rows for the
    # requested times, but ``times`` is the full evaluation grid. Scanning over the
    # raw grid gave mismatched leading axes (or wrong epochs) whenever leapfrog
    # inserted intermediate steps, so select the matching entries first.
    obs_times = times[relevant_inds]

    def interior(px: jnp.ndarray, pv: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        def scan_func(
            carry: None,
            scan_over: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
        ) -> tuple[None, tuple[jnp.ndarray, jnp.ndarray]]:
            position, velocity, time, observer_position = scan_over
            ra, dec = on_sky(position, velocity, time, observer_position, acc_func)
            return None, (ra, dec)

        _, (ras, decs) = jax.lax.scan(
            scan_func,
            None,
            (px, pv, obs_times, observer_positions),
        )
        return ras, decs

    # Map over the particle axis (axis 1 of the (n_times, n_particles, 3) arrays).
    ras, decs = jax.vmap(interior, in_axes=(1, 1))(positions, velocities)
    return ras, decs
@jax.jit
def _ephem_ias15(
    times: jnp.ndarray,
    state: SystemState,
    acc_func: Callable,
    integrator_state: IAS15IntegratorState,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """IAS15-only ephemeris that drives the LTT loop with dense-output polynomials.

    Integrates once with ``ias15_evolve_with_dense_output``, gathers the step
    buffers belonging to each requested observation, and then maps a per-obs
    ``make_ltt_propagator`` + ``on_sky`` evaluation over observations (inner map)
    and particles (outer map). Every particle receives its own light-travel-time
    correction.
    """
    dense = ias15_evolve_with_dense_output(
        state, acc_func, times, integrator_state, step_scheduler
    )
    (
        _,
        _,
        _,
        _,
        _,
        b_buf,
        a0_buf,
        x0_buf,
        v0_buf,
        dts_buf,
        _,
        step_indices,
        h_values,
    ) = dense

    obs_times = times[relevant_inds]
    obs_steps = step_indices[relevant_inds]
    obs_h = h_values[relevant_inds]

    # Gather the buffers of the step containing each observation.
    # Shapes: b (n_obs, 7, n_particles, 3); a0/x0/v0 (n_obs, n_particles, 3); dt (n_obs,).
    b_obs, a0_obs, x0_obs, v0_obs, dt_obs = (
        buf[obs_steps] for buf in (b_buf, a0_buf, x0_buf, v0_buf, dts_buf)
    )

    def _sky_at_obs(
        b_step: jnp.ndarray,
        a0_step: jnp.ndarray,
        x0_step: jnp.ndarray,
        v0_step: jnp.ndarray,
        dt_step: jnp.ndarray,
        h_step: jnp.ndarray,
        t_obs: jnp.ndarray,
        obs_pos: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        propagate = make_ltt_propagator(
            b_step, a0_step, x0_step, v0_step, dt_step, h_step
        )
        x_at_obs = propagate(jnp.array(0.0))
        return on_sky(
            x_at_obs,
            jnp.zeros(3),
            t_obs,
            obs_pos,
            acc_func,
            ltt_position_fn=propagate,
        )

    # Inner map: every argument is batched along its leading (observation) axis.
    sky_over_obs = jax.vmap(_sky_at_obs, in_axes=0)

    def _sky_track(
        b_p: jnp.ndarray,
        a0_p: jnp.ndarray,
        x0_p: jnp.ndarray,
        v0_p: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        # b_p: (n_obs, 7, 3); a0_p/x0_p/v0_p: (n_obs, 3)
        return sky_over_obs(
            b_p, a0_p, x0_p, v0_p, dt_obs, obs_h, obs_times, observer_positions
        )

    # Outer map over particles: axis 2 of b (obs/coeff/particle/xyz) and
    # axis 1 of a0/x0/v0 (obs/particle/xyz).
    ras, decs = jax.vmap(_sky_track, in_axes=(2, 1, 1, 1))(
        b_obs, a0_obs, x0_obs, v0_obs
    )
    return ras, decs
@jax.jit
def _keplerian_system_integrate(
    xs: jnp.ndarray,
    vs: jnp.ndarray,
    t0: float,
    times: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Propagate every particle analytically with ``_keplerian_integrate``.

    The per-particle results come back particle-major, (N, T, 3). They are
    reordered to time-major (T, N, 3) to match the numerical integrators' output
    convention.
    """
    over_particles = jax.vmap(_keplerian_integrate, in_axes=(0, 0, None, None))
    pos_by_particle, vel_by_particle = over_particles(xs, vs, t0, times)

    time_major = (1, 0, 2)
    pos_by_time = jnp.transpose(pos_by_particle, time_major)
    vel_by_time = jnp.transpose(vel_by_particle, time_major)
    return pos_by_time, vel_by_time
@jax.jit
def _keplerian_system_ephem(
    xs: jnp.ndarray,
    vs: jnp.ndarray,
    t0: float,
    times: jnp.ndarray,
    observer_positions: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Sky positions for every particle under analytic Keplerian motion.

    ``_keplerian_on_sky`` handles one (position, velocity, time, observer)
    tuple, so it is mapped first along the time axis (all four inputs batched
    on axis 0), then along the particle axis (axis 1 of the (T, N, 3) states,
    with times and observer positions shared). Output is (n_particles, n_times).
    """
    positions, velocities = _keplerian_system_integrate(xs, vs, t0, times)

    sky_along_time = jax.vmap(_keplerian_on_sky, in_axes=0)
    sky_along_particles = jax.vmap(sky_along_time, in_axes=(1, 1, None, None))

    ras, decs = sky_along_particles(positions, velocities, times, observer_positions)
    return ras, decs