# Source code for jorbit.particle

"""The Particle class and its supporting functions."""

from __future__ import annotations

import jax

jax.config.update("jax_enable_x64", True)
import warnings
from collections.abc import Callable
from copy import deepcopy

import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time

# from jaxlib.xla_extension import PjitFunction
from scipy.optimize import minimize

from jorbit import Observations
from jorbit.accelerations import (
    create_default_ephemeris_acceleration_func,
    create_gr_ephemeris_acceleration_func,
    create_newtonian_ephemeris_acceleration_func,
)
from jorbit.astrometry.orbit_fit_seeds import gauss_method_orbit, simple_circular
from jorbit.astrometry.sky_projection import on_sky, tangent_plane_projection
from jorbit.astrometry.transformations import (
    elements_to_cartesian,
    horizons_ecliptic_to_icrs,
    icrs_to_horizons_ecliptic,
)
from jorbit.data.constants import (
    INV_SPEED_OF_LIGHT,
    SPEED_OF_LIGHT,
    TOTAL_SOLAR_SYSTEM_GM,
    Y4_C,
    Y4_D,
    Y6_C,
    Y6_D,
    Y8_C,
    Y8_D,
)
from jorbit.ephemeris.ephemeris import Ephemeris
from jorbit.integrators import (
    create_leapfrog_times,
    ias15_evolve,
    ias15_evolve_forced_landing,
    ias15_evolve_with_dense_output,
    initialize_ias15_integrator_state,
    leapfrog_evolve,
    make_ltt_propagator,
    next_proposed_dt_global,
    next_proposed_dt_PRS23,
)
from jorbit.likelihoods.setup_static_likelihood import (
    create_default_static_residuals_func,
    precompute_likelihood_data,
)
from jorbit.utils.horizons import get_observer_positions, horizons_bulk_vector_query
from jorbit.utils.kepler import keplerian_propagate
from jorbit.utils.states import (
    CartesianState,
    IAS15IntegratorState,
    KeplerianState,
    LeapfrogIntegratorState,
    SystemState,
)

# Squared conversion factor from radians^2 to arcsec^2.
_RAD2ARCSEC_SQ = (180.0 * 3600.0 / jnp.pi) ** 2


