"""The Particle class and its supporting functions."""
from __future__ import annotations
import jax
jax.config.update("jax_enable_x64", True)
import warnings
from collections.abc import Callable
from copy import deepcopy
import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time
# from jaxlib.xla_extension import PjitFunction
from scipy.optimize import minimize
from jorbit import Observations
from jorbit.accelerations import (
create_default_ephemeris_acceleration_func,
create_gr_ephemeris_acceleration_func,
create_newtonian_ephemeris_acceleration_func,
)
from jorbit.astrometry.orbit_fit_seeds import gauss_method_orbit, simple_circular
from jorbit.astrometry.sky_projection import on_sky, tangent_plane_projection
from jorbit.astrometry.transformations import (
elements_to_cartesian,
horizons_ecliptic_to_icrs,
icrs_to_horizons_ecliptic,
)
from jorbit.data.constants import (
INV_SPEED_OF_LIGHT,
SPEED_OF_LIGHT,
TOTAL_SOLAR_SYSTEM_GM,
Y4_C,
Y4_D,
Y6_C,
Y6_D,
Y8_C,
Y8_D,
)
from jorbit.ephemeris.ephemeris import Ephemeris
from jorbit.integrators import (
create_leapfrog_times,
ias15_evolve,
ias15_evolve_forced_landing,
ias15_evolve_with_dense_output,
initialize_ias15_integrator_state,
leapfrog_evolve,
make_ltt_propagator,
next_proposed_dt_global,
next_proposed_dt_PRS23,
)
from jorbit.likelihoods.setup_static_likelihood import (
create_default_static_residuals_func,
precompute_likelihood_data,
)
from jorbit.utils.horizons import get_observer_positions, horizons_bulk_vector_query
from jorbit.utils.kepler import keplerian_propagate
from jorbit.utils.states import (
CartesianState,
IAS15IntegratorState,
KeplerianState,
LeapfrogIntegratorState,
SystemState,
)
# Squared radians-to-arcseconds factor, (180 * 3600 / pi) ** 2: multiplying a
# quantity expressed in radians^2 by this constant yields arcsec^2.
_RAD2ARCSEC_SQ = (180.0 * 3600.0 / jnp.pi) ** 2
[docs]
class Particle:
"""An object representing a single particle in the solar system.
This class is used to represent and manipulate a single particle moving within the
solar system. It is mostly a collection of convenience wrappers around the more
general integrators and accelerations, but it also provides some useful methods for
projecting the particle's position onto the sky and fitting orbits to observations.
By construction, `Particle` objects are massless.
Note: none of the methods associated with this class will alter the underlying state
of the particle. For example, "integrate" will give you the positions and velocities
of the particle at future times, but after it returns, the particle will still be
at its original state.
Attributes:
state: The state of the particle, either in Cartesian or Keplerian coordinates.
time: The time of the particle's state.
x: The position of the particle in Cartesian coordinates.
v: The velocity of the particle in Cartesian coordinates.
observations: A collection of observations of the particle.
name: The name of the particle.
gravity: The gravitational acceleration function to use for the particle.
integrator: The integrator to use for the particle.
earliest_time: The earliest time for which ephemeris data is available.
latest_time: The latest time for which ephemeris data is available.
fit_seed: A seed for fitting the orbit of the particle.
"""
def __init__(
    self,
    state: KeplerianState | CartesianState | None = None,
    time: Time | None = None,
    x: jnp.ndarray | None = None,
    v: jnp.ndarray | None = None,
    observations: Observations | None = None,
    name: str = "",
    gravity: str | Callable = "default solar system",
    de_ephemeris_version: str | None = "440",
    integrator: str = "ias15",
    earliest_time: Time = Time("1980-01-01"),
    latest_time: Time = Time("2050-01-01"),
    fit_seed: KeplerianState | CartesianState | None = None,
    max_step_size: u.Quantity | None = None,
    step_scheduler: str = "prs23",
) -> None:
    """Initialize a Particle object.

    Args:
        state (KeplerianState | CartesianState | None):
            The state of the particle. None if x and v are provided.
        time (Time | None):
            The time of the particle's state. None if state is provided, since that
            will have its own time baked-in.
        x (jnp.ndarray | None):
            The 3D barycentric cartesian position of the particle in AU. None if
            state is provided.
        v (jnp.ndarray | None):
            The 3D barycentric cartesian velocity of the particle in AU/day. None if
            state is provided.
        observations (Observations | None):
            Optional Observations associated with the particle. Necessary if fitting
            or evaluating likelihoods.
        name (str):
            The name of the particle. Defaults to "".
        gravity (str | Callable):
            The gravitational acceleration function to use when integrating the
            particle's orbit. Defaults to "default solar system", which corresponds
            to parameterized post-Newtonian interactions with the 10 bodies in the
            JPL DE440 ephemeris, plus Newtonian interactions with the 16 largest
            asteroids in the asteroids_de441/sb441-n16.bsp ephemeris. Can also be
            a jax.tree_util.Partial object that follows the same signature as the
            acceleration functions in jorbit.accelerations. Other string options are
            "newtonian planets", "newtonian solar system", "gr planets", "gr
            solar system", and "keplerian" (pure 2-body Keplerian propagation with
            no ephemeris or integrator setup).
        de_ephemeris_version (str | None):
            Which version of the JPL DE ephemeris to use for perturber positions
            when using one of the built-in gravity models. Accepts either "440" or
            "430", default is "440". Ignored when gravity is "keplerian".
        integrator (str):
            The integrator to use for the particle. Choices are "ias15", which is a
            15th order adaptive step-size integrator, or "Y4", "Y6", or "Y8", which
            are 4th, 6th, and 8th order Yoshida leapfrog integrators with fixed
            step sizes. Defaults to "ias15". Ignored when gravity is "keplerian".
        earliest_time (Time):
            The earliest time we expect to integrate the particle to. Defaults to
            Time("1980-01-01"). Larger time windows will result in larger in-memory
            ephemeris objects.
        latest_time (Time):
            The latest time we expect to integrate the particle to. Defaults to
            Time("2050-01-01"). Larger time windows will result in larger in-memory
            ephemeris objects.
        fit_seed (KeplerianState | CartesianState | None):
            A seed for fitting the orbit of the particle. If None, a seed will be
            generated from the observations if they exist. Otherwise, a circular
            orbit with semi-major axis 2.5 AU will be used.
        max_step_size (u.Quantity, optional):
            The fixed step size to use for leapfrog integrators. Required if
            integrator is "Y4", "Y6", or "Y8". Ignored if integrator is "ias15".
            Note that this is the maximum step size; the actual step size may be
            smaller to ensure that the particle lands exactly on the requested
            output times, and that the step size may change if the spacing between
            output times is not constant. Defaults to None.
        step_scheduler (str):
            The scheduler used by IAS15 for picking the next proposed step size.
            Choices are "prs23" (Pham+ 2023 controller, default) or "global"
            (the controller from the original IAS15 paper). Used consistently
            by ``integrate``, ``integrate_or_interpolate``, ``ephemeris``, and
            the residuals/loglike closures. Ignored when gravity is "keplerian"
            or for leapfrog integrators.
    """
    self._observations = observations
    self._earliest_time = earliest_time
    self._latest_time = latest_time
    self._de_ephemeris_version = de_ephemeris_version
    self._is_keplerian = gravity == "keplerian"
    self._step_scheduler = self._resolve_step_scheduler(step_scheduler)
    self.gravity = gravity
    # Preserve the original gravity spec (string or user-supplied Partial) so
    # max_likelihood can pass it to the result Particle rather than self.gravity.
    # self.gravity is overwritten by _setup_acceleration_func with a Partial that
    # already has t_ref_jd baked in; passing *that* Partial to a new Particle would
    # re-wrap it a second time and double-count the offset (see BUG_FIXES_MAY2026).
    self._gravity_str = gravity

    # Work on a private copy: _setup_state rewrites fields on the state object.
    state = deepcopy(state) if state is not None else None

    # self._time is kept at jnp.array(0.0) internally; absolute time lives in
    # self._t_ref_astropy / self._t_ref_jd. All JAX-visible times are offsets
    # (in days) from _t_ref_jd, which keeps step-boundary and interpolation
    # arithmetic well-conditioned at decadal timescales.
    (
        self._x,
        self._v,
        self._time,
        self._t_ref_astropy,
        self._t_ref_jd,
        self._cartesian_state,
        self._keplerian_state,
        self._name,
        self._acc_func_kwargs,
    ) = self._setup_state(x, v, state, time, name)

    if self._is_keplerian:
        # Pure two-body propagation: no ephemeris and no numerical integrator.
        self.gravity = "keplerian"
        self._integrator_state = None
        self._integrator = None
        # Defined here too so every instance carries the same attribute set,
        # regardless of which gravity model was requested.
        self._forced_integrator = None
        self._integrator_method = "keplerian"
        self._max_step_size = None
    else:
        self.gravity = self._setup_acceleration_func(gravity)
        self._integrator_state, self._integrator, self._forced_integrator = (
            self._setup_integrator(integrator, max_step_size)
        )
        self._integrator_method = integrator
        self._max_step_size = max_step_size

    # The likelihood closures are built after the fit seed, and both depend on
    # the state/integrator attributes assigned above.
    self._fit_seed = self._setup_fit_seed(fit_seed)
    (
        self.residuals,
        self.loglike,
        self.scipy_objective,
        self.scipy_objective_grad,
    ) = self._setup_likelihood()
    self.static_residuals = self._setup_default_static_residuals()
