"""Transformations between coordinate systems or representations of a particle's state."""
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jorbit.data.constants import (
HORIZONS_ECLIPTIC_TO_ICRS_ROT_MAT,
ICRS_TO_HORIZONS_ECLIPTIC_ROT_MAT,
)
[docs]
@jax.jit
def icrs_to_horizons_ecliptic(xs: jnp.ndarray) -> jnp.ndarray:
    """Rotate ICRS cartesian positions into the Horizons ecliptic frame.

    Args:
        xs (jnp.ndarray): ICRS 3D cartesian position.

    Returns:
        jnp.ndarray: Horizons ecliptic 3D cartesian position.
    """
    # Right-multiplying by the transposed matrix contracts the last axis of
    # ``xs`` with the columns of the rotation matrix.
    rotation_transposed = jnp.transpose(ICRS_TO_HORIZONS_ECLIPTIC_ROT_MAT)
    return jnp.dot(xs, rotation_transposed)
[docs]
@jax.jit
def horizons_ecliptic_to_icrs(xs: jnp.ndarray) -> jnp.ndarray:
"""Transform a Horizons ecliptic 3D cartesian position to an ICRS 3D cartesian position.
Args:
xs (jnp.ndarray): Horizons ecliptic 3D cartesian position.
Returns:
jnp.ndarray: ICRS 3D cartesian position.
"""
rotated_xs = jnp.dot(xs, HORIZONS_ECLIPTIC_TO_ICRS_ROT_MAT.T)
return rotated_xs
[docs]
@jax.jit
def elements_to_cartesian(
    a: float, ecc: float, nu: float, inc: float, Omega: float, omega: float, mass: float
) -> tuple:
    """Convert orbital elements to cartesian coordinates.

    Relies on the total mass of the solar system, which is assumed to be the sum of all
    GM values of the sun, planets, and 16 most massive asteroids as assumed by DE440.

    This is the inverse of cartesian_to_elements.

    NOTE All orbital element angles are assumed to be *degrees*, in contrast to
    sky coordinate angles, which are usually assumed to be in radians when not in their
    Astropy SkyCoord form, or angular separations, which are usually assumed to be in
    arcsec.

    Each of the six elements may be a scalar or a 1D array of shape
    (n_particles,). Scalars are promoted to length-1 arrays and all six elements
    are broadcast against each other, so e.g. a single ``a`` can be combined
    with an array of ``nu`` values.

    Args:
        a (float): Semi-major axis in AU.
        ecc (float): Eccentricity.
        nu (float): True anomaly in degrees.
        inc (float): Inclination in degrees.
        Omega (float): Longitude of the ascending node in degrees.
        omega (float): Argument of periapsis in degrees.
        mass (float): Total mass (GM) of the central object with G in AU^3 / day^2.

    Returns:
        tuple: (x, v) where x is the position in AU and v is the velocity in AU/day.
        Both have shape (n_particles, 3); all-scalar input gives shape (1, 3).
    """
    # Promote scalars to 1D and bring every element to a common
    # (n_particles,) shape; the indexing below requires 1D arrays.
    a, ecc, nu, inc, Omega, omega = jnp.broadcast_arrays(
        *jnp.atleast_1d(a, ecc, nu, inc, Omega, omega)
    )

    # Orbital element angles are always given in degrees.
    deg_to_rad = jnp.pi / 180
    nu = nu * deg_to_rad
    inc = inc * deg_to_rad
    Omega = Omega * deg_to_rad
    omega = omega * deg_to_rad

    cos_nu = jnp.cos(nu)
    sin_nu = jnp.sin(nu)

    # a * (1 - e^2) as a column so it broadcasts against (n_particles, 3).
    t = (a * (1 - ecc**2))[:, None]

    # Position and velocity in the orbital-plane frame (x axis toward periapsis).
    r_w = (
        t
        / (1 + ecc[:, None] * cos_nu[:, None])
        * jnp.column_stack((cos_nu, sin_nu, nu * 0.0))
    )
    v_w = (
        jnp.sqrt(mass)
        / jnp.sqrt(t)
        * jnp.column_stack((-sin_nu, ecc + cos_nu, nu * 0))
    )

    zeros = jnp.zeros_like(omega, dtype=jnp.float64)
    ones = jnp.ones_like(omega, dtype=jnp.float64)

    def _rot_about_z(angle):
        # Shape (3, 3, n_particles). Written for row vectors (vec @ mat), so
        # this rotates counter-clockwise by ``angle`` about the z axis.
        c, s = jnp.cos(angle), jnp.sin(angle)
        return jnp.array(
            [
                [c, s, zeros],
                [-s, c, zeros],
                [zeros, zeros, ones],
            ]
        )

    def _rot_about_x(angle):
        # Same convention as _rot_about_z, about the x axis.
        c, s = jnp.cos(angle), jnp.sin(angle)
        return jnp.array(
            [
                [ones, zeros, zeros],
                [zeros, c, s],
                [zeros, -s, c],
            ]
        )

    Rot1 = _rot_about_z(omega)
    Rot2 = _rot_about_x(inc)
    Rot3 = _rot_about_z(Omega)

    # Per-particle combined rotation, mapped over the trailing particle axis.
    rot = jax.vmap(
        lambda r1, r2, r3: jnp.matmul(jnp.matmul(r1, r2), r3), in_axes=(2, 2, 2)
    )(Rot1, Rot2, Rot3)

    x = jax.vmap(lambda vec, mat: jnp.matmul(vec, mat))(r_w, rot)
    v = jax.vmap(lambda vec, mat: jnp.matmul(vec, mat))(v_w, rot)
    return x, v