class Particle:
    """An object representing a single particle in the solar system.

    This class is used to represent and manipulate a single particle moving within
    the solar system. It is mostly a collection of convenience wrappers around the
    more general integrators and accelerations, but it also provides some useful
    methods for projecting the particle's position onto the sky and fitting orbits
    to observations. By construction, `Particle` objects are massless.

    Note: none of the methods associated with this class will alter the underlying
    state of the particle. For example, "integrate" will give you the positions and
    velocities of the particle at future times, but after it returns, the particle
    will still be at its original state.

    Attributes:
        state: The state of the particle, either in Cartesian or Keplerian
            coordinates.
        time: The time of the particle's state.
        x: The position of the particle in Cartesian coordinates.
        v: The velocity of the particle in Cartesian coordinates.
        observations: A collection of observations of the particle.
        name: The name of the particle.
        gravity: The gravitational acceleration function to use for the particle.
        integrator: The integrator to use for the particle.
        earliest_time: The earliest time for which ephemeris data is available.
        latest_time: The latest time for which ephemeris data is available.
        fit_seed: A seed for fitting the orbit of the particle.
        residuals: Residuals of a candidate state against the observations, or
            None when no observations were provided.
        loglike: Log-likelihood of a candidate state, or None when no
            observations were provided.
        scipy_objective: Negative log-likelihood of a flat 6-vector ``[x, v]``,
            suitable for ``scipy.optimize.minimize``; None without observations.
        scipy_objective_grad: Gradient of ``scipy_objective``; None without
            observations.
        static_residuals: Residual function built from precomputed likelihood
            data, or None when there are no observations or gravity is
            "keplerian".
    """

    def __init__(
        self,
        state: KeplerianState | CartesianState | None = None,
        time: Time | float | None = None,
        x: jnp.ndarray | None = None,
        v: jnp.ndarray | None = None,
        observations: Observations | None = None,
        name: str = "",
        gravity: str | Callable = "default solar system",
        de_ephemeris_version: str | None = "440",
        integrator: str = "ias15",
        earliest_time: Time = Time("1980-01-01"),
        latest_time: Time = Time("2050-01-01"),
        fit_seed: KeplerianState | CartesianState | None = None,
        max_step_size: u.Quantity | None = None,
        step_scheduler: str = "prs23",
    ) -> None:
        """Initialize a Particle object.

        Args:
            state (KeplerianState | CartesianState | None): The state of the
                particle. None if x and v are provided.
            time (Time | float | None): The time of the particle's state, as an
                astropy Time or a float JD in TDB. None if state is provided, since
                that will have its own time baked-in.
            x (jnp.ndarray | None): The 3D barycentric cartesian position of the
                particle in AU. None if state is provided.
            v (jnp.ndarray | None): The 3D barycentric cartesian velocity of the
                particle in AU/day. None if state is provided.
            observations (Observations | None): Optional Observations associated
                with the particle. Necessary if fitting or evaluating likelihoods.
            name (str): The name of the particle. Defaults to "".
            gravity (str | Callable): The gravitational acceleration function to
                use when integrating the particle's orbit. Defaults to "default
                solar system", which corresponds to parameterized post-Newtonian
                interactions with the 10 bodies in the JPL DE440 ephemeris, plus
                Newtonian interactions with the 16 largest asteroids in the
                asteroids_de441/sb441-n16.bsp ephemeris. Can also be a
                jax.tree_util.Partial object that follows the same signature as the
                acceleration functions in jorbit.accelerations. Other string
                options are "newtonian planets", "newtonian solar system", "gr
                planets", "gr solar system", and "keplerian" (pure 2-body
                Keplerian propagation with no ephemeris or integrator setup).
            de_ephemeris_version (str | None): Which version of the JPL DE
                ephemeris to use for perturber positions when using one of the
                built-in gravity models. Accepts either "440" or "430", default is
                "440". Ignored when gravity is "keplerian".
            integrator (str): The integrator to use for the particle. Choices are
                "ias15", which is a 15th order adaptive step-size integrator, or
                "Y4", "Y6", or "Y8", which are 4th, 6th, and 8th order Yoshida
                leapfrog integrators with fixed step sizes. Defaults to "ias15".
                Ignored when gravity is "keplerian".
            earliest_time (Time): The earliest time we expect to integrate the
                particle to. Defaults to Time("1980-01-01"). Larger time windows
                will result in larger in-memory ephemeris objects.
            latest_time (Time): The latest time we expect to integrate the particle
                to. Defaults to Time("2050-01-01"). Larger time windows will result
                in larger in-memory ephemeris objects.
            fit_seed (KeplerianState | CartesianState | None): A seed for fitting
                the orbit of the particle. If None, a seed will be generated from
                the observations if they exist. Otherwise, a circular orbit with
                semi-major axis 2.5 AU will be used.
            max_step_size (u.Quantity, optional): The fixed step size to use for
                leapfrog integrators. Required if integrator is "Y4", "Y6", or
                "Y8". Must be None if integrator is "ias15". Note that this is the
                maximum step size; the actual step size may be smaller to ensure
                that the particle lands exactly on the requested output times, and
                that the step size may change if the spacing between output times
                is not constant. Defaults to None.
            step_scheduler (str): The scheduler used by IAS15 for picking the next
                proposed step size. Choices are "prs23" (Pham+ 2023 controller,
                default) or "global" (the controller from the original IAS15
                paper). Used consistently by ``integrate``,
                ``integrate_or_interpolate``, ``ephemeris``, and the
                residuals/loglike closures. Ignored when gravity is "keplerian" or
                for leapfrog integrators.

        Raises:
            ValueError: If the gravity model, integrator, step scheduler, or
                time is not recognized, or if neither a state nor x/v is given.
        """
        self._observations = observations
        self._earliest_time = earliest_time
        self._latest_time = latest_time
        self._de_ephemeris_version = de_ephemeris_version
        self._is_keplerian = gravity == "keplerian"
        self._step_scheduler = self._resolve_step_scheduler(step_scheduler)
        self.gravity = gravity

        # Preserve the original gravity spec (string or user-supplied Partial) so
        # max_likelihood can pass it to the result Particle rather than self.gravity.
        # self.gravity is overwritten by _setup_acceleration_func with a Partial that
        # already has t_ref_jd baked in; passing *that* Partial to a new Particle would
        # re-wrap it a second time and double-count the offset (see BUG_FIXES_MAY2026).
        self._gravity_str = gravity

        # Copy so the rebasing done in _setup_state never mutates the caller's object.
        state = deepcopy(state) if state is not None else None

        # self._time is kept at jnp.array(0.0) internally; absolute time lives in
        # self._t_ref_astropy / self._t_ref_jd. All JAX-visible times are offsets
        # (in days) from _t_ref_jd, which keeps step-boundary and interpolation
        # arithmetic well-conditioned at decadal timescales.
        (
            self._x,
            self._v,
            self._time,
            self._t_ref_astropy,
            self._t_ref_jd,
            self._cartesian_state,
            self._keplerian_state,
            self._name,
            self._acc_func_kwargs,
        ) = self._setup_state(x, v, state, time, name)

        if self._is_keplerian:
            # Pure 2-body propagation: no ephemeris, no numerical integrator.
            self.gravity = "keplerian"
            self._integrator_state = None
            self._integrator = None
            self._forced_integrator = None
            self._integrator_method = "keplerian"
            self._max_step_size = None
        else:
            self.gravity = self._setup_acceleration_func(gravity)
            (
                self._integrator_state,
                self._integrator,
                self._forced_integrator,
            ) = self._setup_integrator(integrator, max_step_size)
            self._integrator_method = integrator
            self._max_step_size = max_step_size

        self._fit_seed = self._setup_fit_seed(fit_seed)
        (
            self.residuals,
            self.loglike,
            self.scipy_objective,
            self.scipy_objective_grad,
        ) = self._setup_likelihood()
        self.static_residuals = self._setup_default_static_residuals()

    def __repr__(self) -> str:
        """Return a string representation of the Particle object."""
        return f"Particle: {self._name}"

    @property
    def cartesian_state(self) -> CartesianState:
        """Return the Cartesian state of the particle.

        The state is self-describing about its absolute epoch via the
        ``(relative_time, time_reference)`` pair: ``relative_time`` is ``0.0``
        (the offset in the particle's internal frame) and ``time_reference`` is
        :attr:`t_ref_jd` (the absolute JD anchor).
        """
        return self._cartesian_state

    @property
    def keplerian_state(self) -> KeplerianState:
        """Return the Keplerian state of the particle.

        The state is self-describing about its absolute epoch via the
        ``(relative_time, time_reference)`` pair: ``relative_time`` is ``0.0``
        (the offset in the particle's internal frame) and ``time_reference`` is
        :attr:`t_ref_jd` (the absolute JD anchor).
        """
        return self._keplerian_state

    @property
    def observations(self) -> Observations | None:
        """Return the observations associated with the particle."""
        return self._observations

    @property
    def t_ref(self) -> Time:
        """Reference time (astropy Time, TDB) — the particle's epoch.

        All JAX-visible times inside the Particle are offsets in days from this
        reference, which keeps the integrator's internal arithmetic
        well-conditioned at decadal timescales.
        """
        return self._t_ref_astropy

    @property
    def t_ref_jd(self) -> float:
        """Reference time as a float JD (TDB), matching ``t_ref``."""
        return self._t_ref_jd

    ###############
    # SETUP METHODS
    ###############

    def _times_to_offsets(self, times: Time | jnp.ndarray) -> jnp.ndarray:
        """Convert user-provided query times to offsets from ``self._t_ref_astropy``.

        Astropy ``Time`` inputs preserve their internal jd1+jd2 high-precision pair
        through the subtraction, so the returned offsets retain sub-ns precision
        regardless of how far ``times`` is from ``t_ref``. Plain float/jnp inputs
        are assumed to be absolute JD (TDB) and are subtracted in float64 — this is
        Sterbenz-exact when the magnitudes match, but the offset precision is capped
        at ulp(JD) ≈ 40 μs (≈1 m for typical solar system velocities).
        """
        if isinstance(times, Time):
            return jnp.asarray((times.tdb - self._t_ref_astropy).to_value(u.day))
        return jnp.asarray(times) - self._t_ref_jd

    def _observations_times_as_offsets(self) -> jnp.ndarray:
        """Return the observation times as offsets from ``self._t_ref_astropy``.

        Prefers the full-precision astropy Time stored on the Observations object
        when available (sub-ns precision preserved). Falls back to subtracting
        :attr:`t_ref_jd` from the float-JD array, which is Sterbenz-exact but
        inherits the ulp(JD) ≈ 40 μs quantization of the stored float-JD.
        """
        obs = self._observations
        times_astropy = getattr(obs, "_times_astropy", None)
        if times_astropy is not None:
            return jnp.asarray(
                (times_astropy.tdb - self._t_ref_astropy).to_value(u.day)
            )
        return jnp.asarray(obs.times) - self._t_ref_jd

    def _setup_state(
        self,
        x: jnp.ndarray | None,
        v: jnp.ndarray | None,
        state: CartesianState | KeplerianState | None,
        time: Time | float | None,
        name: str,
    ) -> tuple:
        """Resolve the constructor's state/epoch inputs into internal form.

        Exactly one of ``state`` or the ``(x, v)`` pair must be given. A ``time``
        is required with ``(x, v)`` and forbidden with ``state`` (whose epoch is
        ``relative_time + time_reference``). The returned states are rebased to
        ``relative_time = 0`` against the new absolute anchor ``t_ref_jd``.

        Returns:
            tuple: ``(x, v, internal_time, t_ref_astropy, t_ref_jd,
            cartesian_state, keplerian_state, name, acc_func_kwargs)``.

        Raises:
            AssertionError: On conflicting inputs (state with time, state with
                x/v, x without v) or a missing epoch.
            ValueError: If ``time`` is neither an astropy Time nor convertible to
                float, or if neither ``state`` nor ``x``/``v`` was provided.
        """
        if state is not None:
            assert time is None, "Cannot provide both state and time"
            # The state self-describes its absolute epoch as the sum of its
            # offset (relative_time) and its anchor (time_reference). We use
            # the sum as the Particle's absolute reference time, then rebase
            # the stored state to relative_time=0 against that new anchor.
            time = state.relative_time + state.time_reference
        assert time is not None, "Must provide an epoch for the particle"

        # Build the reference time pair (astropy Time at full precision + float JD)
        # and then set the Particle's internal epoch to 0.0 in the offset frame.
        if isinstance(time, Time):
            t_ref_astropy = time.tdb
        else:
            try:
                time_jd = float(time)
            except (TypeError, ValueError) as err:
                raise ValueError(
                    "time must be either astropy.time.Time or float (interpreted as JD in"
                    " TDB)"
                ) from err
            t_ref_astropy = Time(time_jd, format="jd", scale="tdb")
        t_ref_jd = jnp.array(float(t_ref_astropy.tdb.jd))
        internal_time = jnp.array(0.0)

        if state is not None:
            assert x is None and v is None, "Cannot provide both state and x, v"
            # to_keplerian() and to_cartesian() are documented not to propagate `cov`
            # because the covariance is parameterisation-specific. Save and re-attach
            # it to the same-type output state so callers who set cov on the input
            # state can retrieve it on particle.cartesian_state / particle.keplerian_state.
            input_is_keplerian = isinstance(state, KeplerianState)
            input_cov = state.cov
            state = state.to_cartesian()
            # Internal states carry a leading particle axis: x, v are (1, 3).
            if state.x.ndim != 2:
                state.x = state.x[None, :]
                state.v = state.v[None, :]
            state.relative_time = internal_time
            state.time_reference = t_ref_jd
            keplerian_state = state.to_keplerian()
            cartesian_state = state.to_cartesian()
            if input_is_keplerian:
                keplerian_state.cov = input_cov
            else:
                cartesian_state.cov = input_cov
            x = state.x.flatten()
            v = state.v.flatten()
        elif x is not None:
            assert v is not None, "Must provide both x and v"
            x = x.flatten()
            v = v.flatten()
            cartesian_state = CartesianState(
                x=jnp.array([x]),
                v=jnp.array([v]),
                relative_time=internal_time,
                time_reference=t_ref_jd,
                acceleration_func_kwargs={"c2": SPEED_OF_LIGHT**2},
            )
            keplerian_state = cartesian_state.to_keplerian()
        else:
            raise ValueError(
                "Must provide either a state, or both x and v, to define the particle"
            )

        if name == "":
            name = "unnamed"

        acc_func_kwargs = cartesian_state.acceleration_func_kwargs
        return (
            x,
            v,
            internal_time,
            t_ref_astropy,
            t_ref_jd,
            cartesian_state,
            keplerian_state,
            name,
            acc_func_kwargs,
        )

    def _setup_acceleration_func(self, gravity: str | Callable) -> Callable:
        """Build the acceleration function used by the integrators.

        Args:
            gravity: Either a ``jax.tree_util.Partial`` expecting absolute-JD
                times, or the name of a built-in model ("newtonian planets",
                "newtonian solar system", "gr planets", "gr solar system",
                "default solar system").

        Returns:
            Callable: A function of a SystemState whose ``time`` field is an
            offset in days from ``self._t_ref_jd``.

        Raises:
            AssertionError: If ``de_ephemeris_version`` is not "440" or "430" for
                a non-Partial gravity.
            ValueError: If ``gravity`` is not a recognized model name.
        """
        if isinstance(gravity, jax.tree_util.Partial):
            # The wrapper converts from the Particle's internal offset frame
            # (state.time = days since self._t_ref_jd) to absolute JD before
            # calling the user's function, which must expect absolute JD.
            #
            # CONTRACT: the custom function must be built with t_ref_jd=0 (or
            # equivalently, must treat state.time as an absolute Julian Date).
            # Do NOT pass a function returned by jorbit's factory functions
            # (create_default_ephemeris_acceleration_func, etc.) that was
            # built with a non-zero t_ref_jd — that would double-count the
            # offset and silently query the ephemeris at the wrong time.
            # The safe pattern is to build the function fresh with t_ref_jd=0:
            #     acc_func = create_default_ephemeris_acceleration_func(
            #         eph.processor, t_ref_jd=0.0
            #     )
            user_func = gravity
            t_ref_jd = self._t_ref_jd

            def _wrapped_user_acc(state: SystemState) -> jnp.ndarray:
                shifted = state.replace(time=state.time + t_ref_jd)
                return user_func(shifted)

            return jax.tree_util.Partial(_wrapped_user_acc)

        assert self._de_ephemeris_version in ["440", "430"], (
            "de_ephemeris_version must be either '440' or '430' if not using a custom "
            "gravity function"
        )

        # Built-in model name -> (perturber set to load, acceleration factory).
        builtin_models = {
            "newtonian planets": (
                "default planets",
                create_newtonian_ephemeris_acceleration_func,
            ),
            "newtonian solar system": (
                "default solar system",
                create_newtonian_ephemeris_acceleration_func,
            ),
            "gr planets": (
                "default planets",
                create_gr_ephemeris_acceleration_func,
            ),
            "gr solar system": (
                "default solar system",
                create_gr_ephemeris_acceleration_func,
            ),
            "default solar system": (
                "default solar system",
                create_default_ephemeris_acceleration_func,
            ),
        }
        # isinstance guard: a non-string (e.g. an unhashable object) must reach
        # the ValueError below rather than fail inside the dict lookup.
        model = builtin_models.get(gravity) if isinstance(gravity, str) else None
        if model is None:
            raise ValueError(
                f"Unrecognized gravity '{gravity}'. Valid options are: 'newtonian planets', "
                "'newtonian solar system', 'gr planets', 'gr solar system', "
                "'default solar system', 'keplerian'. For a custom acceleration function, "
                "pass a jax.tree_util.Partial."
            )

        ssos, factory = model
        eph = Ephemeris(
            earliest_time=self._earliest_time,
            latest_time=self._latest_time,
            ssos=ssos,
            de_ephemeris_version=self._de_ephemeris_version,
        )
        return factory(eph.processor, t_ref_jd=self._t_ref_jd)

    @staticmethod
    def _resolve_step_scheduler(name: str) -> Callable:
        """Translate a string name to a JIT-friendly Partial step scheduler."""
        if name.lower() == "prs23":
            return jax.tree_util.Partial(next_proposed_dt_PRS23)
        elif name.lower() == "global":
            return jax.tree_util.Partial(next_proposed_dt_global)
        raise ValueError(
            f"Unknown step_scheduler '{name}'. Choices are 'prs23' or 'global'."
        )

    def _setup_integrator(
        self, integrator: str, max_step_size: u.Quantity | None
    ) -> tuple[IAS15IntegratorState | LeapfrogIntegratorState, Callable, Callable]:
        """Build the integrator state and evolve functions for ``integrator``.

        Returns:
            tuple: ``(integrator_state, integrator, forced_integrator)``. For
            leapfrog methods the forced-landing integrator is the same function
            as the regular one.

        Raises:
            AssertionError: If ``max_step_size`` is given for "ias15" or missing
                for a leapfrog integrator.
            ValueError: If ``integrator`` is not "ias15", "Y4", "Y6", or "Y8".
        """
        # Yoshida leapfrog coefficient tables, keyed by integrator name.
        leapfrog_coefficients = {
            "Y4": (Y4_C, Y4_D),
            "Y6": (Y6_C, Y6_D),
            "Y8": (Y8_C, Y8_D),
        }

        if integrator == "ias15":
            assert (
                max_step_size is None
            ), "max_step_size should not be provided for IAS15 integrator."
            # IAS15 needs the initial acceleration to seed its predictor state.
            a0 = self.gravity(self._cartesian_state.to_system())
            integrator_state = initialize_ias15_integrator_state(a0)
            evolve = jax.tree_util.Partial(ias15_evolve)
            forced_evolve = jax.tree_util.Partial(ias15_evolve_forced_landing)
        elif integrator in ("Y4", "Y6", "Y8"):
            assert (
                max_step_size is not None
            ), "Must provide max_step_size for leapfrog integrators."
            c, d = leapfrog_coefficients[integrator]
            integrator_state = LeapfrogIntegratorState(
                dt=max_step_size.to(u.day).value, C=c, D=d
            )
            evolve = jax.tree_util.Partial(leapfrog_evolve)
            forced_evolve = evolve
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(
                f"Unknown integrator '{integrator}'. Choices are 'ias15', 'Y4', "
                "'Y6', or 'Y8'."
            )
        return integrator_state, evolve, forced_evolve

    def _setup_fit_seed(
        self, fit_seed: KeplerianState | CartesianState | None
    ) -> KeplerianState | CartesianState | None:
        """Choose the starting state for orbit fits.

        Returns None when there are no observations, and a user-supplied seed
        unchanged. With at least 3 observations, runs Gauss's method on the
        first observation, the one closest to the mean time, and the last
        observation, warning if the result is unbound. Otherwise returns a
        2.5 AU circular orbit through the first observation's RA/Dec.
        """
        if self._observations is None:
            return None
        if isinstance(fit_seed, (CartesianState, KeplerianState)):
            return fit_seed

        if len(self._observations) >= 3:
            mean_time = jnp.mean(self._observations.times)
            mid_idx = jnp.argmin(jnp.abs(self._observations.times - mean_time))
            fit_seed = gauss_method_orbit(
                self._observations[0]
                + self._observations[mid_idx]
                + self._observations[-1]
            )
            if fit_seed.to_keplerian().ecc > 1:
                warnings.warn(
                    "Warning: initial Gauss's method fit produced an unbound orbit",
                    RuntimeWarning,
                    stacklevel=2,
                )
        else:
            # Pass the Particle's reference epoch (not the observation time)
            # so that the returned state's x, v represent the object at
            # self._t_ref_jd — the same epoch the optimizer will use as its
            # starting point. Using obs.times[0] instead would yield x, v at
            # a different epoch that the optimizer would then misinterpret.
            fit_seed = simple_circular(
                self._observations.ra[0],
                self._observations.dec[0],
                semi=2.5,
                time=self._t_ref_jd,
            )
        return fit_seed

    def _setup_likelihood(self) -> tuple[Callable, Callable, Callable, Callable]:
        """Build the residual/likelihood closures over this particle's observations.

        Returns:
            tuple: ``(residuals, loglike, scipy_objective, scipy_grad)``, or four
            Nones when the particle has no observations. ``scipy_objective`` and
            ``scipy_grad`` take a flat 6-vector ``[x, y, z, vx, vy, vz]`` at the
            particle's epoch and return the negative log-likelihood and its
            gradient.
        """
        if self._observations is None:
            return None, None, None, None

        obs = self._observations
        obs_times = self._observations_times_as_offsets()

        if self._is_keplerian:
            residuals = jax.tree_util.Partial(
                _keplerian_residuals,
                obs_times,
                obs.observer_positions,
                obs.ra,
                obs.dec,
            )
            ll = jax.tree_util.Partial(
                _keplerian_loglike,
                obs_times,
                obs.observer_positions,
                obs.ra,
                obs.dec,
                obs.inv_cov_matrices,
                obs.cov_log_dets,
            )
            # Keplerian propagation is fully JIT-able, reverse mode works natively
            residuals = jax.jit(residuals)
            loglike = jax.jit(ll)
        else:
            if self._integrator_method == "ias15":
                times = obs_times
                inds = jnp.arange(len(obs_times))
            else:
                # Leapfrog ("Y4"/"Y6"/"Y8"); _setup_integrator has already
                # rejected any other integrator name.
                times, inds = create_leapfrog_times(
                    self._cartesian_state.relative_time,
                    obs_times,
                    self._max_step_size,
                )

            residuals = jax.tree_util.Partial(
                _residuals,
                times,
                self.gravity,
                self._integrator,
                self._integrator_state,
                obs.observer_positions,
                obs.ra,
                obs.dec,
                inds,
                step_scheduler=self._step_scheduler,
            )
            ll = jax.tree_util.Partial(
                _loglike,
                times,
                self.gravity,
                self._integrator,
                self._integrator_state,
                obs.observer_positions,
                obs.ra,
                obs.dec,
                obs.inv_cov_matrices,
                obs.cov_log_dets,
                inds,
                step_scheduler=self._step_scheduler,
            )

            # since we've gone with the while loop version of the ias15 integrator,
            # can no longer use reverse mode. But, actually specifying forward mode
            # everywhere is annoying, so we're going to re-define a custom vjp for
            # "reverse" mode that's actually just forward mode
            @jax.custom_vjp
            def loglike(params: CartesianState | KeplerianState) -> float:
                return ll(params)

            def loglike_fwd(params: CartesianState | KeplerianState) -> tuple:
                output = ll(params)
                jac = jax.jacfwd(ll)(params)
                return output, (jac,)

            def loglike_bwd(res: tuple, g: float) -> tuple:
                # custom_vjp expects one cotangent per primal input; loglike has
                # a single input, so return a 1-tuple holding g * d(loglike)/dparams.
                (jac,) = res
                return (jax.tree.map(lambda leaf: leaf * g, jac),)

            loglike.defvjp(loglike_fwd, loglike_bwd)

            residuals = jax.jit(residuals)
            loglike = jax.jit(loglike)

        def _flat_to_state(flat: jnp.ndarray) -> CartesianState:
            # Interpret a flat [x, y, z, vx, vy, vz] vector as a state at the
            # particle's own epoch.
            return CartesianState(
                x=jnp.array([flat[:3]]),
                v=jnp.array([flat[3:]]),
                relative_time=self._time,
                time_reference=self._t_ref_jd,
                acceleration_func_kwargs=self._acc_func_kwargs,
            )

        def scipy_objective(x: jnp.ndarray) -> float:
            return -loglike(_flat_to_state(x))

        def scipy_grad(x: jnp.ndarray) -> jnp.ndarray:
            c_grad = jax.grad(loglike)(_flat_to_state(x))
            g = jnp.concatenate([c_grad.x.flatten(), c_grad.v.flatten()])
            return -g

        return residuals, loglike, scipy_objective, scipy_grad

    def _setup_default_static_residuals(self) -> Callable:
        """Build the residual function from precomputed likelihood data.

        Returns None when there are no observations or the particle uses
        Keplerian propagation.
        """
        if self._observations is None or self._is_keplerian:
            return None
        precomputed_data = precompute_likelihood_data(self, self._step_scheduler)
        static_residuals_func = create_default_static_residuals_func(precomputed_data)
        return static_residuals_func

    ################
    # PUBLIC METHODS
    ################