def __repr__(self) -> str:
    """Describe the particle by its stored name, prefixed with ``Particle: ``."""
    return "Particle: {}".format(self._name)
@property
def cartesian_state(self) -> CartesianState:
    """Cartesian state stored on the particle at construction.

    Its absolute epoch is encoded by the pair ``(relative_time,
    time_reference)``: ``relative_time`` is ``0.0`` in the particle's internal
    offset frame, and ``time_reference`` equals :attr:`t_ref_jd`, the absolute
    JD anchor.
    """
    return self._cartesian_state
@property
def keplerian_state(self) -> KeplerianState:
    """Keplerian state stored on the particle at construction.

    Its absolute epoch is encoded by the pair ``(relative_time,
    time_reference)``: ``relative_time`` is ``0.0`` in the particle's internal
    offset frame, and ``time_reference`` equals :attr:`t_ref_jd`, the absolute
    JD anchor.
    """
    return self._keplerian_state
@property
def observations(self) -> Observations | None:
    """Observations attached to this particle, or None if none were given."""
    return self._observations
@property
def t_ref(self) -> Time:
    """The particle's epoch as an astropy Time in the TDB scale.

    Every time the Particle hands to JAX is expressed as an offset in days
    from this reference, which keeps the integrator's internal arithmetic
    well-conditioned at decadal timescales.
    """
    return self._t_ref_astropy
@property
def t_ref_jd(self) -> float:
    """The particle's epoch as a JD (TDB); numeric counterpart of ``t_ref``."""
    return self._t_ref_jd
###############
# SETUP METHODS
###############
def _times_to_offsets(self, times: Time | jnp.ndarray) -> jnp.ndarray:
    """Express query times as day offsets from ``self._t_ref_astropy``.

    Two input flavours are handled:

    * astropy ``Time``: converted to TDB and differenced against the astropy
      reference epoch, so the subtraction benefits from astropy's internal
      jd1+jd2 high-precision representation before being reduced to days.
    * anything else (plain floats / arrays): treated as absolute JD in TDB and
      differenced against ``self._t_ref_jd`` in float64, so the offset cannot
      be more precise than the float JD it started from (ulp(JD) ≈ 40 μs).
    """
    if not isinstance(times, Time):
        return jnp.asarray(times) - self._t_ref_jd
    elapsed = times.tdb - self._t_ref_astropy
    return jnp.asarray(elapsed.to_value(u.day))
def _observations_times_as_offsets(self) -> jnp.ndarray:
    """Express the observation epochs as day offsets from the reference epoch.

    When the Observations object exposes a non-None ``_times_astropy``
    attribute, that astropy Time is differenced against
    ``self._t_ref_astropy`` to keep its full precision. Otherwise the float-JD
    array ``times`` is used and :attr:`t_ref_jd` is subtracted from it, which
    inherits the ulp(JD) ≈ 40 μs quantization of the stored float JD.
    """
    precise_times = getattr(self._observations, "_times_astropy", None)
    if precise_times is None:
        # No high-precision times stored: fall back to the float-JD array.
        return jnp.asarray(self._observations.times) - self._t_ref_jd
    elapsed = precise_times.tdb - self._t_ref_astropy
    return jnp.asarray(elapsed.to_value(u.day))
def _setup_state(
    self,
    x: jnp.ndarray | None,
    v: jnp.ndarray | None,
    state: CartesianState | KeplerianState | None,
    time: Time | float | None,
    name: str,
) -> tuple:
    """Normalize the constructor inputs into the Particle's internal state.

    Exactly one of ``state`` or the ``(x, v, time)`` triple must be supplied.
    The absolute epoch is stored as a reference time and the returned states
    are rebased so that their ``relative_time`` is ``0.0`` against it.

    Args:
        x: Position vector, or None when ``state`` is given.
        v: Velocity vector, or None when ``state`` is given.
        state: Cartesian or Keplerian state, or None when ``x``/``v`` are given.
            Note that its fields are rewritten in place.
        time: Epoch as an astropy Time or a number interpreted as JD in TDB.
            Must be None when ``state`` is given.
        name: Particle name; an empty string is replaced by ``"unnamed"``.

    Returns:
        tuple: ``(x, v, internal_time, t_ref_astropy, t_ref_jd,
        cartesian_state, keplerian_state, name, acc_func_kwargs)`` where ``x``
        and ``v`` are flattened and ``internal_time`` is ``jnp.array(0.0)``.

    Raises:
        AssertionError: If both ``state`` and ``time`` (or ``state`` and
            ``x``/``v``) are given, if no epoch is available, or if ``x`` is
            given without ``v``.
        ValueError: If neither ``state`` nor ``x`` is provided.
    """
    if state is not None:
        assert time is None, "Cannot provide both state and time"
        # The state self-describes its absolute epoch as the sum of its
        # offset (relative_time) and its anchor (time_reference). We use
        # the sum as the Particle's absolute reference time, then rebase
        # the stored state to relative_time=0 against that new anchor.
        time = state.relative_time + state.time_reference
    assert time is not None, "Must provide an epoch for the particle"

    # Build the reference time pair (astropy Time at full precision + float JD)
    # and then set the Particle's internal epoch to 0.0 in the offset frame.
    if isinstance(time, Time):
        t_ref_astropy = time.tdb
    else:
        t_ref_astropy = Time(float(time), format="jd", scale="tdb")
    t_ref_jd = jnp.array(float(t_ref_astropy.tdb.jd))
    internal_time = jnp.array(0.0)

    if state is not None:
        assert x is None and v is None, "Cannot provide both state and x, v"
        # to_keplerian() and to_cartesian() are documented not to propagate `cov`
        # because the covariance is parameterisation-specific. Save and re-attach
        # it to the same-type output state so callers who set cov on the input
        # state can retrieve it on particle.cartesian_state / particle.keplerian_state.
        _input_is_keplerian = isinstance(state, KeplerianState)
        _input_cov = state.cov
        state = state.to_cartesian()
        if state.x.ndim != 2:
            # Promote to a 2D (leading axis of length 1) layout.
            state.x = state.x[None, :]
            state.v = state.v[None, :]
        state.relative_time = internal_time
        state.time_reference = t_ref_jd
        keplerian_state = state.to_keplerian()
        cartesian_state = state.to_cartesian()
        if _input_is_keplerian:
            keplerian_state.cov = _input_cov
        else:
            cartesian_state.cov = _input_cov
        x = state.x.flatten()
        v = state.v.flatten()
    elif x is not None:
        assert v is not None, "Must provide both x and v"
        x = x.flatten()
        v = v.flatten()
        cartesian_state = CartesianState(
            x=jnp.array([x]),
            v=jnp.array([v]),
            relative_time=internal_time,
            time_reference=t_ref_jd,
            acceleration_func_kwargs={"c2": SPEED_OF_LIGHT**2},
        )
        keplerian_state = cartesian_state.to_keplerian()
    else:
        # Reached only when neither a state nor a position was supplied; the
        # epoch itself has already been validated above.
        raise ValueError("Must provide either state or both x and v")

    if name == "":
        name = "unnamed"
    acc_func_kwargs = cartesian_state.acceleration_func_kwargs
    return (
        x,
        v,
        internal_time,
        t_ref_astropy,
        t_ref_jd,
        cartesian_state,
        keplerian_state,
        name,
        acc_func_kwargs,
    )
