# Source code for jorbit.particle.particle

"""The Particle class.

The jitted/host helper functions that implement each integration branch live in sibling
modules of this subpackage:

- :mod:`jorbit.particle.ephem` (generic integrate/ephem, leapfrog)
- :mod:`jorbit.particle.ias15_dense` (IAS15 ``interpolate=True``)
- :mod:`jorbit.particle.ias15_forced` (IAS15 ``interpolate=False``)
- :mod:`jorbit.particle.keplerian` (analytic two-body path)
- :mod:`jorbit.particle.covariance` (shared forward-mode-AD covariance leaves)
- :mod:`jorbit.particle.likelihood` (residuals/log-likelihood for fitting)
"""

from __future__ import annotations

import warnings
from collections.abc import Callable
from copy import deepcopy

import astropy.units as u
import jax
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.time import Time
from scipy.optimize import minimize

from jorbit import Observations
from jorbit.accelerations import (
    create_default_ephemeris_acceleration_func,
    create_gr_ephemeris_acceleration_func,
    create_newtonian_ephemeris_acceleration_func,
)
from jorbit.astrometry.orbit_fit_seeds import gauss_method_orbit, simple_circular
from jorbit.data.constants import (
    SPEED_OF_LIGHT,
    Y4_C,
    Y4_D,
    Y6_C,
    Y6_D,
    Y8_C,
    Y8_D,
)
from jorbit.ephemeris.ephemeris import Ephemeris
from jorbit.integrators import (
    budgeted_forced_landing,
    create_leapfrog_times,
    ias15_evolve,
    ias15_evolve_forced_landing,
    ias15_span_probe,
    initialize_ias15_integrator_state,
    leapfrog_evolve,
    next_proposed_dt_global,
    next_proposed_dt_PRS23,
    stitched_interpolate,
)
from jorbit.likelihoods.setup_static_likelihood import (
    create_default_static_residuals_func,
    precompute_likelihood_data,
)
from jorbit.particle.ephem import _ephem, _ephem_with_cov, _integrate
from jorbit.particle.ias15_dense import _ephem_ias15_stitched, _ephem_ias15_with_cov
from jorbit.particle.ias15_forced import _ephem_forced_budgeted
from jorbit.particle.keplerian import (
    _keplerian_ephem,
    _keplerian_ephem_with_cov,
    _keplerian_integrate,
    _keplerian_loglike,
    _keplerian_residuals,
)
from jorbit.particle.likelihood import _loglike, _residuals
from jorbit.utils.horizons import get_observer_positions, horizons_bulk_vector_query
from jorbit.utils.states import (
    CartesianState,
    IAS15IntegratorState,
    KeplerianState,
    LeapfrogIntegratorState,
)