[docs] @classmethod def from_horizons( cls, name: str, time: Time, observations: Observations | None = None, gravity: str | Callable = "default solar system", integrator: str = "ias15", earliest_time: Time = Time("1980-01-01"), latest_time: Time = Time("2050-01-01"), fit_seed: KeplerianState | CartesianState | None = None, max_step_size: u.Quantity | None = None, de_ephemeris_version: str | None = "440", ) -> Particle: """Query JPL Horizons for an SSOs state at a given time and create a Particle object. Args: name (str): The name of the SSO to query. Can be a string or an integer. time (Time): The time to query the SSO at. observations (Observations | None): Optional Observations associated with the particle. Necessary if fitting or evaluating likelihoods. gravity (str | Callable): The gravitational acceleration function to use when integrating the particle's orbit. Defaults to "default solar system", which corresponds to parameterized post-Newtonian interactions with the 10 bodies in the JPL DE440 ephemeris, plus Newtonian interactions with the 16 largest asteroids in the asteroids_de441/sb441-n16.bsp ephemeris. Can also be a jax.tree_util.Partial object that follows the same signature as the acceleration functions in jorbit.accelerations. Other string options are "newtonian planets", "newtonian solar system", "gr planets", and "gr solar system". integrator (str): The integrator to use for the particle. Choices are "ias15", which is a 15th order adaptive step-size integrator, or "Y4", "Y6", or "Y8", which are 4th, 6th, and 8th order Yoshida leapfrog integrators with fixed step sizes. Defaults to "ias15". earliest_time (Time): The earliest time we expect to integrate the particle to. Defaults to Time("1980-01-01"). Larger time windows will result in larger in-memory ephemeris objects. latest_time (Time): The latest time we expect to integrate the particle to. Defaults to Time("2050-01-01"). Larger time windows will result in larger in-memory ephemeris objects. 
fit_seed (KeplerianState | CartesianState | None): A seed for fitting the orbit of the particle. If None, a seed will be generated from the observations if they exist. Otherwise, a circular orbit with semi-major axis 2.5 AU will be used. max_step_size (u.Quantity, optional): The fixed step size to use for leapfrog integrators. Required if integrator is "Y4", "Y6", or "Y8". Ignored if integrator is "ias15". Note that this is the maximum step size; the actual step size may be smaller to ensure that the particle lands exactly on the requested output times, and that the step size may change if the spacing between output times is not constant. Defaults to None. de_ephemeris_version (str | None): Which version of the JPL DE ephemeris to use for perturber positions. When using `from_horizons` to pull an initial state, only DE440 is supported. Defaults to "440", will error on anything else as a safeguard. Returns: Particle: A Particle object representing the SSO at the given time. """ if de_ephemeris_version != "440": raise ValueError( "Only DE440 ephemeris version is supported when pulling " "an initial state from JPL Horizons." ) data = horizons_bulk_vector_query(target=name, center="500@0", times=time) x0 = jnp.array([data["x"][0], data["y"][0], data["z"][0]]) v0 = jnp.array([data["vx"][0], data["vy"][0], data["vz"][0]]) return cls( x=x0, v=v0, time=time, observations=observations, name=name, gravity=gravity, integrator=integrator, earliest_time=earliest_time, latest_time=latest_time, fit_seed=fit_seed, max_step_size=max_step_size, )
def _integrate_base(
    self,
    times: Time | jnp.ndarray,
    state: CartesianState | KeplerianState | None = None,
    forced_landing: bool = False,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray | None]:
    """Shared backend for ``integrate`` and ``integrate_or_interpolate``.

    Args:
        times: Output times (astropy Time, or TDB JD values). They are converted
            to this particle's offset frame via ``self._times_to_offsets``; a
            scalar is promoted to a length-1 array.
        state: Optional state to start from. If None, the particle's own state
            (and its cached integrator state) is used.
        forced_landing: If True, use ``self._forced_integrator`` so the
            integrator lands exactly on every output time. Otherwise use
            ``self._integrator``.

    Returns:
        ``(positions, velocities, steps)``. On the Keplerian path ``steps`` is
        None because no numerical integration is performed.
    """
    if self._is_keplerian:
        if state is not None:
            state = state.to_cartesian()
            # relative_time is already in this Particle's offset frame
            # (provided state.time_reference matches self._t_ref_jd, which
            # is the standard convention for states passed via state=).
            x, v, t0 = (
                state.x.flatten(),
                state.v.flatten(),
                state.relative_time,
            )
        else:
            x, v, t0 = self._x, self._v, self._time
        times = self._times_to_offsets(times)
        if times.shape == ():
            times = jnp.array([times])
        return *_keplerian_integrate(x, v, t0, times), None

    is_leapfrog = self._integrator_method in ["Y4", "Y6", "Y8"]
    if state is None:
        state = self._cartesian_state
        integrator_state = self._integrator_state
    elif is_leapfrog:
        # The leapfrog integrator state does not depend on the dynamical state,
        # so it is reused as-is. Building an IAS15 state here would hand the
        # leapfrog integrator the wrong state type.
        integrator_state = self._integrator_state
    else:
        # IAS15 seeds its predictor with the initial acceleration, so it must be
        # re-initialized from the caller-supplied state.
        a0 = self.gravity(state.to_system())
        integrator_state = initialize_ias15_integrator_state(a0)

    times = self._times_to_offsets(times)
    if times.shape == ():
        times = jnp.array([times])

    if is_leapfrog:
        # Leapfrog needs intermediate landing times no further apart than
        # max_step_size; `inds` picks the requested outputs back out.
        times, inds = create_leapfrog_times(
            state.relative_time, times, self._max_step_size
        )
    else:
        inds = jnp.arange(times.shape[0])

    integrator = self._forced_integrator if forced_landing else self._integrator
    positions, velocities, _final_system_state, _final_integrator_state, steps = (
        _integrate(
            times,
            state,
            self.gravity,
            integrator,
            integrator_state,
            inds,
            self._step_scheduler,
        )
    )
    # Index 0 along axis 1 selects this particle (the single tracer).
    return positions[:, 0, :], velocities[:, 0, :], steps