def _setup_acceleration_func(self, gravity: str) -> Callable:
    """Build the acceleration function for the requested gravity model.

    A ``jax.tree_util.Partial`` is treated as a user-supplied acceleration
    function and wrapped so it receives absolute times. Otherwise ``gravity``
    must name one of the built-in models, for which an ``Ephemeris`` covering
    ``[earliest_time, latest_time]`` is created and handed to the matching
    factory together with the particle's reference JD.

    Raises:
        AssertionError: If a built-in model is requested and
            ``de_ephemeris_version`` is neither "440" nor "430".
        ValueError: If ``gravity`` is not a Partial and not a known model name.
    """
    if isinstance(gravity, jax.tree_util.Partial):
        # The wrapper converts from the Particle's internal offset frame
        # (state.time = days since self._t_ref_jd) to absolute JD before
        # calling the user's function, which must expect absolute JD.
        #
        # CONTRACT: the custom function must be built with t_ref_jd=0 (or
        # equivalently, must treat state.time as an absolute Julian Date).
        # Do NOT pass a function returned by jorbit's factory functions
        # (create_default_ephemeris_acceleration_func, etc.) that was
        # built with a non-zero t_ref_jd — that would double-count the
        # offset and silently query the ephemeris at the wrong time.
        # The safe pattern is to build the function fresh with t_ref_jd=0:
        #     acc_func = create_default_ephemeris_acceleration_func(
        #         eph.processor, t_ref_jd=0.0
        #     )
        custom_acc = gravity
        epoch_jd = self._t_ref_jd

        def _absolute_time_acc(state: SystemState) -> jnp.ndarray:
            return custom_acc(state.replace(time=state.time + epoch_jd))

        return jax.tree_util.Partial(_absolute_time_acc)

    assert self._de_ephemeris_version in ("440", "430"), (
        "de_ephemeris_version must be either '440' or '430' if not using a custom "
        "gravity function"
    )

    # (model name, perturber set, acceleration factory) for each built-in model.
    builtin_models = (
        (
            "newtonian planets",
            "default planets",
            create_newtonian_ephemeris_acceleration_func,
        ),
        (
            "newtonian solar system",
            "default solar system",
            create_newtonian_ephemeris_acceleration_func,
        ),
        (
            "gr planets",
            "default planets",
            create_gr_ephemeris_acceleration_func,
        ),
        (
            "gr solar system",
            "default solar system",
            create_gr_ephemeris_acceleration_func,
        ),
        (
            "default solar system",
            "default solar system",
            create_default_ephemeris_acceleration_func,
        ),
    )
    for model_name, perturbers, factory in builtin_models:
        if gravity == model_name:
            break
    else:
        raise ValueError(
            f"Unrecognized gravity '{gravity}'. Valid options are: 'newtonian planets', "
            "'newtonian solar system', 'gr planets', 'gr solar system', "
            "'default solar system', 'keplerian'. For a custom acceleration function, "
            "pass a jax.tree_util.Partial."
        )

    # The ephemeris is only constructed once the model name is known to be valid.
    eph = Ephemeris(
        earliest_time=self._earliest_time,
        latest_time=self._latest_time,
        ssos=perturbers,
        de_ephemeris_version=self._de_ephemeris_version,
    )
    return factory(eph.processor, t_ref_jd=self._t_ref_jd)
@staticmethod
def _resolve_step_scheduler(name: str) -> Callable:
    """Map a case-insensitive scheduler name to a Partial-wrapped step scheduler.

    Raises:
        ValueError: If ``name`` is neither "prs23" nor "global".
    """
    key = name.lower()
    if key == "prs23":
        scheduler = next_proposed_dt_PRS23
    elif key == "global":
        scheduler = next_proposed_dt_global
    else:
        raise ValueError(
            f"Unknown step_scheduler '{name}'. Choices are 'prs23' or 'global'."
        )
    return jax.tree_util.Partial(scheduler)
def _setup_integrator(
    self, integrator: str, max_step_size: u.Quantity | None
) -> tuple[IAS15IntegratorState | LeapfrogIntegratorState, Callable, Callable]:
    """Create the integrator state and the evolve functions for ``integrator``.

    Args:
        integrator: One of "ias15", "Y4", "Y6" or "Y8".
        max_step_size: Step size for the leapfrog integrators (converted to
            days). Must be None for "ias15" and must be given otherwise.

    Returns:
        tuple: ``(integrator_state, evolve, forced_evolve)``. For IAS15 the
        two evolve functions are ``ias15_evolve`` and
        ``ias15_evolve_forced_landing``; for the leapfrog integrators both
        entries are the same ``leapfrog_evolve`` Partial.

    Raises:
        AssertionError: If ``max_step_size`` is inconsistent with the
            integrator choice.
        ValueError: If ``integrator`` is not a recognized name.
    """
    if integrator == "ias15":
        assert (
            max_step_size is None
        ), "max_step_size should not be provided for IAS15 integrator."
        # IAS15 is seeded with the acceleration at the particle's current state.
        a0 = self.gravity(self._cartesian_state.to_system())
        integrator_state = initialize_ias15_integrator_state(a0)
        evolve = jax.tree_util.Partial(ias15_evolve)
        forced_evolve = jax.tree_util.Partial(ias15_evolve_forced_landing)
    elif integrator in ("Y4", "Y6", "Y8"):
        assert (
            max_step_size is not None
        ), "Must provide max_step_size for leapfrog integrators."
        dt = max_step_size.to(u.day).value
        # Yoshida coefficient sets for each supported order.
        coefficients = {
            "Y4": (Y4_C, Y4_D),
            "Y6": (Y6_C, Y6_D),
            "Y8": (Y8_C, Y8_D),
        }
        c, d = coefficients[integrator]
        integrator_state = LeapfrogIntegratorState(dt=dt, C=c, D=d)
        evolve = jax.tree_util.Partial(leapfrog_evolve)
        forced_evolve = evolve
    else:
        # Previously an unknown name fell through and surfaced as an
        # UnboundLocalError at the return statement.
        raise ValueError(
            f"Unrecognized integrator '{integrator}'. Valid options are: "
            "'ias15', 'Y4', 'Y6', 'Y8'."
        )
    return integrator_state, evolve, forced_evolve
def _setup_fit_seed(
    self, fit_seed: KeplerianState | CartesianState | None
) -> KeplerianState | CartesianState | None:
    """Choose the starting state used when fitting an orbit.

    Returns None when the particle has no observations. A user-supplied
    Cartesian or Keplerian seed is returned untouched. Otherwise, with at
    least three observations, Gauss's method is run on the first, the
    closest-to-mean-time, and the last observation (warning if the result has
    eccentricity above 1); with fewer observations, a circular orbit of
    semi-major axis 2.5 is built from the first observation's RA/Dec.
    """
    observations = self._observations
    if observations is None:
        return None
    if isinstance(fit_seed, (CartesianState, KeplerianState)):
        return fit_seed

    if len(observations) < 3:
        # Pass the Particle's reference epoch (not the observation time)
        # so that the returned state's x, v represent the object at
        # self._t_ref_jd — the same epoch the optimizer will use as its
        # starting point. Using obs.times[0] instead would yield x, v at
        # a different epoch that the optimizer would then misinterpret.
        return simple_circular(
            observations.ra[0],
            observations.dec[0],
            semi=2.5,
            time=self._t_ref_jd,
        )

    # Pick the observation nearest the mean epoch as the middle of the triplet.
    mean_time = jnp.mean(observations.times)
    mid_idx = jnp.argmin(jnp.abs(observations.times - mean_time))
    triplet = observations[0] + observations[mid_idx] + observations[-1]
    seed = gauss_method_orbit(triplet)
    if seed.to_keplerian().ecc > 1:
        warnings.warn(
            "Warning: initial Gauss's method fit produced an unbound orbit",
            RuntimeWarning,
            stacklevel=2,
        )
    return seed