[docs]
def _angle_from_cosine_deg(cos_value, use_principal):
    """Return an angle in degrees on [0, 360] from its cosine and a branch flag.

    The cosine is clipped to [-1, 1] before ``arccos`` so rounding error cannot
    produce NaN. Where ``use_principal`` is True the principal value
    (0 to 180 degrees) is returned, otherwise 360 minus that value.
    """
    principal = jnp.arccos(jnp.clip(cos_value, -1, 1)) * 180 / jnp.pi
    return jnp.where(use_principal, principal, 360.0 - principal)


@jax.jit
def cartesian_to_elements(x: jnp.ndarray, v: jnp.ndarray, mass: float) -> tuple:
    """Convert cartesian coordinates to orbital elements.

    Relies on the total mass of the solar system, which is assumed to be the sum of all
    GM values of the sun, planets, and 16 most massive asteroids as assumed by DE440.

    This is the inverse of elements_to_cartesian. Two degenerate cases are handled:

    - Circular orbits (ecc < 1e-10): omega is set to zero and nu becomes the argument
      of latitude (or position angle from x-axis for in-plane orbits).
    - In-plane orbits (inc = 0, n_mag = 0): Omega is set to zero; omega is set to the
      ecliptic longitude of periapsis (atan2(e_y, e_x)); nu for circular in-plane orbits
      is the position angle from the x-axis.

    NOTE(review): n_mag is also exactly zero when the angular momentum points along
    -z (inc = 180), so such orbits take the same in-plane branch; confirm this is
    the intended convention for retrograde in-plane orbits.

    Args:
        x (jnp.ndarray): Position in AU, shape (n_particles, 3).
        v (jnp.ndarray): Velocity in AU/day, shape (n_particles, 3).
        mass (float): Total mass (GM) of the central object with G in AU^3 / day^2.

    Returns:
        tuple:
            (a, ecc, nu, inc, Omega, omega) where a is the semi-major axis in AU,
            ecc is the eccentricity, nu is the true anomaly in degrees, inc is the
            inclination in degrees, Omega is the longitude of the ascending node in
            degrees, and omega is the argument of periapsis in degrees.
    """
    r_mag = jnp.linalg.norm(x, axis=1)
    v_mag = jnp.linalg.norm(v, axis=1)

    # Specific angular momentum
    h = jnp.cross(x, v)
    h_mag = jnp.linalg.norm(h, axis=1)

    # Eccentricity vector
    e_vec = jnp.cross(v, h) / mass - x / r_mag[:, jnp.newaxis]
    ecc = jnp.linalg.norm(e_vec, axis=1)

    # Specific orbital energy
    specific_energy = v_mag**2 / 2 - mass / r_mag
    a = -mass / (2 * specific_energy)

    # Clip like every other arccos here so rounding in h_z / |h| cannot give NaN.
    inc = jnp.arccos(jnp.clip(h[:, 2] / h_mag, -1, 1)) * 180 / jnp.pi

    # Node vector: z-hat cross h, pointing toward the ascending node.
    n = jnp.cross(jnp.array([0, 0, 1]), h)
    n_mag = jnp.linalg.norm(n, axis=1)

    # Safe denominators: avoid 0/0 NaN inside jnp.where branches that are
    # discarded anyway (n_mag = 0 for in-plane orbits, ecc = 0 for circular ones).
    safe_n_mag = jnp.where(n_mag == 0, 1.0, n_mag)
    safe_e_mag = jnp.where(ecc == 0, 1.0, ecc)

    # Omega (longitude of ascending node); defined as zero for in-plane orbits.
    Omega = _angle_from_cosine_deg(n[:, 0] / safe_n_mag, n[:, 1] >= 0)
    Omega = jnp.where(n_mag == 0, 0, Omega)

    # omega (argument of periapsis)
    # For in-plane orbits (n_mag = 0), omega is the longitude of periapsis,
    # i.e. atan2(e_y, e_x). The circular branch below overrides omega with 0.
    omega_in_plane = _angle_from_cosine_deg(
        e_vec[:, 0] / safe_e_mag, e_vec[:, 1] >= 0
    )
    omega_inclined = _angle_from_cosine_deg(
        jnp.sum(n * e_vec, axis=1) / (safe_n_mag * safe_e_mag), e_vec[:, 2] >= 0
    )
    omega_standard = jnp.where(n_mag > 0, omega_inclined, omega_in_plane)

    # nu (true anomaly) from the eccentricity vector; the sign of x . v tells
    # whether the body is moving away from (>= 0) or toward periapsis.
    nu_standard = _angle_from_cosine_deg(
        jnp.sum(e_vec * x, axis=1) / (safe_e_mag * r_mag),
        jnp.sum(x * v, axis=1) >= 0,
    )

    # Circular orbit fallback: omega = 0, nu = argument of latitude u, the
    # angle from the ascending node to the position within the orbital plane.
    u_standard = _angle_from_cosine_deg(
        jnp.sum(n * x, axis=1) / (safe_n_mag * r_mag), x[:, 2] >= 0
    )
    # For in-plane orbits there is no node; use the position angle from the x axis.
    nu_in_plane = _angle_from_cosine_deg(x[:, 0] / r_mag, x[:, 1] >= 0)
    u_deg = jnp.where(n_mag == 0, nu_in_plane, u_standard)

    is_circular = ecc < 1e-10  # Threshold for circular orbits, can be tuned
    omega = jnp.where(is_circular, 0.0, omega_standard)
    nu = jnp.where(is_circular, u_deg, nu_standard)
    ecc = jnp.where(is_circular, 0.0, ecc)

    return a, ecc, nu, inc, Omega, omega