class Particle:
    """An object representing a single particle in the solar system.

    This class is used to represent and manipulate a single particle moving within
    the solar system. It is mostly a collection of convenience wrappers around the
    more general integrators and accelerations, but it also provides some useful
    methods for projecting the particle's position onto the sky and fitting orbits
    to observations. By construction, `Particle` objects are massless.

    Note: none of the methods associated with this class will alter the underlying
    state of the particle. For example, "integrate" will give you the positions and
    velocities of the particle at future times, but after it returns, the particle
    will still be at its original state.

    Internally, every JAX-visible time is an offset in days from the particle's
    reference epoch (:attr:`t_ref` / :attr:`t_ref_jd`); the particle's own state
    sits at offset ``0.0``.

    Attributes:
        cartesian_state: The particle's state in Cartesian coordinates (AU, AU/day).
        keplerian_state: The same state expressed as Keplerian elements.
        observations: The Observations associated with the particle, or None.
        t_ref: The particle's epoch as an astropy Time in the TDB scale.
        t_ref_jd: The particle's epoch as a float JD (TDB).
        gravity: The acceleration function used when integrating, or the string
            "keplerian" for pure two-body propagation.
        residuals: Jitted function mapping a state to residuals against the
            observations, or None if there are no observations.
        loglike: Jitted log-likelihood of a state given the observations, or None
            if there are no observations. Differentiable with ``jax.grad``.
        scipy_objective: Negative log-likelihood of a flat array whose first three
            entries are the position and the rest the velocity (intended as a
            scipy.optimize objective), or None if there are no observations.
        scipy_objective_grad: Gradient of ``scipy_objective`` with respect to that
            flat array, or None if there are no observations.
        static_residuals: Residual function built from precomputed likelihood
            data, or None if there are no observations or gravity is "keplerian".
    """

    def __init__(
        self,
        state: KeplerianState | CartesianState | None = None,
        time: Time | float | None = None,
        x: jnp.ndarray | None = None,
        v: jnp.ndarray | None = None,
        observations: Observations | None = None,
        name: str = "",
        gravity: str | Callable = "default solar system",
        de_ephemeris_version: str | None = "440",
        integrator: str = "ias15",
        earliest_time: Time = Time("1980-01-01"),
        latest_time: Time = Time("2050-01-01"),
        fit_seed: KeplerianState | CartesianState | None = None,
        max_step_size: u.Quantity | None = None,
        step_scheduler: str = "prs23",
    ) -> None:
        """Initialize a Particle object.

        Args:
            state (KeplerianState | CartesianState | None): The state of the
                particle. None if x and v are provided. The particle's epoch is
                taken from the state's ``relative_time + time_reference``.
            time (Time | float | None): The time of the particle's state. None if
                state is provided, since that will have its own time baked-in. A
                non-Time value is interpreted as a JD in the TDB scale.
            x (jnp.ndarray | None): The 3D barycentric cartesian position of the
                particle in AU. None if state is provided.
            v (jnp.ndarray | None): The 3D barycentric cartesian velocity of the
                particle in AU/day. None if state is provided.
            observations (Observations | None): Optional Observations associated
                with the particle. Necessary if fitting or evaluating likelihoods.
            name (str): The name of the particle. Defaults to "", which is stored
                as "unnamed".
            gravity (str | Callable): The gravitational acceleration function to
                use when integrating the particle's orbit. Defaults to "default
                solar system", which corresponds to parameterized post-Newtonian
                interactions with the 10 bodies in the JPL DE440 ephemeris, plus
                Newtonian interactions with the 16 largest asteroids in the
                asteroids_de441/sb441-n16.bsp ephemeris. Can also be a
                jax.tree_util.Partial object that follows the same signature as
                the acceleration functions in jorbit.accelerations. Other string
                options are "newtonian planets", "newtonian solar system", "gr
                planets", "gr solar system", and "keplerian" (pure 2-body
                Keplerian propagation with no ephemeris or integrator setup).
            de_ephemeris_version (str | None): Which version of the JPL DE
                ephemeris to use for perturber positions when using one of the
                built-in gravity models. Accepts either "440" or "430", default is
                "440". Ignored when gravity is "keplerian".
            integrator (str): The integrator to use for the particle. Choices are
                "ias15", which is a 15th order adaptive step-size integrator, or
                "Y4", "Y6", or "Y8", which are 4th, 6th, and 8th order Yoshida
                leapfrog integrators with fixed step sizes. Defaults to "ias15".
                Any other value raises ValueError. Ignored when gravity is
                "keplerian".
            earliest_time (Time): The earliest time we expect to integrate the
                particle to. Defaults to Time("1980-01-01"). Larger time windows
                will result in larger in-memory ephemeris objects.
            latest_time (Time): The latest time we expect to integrate the
                particle to. Defaults to Time("2050-01-01"). Larger time windows
                will result in larger in-memory ephemeris objects.
            fit_seed (KeplerianState | CartesianState | None): A seed for fitting
                the orbit of the particle. Only kept when observations are
                provided. If None, a seed is generated from the observations:
                Gauss's method on the first, middle, and last observations when
                there are at least 3, otherwise a circular orbit with semi-major
                axis 2.5 AU through the first observation's RA/Dec.
            max_step_size (u.Quantity, optional): The fixed step size to use for
                leapfrog integrators. Required if integrator is "Y4", "Y6", or
                "Y8", and must be None if integrator is "ias15". Note that this is
                the maximum step size; the actual step size may be smaller to
                ensure that the particle lands exactly on the requested output
                times, and that the step size may change if the spacing between
                output times is not constant. Defaults to None.
            step_scheduler (str): The scheduler used by IAS15 for picking the next
                proposed step size. Choices are "prs23" (Pham+ 2023 controller,
                default) or "global" (the controller from the original IAS15
                paper), matched case-insensitively. Used consistently by
                ``integrate``, ``integrate_or_interpolate``, ``ephemeris``, and the
                residuals/loglike closures. Ignored (but still validated) when
                gravity is "keplerian" or for leapfrog integrators.
        """
        self._observations = observations
        self._earliest_time = earliest_time
        self._latest_time = latest_time
        self._de_ephemeris_version = de_ephemeris_version
        self._is_keplerian = gravity == "keplerian"
        self._step_scheduler = self._resolve_step_scheduler(step_scheduler)

        self.gravity = gravity
        # Preserve the original gravity spec (string or user-supplied Partial) so
        # max_likelihood can pass it to the result Particle rather than self.gravity.
        # self.gravity is overwritten by _setup_acceleration_func with a Partial that
        # already has t_ref_jd baked in; passing *that* Partial to a new Particle would
        # re-wrap it a second time and double-count the offset (see BUG_FIXES_MAY2026).
        self._gravity_str = gravity

        # _setup_state rebases the state in place (x/v reshaping, relative_time,
        # time_reference); copy first so those edits can't leak to the caller.
        state = deepcopy(state) if state is not None else None

        # self._time is kept at jnp.array(0.0) internally; absolute time lives in
        # self._t_ref_astropy / self._t_ref_jd. All JAX-visible times are offsets
        # (in days) from _t_ref_jd, which keeps step-boundary and interpolation
        # arithmetic well-conditioned at decadal timescales.
        (
            self._x,
            self._v,
            self._time,
            self._t_ref_astropy,
            self._t_ref_jd,
            self._cartesian_state,
            self._keplerian_state,
            self._name,
            self._acc_func_kwargs,
        ) = self._setup_state(x, v, state, time, name)

        if self._is_keplerian:
            self.gravity = "keplerian"
            self._integrator_state = None
            self._integrator = None
            # Set alongside _integrator so the attribute always exists.
            self._forced_integrator = None
            self._integrator_method = "keplerian"
            self._max_step_size = None
        else:
            self.gravity = self._setup_acceleration_func(gravity)
            self._integrator_state, self._integrator, self._forced_integrator = (
                self._setup_integrator(integrator, max_step_size)
            )
            self._integrator_method = integrator
            self._max_step_size = max_step_size

        self._fit_seed = self._setup_fit_seed(fit_seed)

        (
            self.residuals,
            self.loglike,
            self.scipy_objective,
            self.scipy_objective_grad,
        ) = self._setup_likelihood()

        self.static_residuals = self._setup_default_static_residuals()

    def __repr__(self) -> str:
        """Return a string representation of the Particle object."""
        return f"Particle: {self._name}"

    @property
    def cartesian_state(self) -> CartesianState:
        """Return the Cartesian state of the particle.

        The state is self-describing about its absolute epoch via the
        ``(relative_time, time_reference)`` pair: ``relative_time`` is ``0.0``
        (the offset in the particle's internal frame) and ``time_reference`` is
        :attr:`t_ref_jd` (the absolute JD anchor).
        """
        return self._cartesian_state

    @property
    def keplerian_state(self) -> KeplerianState:
        """Return the Keplerian state of the particle.

        The state is self-describing about its absolute epoch via the
        ``(relative_time, time_reference)`` pair: ``relative_time`` is ``0.0``
        (the offset in the particle's internal frame) and ``time_reference`` is
        :attr:`t_ref_jd` (the absolute JD anchor).
        """
        return self._keplerian_state

    @property
    def observations(self) -> Observations | None:
        """Return the observations associated with the particle."""
        return self._observations

    @property
    def t_ref(self) -> Time:
        """Reference time (astropy Time, TDB) — the particle's epoch.

        All JAX-visible times inside the Particle are offsets in days from this
        reference, which keeps the integrator's internal arithmetic
        well-conditioned at decadal timescales.
        """
        return self._t_ref_astropy

    @property
    def t_ref_jd(self) -> float:
        """Reference time as a float JD (TDB), matching ``t_ref``."""
        return self._t_ref_jd

    ###############
    # SETUP METHODS
    ###############

    def _times_to_offsets(self, times: Time | jnp.ndarray) -> jnp.ndarray:
        """Convert user-provided query times to offsets from ``self._t_ref_astropy``.

        Astropy ``Time`` inputs preserve their internal jd1+jd2 high-precision
        pair through the subtraction, so the returned offsets retain sub-ns
        precision regardless of how far ``times`` is from ``t_ref``. Plain
        float/jnp inputs are assumed to be absolute JD (TDB) and are subtracted
        in float64 — this is Sterbenz-exact when the magnitudes match, but the
        offset precision is capped at ulp(JD) ≈ 40 μs (≈1 m for typical solar
        system velocities).
        """
        if isinstance(times, Time):
            return jnp.asarray((times.tdb - self._t_ref_astropy).to_value(u.day))
        return jnp.asarray(times) - self._t_ref_jd

    def _validate_state_epoch(self, state: CartesianState | KeplerianState) -> None:
        """Guard against a ``state=`` whose epoch disagrees with this particle's.

        Query times handed to ``integrate``/``ephemeris`` are offsets from
        ``self._t_ref_jd``, so a supplied ``state`` must share that anchor: the
        integrator marches ``relative_time`` and the acceleration function adds
        ``state.time_reference`` to recover the absolute JD. If the two disagree
        the ephemeris would be queried at the wrong epoch (potentially years off)
        while still "looking" plausible, so we fail loudly instead.

        Raises:
            ValueError: If the anchors differ by more than 1e-6 days.
        """
        if abs(float(state.time_reference) - float(self._t_ref_jd)) > 1e-6:
            raise ValueError(
                "The provided state's time_reference "
                f"({float(state.time_reference)}) does not match this particle's "
                f"reference epoch ({float(self._t_ref_jd)}). States passed via "
                "`state=` must be expressed in the particle's frame "
                "(time_reference == particle.t_ref_jd, with relative_time the offset "
                "in days). Build the state at the particle's epoch, or create a new "
                "Particle at the desired epoch."
            )

    def _observations_times_as_offsets(self) -> jnp.ndarray:
        """Return the observation times as offsets from ``self._t_ref_astropy``.

        Prefers the full-precision astropy Time stored on the Observations
        object when available (sub-ns precision preserved). Falls back to
        subtracting :attr:`t_ref_jd` from the float-JD array, which is
        Sterbenz-exact but inherits the ulp(JD) ≈ 40 μs quantization of the
        stored float-JD.
        """
        obs = self._observations
        times_astropy = getattr(obs, "_times_astropy", None)
        if times_astropy is not None:
            return jnp.asarray(
                (times_astropy.tdb - self._t_ref_astropy).to_value(u.day)
            )
        return jnp.asarray(obs.times) - self._t_ref_jd

    def _setup_state(
        self,
        x: jnp.ndarray | None,
        v: jnp.ndarray | None,
        state: CartesianState | KeplerianState | None,
        time: Time | float | None,
        name: str,
    ) -> tuple:
        """Normalize the constructor's state inputs into the internal offset frame.

        Exactly one of ``state`` or ``(x, v)`` defines the initial conditions.
        The returned states have ``relative_time == 0`` and
        ``time_reference == t_ref_jd``.

        Returns:
            tuple: ``(x, v, internal_time, t_ref_astropy, t_ref_jd,
            cartesian_state, keplerian_state, name, acc_func_kwargs)``.

        Raises:
            AssertionError: On conflicting inputs (state with time, state with
                x/v, x without v) or when no epoch is available.
            ValueError: If neither ``state`` nor ``x``/``v`` is provided.
        """
        if state is not None:
            assert time is None, "Cannot provide both state and time"
            # The state self-describes its absolute epoch as the sum of its
            # offset (relative_time) and its anchor (time_reference). We use
            # the sum as the Particle's absolute reference time, then rebase
            # the stored state to relative_time=0 against that new anchor.
            time = state.relative_time + state.time_reference
        assert time is not None, "Must provide an epoch for the particle"

        # Build the reference time pair (astropy Time at full precision + float JD)
        # and then set the Particle's internal epoch to 0.0 in the offset frame.
        if isinstance(time, Time):
            t_ref_astropy = time.tdb
        else:
            t_ref_astropy = Time(float(time), format="jd", scale="tdb")
        t_ref_jd = jnp.array(float(t_ref_astropy.tdb.jd))
        internal_time = jnp.array(0.0)

        if state is not None:
            assert x is None and v is None, "Cannot provide both state and x, v"
            # to_keplerian() and to_cartesian() are documented not to propagate `cov`
            # because the covariance is parameterisation-specific. Save and re-attach
            # it to the same-type output state so callers who set cov on the input
            # state can retrieve it on particle.cartesian_state / particle.keplerian_state.
            input_is_keplerian = isinstance(state, KeplerianState)
            input_cov = state.cov
            state = state.to_cartesian()
            # Promote a single (3,) vector to the (1, 3) "one particle" layout.
            if state.x.ndim != 2:
                state.x = state.x[None, :]
                state.v = state.v[None, :]
            state.relative_time = internal_time
            state.time_reference = t_ref_jd
            keplerian_state = state.to_keplerian()
            cartesian_state = state.to_cartesian()
            if input_is_keplerian:
                keplerian_state.cov = input_cov
            else:
                cartesian_state.cov = input_cov
            x = state.x.flatten()
            v = state.v.flatten()
        elif x is not None:
            assert v is not None, "Must provide both x and v"
            x = x.flatten()
            v = v.flatten()
            cartesian_state = CartesianState(
                x=jnp.array([x]),
                v=jnp.array([v]),
                relative_time=internal_time,
                time_reference=t_ref_jd,
                acceleration_func_kwargs={"c2": SPEED_OF_LIGHT**2},
            )
            keplerian_state = cartesian_state.to_keplerian()
        else:
            # Reached only when neither `state` nor `x` was supplied.
            raise ValueError(
                "Must provide either `state`, or both `x` and `v`, to define the "
                "particle's initial conditions."
            )

        if name == "":
            name = "unnamed"

        acc_func_kwargs = cartesian_state.acceleration_func_kwargs

        return (
            x,
            v,
            internal_time,
            t_ref_astropy,
            t_ref_jd,
            cartesian_state,
            keplerian_state,
            name,
            acc_func_kwargs,
        )

    def _setup_acceleration_func(self, gravity: str | Callable) -> Callable:
        """Resolve a gravity spec into an acceleration function.

        Args:
            gravity: Either a ``jax.tree_util.Partial`` (returned unchanged) or
                the name of one of the built-in ephemeris-based models.

        Returns:
            The acceleration function to hand to the integrators.

        Raises:
            AssertionError: If a built-in model is requested and
                ``de_ephemeris_version`` is not "440" or "430".
            ValueError: If ``gravity`` is not a recognized model name.
        """
        if isinstance(gravity, jax.tree_util.Partial):
            # CONTRACT: the custom function must recover the absolute JD from the
            # state itself, as ``inputs.relative_time + inputs.time_reference`` —
            # exactly what jorbit's own factory functions
            # (create_default_ephemeris_acceleration_func, etc.) now do. The state
            # the integrator hands to the function carries
            # ``time_reference == self._t_ref_jd`` and a small ``relative_time``
            # offset, so no wrapper/shim is needed: the state is self-describing.
            return gravity

        assert self._de_ephemeris_version in ["440", "430"], (
            "de_ephemeris_version must be either '440' or '430' if not using a custom "
            "gravity function"
        )

        # Built-in models: name -> (Ephemeris `ssos` preset, acceleration factory).
        builtin_models = {
            "newtonian planets": (
                "default planets",
                create_newtonian_ephemeris_acceleration_func,
            ),
            "newtonian solar system": (
                "default solar system",
                create_newtonian_ephemeris_acceleration_func,
            ),
            "gr planets": (
                "default planets",
                create_gr_ephemeris_acceleration_func,
            ),
            "gr solar system": (
                "default solar system",
                create_gr_ephemeris_acceleration_func,
            ),
            "default solar system": (
                "default solar system",
                create_default_ephemeris_acceleration_func,
            ),
        }
        # Only strings can name a built-in model; anything else is unrecognized.
        model = builtin_models.get(gravity) if isinstance(gravity, str) else None
        if model is None:
            raise ValueError(
                f"Unrecognized gravity '{gravity}'. Valid options are: 'newtonian planets', "
                "'newtonian solar system', 'gr planets', 'gr solar system', "
                "'default solar system', 'keplerian'. For a custom acceleration function, "
                "pass a jax.tree_util.Partial."
            )

        ssos, make_acc_func = model
        eph = Ephemeris(
            earliest_time=self._earliest_time,
            latest_time=self._latest_time,
            ssos=ssos,
            de_ephemeris_version=self._de_ephemeris_version,
        )
        return make_acc_func(eph.processor)

    @staticmethod
    def _resolve_step_scheduler(name: str) -> Callable:
        """Translate a string name to a JIT-friendly Partial step scheduler.

        Matching is case-insensitive.

        Raises:
            ValueError: If ``name`` is neither "prs23" nor "global".
        """
        if name.lower() == "prs23":
            return jax.tree_util.Partial(next_proposed_dt_PRS23)
        elif name.lower() == "global":
            return jax.tree_util.Partial(next_proposed_dt_global)
        raise ValueError(
            f"Unknown step_scheduler '{name}'. Choices are 'prs23' or 'global'."
        )

    def _setup_integrator(
        self, integrator: str, max_step_size: u.Quantity | None
    ) -> tuple[IAS15IntegratorState | LeapfrogIntegratorState, Callable, Callable]:
        """Build the integrator state and evolve functions for ``integrator``.

        Args:
            integrator: "ias15", "Y4", "Y6", or "Y8".
            max_step_size: Fixed step size for the leapfrog integrators; must be
                None for IAS15.

        Returns:
            ``(integrator_state, integrator, forced_integrator)``. For the
            leapfrog integrators the regular and forced-landing functions are the
            same ``leapfrog_evolve`` Partial.

        Raises:
            AssertionError: If ``max_step_size`` is given for IAS15, or missing
                for a leapfrog integrator.
            ValueError: If ``integrator`` is not a recognized name.
        """
        if integrator == "ias15":
            assert (
                max_step_size is None
            ), "max_step_size should not be provided for IAS15 integrator."
            # The IAS15 state is seeded with the acceleration at the starting state.
            a0 = self.gravity(self._cartesian_state.to_system())
            integrator_state = initialize_ias15_integrator_state(a0)
            evolve = jax.tree_util.Partial(ias15_evolve)
            forced_evolve = jax.tree_util.Partial(ias15_evolve_forced_landing)
            return integrator_state, evolve, forced_evolve

        if integrator not in ("Y4", "Y6", "Y8"):
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(
                f"Unrecognized integrator '{integrator}'. Valid options are: "
                "'ias15', 'Y4', 'Y6', 'Y8'."
            )

        assert (
            max_step_size is not None
        ), "Must provide max_step_size for leapfrog integrators."
        dt = max_step_size.to(u.day).value
        # Yoshida coefficient tables, keyed by integrator name.
        leapfrog_coefficients = {
            "Y4": (Y4_C, Y4_D),
            "Y6": (Y6_C, Y6_D),
            "Y8": (Y8_C, Y8_D),
        }
        c, d = leapfrog_coefficients[integrator]
        integrator_state = LeapfrogIntegratorState(dt=dt, C=c, D=d)
        evolve = jax.tree_util.Partial(leapfrog_evolve)
        # Leapfrog has no separate forced-landing variant.
        return integrator_state, evolve, evolve

    def _setup_fit_seed(
        self, fit_seed: KeplerianState | CartesianState | None
    ) -> KeplerianState | CartesianState | None:
        """Choose the starting state for orbit fits.

        Args:
            fit_seed: A user-supplied seed, or None to derive one from the
                observations.

        Returns:
            None when the particle has no observations (a supplied seed is then
            discarded). Otherwise the supplied seed if it is a CartesianState or
            KeplerianState, else a Gauss's-method seed (at least 3 observations)
            or a 2.5 AU circular-orbit seed (fewer than 3).
        """
        obs = self._observations
        if obs is None:
            return None
        if isinstance(fit_seed, (CartesianState, KeplerianState)):
            return fit_seed

        if len(obs) >= 3:
            # Gauss's method on the first observation, the one closest to the
            # mean observation time, and the last observation.
            mean_time = jnp.mean(obs.times)
            mid_idx = jnp.argmin(jnp.abs(obs.times - mean_time))
            seed = gauss_method_orbit(obs[0] + obs[mid_idx] + obs[-1])
            if seed.to_keplerian().ecc > 1:
                warnings.warn(
                    "Warning: initial Gauss's method fit produced an unbound orbit",
                    RuntimeWarning,
                    stacklevel=2,
                )
            return seed

        # Pass the Particle's reference epoch (not the observation time)
        # so that the returned state's x, v represent the object at
        # self._t_ref_jd — the same epoch the optimizer will use as its
        # starting point. Using obs.times[0] instead would yield x, v at
        # a different epoch that the optimizer would then misinterpret.
        return simple_circular(
            obs.ra[0],
            obs.dec[0],
            semi=2.5,
            time=self._t_ref_jd,
        )

    def _setup_likelihood(
        self,
    ) -> tuple[Callable | None, Callable | None, Callable | None, Callable | None]:
        """Build the residual/log-likelihood closures and scipy adapters.

        Returns:
            ``(residuals, loglike, scipy_objective, scipy_grad)``, or four
            ``None`` values when the particle has no observations.
        """
        if self._observations is None:
            return None, None, None, None

        obs = self._observations
        obs_times = self._observations_times_as_offsets()

        if self._is_keplerian:
            residuals = jax.tree_util.Partial(
                _keplerian_residuals,
                obs_times,
                obs.observer_positions,
                obs.ra,
                obs.dec,
            )
            ll = jax.tree_util.Partial(
                _keplerian_loglike,
                obs_times,
                obs.observer_positions,
                obs.ra,
                obs.dec,
                obs.inv_cov_matrices,
                obs.cov_log_dets,
            )
            # Keplerian propagation is fully JIT-able, reverse mode works natively
            residuals = jax.jit(residuals)
            loglike = jax.jit(ll)
        else:
            if self._integrator_method == "ias15":
                times = obs_times
                inds = jnp.arange(len(obs_times))
            elif self._integrator_method in ["Y4", "Y6", "Y8"]:
                times, inds = create_leapfrog_times(
                    self._cartesian_state.relative_time,
                    obs_times,
                    self._max_step_size,
                )

            residuals = jax.tree_util.Partial(
                _residuals,
                times,
                self.gravity,
                self._integrator,
                self._integrator_state,
                obs.observer_positions,
                obs.ra,
                obs.dec,
                inds,
                step_scheduler=self._step_scheduler,
            )
            ll = jax.tree_util.Partial(
                _loglike,
                times,
                self.gravity,
                self._integrator,
                self._integrator_state,
                obs.observer_positions,
                obs.ra,
                obs.dec,
                obs.inv_cov_matrices,
                obs.cov_log_dets,
                inds,
                step_scheduler=self._step_scheduler,
            )

            # since we've gone with the while loop version of the ias15 integrator,
            # can no longer use reverse mode. But, actually specifying forward mode
            # everywhere is annoying, so we're going to re-define a custom vjp for
            # "reverse" mode that's actually just forward mode
            @jax.custom_vjp
            def loglike(params: CartesianState | KeplerianState) -> float:
                return ll(params)

            def loglike_fwd(params: CartesianState | KeplerianState) -> tuple:
                output = ll(params)
                jac = jax.jacfwd(ll)(params)
                return output, (jac,)

            def loglike_bwd(res: tuple, g: float) -> tuple:
                # `res` is the (jac,) residual tuple from loglike_fwd. custom_vjp
                # wants one cotangent per primal input, i.e. a 1-tuple here.
                (jac,) = res
                return (jax.tree.map(lambda leaf: leaf * g, jac),)

            loglike.defvjp(loglike_fwd, loglike_bwd)

            residuals = jax.jit(residuals)
            loglike = jax.jit(loglike)

        def _flat_to_state(flat: jnp.ndarray) -> CartesianState:
            # Unpack [position(3), velocity(3)] into a state at the particle epoch.
            return CartesianState(
                x=jnp.array([flat[:3]]),
                v=jnp.array([flat[3:]]),
                relative_time=self._time,
                time_reference=self._t_ref_jd,
                acceleration_func_kwargs=self._acc_func_kwargs,
            )

        def scipy_objective(x: jnp.ndarray) -> float:
            return -loglike(_flat_to_state(x))

        def scipy_grad(x: jnp.ndarray) -> jnp.ndarray:
            c_grad = jax.grad(loglike)(_flat_to_state(x))
            g = jnp.concatenate([c_grad.x.flatten(), c_grad.v.flatten()])
            return -g

        return residuals, loglike, scipy_objective, scipy_grad

    def _setup_default_static_residuals(self) -> Callable | None:
        """Build the precomputed-data residual function.

        Returns None when there are no observations or gravity is "keplerian".
        """
        if self._observations is None or self._is_keplerian:
            return None
        precomputed_data = precompute_likelihood_data(self, self._step_scheduler)
        return create_default_static_residuals_func(precomputed_data)

    ################
    # PUBLIC METHODS
    ################