def _setup_likelihood(self) -> tuple[Callable, Callable, Callable, Callable]:
    """Build the residual / log-likelihood closures tied to the observations.

    Returns:
        tuple: ``(residuals, loglike, scipy_objective, scipy_grad)``. All four
        entries are None when the particle has no observations. ``residuals``
        and ``loglike`` are jitted callables taking a state;
        ``scipy_objective`` and ``scipy_grad`` take a flat length-6 vector
        (position then velocity) and return the negative log-likelihood and
        its gradient respectively.
    """
    if self._observations is None:
        return None, None, None, None
    # Observation epochs in the particle's offset frame (days from t_ref).
    obs_times = self._observations_times_as_offsets()
    if self._is_keplerian:
        times = obs_times
        residuals = jax.tree_util.Partial(
            _keplerian_residuals,
            times,
            self._observations.observer_positions,
            self._observations.ra,
            self._observations.dec,
        )
        ll = jax.tree_util.Partial(
            _keplerian_loglike,
            times,
            self._observations.observer_positions,
            self._observations.ra,
            self._observations.dec,
            self._observations.inv_cov_matrices,
            self._observations.cov_log_dets,
        )
        # Keplerian propagation is fully JIT-able, reverse mode works natively
        residuals = jax.jit(residuals)
        loglike = jax.jit(ll)
    else:
        if self._integrator_method == "ias15":
            # IAS15 integrates straight to each observation time.
            times = obs_times
            inds = jnp.arange(len(obs_times))
        elif self._integrator_method in ["Y4", "Y6", "Y8"]:
            # Leapfrog uses an expanded time grid; inds is returned alongside it
            # by create_leapfrog_times and forwarded to the closures below.
            times, inds = create_leapfrog_times(
                self._cartesian_state.relative_time,
                obs_times,
                self._max_step_size,
            )
        residuals = jax.tree_util.Partial(
            _residuals,
            times,
            self.gravity,
            self._integrator,
            self._integrator_state,
            self._observations.observer_positions,
            self._observations.ra,
            self._observations.dec,
            inds,
            step_scheduler=self._step_scheduler,
        )
        ll = jax.tree_util.Partial(
            _loglike,
            times,
            self.gravity,
            self._integrator,
            self._integrator_state,
            self._observations.observer_positions,
            self._observations.ra,
            self._observations.dec,
            self._observations.inv_cov_matrices,
            self._observations.cov_log_dets,
            inds,
            step_scheduler=self._step_scheduler,
        )

        # since we've gone with the while loop version of the ias15 integrator,
        # can no longer use reverse mode. But, actually specifying forward mode
        # everywhere is annoying, so we're going to re-define a custom vjp for
        # "reverse" mode that's actually just forward mode
        @jax.custom_vjp
        def loglike(params: CartesianState | KeplerianState) -> float:
            return ll(params)

        def loglike_fwd(params: CartesianState | KeplerianState) -> tuple:
            # Forward pass: compute the value and stash the forward-mode
            # Jacobian as the residual for the backward pass.
            output = ll(params)
            jac = jax.jacfwd(ll)(params)
            return output, (jac,)

        def loglike_bwd(res: tuple, g: float) -> float:
            # res is the 1-tuple (jac,) saved by loglike_fwd, so mapping over
            # it yields a 1-tuple of cotangents, one for the single argument.
            jac = res
            val = jax.tree.map(lambda x: x * g, jac)
            return val

        loglike.defvjp(loglike_fwd, loglike_bwd)
        residuals = jax.jit(residuals)
        loglike = jax.jit(loglike)

    def scipy_objective(x: jnp.ndarray) -> float:
        # x[:3] is the position and x[3:] the velocity, at the particle's epoch.
        c = CartesianState(
            x=jnp.array([x[:3]]),
            v=jnp.array([x[3:]]),
            relative_time=self._time,
            time_reference=self._t_ref_jd,
            acceleration_func_kwargs=self._acc_func_kwargs,
        )
        # Negated so that minimizing this objective maximizes the likelihood.
        return -loglike(c)

    def scipy_grad(x: jnp.ndarray) -> jnp.ndarray:
        c = CartesianState(
            x=jnp.array([x[:3]]),
            v=jnp.array([x[3:]]),
            relative_time=self._time,
            time_reference=self._t_ref_jd,
            acceleration_func_kwargs=self._acc_func_kwargs,
        )
        c_grad = jax.grad(loglike)(c)
        # Flatten the gradient back into the same (x, v) layout as the input.
        g = jnp.concatenate([c_grad.x.flatten(), c_grad.v.flatten()])
        return -g

    return residuals, loglike, scipy_objective, scipy_grad
def _setup_default_static_residuals(self) -> Callable:
    """Build the default static residuals function, when one applies.

    Returns None if the particle has no observations or uses Keplerian
    gravity; otherwise the likelihood data is precomputed from this particle
    and its step scheduler and turned into a residuals function.
    """
    has_observations = self._observations is not None
    if not has_observations or self._is_keplerian:
        return None
    return create_default_static_residuals_func(
        precompute_likelihood_data(self, self._step_scheduler)
    )
################
# PUBLIC METHODS
################
[docs]
@classmethod
def from_horizons(
    cls,
    name: str,
    time: Time,
    observations: Observations | None = None,
    gravity: str | Callable = "default solar system",
    integrator: str = "ias15",
    earliest_time: Time = Time("1980-01-01"),
    latest_time: Time = Time("2050-01-01"),
    fit_seed: KeplerianState | CartesianState | None = None,
    max_step_size: u.Quantity | None = None,
    de_ephemeris_version: str | None = "440",
    step_scheduler: str = "prs23",
) -> Particle:
    """Query JPL Horizons for an SSOs state at a given time and create a Particle object.

    Args:
        name (str):
            The name of the SSO to query. Can be a string or an integer.
        time (Time):
            The time to query the SSO at.
        observations (Observations | None):
            Optional Observations associated with the particle. Necessary if fitting
            or evaluating likelihoods.
        gravity (str | Callable):
            The gravitational acceleration function to use when integrating the
            particle's orbit. Defaults to "default solar system", which corresponds
            to parameterized post-Newtonian interactions with the 10 bodies in the
            JPL DE440 ephemeris, plus Newtonian interactions with the 16 largest
            asteroids in the asteroids_de441/sb441-n16.bsp ephemeris. Can also be
            a jax.tree_util.Partial object that follows the same signature as the
            acceleration functions in jorbit.accelerations. Other string options are
            "newtonian planets", "newtonian solar system", "gr planets", and "gr
            solar system".
        integrator (str):
            The integrator to use for the particle. Choices are "ias15", which is a
            15th order adaptive step-size integrator, or "Y4", "Y6", or "Y8", which
            are 4th, 6th, and 8th order Yoshida leapfrog integrators with fixed
            step sizes. Defaults to "ias15".
        earliest_time (Time):
            The earliest time we expect to integrate the particle to. Defaults to
            Time("1980-01-01"). Larger time windows will result in larger in-memory
            ephemeris objects.
        latest_time (Time):
            The latest time we expect to integrate the particle to. Defaults to
            Time("2050-01-01"). Larger time windows will result in larger in-memory
            ephemeris objects.
        fit_seed (KeplerianState | CartesianState | None):
            A seed for fitting the orbit of the particle. If None, a seed will be
            generated from the observations if they exist. Otherwise, a circular
            orbit with semi-major axis 2.5 AU will be used.
        max_step_size (u.Quantity, optional):
            The fixed step size to use for leapfrog integrators. Required if
            integrator is "Y4", "Y6", or "Y8". Ignored if integrator is "ias15".
            Note that this is the maximum step size; the actual step size may be
            smaller to ensure that the particle lands exactly on the requested
            output times, and that the step size may change if the spacing between
            output times is not constant. Defaults to None.
        de_ephemeris_version (str | None):
            Which version of the JPL DE ephemeris to use for perturber positions.
            When using `from_horizons` to pull an initial state, only DE440 is
            supported. Defaults to "440", will error on anything else as a
            safeguard.
        step_scheduler (str):
            The scheduler used by IAS15 for picking the next proposed step size,
            forwarded unchanged to the Particle constructor. Choices are "prs23"
            (default) or "global".

    Returns:
        Particle:
            A Particle object representing the SSO at the given time.

    Raises:
        ValueError: If ``de_ephemeris_version`` is anything other than "440".
    """
    if de_ephemeris_version != "440":
        raise ValueError(
            "Only DE440 ephemeris version is supported when pulling "
            "an initial state from JPL Horizons."
        )

    # "500@0" is the center passed to the Horizons vector query; only the first
    # returned row is used for the initial position and velocity.
    data = horizons_bulk_vector_query(target=name, center="500@0", times=time)
    x0 = jnp.array([data["x"][0], data["y"][0], data["z"][0]])
    v0 = jnp.array([data["vx"][0], data["vy"][0], data["vz"][0]])

    return cls(
        x=x0,
        v=v0,
        time=time,
        observations=observations,
        name=name,
        gravity=gravity,
        # Forward the validated version explicitly instead of relying on the
        # constructor default happening to match.
        de_ephemeris_version=de_ephemeris_version,
        integrator=integrator,
        earliest_time=earliest_time,
        latest_time=latest_time,
        fit_seed=fit_seed,
        max_step_size=max_step_size,
        step_scheduler=step_scheduler,
    )
def _integrate_base(
    self,
    times: Time,
    state: CartesianState | KeplerianState | None = None,
    forced_landing: bool = False,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray | None]:
    """Propagate the particle (or a supplied state) to the requested times.

    Args:
        times: Output times, as an astropy Time or absolute JD (TDB) values;
            converted to offsets from the particle's reference epoch.
        state: Optional starting state. When None the particle's own state is
            used.
        forced_landing: If True use the forced-landing integrator, otherwise
            the regular one. Has no effect for Keplerian particles.

    Returns:
        tuple: ``(positions, velocities, steps)``. ``steps`` is None for
        Keplerian particles.
    """
    if self._is_keplerian:
        if state is not None:
            state = state.to_cartesian()
            # relative_time is already in this Particle's offset frame
            # (provided state.time_reference matches self._t_ref_jd, which
            # is the standard convention for states passed via state=).
            x, v, t0 = (
                state.x.flatten(),
                state.v.flatten(),
                state.relative_time,
            )
        else:
            x, v, t0 = self._x, self._v, self._time
        times = self._times_to_offsets(times)
        if times.shape == ():
            times = jnp.array([times])
        return *_keplerian_integrate(x, v, t0, times), None

    if state is None:
        state = self._cartesian_state
        integrator_state = self._integrator_state
    elif self._integrator_method == "ias15":
        # IAS15 must be re-seeded with the acceleration at the supplied state.
        a0 = self.gravity(state.to_system())
        integrator_state = initialize_ias15_integrator_state(a0)
    else:
        # The leapfrog state (dt and coefficients) does not depend on the
        # particle state, so the one built at construction is reused rather
        # than handing an IAS15 state to the leapfrog integrator.
        integrator_state = self._integrator_state

    times = self._times_to_offsets(times)
    if times.shape == ():
        # Promote a scalar query to a length-1 array.
        times = jnp.array([times])

    if self._integrator_method in ["Y4", "Y6", "Y8"]:
        times, inds = create_leapfrog_times(
            state.relative_time, times, self._max_step_size
        )
    else:
        inds = jnp.arange(times.shape[0])

    integrator = self._forced_integrator if forced_landing else self._integrator
    positions, velocities, _final_system_state, _final_integrator_state, steps = (
        _integrate(
            times,
            state,
            self.gravity,
            integrator,
            integrator_state,
            inds,
            self._step_scheduler,
        )
    )
    # Axis 1 indexes particles; this class always carries exactly one.
    return positions[:, 0, :], velocities[:, 0, :], steps