def integrate(
    self,
    times: Time | jnp.ndarray,
    state: CartesianState | KeplerianState | None = None,
    return_steps: bool = False,
) -> (
    tuple[jnp.ndarray, jnp.ndarray]
    | tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray | None]
):
    """Integrate the particle's orbit to specified times, landing exactly on each one.

    This method does not change the state of the particle. It returns the
    positions and velocities of the particle at the given times and leaves the
    particle itself untouched.

    Args:
        times (Time | jnp.ndarray): The times to integrate to. Can be a single time
            or an array of times. If provided as a jnp.array, the entries are
            assumed to be in TDB JD.
        state (CartesianState | KeplerianState | None): The state to integrate
            from. If None, the particle's current state will be used. Usually not
            necessary to provide this.
        return_steps (bool): Whether to also return the number of steps taken to
            reach each output time. Defaults to False.

    Returns:
        tuple: ``(positions, velocities)``, with positions in AU and velocities in
        AU/day. If ``return_steps`` is True, returns
        ``(positions, velocities, steps)`` instead, where ``steps`` is the number
        of steps taken to reach each output time (None for Keplerian particles,
        which are propagated analytically).
    """
    positions, velocities, steps = self._integrate_base(
        times, state, forced_landing=True
    )
    if not return_steps:
        return positions, velocities
    return positions, velocities, steps