[docs] @classmethod def from_horizons( cls, name: str, time: Time, observations: Observations | None = None, gravity: str | Callable = "default solar system", integrator: str = "ias15", earliest_time: Time = Time("1980-01-01"), latest_time: Time = Time("2050-01-01"), fit_seed: KeplerianState | CartesianState | None = None, max_step_size: u.Quantity | None = None, de_ephemeris_version: str | None = "440", ) -> Particle: """Query JPL Horizons for an SSOs state at a given time and create a Particle object. Args: name (str): The name of the SSO to query. Can be a string or an integer. time (Time): The time to query the SSO at. observations (Observations | None): Optional Observations associated with the particle. Necessary if fitting or evaluating likelihoods. gravity (str | Callable): The gravitational acceleration function to use when integrating the particle's orbit. Defaults to "default solar system", which corresponds to parameterized post-Newtonian interactions with the 10 bodies in the JPL DE440 ephemeris, plus Newtonian interactions with the 16 largest asteroids in the asteroids_de441/sb441-n16.bsp ephemeris. Can also be a jax.tree_util.Partial object that follows the same signature as the acceleration functions in jorbit.accelerations. Other string options are "newtonian planets", "newtonian solar system", "gr planets", and "gr solar system". integrator (str): The integrator to use for the particle. Choices are "ias15", which is a 15th order adaptive step-size integrator, or "Y4", "Y6", or "Y8", which are 4th, 6th, and 8th order Yoshida leapfrog integrators with fixed step sizes. Defaults to "ias15". earliest_time (Time): The earliest time we expect to integrate the particle to. Defaults to Time("1980-01-01"). Larger time windows will result in larger in-memory ephemeris objects. latest_time (Time): The latest time we expect to integrate the particle to. Defaults to Time("2050-01-01"). Larger time windows will result in larger in-memory ephemeris objects. 
fit_seed (KeplerianState | CartesianState | None): A seed for fitting the orbit of the particle. If None, a seed will be generated from the observations if they exist. Otherwise, a circular orbit with semi-major axis 2.5 AU will be used. max_step_size (u.Quantity, optional): The fixed step size to use for leapfrog integrators. Required if integrator is "Y4", "Y6", or "Y8". Ignored if integrator is "ias15". Note that this is the maximum step size; the actual step size may be smaller to ensure that the particle lands exactly on the requested output times, and that the step size may change if the spacing between output times is not constant. Defaults to None. de_ephemeris_version (str | None): Which version of the JPL DE ephemeris to use for perturber positions. When using `from_horizons` to pull an initial state, only DE440 is supported. Defaults to "440", will error on anything else as a safeguard. Returns: Particle: A Particle object representing the SSO at the given time. """ if de_ephemeris_version != "440": raise ValueError( "Only DE440 ephemeris version is supported when pulling " "an initial state from JPL Horizons." ) data = horizons_bulk_vector_query(target=name, center="500@0", times=time) x0 = jnp.array([data["x"][0], data["y"][0], data["z"][0]]) v0 = jnp.array([data["vx"][0], data["vy"][0], data["vz"][0]]) return cls( x=x0, v=v0, time=time, observations=observations, name=name, gravity=gravity, integrator=integrator, earliest_time=earliest_time, latest_time=latest_time, fit_seed=fit_seed, max_step_size=max_step_size, )
def _integrate_base(
    self,
    times: Time,
    state: CartesianState | KeplerianState | None = None,
    forced_landing: bool = False,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray | None]:
    """Shared implementation behind ``integrate`` and ``integrate_or_interpolate``.

    Args:
        times (Time | jnp.ndarray): Target time(s). They are converted with
            ``self._times_to_offsets``, and a scalar result is promoted to a
            length-1 array.
        state (CartesianState | KeplerianState | None): Optional starting
            state, checked with ``self._validate_state_epoch``. If None, the
            particle's own state (and, for IAS15, its cached integrator state)
            is used.
        forced_landing (bool): If True, land exactly on each requested time.
            If False, the integrator may overshoot and interpolate.

    Returns:
        tuple: ``(positions, velocities, steps)``, with positions and
        velocities taken for the first (only) particle. ``steps`` is None on
        the analytic Keplerian path; otherwise it is the step count reported
        by the integration routine used.
    """
    if state is not None:
        self._validate_state_epoch(state)

    if self._is_keplerian:
        if state is not None:
            state = state.to_cartesian()
            # relative_time is already in this Particle's offset frame
            # (validated above: state.time_reference == self._t_ref_jd).
            x, v, t0 = (
                state.x.flatten(),
                state.v.flatten(),
                state.relative_time,
            )
        else:
            x, v, t0 = self._x, self._v, self._time
        times = self._times_to_offsets(times)
        if times.shape == ():
            times = jnp.array([times])
        return *_keplerian_integrate(x, v, t0, times), None

    if state is None:
        state = self._cartesian_state
        integrator_state = self._integrator_state
    else:
        # A caller-supplied state needs a freshly initialised IAS15 state.
        a0 = self.gravity(state.to_system())
        integrator_state = initialize_ias15_integrator_state(a0)

    times = self._times_to_offsets(times)
    if times.shape == ():
        times = jnp.array([times])

    if self._integrator_method in ["Y4", "Y6", "Y8"]:
        # Leapfrog pre-expands its (uncapped) time array, so it never truncates.
        times, inds = create_leapfrog_times(
            state.relative_time, times, self._max_step_size
        )
        integrator = self._forced_integrator if forced_landing else self._integrator
        positions, velocities, _fss, _fis, steps = _integrate(
            times,
            state,
            self.gravity,
            integrator,
            integrator_state,
            inds,
            self._step_scheduler,
        )
        return positions[:, 0, :], velocities[:, 0, :], steps

    # IAS15: orchestrate on the host so neither the dense-output buffer
    # (interpolation) nor the per-interval cap (forced landing) can silently
    # truncate. See jorbit.integrators.budgeted.
    sys_state = state.to_system()
    if forced_landing:
        positions, velocities, steps = budgeted_forced_landing(
            sys_state, self.gravity, times, integrator_state, self._step_scheduler
        )
    else:
        positions, velocities, steps = stitched_interpolate(
            sys_state, self.gravity, times, integrator_state, self._step_scheduler
        )
    return positions[:, 0, :], velocities[:, 0, :], steps