[docs]
def integrate(
    self,
    times: Time,
    state: CartesianState | KeplerianState | None = None,
    return_steps: bool = False,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Integrate the particle's orbit to specified times, landing exactly on each one.

    The particle itself is left untouched: only the positions and velocities
    at the requested times are returned.

    Args:
        times (Time | jnp.ndarray):
            The times to integrate to. Can be a single time or an array of times.
            If provided as a jnp.array, the entries are assumed to be in TDB JD.
        state (CartesianState | None):
            The state to integrate from. If None, the particle's current state will
            be used. Usually not necessary to provide this.
        return_steps (bool):
            Whether to also return the number of steps taken to reach each
            output time. Defaults to False.

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]:
            The positions of the particle at the given times, in AU, and the
            velocities of the particle at the given times, in AU/day. If
            return_steps is True, a third entry holds the number of steps taken
            to reach each output time.
    """
    positions, velocities, steps = self._integrate_base(
        times, state, forced_landing=True
    )
    return (positions, velocities, steps) if return_steps else (positions, velocities)
[docs]
def integrate_or_interpolate(
    self,
    times: Time,
    state: CartesianState | KeplerianState | None = None,
    return_steps: bool = False,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Integrate the particle's orbit to specified times, overshooting and 'interpolating' if necessary.

    The particle itself is left untouched: only the positions and velocities
    at the requested times are returned.

    Args:
        times (Time | jnp.ndarray):
            The times to integrate to. Can be a single time or an array of times.
            If provided as a jnp.array, the entries are assumed to be in TDB JD.
        state (CartesianState | None):
            The state to integrate from. If None, the particle's current state will
            be used. Usually not necessary to provide this.
        return_steps (bool):
            Whether to also return the number of steps taken to reach each
            output time. Defaults to False.

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]:
            The positions of the particle at the given times, in AU, and the
            velocities of the particle at the given times, in AU/day. If
            return_steps is True, a third entry holds the number of steps taken
            to reach each output time.
    """
    positions, velocities, steps = self._integrate_base(
        times, state, forced_landing=False
    )
    return (positions, velocities, steps) if return_steps else (positions, velocities)
