"""Module for the Observations class."""
import jax
jax.config.update("jax_enable_x64", True)
import warnings
import jax.numpy as jnp
warnings.filterwarnings("ignore", module="erfa")
import astropy.units as u
from astropy.coordinates import ICRS, SkyCoord
from astropy.time import Time
from jorbit.data.observatory_codes import OBSERVATORY_CODES
from jorbit.utils.horizons import get_observer_positions
from jorbit.utils.mpc import read_mpc_file
[docs]
class Observations:
"""The Observations class.
This is a container for astrometric observations of a particle at different times.
When a user supplies times, coordinates, and observatory names, this class will
under-the-hood pre-compute the required covariance matrices required for fitting,
query Horizons to get the Barycentric positions of the observers, and store
everything as ready-to-use-later JAX arrays.
"""
def __init__(
self,
observed_coordinates: SkyCoord | None = None,
times: Time | None = None,
observatories: str | list[str] | None = None,
astrometric_uncertainties: u.Quantity | None = None,
de_ephemeris_version: str | None = "440",
mpc_file: str | None = None,
) -> None:
"""Initialize the Observations class.
Args:
observed_coordinates (SkyCoord | None):
The observed coordinates of the particle. None if loading an MPC file.
times (Time):
The times of the observations. None if loading an MPC file.
observatories (str | list[str] | None):
The observatories where the observations were made. If only one
observatory is used, it's assumed that all observations were made from
that observatory. None if loading an MPC file.
astrometric_uncertainties (u.Quantity | None):
The astrometric uncertainties of the observations. None if loading an
MPC file.
de_ephemeris_version (str | None):
Which version of the JPL DE ephemeris to use when calculating Earth's
position. Accepts either "440" or "430", default is "440".
mpc_file (str | None):
The path to an MPC file containing the observations.
"""
self._observed_coordinates = observed_coordinates
self._times = times
self._observatories = observatories
self._astrometric_uncertainties = astrometric_uncertainties
self._de_ephemeris_version = de_ephemeris_version
self._mpc_file = mpc_file
self._input_checks()
(
self._ra,
self._dec,
self._times,
self._times_astropy,
self._observatories,
self._astrometric_uncertainties,
self._observer_positions,
self._cov_matrices,
self._inv_cov_matrices,
self._cov_log_dets,
) = self._parse_astrometry()
self._final_init_checks()
def __repr__(self) -> str:
"""Return a string representation of the Observations class."""
return f"Observations with {len(self._ra)} set(s) of observations"
def __len__(self) -> int:
"""Return the number of observations."""
return len(self._ra)
def __add__(self, newobs: "Observations") -> "Observations":
"""Add two Observations objects together."""
t = jnp.concatenate([self._times, newobs.times])
ra = jnp.concatenate([self._ra, newobs.ra])
dec = jnp.concatenate([self._dec, newobs.dec])
obs_precision = jnp.concatenate(
[self._astrometric_uncertainties, newobs.astrometric_uncertainties]
)
observer_positions = jnp.concatenate(
[self._observer_positions, newobs.observer_positions]
)
order = jnp.argsort(t)
t_sorted = t[order]
# Pass times as an astropy Time so that the new Observations object retains
# a valid times_astropy attribute. Passing a raw jnp array causes the
# __init__ parser to set times_astropy=None, which forces downstream code
# to reconstruct times from the float JD array and introduces tiny
# floating-point offsets that can break the static-likelihood dt_seed logic.
t_astropy = Time(t_sorted.tolist(), format="jd", scale="tdb")
return Observations(
observed_coordinates=SkyCoord(ra=ra[order], dec=dec[order], unit=u.rad),
times=t_astropy,
observatories=observer_positions[order],
astrometric_uncertainties=obs_precision[order],
mpc_file=None,
)
def __getitem__(self, index: int | slice | jnp.ndarray) -> "Observations":
"""Return a new Observations object from a slice of the current one."""
ra = self._slice_observation_axis(self._ra, index)
dec = self._slice_observation_axis(self._dec, index)
times = self._slice_observation_axis(self._times, index)
astrometric_uncertainties = self._slice_observation_axis(
self._astrometric_uncertainties, index
)
observatories = self._observatories[index]
if isinstance(self._observatories, jnp.ndarray):
observatories = self._slice_observation_axis(self._observatories, index)
# Preserve the high-precision astropy Time if available. Passing it as
# `times` lets the constructor set _times_astropy correctly, so that
# downstream code (_observations_times_as_offsets, precompute_likelihood_data)
# can still use the full jd1+jd2 precision path.
# Integer indexing on an astropy Time array returns a scalar Time, which
# doesn't have len() and breaks _input_checks. Wrap it to stay 1-D.
if self._times_astropy is not None:
t_sliced = self._times_astropy[index]
if t_sliced.isscalar:
t_sliced = Time([t_sliced.jd], format="jd", scale=t_sliced.scale)
times = t_sliced
# _astrometric_uncertainties is stored internally in arcsec (1-D) or as
# dimensionless arcsec² covariance matrices (N, 2, 2). The constructor
# distinguishes these by ndim, so we pass the raw JAX array directly:
# attaching `* u.arcsec` is wrong for the covariance-matrix path (the
# unit would be arcsec, not arcsec²) and unnecessary for the 1-D path
# (the values are already in arcsec). mpc_file is cleared: the slice
# is no longer the full file.
return Observations(
observed_coordinates=SkyCoord(ra=ra, dec=dec, unit=u.rad),
times=times,
observatories=observatories,
astrometric_uncertainties=astrometric_uncertainties,
mpc_file=None,
)
@staticmethod
def _slice_observation_axis(
values: jnp.ndarray, index: int | slice | jnp.ndarray
) -> jnp.ndarray:
"""Slice an observation-indexed array without dropping singleton rows."""
sliced = values[index]
if sliced.ndim == values.ndim - 1:
sliced = sliced[jnp.newaxis, ...]
return sliced
@property
def ra(self) -> jnp.ndarray:
"""Right ascension of the observations in radians, ICRS."""
return self._ra
@property
def dec(self) -> jnp.ndarray:
"""Declination of the observations in radians, ICRS."""
return self._dec
@property
def times(self) -> jnp.ndarray:
"""Times of the observations in JD TDB."""
return self._times
@property
def times_astropy(self) -> Time | None:
"""Original astropy Time (TDB) if provided at construction; else None."""
return self._times_astropy
@property
def observatories(self) -> list[str] | str:
"""Names of the observatories."""
return self._observatories
@property
def astrometric_uncertainties(self) -> jnp.ndarray:
"""Astrometric uncertainties of the observations in arcseconds."""
return self._astrometric_uncertainties
@property
def observer_positions(self) -> jnp.ndarray:
"""Barycentric cartesian positions of the observers in AU."""
return self._observer_positions
@property
def cov_matrices(self) -> jnp.ndarray:
"""Covariance matrices of the observations in arcsec^2."""
return self._cov_matrices
@property
def inv_cov_matrices(self) -> jnp.ndarray:
"""Inverse covariance matrices of the observations in arcsec^-2."""
return self._inv_cov_matrices
@property
def cov_log_dets(self) -> jnp.ndarray:
"""Log determinants of the covariance matrices."""
return self._cov_log_dets
####################################################################################
# Initialization helpers
def _input_checks(self) -> None:
if self._mpc_file is None:
assert (
(self._observed_coordinates is not None)
and (self._times is not None)
and (self._observatories is not None)
and (self._astrometric_uncertainties is not None)
), (
"If no MPC file is provided, observed_coordinates, times,"
" observatories, and astrometric_uncertainties must be given"
" manually."
)
if not isinstance(
self._times, (type(Time("2023-01-01")), list, jnp.ndarray)
):
raise ValueError(
"times must be either astropy.time.Time, list of astropy.time.Time,"
" or jax.numpy.ndarray (interpreted as JD in TDB)"
)
assert isinstance(self._observatories, (str, list, jnp.ndarray)), (
"observatories must be either a string (interpreted as an MPC"
" observatory code), a list of observatory codes, or a"
" jax.numpy.ndarray"
)
if isinstance(self._observatories, list):
assert len(self._observatories) == len(self._times), (
"If observatories is a list, it must be the same length as"
" the number of observations."
)
elif isinstance(self._observatories, jnp.ndarray):
assert len(self._observatories) == len(self._times), (
"If observatories is a jax.numpy.ndarray, it must be the"
" same length as the number of observations."
)
else:
assert (
(self._observed_coordinates is None)
and (self._times is None)
and (self._observatories is None)
and (self._astrometric_uncertainties is None)
), (
"If an MPC file is provided, observed_coordinates, times,"
" observatories, and astrometric_uncertainties must be None."
)
def _parse_astrometry(self) -> tuple:
if self._mpc_file is None:
(
observed_coordinates,
times,
observatories,
astrometric_uncertainties,
) = (
self._observed_coordinates,
self._times,
self._observatories,
self._astrometric_uncertainties,
)
else:
(
observed_coordinates,
times,
observatories,
astrometric_uncertainties,
) = read_mpc_file(self._mpc_file)
# POSITIONS
if isinstance(observed_coordinates, SkyCoord):
# in case they're barycentric, etc
s = observed_coordinates.transform_to(ICRS)
ra = s.ra.rad
dec = s.dec.rad
elif isinstance(observed_coordinates, list):
ras = []
decs = []
for s in observed_coordinates:
s = s.transform_to(ICRS)
ras.append(s.ra.rad)
decs.append(s.dec.rad)
ra = jnp.array(ras)
dec = jnp.array(decs)
if ra.shape == ():
ra = jnp.array([ra])
dec = jnp.array([dec])
# TIMES
if isinstance(times, Time):
times_astropy = times.tdb
if times_astropy.isscalar:
times_astropy = Time([times_astropy.jd], format="jd", scale="tdb")
times = jnp.array(times.tdb.jd)
elif isinstance(times, list):
times_astropy = Time([t.tdb for t in times])
times = jnp.array([t.tdb.jd for t in times])
else:
times_astropy = None
if times.shape == ():
times = jnp.array([times])
# OBSERVER POSITIONS
if isinstance(observatories, str):
observatories = [observatories] * len(times)
if isinstance(observatories, list):
for i, loc in enumerate(observatories):
loc = loc.lower()
if loc in OBSERVATORY_CODES:
observatories[i] = OBSERVATORY_CODES[loc]
elif "@" in loc:
pass
else:
raise ValueError(
f"Observer location '{loc}' is not a recognized observatory. Please"
" refer to"
" https://minorplanetcenter.net/iau/lists/ObsCodesF.html"
)
observer_positions = get_observer_positions(
times=Time(times, format="jd", scale="tdb"),
observatories=observatories,
de_ephemeris_version=self._de_ephemeris_version,
)
else:
observer_positions = observatories
# UNCERTAINTIES
if astrometric_uncertainties.shape == ():
astrometric_uncertainties = (
jnp.ones(len(times)) * astrometric_uncertainties.to(u.arcsec).value
)
if isinstance(astrometric_uncertainties, u.Quantity):
astrometric_uncertainties = astrometric_uncertainties.to(u.arcsec).value
# if our uncertainties are 1D, convert to diagonal covariance matrices
if astrometric_uncertainties.ndim == 1:
cov_matrices = jnp.array(
[jnp.diag(jnp.array([a**2, a**2])) for a in astrometric_uncertainties]
)
else:
cov_matrices = astrometric_uncertainties
inv_cov_matrices = jnp.array([jnp.linalg.inv(c) for c in cov_matrices])
cov_log_dets = jnp.log(jnp.array([jnp.linalg.det(c) for c in cov_matrices]))
return (
jnp.array(ra),
jnp.array(dec),
times,
times_astropy,
observatories,
astrometric_uncertainties,
jnp.array(observer_positions),
jnp.array(cov_matrices),
jnp.array(inv_cov_matrices),
jnp.array(cov_log_dets),
)
def _final_init_checks(self) -> None:
assert (
len(self._ra)
== len(self._dec)
== len(self._times)
== len(self.observer_positions)
== len(self.astrometric_uncertainties)
), (
f"Inputs must have the same length. Currently: ra={len(self._ra)}, dec={len(self._dec)}, times={len(self._times)},"
f" observer_positions={len(self.observer_positions)}, astrometric_uncertainties={len(self.astrometric_uncertainties)}"
)
t = self._times[0]
for i in range(1, len(self._times)):
assert (
self._times[i] > t
), "Observations must be in ascending chronological order."
t = self._times[i]