def integrate(
    self,
    times: Time,
    state: CartesianState | KeplerianState | None = None,
    return_steps: bool = False,
) -> (
    tuple[jnp.ndarray, jnp.ndarray]
    | tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray | None]
):
    """Integrate the particle's orbit to specified times, landing exactly on each one.

    This method does not change the state of the particle. It only returns
    the positions and velocities at the requested times.

    Args:
        times (Time | jnp.ndarray): The times to integrate to. Can be a single
            time or an array of times. If provided as a jnp.array, the entries
            are assumed to be in TDB JD.
        state (CartesianState | KeplerianState | None): The state to integrate
            from. If None, the particle's current state is used. Usually not
            necessary to provide this.
        return_steps (bool): Whether to also return the number of integration
            steps. If True, returns (positions, velocities, steps); otherwise
            returns (positions, velocities). Defaults to False.

    Returns:
        tuple: The positions of the particle at the given times, in AU, and
        the velocities of the particle at the given times, in AU/day. If
        return_steps is True, the step count (None on the analytic Keplerian
        path) is appended as a third element.
    """
    x, v, steps = self._integrate_base(times, state, forced_landing=True)
    if return_steps:
        return x, v, steps
    return x, v
def integrate_or_interpolate(
    self,
    times: Time,
    state: CartesianState | KeplerianState | None = None,
    return_steps: bool = False,
) -> (
    tuple[jnp.ndarray, jnp.ndarray]
    | tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray | None]
):
    """Integrate the particle's orbit to specified times, overshooting and 'interpolating' if necessary.

    This method does not change the state of the particle. It only returns
    the positions and velocities at the requested times.

    Args:
        times (Time | jnp.ndarray): The times to integrate to. Can be a single
            time or an array of times. If provided as a jnp.array, the entries
            are assumed to be in TDB JD.
        state (CartesianState | KeplerianState | None): The state to integrate
            from. If None, the particle's current state is used. Usually not
            necessary to provide this.
        return_steps (bool): Whether to also return the number of integration
            steps. If True, returns (positions, velocities, steps); otherwise
            returns (positions, velocities). Defaults to False.

    Returns:
        tuple: The positions of the particle at the given times, in AU, and
        the velocities of the particle at the given times, in AU/day. If
        return_steps is True, the step count (None on the analytic Keplerian
        path) is appended as a third element.
    """
    x, v, steps = self._integrate_base(times, state, forced_landing=False)
    if return_steps:
        return x, v, steps
    return x, v