[docs]
def ephemeris(
    self,
    times: Time,
    observer: str | jnp.ndarray,
    state: CartesianState | KeplerianState | None = None,
    interpolate: bool = True,
    uncertainty: bool = False,
) -> SkyCoord | tuple[SkyCoord, jnp.ndarray]:
    """Compute an ephemeris for the particle.

    Args:
        times (Time | jnp.ndarray):
            The times to compute the ephemeris for. Can be a single time or an array
            of times. If provided as a jnp.array, the entries are assumed to be in
            TDB JD.
        observer (str | jnp.ndarray):
            The observer to compute the ephemeris for. Can be a string representing
            an observatory name, or a 3D position vector in AU. For more info on
            acceptable strings, see the get_observer_positions function. A
            non-string value is used as-is as the observer positions.
        state (CartesianState | KeplerianState | None):
            The state to compute the ephemeris from. If None, the particle's current
            state will be used. Usually not necessary to provide this.
        interpolate (bool):
            Selects the integrator for the non-Keplerian path: if True,
            ``self._integrator`` is used, and for any integrator method other than
            Y4/Y6/Y8 the dense-output light-travel-time path is taken; if False,
            ``self._forced_integrator`` is used. Ignored when the particle is
            Keplerian, since that branch returns before an integrator is chosen.
            Defaults to True.
        uncertainty (bool):
            If True, also propagate the 6x6 covariance matrix stored on the state's
            ``cov`` field onto the sky plane via forward-mode autodiff (linear error
            propagation). The covariance is propagated in whichever parameterization
            the input state uses (Cartesian or Keplerian); no automatic conversion
            between parameterizations is performed. Expect roughly a ``6x`` cost
            relative to the nominal call because ``jax.jacfwd`` performs one JVP
            evaluation per input dimension. Defaults to False.

    Returns:
        coords (SkyCoord | tuple[SkyCoord, jnp.ndarray]):
            If ``uncertainty=False`` (default), returns a SkyCoord with the
            ephemeris of the particle at the given times, in ICRS coordinates, as
            seen from that specific observer and correcting for light travel time.
            If ``uncertainty=True``, returns a tuple with the same SkyCoord and the
            propagated ``(N, 2, 2)`` sky-plane covariance in ``arcsec**2``,
            axis order (RA, Dec).

    Raises:
        ValueError:
            If ``uncertainty=True`` and the ``cov`` field of the relevant state
            does not have shape ``(6, 6)``.
    """
    # Resolve the observer: named observatories are looked up, anything else
    # is taken to already be the observer position(s).
    if isinstance(observer, str):
        observer_positions = get_observer_positions(
            times, observer, self._de_ephemeris_version
        )
    else:
        observer_positions = observer
    if uncertainty:
        # The covariance comes from the explicitly supplied state if there is
        # one, otherwise from the particle's own state in its native
        # parameterization.
        cov_state = (
            state
            if state is not None
            else (
                self._keplerian_state
                if self._is_keplerian
                else self._cartesian_state
            )
        )
        cov = cov_state.cov
        if cov.shape != (6, 6):
            raise ValueError(
                "uncertainty=True requires a (6, 6) covariance matrix on the "
                f"state's `cov` field, but got shape {tuple(cov.shape)}. "
                "Build a CartesianState or KeplerianState with `cov=...` and pass "
                "it via the `state=` keyword (or set it on the particle's "
                "internal state)."
            )
    if self._is_keplerian:
        # Keplerian particles skip the N-body integrator entirely.
        offsets = self._times_to_offsets(times)
        if offsets.shape == ():
            # Promote a scalar time to a length-1 array.
            offsets = jnp.array([offsets])
        if uncertainty:
            kep_state = state if state is not None else self._keplerian_state
            ras, decs, cov_radec = _keplerian_ephem_with_cov(
                kep_state, offsets, observer_positions, cov
            )
            return (SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs"), cov_radec)
        if state is not None:
            state = state.to_cartesian()
            x, v, t0 = (
                state.x.flatten(),
                state.v.flatten(),
                state.relative_time,
            )
        else:
            x, v, t0 = self._x, self._v, self._time
        ras, decs = _keplerian_ephem(x, v, t0, offsets, observer_positions)
        return SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs")
    if state is None:
        state = self._cartesian_state
        integrator_state = self._integrator_state
    else:
        # NOTE(review): an IAS15 integrator state is built here whenever a
        # state is supplied, regardless of self._integrator_method -- confirm
        # this is intended for the Y4/Y6/Y8 leapfrog integrators.
        a0 = self.gravity(state.to_system())
        integrator_state = initialize_ias15_integrator_state(a0)
    # From here on `times` holds offsets rather than the caller's input.
    times = self._times_to_offsets(times)
    if times.shape == ():
        times = jnp.array([times])
    if self._integrator_method in ["Y4", "Y6", "Y8"]:
        # Leapfrog methods integrate on an expanded time grid; `inds` picks
        # the requested times back out of it.
        times, inds = create_leapfrog_times(
            state.relative_time, times, self._max_step_size
        )
    else:
        inds = jnp.arange(times.shape[0])
    integrator = self._integrator if interpolate else self._forced_integrator
    # IAS15 with interpolation has dense-output b-coefficients available, so
    # we can use them to evaluate the polynomial at light-travel-delayed times in on_sky's
    # LTT loop instead of a constant-acceleration Taylor expansion. All other
    # paths (leapfrog, IAS15 forced-landing) fall back to the original Taylor.
    use_dense_ltt = interpolate and self._integrator_method not in (
        "Y4",
        "Y6",
        "Y8",
    )
    if uncertainty:
        if use_dense_ltt:
            # The dense-output covariance helper builds its own integrator
            # state, so none is passed here.
            ras, decs, cov_radec = _ephem_ias15_with_cov(
                times,
                state,
                self.gravity,
                observer_positions,
                inds,
                self._step_scheduler,
                cov,
            )
        else:
            ras, decs, cov_radec = _ephem_with_cov(
                times,
                state,
                self.gravity,
                integrator,
                integrator_state,
                observer_positions,
                inds,
                self._step_scheduler,
                cov,
            )
        return (SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs"), cov_radec)
    if use_dense_ltt:
        ras, decs = _ephem_ias15(
            times,
            state,
            self.gravity,
            integrator_state,
            observer_positions,
            inds,
            self._step_scheduler,
        )
    else:
        ras, decs = _ephem(
            times,
            state,
            self.gravity,
            integrator,
            integrator_state,
            observer_positions,
            inds,
            self._step_scheduler,
        )
    return SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs")
[docs]
def max_likelihood(
    self,
    fit_seed: CartesianState | KeplerianState | None = None,
    verbose: bool = False,
) -> Particle:
    """Find the maximum likelihood orbit for the particle.

    The Cartesian position and velocity of the seed are optimized with SciPy's
    L-BFGS-B, using ``self.scipy_objective`` and ``self.scipy_objective_grad``.

    Args:
        fit_seed (CartesianState | KeplerianState | None):
            A seed for fitting the orbit of the particle. If None, the seed
            stored on the particle (``self._fit_seed``) is used.
        verbose (bool):
            Whether to print the optimization progress. Defaults to False.

    Returns:
        Particle:
            A new Particle object whose state matches the maximum likelihood orbit.

    Raises:
        ValueError:
            If the particle has no likelihood (``self.loglike`` is None) or if
            the optimizer does not report success.
    """
    if self.loglike is None:
        raise ValueError("No observations provided, cannot fit an orbit")

    seed = self._fit_seed if fit_seed is None else fit_seed
    seed_cartesian = seed.to_cartesian()
    x0 = jnp.concatenate(
        [seed_cartesian.x.flatten(), seed_cartesian.v.flatten()]
    )

    # Both branches below run L-BFGS-B with the same settings.
    lbfgsb_options = {
        "disp": verbose,
        "maxls": 100,
        "maxcor": 100,
        "maxfun": 5000,
        "maxiter": 1000,
        "ftol": 1e-14,
    }

    if self._is_keplerian:
        # Keplerian propagation produces NaN for hyperbolic orbits, so the
        # objective and gradient are guarded. The parameters are also rescaled
        # so the initial identity-Hessian step in L-BFGS-B is well-sized.
        scale = jnp.abs(x0) + 1e-10

        def _scaled_objective(x_scaled: jnp.ndarray) -> float:
            val = self.scipy_objective(x_scaled * scale)
            return float(jnp.where(jnp.isnan(val), 1e30, val))

        def _scaled_grad(x_scaled: jnp.ndarray) -> jnp.ndarray:
            g = self.scipy_objective_grad(x_scaled * scale) * scale
            return jnp.where(jnp.isnan(g), 0.0, g)

        result = minimize(
            fun=_scaled_objective,
            x0=x0 / scale,
            jac=_scaled_grad,
            method="L-BFGS-B",
            options=lbfgsb_options,
        )
        # Undo the rescaling so result.x is back in physical units.
        result.x = result.x * scale
    else:
        result = minimize(
            fun=self.scipy_objective,
            x0=x0,
            jac=self.scipy_objective_grad,
            method="L-BFGS-B",
            options=lbfgsb_options,
        )

    if not result.success:
        raise ValueError("Failed to converge")

    best_x = result.x[:3]
    best_v = result.x[3:]

    # Build a throwaway state only to check whether the fit is bound.
    fitted_state = CartesianState(
        x=jnp.array([best_x]),
        v=jnp.array([best_v]),
        relative_time=self._time,
        time_reference=self._t_ref_jd,
        acceleration_func_kwargs=self._acc_func_kwargs,
    )
    if fitted_state.to_keplerian().ecc > 1:
        warnings.warn(
            "Warning: max_likelihood fit produced an unbound orbit",
            RuntimeWarning,
            stacklevel=2,
        )

    return Particle(
        x=best_x,
        v=best_v,
        time=self._t_ref_astropy,
        state=None,
        observations=self._observations,
        name=self._name,
        gravity=self._gravity_str,
        de_ephemeris_version=self._de_ephemeris_version,
        integrator=self._integrator_method,
        earliest_time=self._earliest_time,
        latest_time=self._latest_time,
        fit_seed=self._fit_seed,
        max_step_size=self._max_step_size,
    )
[docs]
def is_observable(
    self,
    times: Time,
    observer: str | jnp.ndarray,
    sun_limit: float = 20.0,
    ephem: SkyCoord | None = None,
    return_angle: bool = False,
) -> jnp.ndarray:
    """Check whether a particle is observable or too close to the Sun.

    Args:
        times (Time):
            The times to check the observability.
        observer (str | jnp.ndarray):
            The observer/observatory making the observations. Can be a string for
            the name/code of an observatory, or a jnp.array of 3D barycentric ICRS
            positions in AU.
        sun_limit (float):
            The minimum allowed angular separation from the Sun, in degrees.
            Defaults to 20 degrees.
        ephem (SkyCoord | None, optional):
            Optionally, the ephemeris of the particle at the given times. If not
            provided, will be computed using the ephemeris method. Helpful if you've
            already computed the ephemeris and want to avoid doing it twice.
        return_angle (bool, optional):
            If True, will return the angles to the Sun in degrees, not the mask.
            Default is False.

    Returns:
        jnp.ndarray:
            If ``return_angle`` is False, a boolean array indicating whether the
            particle is observable at each time (True) or too close to the Sun
            (False). If ``return_angle`` is True, the angular separation from
            the Sun at each time, in degrees.
    """
    if isinstance(observer, str):
        observer_pos = get_observer_positions(
            times, observer, self._de_ephemeris_version
        )
    else:
        observer_pos = observer

    if ephem is None:
        # Hand over the already-resolved positions so a named observer is not
        # looked up a second time inside ephemeris().
        ephem = self.ephemeris(times, observer_pos)

    if isinstance(times, Time):
        times = jnp.array(times.tdb.jd)
    if times.shape == ():
        times = jnp.array([times])

    eph = Ephemeris(
        earliest_time=self._earliest_time,
        latest_time=self._latest_time,
    )
    # First returned element is taken as positions; body index 0 is used as the Sun.
    sun_pos = jax.vmap(eph.processor.state)(times)[0][:, 0, :]

    eph_unit_vec = jnp.array(ephem.cartesian.xyz).T

    # We want the angle between eph_unit_vec (tail on the observer) and the
    # vector from the observer to the Sun.
    obs_to_sun_vec = sun_pos - observer_pos
    obs_to_sun_unit_vec = (
        obs_to_sun_vec / jnp.linalg.norm(obs_to_sun_vec, axis=1)[:, None]
    )

    # Rounding can push the dot product of two unit vectors slightly outside
    # [-1, 1], where arccos returns NaN; clip so (anti-)aligned geometries
    # give 0 / pi instead.
    cos_angle = jnp.clip(
        jnp.sum(obs_to_sun_unit_vec * eph_unit_vec, axis=1), -1.0, 1.0
    )
    angle = jnp.arccos(cos_angle)

    if return_angle:
        return jnp.rad2deg(angle)
    return angle > jnp.deg2rad(sun_limit)
###########################
# EXTERNAL JITTED FUNCTIONS
###########################
@jax.jit
def _integrate(
    times: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    acc_func: Callable,
    integrator_func: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[
    jnp.ndarray,
    jnp.ndarray,
    SystemState,
    IAS15IntegratorState | LeapfrogIntegratorState,
]:
    """Run ``integrator_func`` on the particle and keep only the requested outputs.

    The particle state is converted with ``to_system()`` and handed to
    ``integrator_func`` together with the acceleration function, the output
    times, the integrator state, and the step scheduler.

    Returns:
        A 5-tuple: positions and velocities indexed by ``relevant_inds``, the
        final system state, the final integrator state, and the step
        information reported by the integrator.
        NOTE(review): the return annotation lists only the first four entries.
    """
    integrator_outputs = integrator_func(
        particle_state.to_system(), acc_func, times, integrator_state, step_scheduler
    )
    (
        all_positions,
        all_velocities,
        end_system_state,
        end_integrator_state,
        n_steps,
    ) = integrator_outputs
    selected_positions = all_positions[relevant_inds]
    selected_velocities = all_velocities[relevant_inds]
    return (
        selected_positions,
        selected_velocities,
        end_system_state,
        end_integrator_state,
        n_steps,
    )
@jax.jit
def _ephem(
    times: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    acc_func: Callable,
    integrator_func: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Integrate the particle and project it onto the sky at each requested time.

    ``_integrate`` supplies the positions/velocities at ``relevant_inds``; each
    one is then passed through ``on_sky`` together with its time and observer
    position.

    Returns:
        ``(ras, decs)``, one entry per requested time.
    """
    positions, velocities, *_ = _integrate(
        times,
        particle_state,
        acc_func,
        integrator_func,
        integrator_state,
        relevant_inds,
        step_scheduler,
    )

    # Index 0 selects the single particle along the second axis.
    per_observation = (
        positions[:, 0, :],
        velocities[:, 0, :],
        times[relevant_inds],
        observer_positions,
    )

    def _project_one(carry: None, observation: tuple) -> tuple[None, tuple]:
        pos, vel, epoch, obs_pos = observation
        ra, dec = on_sky(pos, vel, epoch, obs_pos, acc_func)
        return None, (ra, dec)

    # No carry is needed; scan is used purely to loop over observations.
    _, (ras, decs) = jax.lax.scan(_project_one, None, per_observation)
    return ras, decs
@jax.jit
def _ephem_ias15(
    times: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    acc_func: Callable,
    integrator_state: IAS15IntegratorState,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """IAS15-only variant of ``_ephem`` that uses dense-output b-coefficients for LTT.

    The ``on_sky`` light-travel-time correction defaults to a 2nd-order Taylor with a
    constant acceleration. For IAS15, we already have the converged 7th-order
    polynomial per step (the "dense output"). This variant builds a per-observation
    closure that evaluates that polynomial at the light-travel-delayed time, replacing
    the Taylor expansion. Only used when ``interpolate=True``.

    Returns:
        ``(ras, decs)``, one entry per index in ``relevant_inds``.
    """
    (
        _positions,
        _velocities,
        _final_system_state,
        _final_integrator_state,
        _iter_num,
        b_buf,
        a0_buf,
        x0_buf,
        v0_buf,
        dts_buf,
        _t_step_starts,
        step_indices,
        h_values,
    ) = ias15_evolve_with_dense_output(
        particle_state.to_system(), acc_func, times, integrator_state, step_scheduler
    )

    # Keep only the requested observation times (drops any intermediate
    # landing times).
    step_of_obs = step_indices[relevant_inds]

    def _single_particle_rows(buffer: jnp.ndarray) -> jnp.ndarray:
        # Gather the step belonging to each observation, particle index 0.
        return buffer[step_of_obs][:, 0, :]

    def per_obs_on_sky(
        b_step: jnp.ndarray,
        a0_step: jnp.ndarray,
        x0_step: jnp.ndarray,
        v0_step: jnp.ndarray,
        dt_step: jnp.ndarray,
        h_obs: jnp.ndarray,
        time: jnp.ndarray,
        observer_pos: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        ltt_propagator = make_ltt_propagator(
            b_step, a0_step, x0_step, v0_step, dt_step, h_obs
        )
        # Zero delay gives the position at the observation time itself.
        position_at_obs = ltt_propagator(jnp.array(0.0))
        # A zero velocity is passed; the delayed positions come from the
        # propagator instead.
        return on_sky(
            position_at_obs,
            jnp.zeros(3),
            time,
            observer_pos,
            acc_func,
            ltt_position_fn=ltt_propagator,
        )

    # vmap maps over the leading axis of every argument by default.
    ras, decs = jax.vmap(per_obs_on_sky)(
        b_buf[step_of_obs][:, :, 0, :],
        _single_particle_rows(a0_buf),
        _single_particle_rows(x0_buf),
        _single_particle_rows(v0_buf),
        dts_buf[step_of_obs],
        h_values[relevant_inds],
        times[relevant_inds],
        observer_positions,
    )
    return ras, decs
def _state_vec_to_xv(
state_vec: jnp.ndarray, is_keplerian_param: bool
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Unpack a (6,) parameter vector to ICRS Cartesian position/velocity.
If ``is_keplerian_param``, the entries are (semi, ecc, inc, Omega, omega, nu)
and we convert via :func:`elements_to_cartesian` followed by an ecliptic ->
ICRS rotation. Otherwise the entries are interpreted directly as flat
Cartesian (x, y, z, vx, vy, vz) in ICRS.
Returns ``(x, v)`` each shaped ``(1, 3)``.
"""
if is_keplerian_param:
x_ecl, v_ecl = elements_to_cartesian(
state_vec[0:1],
state_vec[1:2],
state_vec[5:6],
state_vec[2:3],
state_vec[3:4],
state_vec[4:5],
TOTAL_SOLAR_SYSTEM_GM,
)
x = horizons_ecliptic_to_icrs(x_ecl)
v = horizons_ecliptic_to_icrs(v_ecl)
else:
x = state_vec[:3].reshape(1, 3)
v = state_vec[3:].reshape(1, 3)
return x, v
def _state_to_vec(state: CartesianState | KeplerianState) -> jnp.ndarray:
    """Flatten a CartesianState or KeplerianState to a (6,) parameter vector.

    Keplerian states are laid out as (semi, ecc, inc, Omega, omega, nu); any
    other state is laid out as its flattened ``x`` followed by its flattened ``v``.
    """
    if not isinstance(state, KeplerianState):
        return jnp.concatenate([state.x.flatten(), state.v.flatten()])

    elements = (
        state.semi,
        state.ecc,
        state.inc,
        state.Omega,
        state.omega,
        state.nu,
    )
    # atleast_1d lets scalar-valued elements be concatenated.
    return jnp.concatenate([jnp.atleast_1d(element) for element in elements])
def _cov_from_jacobian(
radec_fn: Callable,
nominal_vec: jnp.ndarray,
cov: jnp.ndarray,
N: int,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Linear error propagation via forward-mode AD.
``radec_fn`` must accept a ``(6,)`` parameter vector and return a flat
``(2N,)`` interleaved ``[ra0, dec0, ra1, dec1, ...]`` array (radians).
Returns ``(ra, dec, cov_radec)`` where ``cov_radec`` has shape ``(N, 2, 2)``
in ``arcsec**2``.
"""
radec_nominal = radec_fn(nominal_vec)
ras = radec_nominal[0::2]
decs = radec_nominal[1::2]
J = jax.jacfwd(radec_fn)(nominal_vec)
J_t = J.reshape(N, 2, 6)
cov_radec = jnp.einsum("nij,jk,nlk->nil", J_t, cov, J_t) * _RAD2ARCSEC_SQ
return ras, decs, cov_radec
@jax.jit
def _ephem_ias15_with_cov(
    times: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    acc_func: Callable,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
    cov: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """IAS15 dense-output ephemeris with sky-plane covariance via forward-mode AD.

    The state is flattened to a (6,) vector in its own parameterization, and
    ``_cov_from_jacobian`` differentiates the RA/Dec produced by
    ``_ephem_ias15`` with respect to that vector.

    Returns:
        ``(ras, decs, cov_radec)`` as produced by ``_cov_from_jacobian``.
    """
    keplerian_input = isinstance(particle_state, KeplerianState)
    n_obs = relevant_inds.shape[0]

    def radec_fn(state_vec: jnp.ndarray) -> jnp.ndarray:
        pos, vel = _state_vec_to_xv(state_vec, keplerian_input)
        perturbed_state = CartesianState(
            x=pos,
            v=vel,
            relative_time=particle_state.relative_time,
            time_reference=particle_state.time_reference,
            acceleration_func_kwargs=particle_state.acceleration_func_kwargs,
        )
        # The integrator state is rebuilt from the perturbed state so its
        # initial acceleration is differentiated along with everything else.
        fresh_integrator_state = initialize_ias15_integrator_state(
            acc_func(perturbed_state.to_system())
        )
        ras, decs = _ephem_ias15(
            times,
            perturbed_state,
            acc_func,
            fresh_integrator_state,
            observer_positions,
            relevant_inds,
            step_scheduler,
        )
        # Interleave as [ra0, dec0, ra1, dec1, ...].
        return jnp.stack([ras, decs], axis=1).flatten()

    return _cov_from_jacobian(radec_fn, _state_to_vec(particle_state), cov, n_obs)
@jax.jit
def _ephem_with_cov(
    times: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    acc_func: Callable,
    integrator_func: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
    cov: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Generic non-dense ephemeris with sky-plane covariance via forward-mode AD.

    Handles both the IAS15 forced-landing path (``interpolate=False`` + IAS15) and
    the leapfrog path. For IAS15, the integrator state must be re-initialized
    inside the AD closure so the initial-acceleration entry tracks the perturbed
    state vector; for leapfrog, ``LeapfrogIntegratorState`` is independent of
    the dynamical state and is reused as-is.

    Returns:
        ``(ras, decs, cov_radec)`` as produced by ``_cov_from_jacobian``.
    """
    keplerian_input = isinstance(particle_state, KeplerianState)
    needs_ias15_reinit = isinstance(integrator_state, IAS15IntegratorState)
    n_obs = relevant_inds.shape[0]

    def radec_fn(state_vec: jnp.ndarray) -> jnp.ndarray:
        pos, vel = _state_vec_to_xv(state_vec, keplerian_input)
        perturbed_state = CartesianState(
            x=pos,
            v=vel,
            relative_time=particle_state.relative_time,
            time_reference=particle_state.time_reference,
            acceleration_func_kwargs=particle_state.acceleration_func_kwargs,
        )
        local_integrator_state = (
            initialize_ias15_integrator_state(acc_func(perturbed_state.to_system()))
            if needs_ias15_reinit
            else integrator_state
        )
        ras, decs = _ephem(
            times,
            perturbed_state,
            acc_func,
            integrator_func,
            local_integrator_state,
            observer_positions,
            relevant_inds,
            step_scheduler,
        )
        # Interleave as [ra0, dec0, ra1, dec1, ...].
        return jnp.stack([ras, decs], axis=1).flatten()

    return _cov_from_jacobian(radec_fn, _state_to_vec(particle_state), cov, n_obs)
@jax.jit
def _residuals(
    times: jnp.ndarray,
    gravity: Callable,
    integrator: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    observer_positions: jnp.ndarray,
    ra: jnp.ndarray,
    dec: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    step_scheduler: Callable,
) -> jnp.ndarray:
    """Tangent-plane residuals between observed and modelled sky positions.

    The model RA/Dec come from ``_ephem``; ``tangent_plane_projection`` is then
    applied per observation to ``(ra, dec, model_ra, model_dec)``.
    """
    model_ras, model_decs = _ephem(
        times,
        particle_state,
        gravity,
        integrator,
        integrator_state,
        observer_positions,
        relevant_inds,
        step_scheduler,
    )
    project_all = jax.vmap(tangent_plane_projection)
    return project_all(ra, dec, model_ras, model_decs)
# NOTE: forward-mode autodiff is not enforced on this external jitted function,
# so it will break under reverse mode when used with ias15.
@jax.jit
def _loglike(
    times: jnp.ndarray,
    gravity: Callable,
    integrator: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    observer_positions: jnp.ndarray,
    ra: jnp.ndarray,
    dec: jnp.ndarray,
    inv_cov_matrices: jnp.ndarray,
    cov_log_dets: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    step_scheduler: Callable,
) -> float:
    """Summed Gaussian log-likelihood of the residuals from ``_residuals``.

    Each observation contributes
    ``-0.5 * (2 * log(2 * pi) + cov_log_det + r^T inv_cov r)``.
    """
    residuals = _residuals(
        times,
        gravity,
        integrator,
        integrator_state,
        observer_positions,
        ra,
        dec,
        relevant_inds,
        particle_state,
        step_scheduler,
    )
    # Per-observation quadratic form r^T C^{-1} r.
    chi2_terms = jnp.einsum("bi,bij,bj->b", residuals, inv_cov_matrices, residuals)
    per_obs_loglike = -0.5 * (2 * jnp.log(2 * jnp.pi) + cov_log_dets + chi2_terms)
    return jnp.sum(per_obs_loglike)
###################################
# KEPLERIAN EXTERNAL JITTED FUNCTIONS
###################################
@jax.jit
def _keplerian_integrate(
    x: jnp.ndarray,
    v: jnp.ndarray,
    t0: float,
    times: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Propagate a single ICRS state to ``times`` with ``keplerian_propagate``.

    The state is converted to the Horizons ecliptic frame, propagated using
    ``TOTAL_SOLAR_SYSTEM_GM``, and the results are converted back to ICRS.

    Returns:
        ``(positions, velocities)`` in ICRS.
    """
    # A leading axis is added before the frame conversion.
    ecliptic_x = icrs_to_horizons_ecliptic(x[None, :])
    ecliptic_v = icrs_to_horizons_ecliptic(v[None, :])
    propagated_x, propagated_v = keplerian_propagate(
        ecliptic_x, ecliptic_v, t0, times, TOTAL_SOLAR_SYSTEM_GM
    )
    return (
        horizons_ecliptic_to_icrs(propagated_x),
        horizons_ecliptic_to_icrs(propagated_v),
    )
@jax.jit
def _keplerian_on_sky(
    x: jnp.ndarray,
    v: jnp.ndarray,
    time: float,
    observer_position: jnp.ndarray,
) -> tuple[float, float]:
    """Project a position onto the sky with a light-travel-time correction.

    The particle is stepped back along a second-order Taylor expansion that uses
    a point-mass acceleration ``-GM * x / r**3``; the light-travel delay is
    refined over three fixed iterations. ``time`` is accepted but not used.

    Returns:
        ``(ra, dec)`` in radians, with RA wrapped into ``[0, 2*pi)``.
    """
    radius = jnp.linalg.norm(x)
    accel = -TOTAL_SOLAR_SYSTEM_GM * x / (radius**3)

    delayed_x = x
    for _ in range(3):
        separation = jnp.linalg.norm(delayed_x - observer_position)
        # Negative: the light left the particle before the observation.
        delay = -separation * INV_SPEED_OF_LIGHT
        delayed_x = x + v * delay + 0.5 * accel * delay * delay

    line_of_sight = delayed_x - observer_position
    two_pi = 2 * jnp.pi
    calc_ra = jnp.mod(
        jnp.arctan2(line_of_sight[1], line_of_sight[0]) + two_pi, two_pi
    )
    # Declination as the complement of the polar angle.
    calc_dec = jnp.pi / 2 - jnp.arccos(
        line_of_sight[-1] / jnp.linalg.norm(line_of_sight)
    )
    return calc_ra, calc_dec
@jax.jit
def _keplerian_ephem(
    x: jnp.ndarray,
    v: jnp.ndarray,
    t0: float,
    times: jnp.ndarray,
    observer_positions: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Keplerian ephemeris: propagate with ``_keplerian_integrate``, then project.

    Each propagated state is passed through ``_keplerian_on_sky`` with its time
    and observer position.

    Returns:
        ``(ras, decs)``, one entry per time.
    """
    propagated_x, propagated_v = _keplerian_integrate(x, v, t0, times)
    per_observation = (propagated_x, propagated_v, times, observer_positions)

    def _project_one(carry: None, observation: tuple) -> tuple[None, tuple]:
        pos, vel, epoch, obs_pos = observation
        ra, dec = _keplerian_on_sky(pos, vel, epoch, obs_pos)
        return None, (ra, dec)

    # No carry is needed; scan is used purely to loop over observations.
    _, (ras, decs) = jax.lax.scan(_project_one, None, per_observation)
    return ras, decs
@jax.jit
def _keplerian_ephem_with_cov(
    particle_state: CartesianState | KeplerianState,
    times: jnp.ndarray,
    observer_positions: jnp.ndarray,
    cov: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Keplerian-path ephemeris with sky-plane covariance via forward-mode AD.

    Supports both Keplerian and Cartesian input parameterizations; the covariance
    is propagated in whichever space the input state was supplied in.

    Returns:
        ``(ras, decs, cov_radec)`` as produced by ``_cov_from_jacobian``.
    """
    keplerian_input = isinstance(particle_state, KeplerianState)
    epoch = particle_state.relative_time
    n_obs = times.shape[0]

    def radec_fn(state_vec: jnp.ndarray) -> jnp.ndarray:
        pos, vel = _state_vec_to_xv(state_vec, keplerian_input)
        ras, decs = _keplerian_ephem(
            pos.flatten(), vel.flatten(), epoch, times, observer_positions
        )
        # Interleave as [ra0, dec0, ra1, dec1, ...].
        return jnp.stack([ras, decs], axis=1).flatten()

    return _cov_from_jacobian(radec_fn, _state_to_vec(particle_state), cov, n_obs)
@jax.jit
def _keplerian_residuals(
    times: jnp.ndarray,
    observer_positions: jnp.ndarray,
    ra: jnp.ndarray,
    dec: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
) -> jnp.ndarray:
    """Tangent-plane residuals for the Keplerian path.

    The state is converted to Cartesian, ``_keplerian_ephem`` supplies the model
    RA/Dec, and ``tangent_plane_projection`` is applied per observation to
    ``(ra, dec, model_ra, model_dec)``.
    """
    cartesian = particle_state.to_cartesian()
    model_ras, model_decs = _keplerian_ephem(
        cartesian.x.flatten(),
        cartesian.v.flatten(),
        particle_state.relative_time,
        times,
        observer_positions,
    )
    project_all = jax.vmap(tangent_plane_projection)
    return project_all(ra, dec, model_ras, model_decs)
@jax.jit
def _keplerian_loglike(
    times: jnp.ndarray,
    observer_positions: jnp.ndarray,
    ra: jnp.ndarray,
    dec: jnp.ndarray,
    inv_cov_matrices: jnp.ndarray,
    cov_log_dets: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
) -> float:
    """Summed Gaussian log-likelihood of the residuals from ``_keplerian_residuals``.

    Each observation contributes
    ``-0.5 * (2 * log(2 * pi) + cov_log_det + r^T inv_cov r)``.
    """
    residuals = _keplerian_residuals(
        times, observer_positions, ra, dec, particle_state
    )
    # Per-observation quadratic form r^T C^{-1} r.
    chi2_terms = jnp.einsum("bi,bij,bj->b", residuals, inv_cov_matrices, residuals)
    per_obs_loglike = -0.5 * (2 * jnp.log(2 * jnp.pi) + cov_log_dets + chi2_terms)
    return jnp.sum(per_obs_loglike)