# Source code for jorbit.observation

"""Module for the Observations class."""

import jax

jax.config.update("jax_enable_x64", True)
import warnings

import jax.numpy as jnp

warnings.filterwarnings("ignore", module="erfa")
import astropy.units as u
from astropy.coordinates import ICRS, SkyCoord
from astropy.time import Time

from jorbit.data.observatory_codes import OBSERVATORY_CODES
from jorbit.utils.horizons import get_observer_positions
from jorbit.utils.mpc import read_mpc_file


class Observations:
    """A container for astrometric observations of a particle at different times.

    When a user supplies times, coordinates, observatories, and astrometric
    uncertainties (or an MPC file providing all four), this class:

    - converts the coordinates to ICRS right ascension / declination in radians;
    - converts the times to JD in TDB;
    - when observatories are given by name, queries Horizons (via
      ``get_observer_positions``) for the barycentric positions of the
      observers;
    - pre-computes the covariance matrices, their inverses, and their log
      determinants needed for fitting.

    Everything is stored as ready-to-use JAX arrays.
    """

    def __init__(
        self,
        observed_coordinates: SkyCoord | None = None,
        times: Time | None = None,
        observatories: str | list[str] | None = None,
        astrometric_uncertainties: u.Quantity | None = None,
        de_ephemeris_version: str | None = "440",
        mpc_file: str | None = None,
    ) -> None:
        """Initialize the Observations class.

        Args:
            observed_coordinates (SkyCoord | list[SkyCoord] | None): The
                observed coordinates of the particle. They are transformed to
                ICRS internally. None if loading an MPC file.
            times (Time | list[Time] | jnp.ndarray | None): The times of the
                observations. A jax.numpy array is interpreted as JD in TDB.
                Times must be in strictly ascending order. None if loading an
                MPC file.
            observatories (str | list[str] | jnp.ndarray | None): The
                observatories where the observations were made, given as MPC
                observatory codes or locations containing "@". If a single
                string is given, it is assumed that all observations were made
                from that observatory. A jax.numpy array is used directly as
                the barycentric observer positions. None if loading an MPC
                file.
            astrometric_uncertainties (u.Quantity | jnp.ndarray | None): Either
                a scalar or 1-D set of per-observation uncertainties, or an
                (N, 2, 2) stack of covariance matrices. Quantities are
                converted to arcsec (or arcsec^2 for covariance matrices with
                squared-angle units). Bare arrays are assumed to already be in
                those units. None if loading an MPC file.
            de_ephemeris_version (str | None): Which version of the JPL DE
                ephemeris to use when calculating the observer positions.
                Accepts either "440" or "430"; the default is "440".
            mpc_file (str | None): The path to an MPC file containing the
                observations. If given, the four inputs above must be None.

        Raises:
            AssertionError: If the inputs are inconsistent (missing, of
                mismatched length, or not in ascending time order).
            ValueError: If ``times`` has an unsupported type or an
                observatory is not recognized.
            TypeError: If ``observed_coordinates`` has an unsupported type.
        """
        self._observed_coordinates = observed_coordinates
        self._times = times
        self._observatories = observatories
        self._astrometric_uncertainties = astrometric_uncertainties
        self._de_ephemeris_version = de_ephemeris_version
        self._mpc_file = mpc_file

        self._input_checks()
        (
            self._ra,
            self._dec,
            self._times,
            self._times_astropy,
            self._observatories,
            self._astrometric_uncertainties,
            self._observer_positions,
            self._cov_matrices,
            self._inv_cov_matrices,
            self._cov_log_dets,
        ) = self._parse_astrometry()
        self._final_init_checks()

    def __repr__(self) -> str:
        """Return a string representation of the Observations class."""
        return f"Observations with {len(self._ra)} set(s) of observations"

    def __len__(self) -> int:
        """Return the number of observations."""
        return len(self._ra)

    def __add__(self, newobs: "Observations") -> "Observations":
        """Combine two Observations objects into a new, time-sorted one.

        The already-computed observer positions of both operands are passed
        on as a positions array, so the combined object does not need to look
        them up again.

        Args:
            newobs (Observations): The observations to combine with these.

        Returns:
            Observations: A new object containing the observations of both
            operands, sorted by time.
        """
        t = jnp.concatenate([self._times, newobs.times])
        order = jnp.argsort(t)

        ra = jnp.concatenate([self._ra, newobs.ra])[order]
        dec = jnp.concatenate([self._dec, newobs.dec])[order]
        observer_positions = jnp.concatenate(
            [self._observer_positions, newobs.observer_positions]
        )[order]

        own_unc = self._astrometric_uncertainties
        new_unc = newobs.astrometric_uncertainties
        if own_unc.ndim != new_unc.ndim:
            # One operand stores 1-D sigmas and the other full (N, 2, 2)
            # covariance matrices. These can't be concatenated directly, so
            # use the covariance matrices of both.
            own_unc, new_unc = self._cov_matrices, newobs.cov_matrices
        uncertainties = jnp.concatenate([own_unc, new_unc])[order]

        # Pass times as an astropy Time (not a raw jnp array) so the new object
        # keeps a valid times_astropy. A raw array would make the parser set
        # times_astropy=None. Downstream code would then have to reconstruct
        # times from the float JD array, introducing tiny floating-point
        # offsets. When both operands carry a Time, rebuild it from their
        # jd1/jd2 pairs so the two-double precision is not collapsed into a
        # single float JD.
        own_t = self._times_astropy
        new_t = newobs.times_astropy
        if own_t is not None and new_t is not None:
            jd1 = jnp.concatenate(
                [jnp.atleast_1d(own_t.tdb.jd1), jnp.atleast_1d(new_t.tdb.jd1)]
            )[order]
            jd2 = jnp.concatenate(
                [jnp.atleast_1d(own_t.tdb.jd2), jnp.atleast_1d(new_t.tdb.jd2)]
            )[order]
            t_astropy = Time(jd1.tolist(), jd2.tolist(), format="jd", scale="tdb")
        else:
            t_astropy = Time(t[order].tolist(), format="jd", scale="tdb")

        return Observations(
            observed_coordinates=SkyCoord(ra=ra, dec=dec, unit=u.rad),
            times=t_astropy,
            observatories=observer_positions,
            astrometric_uncertainties=uncertainties,
            de_ephemeris_version=self._de_ephemeris_version,
            mpc_file=None,
        )

    def __getitem__(self, index: int | slice | jnp.ndarray) -> "Observations":
        """Return a new Observations object from a slice of the current one.

        Args:
            index (int | slice | jnp.ndarray): Selects the observations to
                keep. An integer still produces a length-1 object.

        Returns:
            Observations: A new object holding only the selected observations.
        """
        ra = self._slice_observation_axis(self._ra, index)
        dec = self._slice_observation_axis(self._dec, index)
        times = self._slice_observation_axis(self._times, index)
        astrometric_uncertainties = self._slice_observation_axis(
            self._astrometric_uncertainties, index
        )

        if isinstance(self._observatories, jnp.ndarray):
            observatories = self._slice_observation_axis(self._observatories, index)
        elif isinstance(self._observatories, list):
            # A Python list can't be indexed by a JAX array (integer or boolean
            # mask), so resolve the index to explicit positions first.
            positions = jnp.atleast_1d(jnp.arange(len(self._observatories))[index])
            observatories = [self._observatories[i] for i in positions.tolist()]
        else:
            observatories = self._observatories[index]

        # Preserve the high-precision astropy Time if available. Passing it as
        # `times` lets the constructor set _times_astropy correctly, so that
        # downstream code can still use the full jd1+jd2 precision path.
        # Integer indexing returns a scalar Time, which has no len() and
        # breaks _input_checks. Wrap it to stay 1-D, keeping both jd parts.
        if self._times_astropy is not None:
            t_sliced = self._times_astropy[index]
            if t_sliced.isscalar:
                t_sliced = Time(
                    [t_sliced.jd1], [t_sliced.jd2], format="jd", scale=t_sliced.scale
                )
            times = t_sliced

        # _astrometric_uncertainties is stored internally either in arcsec
        # (1-D) or as dimensionless arcsec^2 covariance matrices (N, 2, 2). The
        # constructor distinguishes these by ndim, so pass the raw JAX array:
        # attaching `* u.arcsec` would be wrong for the covariance path and
        # unnecessary for the 1-D path. mpc_file is cleared because the slice
        # is no longer the full file. The ephemeris version is propagated so
        # any re-computed observer positions match the original object's.
        return Observations(
            observed_coordinates=SkyCoord(ra=ra, dec=dec, unit=u.rad),
            times=times,
            observatories=observatories,
            astrometric_uncertainties=astrometric_uncertainties,
            de_ephemeris_version=self._de_ephemeris_version,
            mpc_file=None,
        )

    @staticmethod
    def _slice_observation_axis(
        values: jnp.ndarray, index: int | slice | jnp.ndarray
    ) -> jnp.ndarray:
        """Slice an observation-indexed array without dropping singleton rows."""
        sliced = values[index]
        if sliced.ndim == values.ndim - 1:
            sliced = sliced[jnp.newaxis, ...]
        return sliced

    @property
    def ra(self) -> jnp.ndarray:
        """Right ascension of the observations in radians, ICRS."""
        return self._ra

    @property
    def dec(self) -> jnp.ndarray:
        """Declination of the observations in radians, ICRS."""
        return self._dec

    @property
    def times(self) -> jnp.ndarray:
        """Times of the observations in JD TDB."""
        return self._times

    @property
    def times_astropy(self) -> Time | None:
        """Original astropy Time (TDB) if provided at construction; else None."""
        return self._times_astropy

    @property
    def observatories(self) -> list[str] | str:
        """Names of the observatories."""
        return self._observatories

    @property
    def astrometric_uncertainties(self) -> jnp.ndarray:
        """Astrometric uncertainties of the observations in arcseconds."""
        return self._astrometric_uncertainties

    @property
    def observer_positions(self) -> jnp.ndarray:
        """Barycentric cartesian positions of the observers in AU."""
        return self._observer_positions

    @property
    def cov_matrices(self) -> jnp.ndarray:
        """Covariance matrices of the observations in arcsec^2."""
        return self._cov_matrices

    @property
    def inv_cov_matrices(self) -> jnp.ndarray:
        """Inverse covariance matrices of the observations in arcsec^-2."""
        return self._inv_cov_matrices

    @property
    def cov_log_dets(self) -> jnp.ndarray:
        """Log determinants of the covariance matrices."""
        return self._cov_log_dets

    ####################################################################################
    # Initialization helpers

    def _input_checks(self) -> None:
        """Validate the raw constructor arguments before they are parsed.

        Raises:
            AssertionError: If required inputs are missing, conflict with an
                MPC file, or have mismatched lengths.
            ValueError: If ``times`` has an unsupported type.
        """
        if self._mpc_file is not None:
            assert (
                (self._observed_coordinates is None)
                and (self._times is None)
                and (self._observatories is None)
                and (self._astrometric_uncertainties is None)
            ), (
                "If an MPC file is provided, observed_coordinates, times,"
                " observatories, and astrometric_uncertainties must be None."
            )
            return

        assert (
            (self._observed_coordinates is not None)
            and (self._times is not None)
            and (self._observatories is not None)
            and (self._astrometric_uncertainties is not None)
        ), (
            "If no MPC file is provided, observed_coordinates, times,"
            " observatories, and astrometric_uncertainties must be given"
            " manually."
        )
        if not isinstance(self._times, (Time, list, jnp.ndarray)):
            raise ValueError(
                "times must be either astropy.time.Time, list of astropy.time.Time,"
                " or jax.numpy.ndarray (interpreted as JD in TDB)"
            )
        assert isinstance(self._observatories, (str, list, jnp.ndarray)), (
            "observatories must be either a string (interpreted as an MPC"
            " observatory code), a list of observatory codes, or a"
            " jax.numpy.ndarray"
        )

        # A scalar Time / 0-d array has no len(); it represents one observation.
        if getattr(self._times, "shape", None) == ():
            n_times = 1
        else:
            n_times = len(self._times)

        if isinstance(self._observatories, list):
            assert len(self._observatories) == n_times, (
                "If observatories is a list, it must be the same length as"
                " the number of observations."
            )
        elif isinstance(self._observatories, jnp.ndarray):
            assert len(self._observatories) == n_times, (
                "If observatories is a jax.numpy.ndarray, it must be the"
                " same length as the number of observations."
            )

    def _parse_astrometry(self) -> tuple:
        """Convert the raw inputs (or MPC file contents) into fitting-ready arrays.

        Returns:
            tuple: ``(ra, dec, times, times_astropy, observatories,
            astrometric_uncertainties, observer_positions, cov_matrices,
            inv_cov_matrices, cov_log_dets)``. ra/dec are in radians (ICRS),
            times are JD TDB, and times_astropy is a 1-D TDB Time or None.

        Raises:
            TypeError: If the observed coordinates have an unsupported type.
            ValueError: If an observatory name is not recognized.
        """
        if self._mpc_file is None:
            observed_coordinates = self._observed_coordinates
            times = self._times
            observatories = self._observatories
            astrometric_uncertainties = self._astrometric_uncertainties
        else:
            (
                observed_coordinates,
                times,
                observatories,
                astrometric_uncertainties,
            ) = read_mpc_file(self._mpc_file)

        # POSITIONS
        if isinstance(observed_coordinates, SkyCoord):
            # transform in case the input frame isn't ICRS (barycentric, etc)
            s = observed_coordinates.transform_to(ICRS)
            ra = s.ra.rad
            dec = s.dec.rad
        elif isinstance(observed_coordinates, list):
            icrs_coords = [c.transform_to(ICRS) for c in observed_coordinates]
            ra = jnp.array([c.ra.rad for c in icrs_coords])
            dec = jnp.array([c.dec.rad for c in icrs_coords])
        else:
            raise TypeError(
                "observed_coordinates must be an astropy SkyCoord or a list of"
                " SkyCoord objects"
            )
        if ra.shape == ():
            ra = jnp.array([ra])
            dec = jnp.array([dec])

        # TIMES
        if isinstance(times, Time):
            tdb = times.tdb
            times = jnp.array(tdb.jd)
            if tdb.isscalar:
                # keep the result 1-D without collapsing jd1+jd2 into one float
                times_astropy = Time([tdb.jd1], [tdb.jd2], format="jd", scale="tdb")
            else:
                times_astropy = tdb
        elif isinstance(times, list):
            times_astropy = Time([t.tdb for t in times])
            times = jnp.array([t.tdb.jd for t in times])
        else:
            times_astropy = None
        if times.shape == ():
            times = jnp.array([times])

        # OBSERVER POSITIONS
        if isinstance(observatories, str):
            observatories = [observatories] * len(times)
        if isinstance(observatories, list):
            # Build a new list so the caller's list is never mutated in place.
            resolved = []
            for loc in observatories:
                key = loc.lower()
                if key in OBSERVATORY_CODES:
                    resolved.append(OBSERVATORY_CODES[key])
                elif "@" in key:
                    resolved.append(loc)
                else:
                    raise ValueError(
                        f"Observer location '{key}' is not a recognized observatory. Please"
                        " refer to"
                        " https://minorplanetcenter.net/iau/lists/ObsCodesF.html"
                    )
            observatories = resolved
            observer_positions = get_observer_positions(
                times=Time(times, format="jd", scale="tdb"),
                observatories=observatories,
                de_ephemeris_version=self._de_ephemeris_version,
            )
        else:
            # a jax.numpy array is taken to already be the observer positions
            observer_positions = observatories

        # UNCERTAINTIES
        if astrometric_uncertainties.shape == ():
            astrometric_uncertainties = (
                jnp.ones(len(times)) * astrometric_uncertainties.to(u.arcsec).value
            )
        elif isinstance(astrometric_uncertainties, u.Quantity):
            if astrometric_uncertainties.ndim > 1 and (
                astrometric_uncertainties.unit.is_equivalent(u.arcsec**2)
            ):
                # covariance matrices carry squared-angle units
                astrometric_uncertainties = astrometric_uncertainties.to(
                    u.arcsec**2
                ).value
            else:
                astrometric_uncertainties = astrometric_uncertainties.to(u.arcsec).value
        astrometric_uncertainties = jnp.asarray(astrometric_uncertainties)

        # 1-D uncertainties become diagonal (N, 2, 2) covariance matrices
        if astrometric_uncertainties.ndim == 1:
            variances = astrometric_uncertainties**2
            cov_matrices = (
                jnp.zeros(variances.shape + (2, 2), dtype=variances.dtype)
                .at[:, 0, 0]
                .set(variances)
                .at[:, 1, 1]
                .set(variances)
            )
        else:
            cov_matrices = astrometric_uncertainties

        # jnp.linalg.inv / det operate on the trailing two axes, batching over N
        inv_cov_matrices = jnp.linalg.inv(cov_matrices)
        cov_log_dets = jnp.log(jnp.linalg.det(cov_matrices))

        return (
            jnp.array(ra),
            jnp.array(dec),
            times,
            times_astropy,
            observatories,
            astrometric_uncertainties,
            jnp.array(observer_positions),
            jnp.array(cov_matrices),
            jnp.array(inv_cov_matrices),
            jnp.array(cov_log_dets),
        )

    def _final_init_checks(self) -> None:
        """Check that the parsed arrays agree in length and are time-ordered.

        Raises:
            AssertionError: On a length mismatch, or if the times are not
                strictly ascending.
        """
        assert (
            len(self._ra)
            == len(self._dec)
            == len(self._times)
            == len(self.observer_positions)
            == len(self.astrometric_uncertainties)
        ), (
            f"Inputs must have the same length. Currently: ra={len(self._ra)}, dec={len(self._dec)}, times={len(self._times)},"
            f" observer_positions={len(self.observer_positions)}, astrometric_uncertainties={len(self.astrometric_uncertainties)}"
        )
        # vectorized strict-ascending check (one device op instead of N)
        assert bool(
            jnp.all(self._times[1:] > self._times[:-1])
        ), "Observations must be in ascending chronological order."