def ephemeris(
    self,
    times: Time,
    observer: str | jnp.ndarray,
    state: CartesianState | KeplerianState | None = None,
    interpolate: bool = True,
    uncertainty: bool = False,
    return_steps: bool = False,
) -> SkyCoord | tuple:
    """Compute the particle's sky positions as seen by an observer.

    Args:
        times (Time | jnp.ndarray): Epoch(s) at which to evaluate the
            ephemeris; a single time or an array. Plain jnp arrays are
            interpreted as TDB JD.
        observer (str | jnp.ndarray): Either an observatory name/code string,
            resolved through ``get_observer_positions``, or 3D observer
            position(s) in AU.
        state (CartesianState | KeplerianState | None): Optional state to start
            from. When None, the particle's own state is used. Rarely needed.
        interpolate (bool): If True, use the interpolating integration path
            (as in ``integrate_or_interpolate``). If False, land exactly on each
            requested time (as in ``integrate``).
        uncertainty (bool): If True, also map the 6x6 covariance stored on the
            state's ``cov`` field onto the sky plane with forward-mode autodiff
            (linear error propagation). The covariance stays in the input
            state's parameterization (Cartesian or Keplerian); nothing is
            converted. Expect roughly ``6x`` the nominal cost, since
            ``jax.jacfwd`` runs one JVP per input dimension. Defaults to False.
        return_steps (bool): If True, append the total number of integration
            steps as the final tuple element. The total is summed over stitched
            interpolation chunks or inserted forced-landing landings. For IAS15
            this count flags unusually heavy integrations; the nominal
            ephemeris itself is truncation-proof. The value is ``None`` for the
            analytic Keplerian and fixed-step leapfrog paths. Defaults to False.

    Returns:
        SkyCoord | tuple: With ``uncertainty=False`` (default), an ICRS SkyCoord
        of the particle as seen from the observer, corrected for light travel
        time. With ``uncertainty=True``, a tuple of that SkyCoord and the
        propagated ``(N, 2, 2)`` sky-plane covariance in ``arcsec**2``, with
        axis order (RA, Dec). With ``return_steps=True``, the step count is
        appended as the final element.

    Raises:
        ValueError: If ``uncertainty=True`` and the relevant state's ``cov`` is
            not shaped ``(6, 6)``.
        RuntimeError: If ``uncertainty=True`` on an IAS15 path and the span
            would overflow the IAS15 dense-output buffer.
    """
    # Observer positions are resolved before any state validation.
    if isinstance(observer, str):
        observer_positions = get_observer_positions(
            times, observer, self._de_ephemeris_version
        )
    else:
        observer_positions = observer

    if state is not None:
        self._validate_state_epoch(state)

    if uncertainty:
        # The covariance lives on whichever state the ephemeris starts from.
        if state is not None:
            cov_source = state
        elif self._is_keplerian:
            cov_source = self._keplerian_state
        else:
            cov_source = self._cartesian_state
        cov = cov_source.cov
        if cov.shape != (6, 6):
            raise ValueError(
                "uncertainty=True requires a (6, 6) covariance matrix on the "
                f"state's `cov` field, but got shape {tuple(cov.shape)}. "
                "Build a CartesianState or KeplerianState with `cov=...` and pass "
                "it via the `state=` keyword (or set it on the particle's "
                "internal state)."
            )

    # ---- Analytic two-body branch -------------------------------------
    if self._is_keplerian:
        offsets = self._times_to_offsets(times)
        if offsets.shape == ():
            offsets = jnp.array([offsets])

        if uncertainty:
            kep_state = self._keplerian_state if state is None else state
            ra_rad, dec_rad, sky_cov = _keplerian_ephem_with_cov(
                kep_state, offsets, observer_positions, cov
            )
            coords = SkyCoord(ra=ra_rad, dec=dec_rad, unit=u.rad, frame="icrs")
            # Analytic propagation takes no integration steps.
            return (coords, sky_cov, None) if return_steps else (coords, sky_cov)

        if state is None:
            x, v, t0 = self._x, self._v, self._time
        else:
            cart = state.to_cartesian()
            x, v, t0 = cart.x.flatten(), cart.v.flatten(), cart.relative_time
        ra_rad, dec_rad = _keplerian_ephem(x, v, t0, offsets, observer_positions)
        coords = SkyCoord(ra=ra_rad, dec=dec_rad, unit=u.rad, frame="icrs")
        return (coords, None) if return_steps else coords

    # ---- Numerically integrated branch --------------------------------
    if state is None:
        state = self._cartesian_state
        integrator_state = self._integrator_state
    else:
        a0 = self.gravity(state.to_system())
        integrator_state = initialize_ias15_integrator_state(a0)

    offsets = self._times_to_offsets(times)
    if offsets.shape == ():
        offsets = jnp.array([offsets])

    is_leapfrog = self._integrator_method in ("Y4", "Y6", "Y8")
    if is_leapfrog:
        offsets, inds = create_leapfrog_times(
            state.relative_time, offsets, self._max_step_size
        )
    else:
        inds = jnp.arange(offsets.shape[0])

    integrator = self._integrator if interpolate else self._forced_integrator
    # Interpolating IAS15 has dense-output b-coefficients, so the light-travel
    # loop in on_sky can evaluate that polynomial at delayed times. Leapfrog
    # and forced-landing IAS15 keep the constant-acceleration Taylor expansion.
    use_dense_ltt = interpolate and not is_leapfrog

    if uncertainty:
        # The covariance path runs inside jax.jacfwd and cannot use host-side
        # stitching, so an IAS15 span that would overflow the dense-output
        # buffer is reported rather than silently truncated.
        steps = None
        if not is_leapfrog:
            would_truncate, steps = ias15_span_probe(
                state.to_system(),
                self.gravity,
                offsets,
                integrator_state,
                self._step_scheduler,
            )
            if would_truncate:
                raise RuntimeError(
                    "ephemeris(uncertainty=True) over this span would exceed the "
                    "IAS15 dense-output buffer and silently truncate. The covariance "
                    "path uses forward-mode autodiff and cannot be transparently "
                    "stitched; request a shorter span or more closely spaced times. "
                    "(A nominal ephemeris(uncertainty=False) call auto-handles this.)"
                )
        if use_dense_ltt:
            ra_rad, dec_rad, sky_cov = _ephem_ias15_with_cov(
                offsets,
                state,
                self.gravity,
                observer_positions,
                inds,
                self._step_scheduler,
                cov,
            )
        else:
            ra_rad, dec_rad, sky_cov = _ephem_with_cov(
                offsets,
                state,
                self.gravity,
                integrator,
                integrator_state,
                observer_positions,
                inds,
                self._step_scheduler,
                cov,
            )
        coords = SkyCoord(ra=ra_rad, dec=dec_rad, unit=u.rad, frame="icrs")
        return (coords, sky_cov, steps) if return_steps else (coords, sky_cov)

    steps = None
    if use_dense_ltt:
        ra_rad, dec_rad, steps = _ephem_ias15_stitched(
            offsets,
            state,
            self.gravity,
            integrator_state,
            observer_positions,
            inds,
            self._step_scheduler,
        )
    elif not is_leapfrog:
        # IAS15 forced-landing (interpolate=False): Taylor-LTT on budgeted landings.
        ra_rad, dec_rad, steps = _ephem_forced_budgeted(
            offsets,
            state,
            self.gravity,
            integrator_state,
            observer_positions,
            inds,
            self._step_scheduler,
        )
    else:
        ra_rad, dec_rad = _ephem(
            offsets,
            state,
            self.gravity,
            integrator,
            integrator_state,
            observer_positions,
            inds,
            self._step_scheduler,
        )
    coords = SkyCoord(ra=ra_rad, dec=dec_rad, unit=u.rad, frame="icrs")
    return (coords, steps) if return_steps else coords