def integrate_or_interpolate(
    self,
    times: Time | jnp.ndarray,
    state: CartesianState | KeplerianState | None = None,
    return_steps: bool = False,
) -> (
    tuple[jnp.ndarray, jnp.ndarray]
    | tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray | None]
):
    """Integrate the particle's orbit to specified times, overshooting and 'interpolating' if necessary.

    This method does not change the state of the particle. It returns the
    positions and velocities of the particle at the given times and leaves the
    particle itself untouched.

    Args:
        times (Time | jnp.ndarray): The times to integrate to. Can be a single time
            or an array of times. If provided as a jnp.array, the entries are
            assumed to be in TDB JD.
        state (CartesianState | KeplerianState | None): The state to integrate
            from. If None, the particle's current state will be used. Usually not
            necessary to provide this.
        return_steps (bool): Whether to also return the number of steps taken to
            reach each output time. Defaults to False.

    Returns:
        tuple: ``(positions, velocities)``, with positions in AU and velocities in
        AU/day. If ``return_steps`` is True, returns
        ``(positions, velocities, steps)`` instead, where ``steps`` is the number
        of steps taken to reach each output time (None for Keplerian particles,
        which are propagated analytically).
    """
    positions, velocities, steps = self._integrate_base(
        times, state, forced_landing=False
    )
    if not return_steps:
        return positions, velocities
    return positions, velocities, steps