def max_likelihood(
    self,
    fit_seed: CartesianState | KeplerianState | None = None,
    verbose: bool = False,
) -> Particle:
    """Find the maximum likelihood orbit for the particle.

    Args:
        fit_seed (CartesianState | KeplerianState | None): A seed for fitting
            the orbit of the particle. If None, the particle's stored seed
            (``self._fit_seed``) is used.
        verbose (bool): Whether to print the optimization progress (forwarded
            to L-BFGS-B as ``disp``). Defaults to False.

    Returns:
        Particle: A new Particle object whose state matches the maximum
        likelihood orbit. A RuntimeWarning is emitted if that orbit is unbound
        (eccentricity > 1).

    Raises:
        ValueError: If the particle has no likelihood (no observations), or if
            the optimizer does not report success.
        RuntimeError: For a non-Keplerian IAS15 particle, if integrating the
            seed orbit over the observation span would overflow the IAS15
            dense-output buffer.
    """
    if self.loglike is None:
        raise ValueError("No observations provided, cannot fit an orbit")

    if fit_seed is None:
        fit_seed = self._fit_seed

    # The optimizer works on a flat 6-vector (x, y, z, vx, vy, vz). Convert
    # the seed once instead of once per component.
    seed_cartesian = fit_seed.to_cartesian()
    x0 = jnp.concatenate(
        [seed_cartesian.x.flatten(), seed_cartesian.v.flatten()]
    )

    if (not self._is_keplerian) and self._integrator_method == "ias15":
        # Best-effort detect-and-raise guard. The likelihood integrates the
        # observation span with the interpolation path (the dense-output
        # buffer). The orbit changes during optimization, but if even the seed
        # already overflows the buffer the fit would silently truncate. Orbit
        # fitting runs inside forward-mode autodiff and so cannot be
        # transparently stitched the way integrate/ephemeris are.
        seed_state = CartesianState(
            x=jnp.array([x0[:3]]),
            v=jnp.array([x0[3:]]),
            relative_time=self._time,
            time_reference=self._t_ref_jd,
            acceleration_func_kwargs=self._acc_func_kwargs,
        )
        obs_times = self._observations_times_as_offsets()
        a0 = self.gravity(seed_state.to_system())
        would_truncate, _steps = ias15_span_probe(
            seed_state.to_system(),
            self.gravity,
            obs_times,
            initialize_ias15_integrator_state(a0),
            self._step_scheduler,
        )
        if would_truncate:
            raise RuntimeError(
                "The observation time span would exceed the IAS15 dense-output "
                "buffer and silently truncate the likelihood integration. Orbit "
                "fitting runs inside forward-mode autodiff and cannot be "
                "transparently stitched; fit over a shorter arc or split the data."
            )

    # Both parameterizations use the same L-BFGS-B settings.
    lbfgsb_options = {
        "disp": verbose,
        "maxls": 100,
        "maxcor": 100,
        "maxfun": 5000,
        "maxiter": 1000,
        "ftol": 1e-14,
    }

    if self._is_keplerian:
        # Keplerian propagation produces NaN for hyperbolic orbits, so we
        # guard the objective and gradient. We also scale the parameters so
        # the initial identity-Hessian step in L-BFGS-B is well-sized.
        scale = jnp.abs(x0) + 1e-10

        def _scaled_objective(x_scaled: jnp.ndarray) -> float:
            val = self.scipy_objective(x_scaled * scale)
            return float(jnp.where(jnp.isnan(val), 1e30, val))

        def _scaled_grad(x_scaled: jnp.ndarray) -> jnp.ndarray:
            g = self.scipy_objective_grad(x_scaled * scale) * scale
            return jnp.where(jnp.isnan(g), 0.0, g)

        result = minimize(
            fun=_scaled_objective,
            x0=x0 / scale,
            jac=_scaled_grad,
            method="L-BFGS-B",
            options=lbfgsb_options,
        )
        # Undo the scaling so result.x is back in physical units.
        result.x = result.x * scale
    else:
        result = minimize(
            fun=self.scipy_objective,
            x0=x0,
            jac=self.scipy_objective_grad,
            method="L-BFGS-B",
            options=lbfgsb_options,
        )

    if not result.success:
        raise ValueError("Failed to converge")

    best_state = CartesianState(
        x=jnp.array([result.x[:3]]),
        v=jnp.array([result.x[3:]]),
        relative_time=self._time,
        time_reference=self._t_ref_jd,
        acceleration_func_kwargs=self._acc_func_kwargs,
    )
    if best_state.to_keplerian().ecc > 1:
        warnings.warn(
            "Warning: max_likelihood fit produced an unbound orbit",
            RuntimeWarning,
            stacklevel=2,
        )

    return Particle(
        x=result.x[:3],
        v=result.x[3:],
        time=self._t_ref_astropy,
        state=None,
        observations=self._observations,
        name=self._name,
        gravity=self._gravity_str,
        de_ephemeris_version=self._de_ephemeris_version,
        integrator=self._integrator_method,
        earliest_time=self._earliest_time,
        latest_time=self._latest_time,
        fit_seed=self._fit_seed,
        max_step_size=self._max_step_size,
    )