def ephemeris(
    self,
    times: Time | jnp.ndarray,
    observer: str | jnp.ndarray,
    state: CartesianState | KeplerianState | None = None,
    interpolate: bool = True,
    uncertainty: bool = False,
) -> SkyCoord | tuple[SkyCoord, jnp.ndarray]:
    """Compute an ephemeris for the particle.

    Args:
        times (Time | jnp.ndarray): The times to compute the ephemeris for. Can be
            a single time or an array of times. If provided as a jnp.array, the
            entries are assumed to be in TDB JD.
        observer (str | jnp.ndarray): The observer to compute the ephemeris for.
            Can be a string representing an observatory name, or a 3D position
            vector in AU. For more info on acceptable strings, see the
            get_observer_positions function.
        state (CartesianState | KeplerianState | None): The state to compute the
            ephemeris from. If None, the particle's current state will be used.
            Usually not necessary to provide this.
        interpolate (bool): If True (default), use the overshoot-and-interpolate
            integrator, as in ``integrate_or_interpolate``. If False, use the
            forced-landing integrator, as in ``integrate``. With IAS15 and
            ``interpolate=True``, the light-travel-time correction evaluates the
            integrator's dense-output polynomial instead of a Taylor expansion.
            Has no effect for Keplerian particles.
        uncertainty (bool): If True, also propagate the 6x6 covariance matrix
            stored on the state's ``cov`` field onto the sky plane via
            forward-mode autodiff (linear error propagation). The covariance is
            propagated in whichever parameterization the input state uses
            (Cartesian or Keplerian); no automatic conversion between
            parameterizations is performed. Expect roughly a ``6x`` cost relative
            to the nominal call because ``jax.jacfwd`` performs one JVP evaluation
            per input dimension. Defaults to False.

    Returns:
        coords (SkyCoord | tuple[SkyCoord, jnp.ndarray]): If ``uncertainty=False``
            (default), a SkyCoord with the ephemeris of the particle at the given
            times, in ICRS coordinates, as seen from the given observer and
            corrected for light travel time. If ``uncertainty=True``, a tuple of
            that SkyCoord and the propagated ``(N, 2, 2)`` sky-plane covariance in
            ``arcsec**2``, axis order (RA, Dec).

    Raises:
        ValueError: If ``uncertainty=True`` and the relevant state's ``cov`` is
            not a ``(6, 6)`` matrix.
    """
    if isinstance(observer, str):
        observer_positions = get_observer_positions(
            times, observer, self._de_ephemeris_version
        )
    else:
        observer_positions = observer

    if uncertainty:
        if state is not None:
            cov_state = state
        elif self._is_keplerian:
            cov_state = self._keplerian_state
        else:
            cov_state = self._cartesian_state
        cov = cov_state.cov
        if cov.shape != (6, 6):
            raise ValueError(
                "uncertainty=True requires a (6, 6) covariance matrix on the "
                f"state's `cov` field, but got shape {tuple(cov.shape)}. "
                "Build a CartesianState or KeplerianState with `cov=...` and pass "
                "it via the `state=` keyword (or set it on the particle's "
                "internal state)."
            )

    # ---- Keplerian (analytic two-body) path --------------------------------
    if self._is_keplerian:
        offsets = self._times_to_offsets(times)
        if offsets.shape == ():
            offsets = jnp.array([offsets])
        if uncertainty:
            kep_state = state if state is not None else self._keplerian_state
            ras, decs, cov_radec = _keplerian_ephem_with_cov(
                kep_state, offsets, observer_positions, cov
            )
            return (SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs"), cov_radec)
        if state is not None:
            state = state.to_cartesian()
            x, v, t0 = (
                state.x.flatten(),
                state.v.flatten(),
                state.relative_time,
            )
        else:
            x, v, t0 = self._x, self._v, self._time
        ras, decs = _keplerian_ephem(x, v, t0, offsets, observer_positions)
        return SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs")

    # ---- N-body integration path -------------------------------------------
    is_leapfrog = self._integrator_method in ["Y4", "Y6", "Y8"]
    if state is None:
        state = self._cartesian_state
        integrator_state = self._integrator_state
    elif is_leapfrog:
        # The leapfrog integrator state does not depend on the dynamical state,
        # so it is reused as-is. An IAS15 state would be the wrong type for the
        # leapfrog integrator.
        integrator_state = self._integrator_state
    else:
        # IAS15 seeds its predictor with the initial acceleration, so it must be
        # re-initialized from the caller-supplied state.
        a0 = self.gravity(state.to_system())
        integrator_state = initialize_ias15_integrator_state(a0)

    times = self._times_to_offsets(times)
    if times.shape == ():
        times = jnp.array([times])

    if is_leapfrog:
        times, inds = create_leapfrog_times(
            state.relative_time, times, self._max_step_size
        )
    else:
        inds = jnp.arange(times.shape[0])

    integrator = self._integrator if interpolate else self._forced_integrator

    # IAS15 with interpolation has dense-output b-coefficients available, so we
    # can evaluate the step polynomial at light-travel-delayed times in on_sky's
    # LTT loop instead of using a constant-acceleration Taylor expansion. All
    # other paths (leapfrog, IAS15 forced-landing) fall back to the Taylor form.
    use_dense_ltt = interpolate and not is_leapfrog

    if uncertainty:
        if use_dense_ltt:
            ras, decs, cov_radec = _ephem_ias15_with_cov(
                times,
                state,
                self.gravity,
                observer_positions,
                inds,
                self._step_scheduler,
                cov,
            )
        else:
            ras, decs, cov_radec = _ephem_with_cov(
                times,
                state,
                self.gravity,
                integrator,
                integrator_state,
                observer_positions,
                inds,
                self._step_scheduler,
                cov,
            )
        return (SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs"), cov_radec)

    if use_dense_ltt:
        ras, decs = _ephem_ias15(
            times,
            state,
            self.gravity,
            integrator_state,
            observer_positions,
            inds,
            self._step_scheduler,
        )
    else:
        ras, decs = _ephem(
            times,
            state,
            self.gravity,
            integrator,
            integrator_state,
            observer_positions,
            inds,
            self._step_scheduler,
        )
    return SkyCoord(ra=ras, dec=decs, unit=u.rad, frame="icrs")
def max_likelihood(
    self,
    fit_seed: CartesianState | KeplerianState | None = None,
    verbose: bool = False,
) -> Particle:
    """Find the maximum likelihood orbit for the particle.

    The optimization runs SciPy's L-BFGS-B over a flat Cartesian
    ``(x, y, z, vx, vy, vz)`` vector, using ``self.scipy_objective`` and
    ``self.scipy_objective_grad``.

    Args:
        fit_seed (CartesianState | KeplerianState | None): A seed for fitting the
            orbit of the particle. If None, the particle's stored fit seed
            (``self._fit_seed``) is used.
        verbose (bool): Whether to print the optimization progress. Defaults to
            False.

    Returns:
        Particle: A new Particle object whose state matches the maximum
        likelihood orbit.

    Raises:
        ValueError: If the particle has no observations (``self.loglike`` is
            None), or if the optimizer does not report success.

    Warns:
        RuntimeWarning: If the fitted orbit is unbound (eccentricity > 1).
    """
    if self.loglike is None:
        raise ValueError("No observations provided, cannot fit an orbit")
    if fit_seed is None:
        fit_seed = self._fit_seed

    # The optimizer always works in Cartesian coordinates, whatever the seed's
    # parameterization, so convert once.
    seed_cartesian = fit_seed.to_cartesian()
    x0 = jnp.concatenate([seed_cartesian.x.flatten(), seed_cartesian.v.flatten()])

    # Shared L-BFGS-B settings for both the Keplerian and N-body paths.
    lbfgsb_options = {
        "disp": verbose,
        "maxls": 100,
        "maxcor": 100,
        "maxfun": 5000,
        "maxiter": 1000,
        "ftol": 1e-14,
    }

    if self._is_keplerian:
        # Keplerian propagation produces NaN for hyperbolic orbits, so we
        # guard the objective and gradient. We also scale the parameters so
        # the initial identity-Hessian step in L-BFGS-B is well-sized.
        scale = jnp.abs(x0) + 1e-10

        def _scaled_objective(x_scaled: jnp.ndarray) -> float:
            val = self.scipy_objective(x_scaled * scale)
            return float(jnp.where(jnp.isnan(val), 1e30, val))

        def _scaled_grad(x_scaled: jnp.ndarray) -> jnp.ndarray:
            # Chain rule: d/dx_scaled f(x_scaled * scale) = grad_f * scale.
            g = self.scipy_objective_grad(x_scaled * scale) * scale
            return jnp.where(jnp.isnan(g), 0.0, g)

        result = minimize(
            fun=_scaled_objective,
            x0=x0 / scale,
            jac=_scaled_grad,
            method="L-BFGS-B",
            options=lbfgsb_options,
        )
        # Map the solution back from scaled to physical coordinates.
        result.x = result.x * scale
    else:
        result = minimize(
            fun=lambda x: self.scipy_objective(x),
            x0=x0,
            jac=lambda x: self.scipy_objective_grad(x),
            method="L-BFGS-B",
            options=lbfgsb_options,
        )

    if not result.success:
        # Include SciPy's termination reason so failures are diagnosable.
        raise ValueError(f"Failed to converge: {result.message}")

    c = CartesianState(
        x=jnp.array([result.x[:3]]),
        v=jnp.array([result.x[3:]]),
        relative_time=self._time,
        time_reference=self._t_ref_jd,
        acceleration_func_kwargs=self._acc_func_kwargs,
    )
    if c.to_keplerian().ecc > 1:
        warnings.warn(
            "Warning: max_likelihood fit produced an unbound orbit",
            RuntimeWarning,
            stacklevel=2,
        )

    return Particle(
        x=result.x[:3],
        v=result.x[3:],
        time=self._t_ref_astropy,
        state=None,
        observations=self._observations,
        name=self._name,
        gravity=self._gravity_str,
        de_ephemeris_version=self._de_ephemeris_version,
        integrator=self._integrator_method,
        earliest_time=self._earliest_time,
        latest_time=self._latest_time,
        fit_seed=self._fit_seed,
        max_step_size=self._max_step_size,
    )
def is_observable(
    self,
    times: Time,
    observer: str | jnp.ndarray,
    sun_limit: float = 20.0,
    ephem: SkyCoord | None = None,
    return_angle: bool = False,
) -> jnp.ndarray:
    """Check whether a particle is observable or too close to the Sun.

    Args:
        times (Time): The times to check the observability.
        observer (str | jnp.ndarray): The observer/observatory making the
            observations. Can be a string for the name/code of an observatory, or
            a jnp.array of 3D barycentric ICRS positions in AU.
        sun_limit (float): The minimum allowed angular separation from the Sun,
            in degrees. Defaults to 20 degrees.
        ephem (SkyCoord | None, optional): Optionally, the ephemeris of the
            particle at the given times. If not provided, it is computed with the
            ephemeris method. Helpful if you've already computed the ephemeris
            and want to avoid doing it twice.
        return_angle (bool, optional): If True, return the observer-centred
            particle–Sun angles in degrees instead of the mask. Default is False.

    Returns:
        jnp.ndarray: A boolean array indicating whether the particle is
        observable at each time (True) or too close to the Sun (False), or the
        angles in degrees if ``return_angle`` is True.
    """
    if isinstance(observer, str):
        observer_pos = get_observer_positions(
            times, observer, self._de_ephemeris_version
        )
    else:
        observer_pos = observer

    if ephem is None:
        # Reuse the already-resolved observer positions; ephemeris() treats an
        # array observer as-is, so this avoids resolving the name a second time.
        ephem = self.ephemeris(times, observer_pos)

    if isinstance(times, Time):
        times = jnp.array(times.tdb.jd)
    if times.shape == ():
        times = jnp.array([times])

    # NOTE(review): this Ephemeris is built with its default DE version, while
    # observer positions above use self._de_ephemeris_version — confirm they agree.
    eph = Ephemeris(
        earliest_time=self._earliest_time,
        latest_time=self._latest_time,
    )
    # Body index 0 of the ephemeris output is taken as the Sun (presumably
    # positions are element [0] of processor.state's output — confirm).
    sun_pos = jax.vmap(eph.processor.state)(times)[0][:, 0, :]

    # Unit vector from the observer toward the particle (SkyCoord has no distance).
    eph_unit_vec = jnp.array(ephem.cartesian.xyz).T

    # Angle between the observer->particle direction and the observer->Sun direction.
    obs_to_sun_vec = sun_pos - observer_pos
    obs_to_sun_unit_vec = (
        obs_to_sun_vec / jnp.linalg.norm(obs_to_sun_vec, axis=1)[:, None]
    )
    cos_angle = jnp.sum(obs_to_sun_unit_vec * eph_unit_vec, axis=1)
    # Rounding can push |cos| slightly past 1, which would make arccos return NaN.
    angle = jnp.arccos(jnp.clip(cos_angle, -1.0, 1.0))

    if return_angle:
        return jnp.rad2deg(angle)
    return angle > jnp.deg2rad(sun_limit)
###########################
# EXTERNAL JITTED FUNCTIONS
###########################


@jax.jit
def _integrate(
    times: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    acc_func: Callable,
    integrator_func: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[
    jnp.ndarray,
    jnp.ndarray,
    SystemState,
    IAS15IntegratorState | LeapfrogIntegratorState,
    jnp.ndarray,
]:
    """Run ``integrator_func`` from ``particle_state`` over ``times``.

    Returns the positions and velocities at ``relevant_inds`` only (dropping any
    intermediate landing times), the final system and integrator states, and
    the ``steps`` value reported by the integrator.
    """
    state = particle_state.to_system()
    positions, velocities, final_system_state, final_integrator_state, steps = (
        integrator_func(state, acc_func, times, integrator_state, step_scheduler)
    )
    return (
        positions[relevant_inds],
        velocities[relevant_inds],
        final_system_state,
        final_integrator_state,
        steps,
    )


@jax.jit
def _ephem(
    times: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    acc_func: Callable,
    integrator_func: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Integrate, then project each output onto the sky with ``on_sky``.

    The light-travel-time correction inside ``on_sky`` uses its default
    (Taylor-expansion) propagation. Returns ``(ras, decs)`` in radians, one per
    entry of ``relevant_inds``.
    """
    positions, velocities, _, _, _ = _integrate(
        times,
        particle_state,
        acc_func,
        integrator_func,
        integrator_state,
        relevant_inds,
        step_scheduler,
    )

    def scan_func(carry: None, scan_over: tuple) -> tuple[None, tuple]:
        position, velocity, time, observer_position = scan_over
        ra, dec = on_sky(position, velocity, time, observer_position, acc_func)
        return None, (ra, dec)

    # Index 0 on axis 1 selects the single tracer particle.
    _, (ras, decs) = jax.lax.scan(
        scan_func,
        None,
        (
            positions[:, 0, :],
            velocities[:, 0, :],
            times[relevant_inds],
            observer_positions,
        ),
    )
    return ras, decs


@jax.jit
def _ephem_ias15(
    times: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    acc_func: Callable,
    integrator_state: IAS15IntegratorState,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """IAS15-only variant of ``_ephem`` that uses dense-output b-coefficients for LTT.

    The ``on_sky`` light-travel-time correction defaults to a 2nd-order Taylor
    expansion with a constant acceleration. For IAS15, we already have the
    converged 7th-order polynomial per step (the "dense output"). This variant
    builds a per-observation closure that evaluates that polynomial at the
    light-travel-delayed time, replacing the Taylor expansion. Only used when
    ``interpolate=True``.
    """
    state = particle_state.to_system()
    (
        _positions,
        _velocities,
        _final_system_state,
        _final_integrator_state,
        _iter_num,
        b_buf,
        a0_buf,
        x0_buf,
        v0_buf,
        dts_buf,
        _t_step_starts,
        step_indices,
        h_values,
    ) = ias15_evolve_with_dense_output(
        state, acc_func, times, integrator_state, step_scheduler
    )

    # Restrict to observation times only (drops any intermediate landing times).
    obs_times = times[relevant_inds]
    obs_step_indices = step_indices[relevant_inds]
    obs_h_values = h_values[relevant_inds]

    # Per-obs dense-output gather (single tracer at index 0).
    b_per_obs = b_buf[obs_step_indices][:, :, 0, :]
    a0_per_obs = a0_buf[obs_step_indices][:, 0, :]
    x0_per_obs = x0_buf[obs_step_indices][:, 0, :]
    v0_per_obs = v0_buf[obs_step_indices][:, 0, :]
    dt_per_obs = dts_buf[obs_step_indices]

    def per_obs_on_sky(
        b_step: jnp.ndarray,
        a0_step: jnp.ndarray,
        x0_step: jnp.ndarray,
        v0_step: jnp.ndarray,
        dt_step: jnp.ndarray,
        h_obs: jnp.ndarray,
        time: jnp.ndarray,
        observer_pos: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        propagator = make_ltt_propagator(
            b_step, a0_step, x0_step, v0_step, dt_step, h_obs
        )
        # Presumably the propagator's argument is the light-time delay, so 0.0
        # yields the position at the observation epoch — confirm against
        # make_ltt_propagator.
        x_obs = propagator(jnp.array(0.0))
        # Velocity is passed as zeros; presumably on_sky ignores it when
        # ltt_position_fn is supplied — confirm.
        return on_sky(
            x_obs,
            jnp.zeros(3),
            time,
            observer_pos,
            acc_func,
            ltt_position_fn=propagator,
        )

    ras, decs = jax.vmap(per_obs_on_sky, in_axes=(0, 0, 0, 0, 0, 0, 0, 0))(
        b_per_obs,
        a0_per_obs,
        x0_per_obs,
        v0_per_obs,
        dt_per_obs,
        obs_h_values,
        obs_times,
        observer_positions,
    )
    return ras, decs


def _state_vec_to_xv(
    state_vec: jnp.ndarray, is_keplerian_param: bool
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Unpack a (6,) parameter vector to ICRS Cartesian position/velocity.

    If ``is_keplerian_param``, the entries are (semi, ecc, inc, Omega, omega, nu)
    and we convert via :func:`elements_to_cartesian` followed by an ecliptic ->
    ICRS rotation. Otherwise the entries are interpreted directly as flat
    Cartesian (x, y, z, vx, vy, vz) in ICRS.

    Returns ``(x, v)`` each shaped ``(1, 3)``.
    """
    if is_keplerian_param:
        # Slices (not scalar indices) keep each element as a length-1 array.
        # nu (index 5) is passed third, presumably matching
        # elements_to_cartesian's positional order — confirm.
        x_ecl, v_ecl = elements_to_cartesian(
            state_vec[0:1],
            state_vec[1:2],
            state_vec[5:6],
            state_vec[2:3],
            state_vec[3:4],
            state_vec[4:5],
            TOTAL_SOLAR_SYSTEM_GM,
        )
        x = horizons_ecliptic_to_icrs(x_ecl)
        v = horizons_ecliptic_to_icrs(v_ecl)
    else:
        x = state_vec[:3].reshape(1, 3)
        v = state_vec[3:].reshape(1, 3)
    return x, v


def _state_to_vec(state: CartesianState | KeplerianState) -> jnp.ndarray:
    """Flatten a CartesianState or KeplerianState to a (6,) parameter vector.

    The ordering is the inverse of :func:`_state_vec_to_xv`:
    (semi, ecc, inc, Omega, omega, nu) for Keplerian states, and
    (x, y, z, vx, vy, vz) otherwise.
    """
    if isinstance(state, KeplerianState):
        return jnp.concatenate(
            [
                jnp.atleast_1d(state.semi),
                jnp.atleast_1d(state.ecc),
                jnp.atleast_1d(state.inc),
                jnp.atleast_1d(state.Omega),
                jnp.atleast_1d(state.omega),
                jnp.atleast_1d(state.nu),
            ]
        )
    return jnp.concatenate([state.x.flatten(), state.v.flatten()])


def _cov_from_jacobian(
    radec_fn: Callable,
    nominal_vec: jnp.ndarray,
    cov: jnp.ndarray,
    N: int,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Linear error propagation via forward-mode AD.

    ``radec_fn`` must accept a ``(6,)`` parameter vector and return a flat
    ``(2N,)`` interleaved ``[ra0, dec0, ra1, dec1, ...]`` array (radians).
    Returns ``(ra, dec, cov_radec)`` where ``cov_radec`` has shape
    ``(N, 2, 2)`` in ``arcsec**2``.
    """
    radec_nominal = radec_fn(nominal_vec)
    ras = radec_nominal[0::2]
    decs = radec_nominal[1::2]
    # J has shape (2N, 6); regroup it into one (2, 6) block per observation.
    J = jax.jacfwd(radec_fn)(nominal_vec)
    J_t = J.reshape(N, 2, 6)
    # Per observation: J_n @ cov @ J_n^T, converted from rad^2 to arcsec^2.
    cov_radec = jnp.einsum("nij,jk,nlk->nil", J_t, cov, J_t) * _RAD2ARCSEC_SQ
    return ras, decs, cov_radec


@jax.jit
def _ephem_ias15_with_cov(
    times: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    acc_func: Callable,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
    cov: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """IAS15 dense-output ephemeris with sky-plane covariance via forward-mode AD.

    The IAS15 integrator state is rebuilt inside the differentiated closure, so
    its initial acceleration tracks the perturbed state vector.
    """
    # Python-level branch: resolved once at trace time from the state's type.
    is_keplerian_param = isinstance(particle_state, KeplerianState)

    def radec_fn(state_vec: jnp.ndarray) -> jnp.ndarray:
        x, v = _state_vec_to_xv(state_vec, is_keplerian_param)
        state = CartesianState(
            x=x,
            v=v,
            relative_time=particle_state.relative_time,
            time_reference=particle_state.time_reference,
            acceleration_func_kwargs=particle_state.acceleration_func_kwargs,
        )
        a0 = acc_func(state.to_system())
        integrator_state = initialize_ias15_integrator_state(a0)
        ras, decs = _ephem_ias15(
            times,
            state,
            acc_func,
            integrator_state,
            observer_positions,
            relevant_inds,
            step_scheduler,
        )
        # Interleave to [ra0, dec0, ra1, dec1, ...] as _cov_from_jacobian expects.
        return jnp.stack([ras, decs], axis=1).flatten()

    nominal_vec = _state_to_vec(particle_state)
    return _cov_from_jacobian(radec_fn, nominal_vec, cov, relevant_inds.shape[0])


@jax.jit
def _ephem_with_cov(
    times: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    acc_func: Callable,
    integrator_func: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    observer_positions: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    step_scheduler: Callable,
    cov: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Generic non-dense ephemeris with sky-plane covariance via forward-mode AD.

    Handles both the IAS15 forced-landing path (``interpolate=False`` + IAS15)
    and the leapfrog path. For IAS15, the integrator state must be
    re-initialized inside the AD closure so the initial-acceleration entry
    tracks the perturbed state vector; for leapfrog, ``LeapfrogIntegratorState``
    is independent of the dynamical state and is reused as-is.
    """
    is_keplerian_param = isinstance(particle_state, KeplerianState)
    # Which path to take is decided by the type of the integrator state passed in.
    reinit_ias15 = isinstance(integrator_state, IAS15IntegratorState)

    def radec_fn(state_vec: jnp.ndarray) -> jnp.ndarray:
        x, v = _state_vec_to_xv(state_vec, is_keplerian_param)
        state = CartesianState(
            x=x,
            v=v,
            relative_time=particle_state.relative_time,
            time_reference=particle_state.time_reference,
            acceleration_func_kwargs=particle_state.acceleration_func_kwargs,
        )
        if reinit_ias15:
            a0 = acc_func(state.to_system())
            local_integrator_state = initialize_ias15_integrator_state(a0)
        else:
            local_integrator_state = integrator_state
        ras, decs = _ephem(
            times,
            state,
            acc_func,
            integrator_func,
            local_integrator_state,
            observer_positions,
            relevant_inds,
            step_scheduler,
        )
        # Interleave to [ra0, dec0, ra1, dec1, ...] as _cov_from_jacobian expects.
        return jnp.stack([ras, decs], axis=1).flatten()

    nominal_vec = _state_to_vec(particle_state)
    return _cov_from_jacobian(radec_fn, nominal_vec, cov, relevant_inds.shape[0])


@jax.jit
def _residuals(
    times: jnp.ndarray,
    gravity: Callable,
    integrator: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    observer_positions: jnp.ndarray,
    ra: jnp.ndarray,
    dec: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    step_scheduler: Callable,
) -> jnp.ndarray:
    """Tangent-plane residuals between observed (ra, dec) and the model ephemeris.

    Each observed/computed pair is mapped through ``tangent_plane_projection``;
    the result feeds the quadratic form in ``_loglike``.
    """
    ras, decs = _ephem(
        times,
        particle_state,
        gravity,
        integrator,
        integrator_state,
        observer_positions,
        relevant_inds,
        step_scheduler,
    )
    xis_etas = jax.vmap(tangent_plane_projection)(ra, dec, ras, decs)
    return xis_etas


# note: this external jitted function does not have fwd mode autodiff enforced, will
# break on reverse mode when using ias15
@jax.jit
def _loglike(
    times: jnp.ndarray,
    gravity: Callable,
    integrator: Callable,
    integrator_state: IAS15IntegratorState | LeapfrogIntegratorState,
    observer_positions: jnp.ndarray,
    ra: jnp.ndarray,
    dec: jnp.ndarray,
    inv_cov_matrices: jnp.ndarray,
    cov_log_dets: jnp.ndarray,
    relevant_inds: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
    step_scheduler: Callable,
) -> float:
    """Sum of per-observation 2-D Gaussian log-likelihoods of the residuals.

    Each term is ``-0.5 * (2*log(2*pi) + log|C| + r^T C^{-1} r)``, with ``r`` the
    tangent-plane residual, ``C^{-1}`` from ``inv_cov_matrices`` and ``log|C|``
    from ``cov_log_dets``.
    """
    xis_etas = _residuals(
        times,
        gravity,
        integrator,
        integrator_state,
        observer_positions,
        ra,
        dec,
        relevant_inds,
        particle_state,
        step_scheduler,
    )
    # Batched quadratic form r_b^T C_b^{-1} r_b for each observation b.
    quad = jnp.einsum("bi,bij,bj->b", xis_etas, inv_cov_matrices, xis_etas)
    ll = jnp.sum(-0.5 * (2 * jnp.log(2 * jnp.pi) + cov_log_dets + quad))
    return ll


###################################
# KEPLERIAN EXTERNAL JITTED FUNCTIONS
###################################


@jax.jit
def _keplerian_integrate(
    x: jnp.ndarray,
    v: jnp.ndarray,
    t0: float,
    times: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Analytic two-body propagation from (x, v) at ``t0`` to ``times``.

    Inputs and outputs are ICRS. Propagation itself happens in the Horizons
    ecliptic frame with ``TOTAL_SOLAR_SYSTEM_GM`` as the central mass parameter.
    """
    # x and v are single vectors; [None, :] adds the leading axis the frame
    # rotation works on.
    x_ecl = icrs_to_horizons_ecliptic(x[None, :])
    v_ecl = icrs_to_horizons_ecliptic(v[None, :])
    positions_ecl, velocities_ecl = keplerian_propagate(
        x_ecl, v_ecl, t0, times, TOTAL_SOLAR_SYSTEM_GM
    )
    positions = horizons_ecliptic_to_icrs(positions_ecl)
    velocities = horizons_ecliptic_to_icrs(velocities_ecl)
    return positions, velocities


@jax.jit
def _keplerian_on_sky(
    x: jnp.ndarray,
    v: jnp.ndarray,
    time: float,
    observer_position: jnp.ndarray,
) -> tuple[float, float]:
    """RA/Dec (radians) of ``x`` as seen from ``observer_position``, LTT-corrected.

    Light travel time is handled by three fixed-point iterations of a
    second-order Taylor step, using the two-body acceleration at ``x``.
    ``time`` is not used in the body.
    """
    r = jnp.linalg.norm(x)
    # Two-body acceleration toward the origin with the total solar-system GM.
    a0 = -TOTAL_SOLAR_SYSTEM_GM * x / (r**3)
    xz = x
    for _ in range(3):
        earth_distance = jnp.linalg.norm(xz - observer_position)
        # Negative: step back in time by the light-travel delay.
        dt = -earth_distance * INV_SPEED_OF_LIGHT
        xz = x + v * dt + 0.5 * a0 * dt * dt
    X = xz - observer_position
    # Wrap RA into [0, 2*pi).
    calc_ra = jnp.mod(jnp.arctan2(X[1], X[0]) + 2 * jnp.pi, 2 * jnp.pi)
    calc_dec = jnp.pi / 2 - jnp.arccos(X[-1] / jnp.linalg.norm(X))
    return calc_ra, calc_dec


@jax.jit
def _keplerian_ephem(
    x: jnp.ndarray,
    v: jnp.ndarray,
    t0: float,
    times: jnp.ndarray,
    observer_positions: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Keplerian ephemeris: propagate analytically, then project each epoch to RA/Dec."""
    positions, velocities = _keplerian_integrate(x, v, t0, times)

    def scan_func(carry: None, scan_over: tuple) -> tuple[None, tuple]:
        position, velocity, time, observer_position = scan_over
        ra, dec = _keplerian_on_sky(position, velocity, time, observer_position)
        return None, (ra, dec)

    _, (ras, decs) = jax.lax.scan(
        scan_func,
        None,
        (positions, velocities, times, observer_positions),
    )
    return ras, decs


@jax.jit
def _keplerian_ephem_with_cov(
    particle_state: CartesianState | KeplerianState,
    times: jnp.ndarray,
    observer_positions: jnp.ndarray,
    cov: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Keplerian-path ephemeris with sky-plane covariance via forward-mode AD.

    Supports both Keplerian and Cartesian input parameterizations; the
    covariance is propagated in whichever space the input state was supplied
    in.
    """
    is_keplerian_param = isinstance(particle_state, KeplerianState)
    t0 = particle_state.relative_time

    def radec_fn(state_vec: jnp.ndarray) -> jnp.ndarray:
        x, v = _state_vec_to_xv(state_vec, is_keplerian_param)
        # _state_vec_to_xv returns (1, 3) arrays; _keplerian_ephem takes flat vectors.
        ras, decs = _keplerian_ephem(
            x.flatten(), v.flatten(), t0, times, observer_positions
        )
        return jnp.stack([ras, decs], axis=1).flatten()

    nominal_vec = _state_to_vec(particle_state)
    return _cov_from_jacobian(radec_fn, nominal_vec, cov, times.shape[0])


@jax.jit
def _keplerian_residuals(
    times: jnp.ndarray,
    observer_positions: jnp.ndarray,
    ra: jnp.ndarray,
    dec: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
) -> jnp.ndarray:
    """Tangent-plane residuals for the Keplerian path (analytic propagation)."""
    x = particle_state.to_cartesian().x.flatten()
    v = particle_state.to_cartesian().v.flatten()
    t0 = particle_state.relative_time
    ras, decs = _keplerian_ephem(x, v, t0, times, observer_positions)
    xis_etas = jax.vmap(tangent_plane_projection)(ra, dec, ras, decs)
    return xis_etas


@jax.jit
def _keplerian_loglike(
    times: jnp.ndarray,
    observer_positions: jnp.ndarray,
    ra: jnp.ndarray,
    dec: jnp.ndarray,
    inv_cov_matrices: jnp.ndarray,
    cov_log_dets: jnp.ndarray,
    particle_state: CartesianState | KeplerianState,
) -> float:
    """Keplerian-path counterpart of ``_loglike`` (same 2-D Gaussian form)."""
    xis_etas = _keplerian_residuals(times, observer_positions, ra, dec, particle_state)
    quad = jnp.einsum("bi,bij,bj->b", xis_etas, inv_cov_matrices, xis_etas)
    ll = jnp.sum(-0.5 * (2 * jnp.log(2 * jnp.pi) + cov_log_dets + quad))
    return ll