def is_observable(
    self,
    times: Time,
    observer: str | jnp.ndarray,
    sun_limit: float = 20.0,
    ephem: SkyCoord | None = None,
    return_angle: bool = False,
) -> jnp.ndarray:
    """Check whether a particle is observable or too close to the Sun.

    Args:
        times (Time | jnp.ndarray): The times at which to check observability.
            Time objects are converted to TDB JD; a jnp array is used as-is.
        observer (str | jnp.ndarray): The observer/observatory making the
            observations. Can be a string for the name/code of an observatory,
            or a jnp.array of 3D barycentric ICRS positions in AU.
        sun_limit (float): The minimum allowed angular separation from the
            Sun, in degrees. Defaults to 20 degrees.
        ephem (SkyCoord | None, optional): Optionally, the ephemeris of the
            particle at the given times. If not provided, it is computed with
            the ephemeris method. Helpful if you've already computed the
            ephemeris and want to avoid doing it twice. It may or may not carry
            distances; only its direction is used.
        return_angle (bool, optional): If True, return the angles to the Sun
            in degrees instead of the mask. Default is False.

    Returns:
        jnp.ndarray: A boolean array indicating whether the particle is
        observable at each time (True) or too close to the Sun (False), or the
        particle-observer-Sun angles in degrees if ``return_angle`` is True.
    """
    if isinstance(observer, str):
        observer_pos = get_observer_positions(
            times, observer, self._de_ephemeris_version
        )
    else:
        observer_pos = observer

    if ephem is None:
        # Pass the already-resolved positions so ephemeris() does not repeat
        # the observatory lookup; it accepts position arrays directly.
        ephem = self.ephemeris(times, observer_pos)

    if isinstance(times, Time):
        times = jnp.array(times.tdb.jd)
    if times.shape == ():
        times = jnp.array([times])

    # NOTE(review): unlike get_observer_positions above, this Ephemeris is not
    # given self._de_ephemeris_version; confirm whether it should be.
    eph = Ephemeris(
        earliest_time=self._earliest_time,
        latest_time=self._latest_time,
    )
    # Body index 0 of the processor output is used as the Sun's position.
    sun_pos = jax.vmap(eph.processor.state)(times)[0][:, 0, :]

    # Direction from the observer to the particle. Normalise explicitly so a
    # user-supplied SkyCoord that carries distances still gives unit vectors.
    eph_vec = jnp.array(ephem.cartesian.xyz).T
    eph_unit_vec = eph_vec / jnp.linalg.norm(eph_vec, axis=-1, keepdims=True)

    # Angle between the observer->particle and observer->Sun directions.
    obs_to_sun_vec = sun_pos - observer_pos
    obs_to_sun_unit_vec = (
        obs_to_sun_vec / jnp.linalg.norm(obs_to_sun_vec, axis=1)[:, None]
    )
    cos_angle = jnp.sum(obs_to_sun_unit_vec * eph_unit_vec, axis=1)
    # Rounding can push a unit-vector dot product just outside [-1, 1], where
    # arccos returns NaN. That would wrongly flag near-opposition targets as
    # unobservable, so clip first.
    angle = jnp.arccos(jnp.clip(cos_angle, -1.0, 1.0))

    if return_angle:
        return jnp.rad2deg(angle)
    return angle > jnp.deg2rad(sun_limit)