Source code for jorbit.astrometry.orbit_fit_seeds

"""Methods for an initial orbit fit from astrometry, incl. Gauss's method."""

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from jorbit import Observations
from jorbit.astrometry.transformations import icrs_to_horizons_ecliptic
from jorbit.data.constants import SPEED_OF_LIGHT, TOTAL_SOLAR_SYSTEM_GM
from jorbit.utils.states import CartesianState, KeplerianState


def gauss_method_orbit(obs: Observations) -> CartesianState:
    """Gauss's method for orbit determination from three observations.

    Uses the truncated Lagrange f and g series. The distance from the origin
    at the middle epoch is found from Gauss's eighth-order polynomial. The three
    slant ranges and position vectors follow from it. The velocity at the
    middle epoch is estimated from the two outer positions. Finally, that
    state is propagated back to the epoch of the first observation.

    Args:
        obs (Observations): A set of three observations.

    Returns:
        CartesianState: The state of the implied orbit, referenced to the time
        of the *first* observation (``obs.times[0]``).

    Raises:
        AssertionError: If ``obs`` does not hold exactly three observations.
    """
    assert len(obs) == 3, "Gauss's method requires 3 (and only 3) observations"

    mu = TOTAL_SOLAR_SYSTEM_GM
    # Observer positions, one row per observation (row i pairs with obs.times[i]).
    R = obs.observer_positions

    def radec_to_unit(ra: float, dec: float) -> jnp.ndarray:
        cos_dec = jnp.cos(dec)
        return jnp.array([cos_dec * jnp.cos(ra), cos_dec * jnp.sin(ra), jnp.sin(dec)])

    # Unit line-of-sight vectors for the three observations.
    rho0 = radec_to_unit(obs.ra[0], obs.dec[0])
    rho1 = radec_to_unit(obs.ra[1], obs.dec[1])
    rho2 = radec_to_unit(obs.ra[2], obs.dec[2])

    # Step 1: time intervals, measured relative to the middle observation
    tau1 = obs.times[0] - obs.times[1]
    tau3 = obs.times[2] - obs.times[1]
    tau = obs.times[2] - obs.times[0]

    # Step 2: cross products of the line-of-sight vectors
    p1 = jnp.cross(rho1, rho2)
    p2 = jnp.cross(rho0, rho2)
    p3 = jnp.cross(rho0, rho1)

    # Step 3: scalar triple product. It approaches zero when the three lines
    # of sight are (nearly) coplanar, which makes every division by D0 below
    # ill-conditioned.
    D0 = jnp.dot(rho0, p1)

    # Step 4: the nine scalars D[i, j] = R_i . p_j
    D = R @ jnp.stack([p1, p2, p3], axis=1)

    # Step 5: scalar position coefficients
    A = (1 / D0) * (-D[0, 1] * tau3 / tau + D[1, 1] + D[2, 1] * tau1 / tau)
    B = (1 / (6 * D0)) * (
        D[0, 1] * (tau3**2 - tau**2) * tau3 / tau
        + D[2, 1] * (tau**2 - tau1**2) * tau1 / tau
    )
    E = jnp.dot(R[1], rho1)

    # Step 6: squared distance of the observer at the middle observation
    R2_squared = jnp.dot(R[1], R[1])

    # Step 7: coefficients of r2^8 + a r2^6 + b r2^3 + c = 0
    a = -(A**2 + 2 * A * E + R2_squared)
    b = -2 * mu * B * (A + E)
    c = -((mu * B) ** 2)

    # Step 8: solve for the scalar distance r2 with Newton-Raphson.
    # NOTE(review): no error is raised if this fails to converge within 100
    # iterations; the last iterate is used as-is.
    r2 = 100.0  # initial guess
    for _ in range(100):
        f = r2**8 + a * r2**6 + b * r2**3 + c
        f_prime = 8 * r2**7 + 6 * a * r2**5 + 3 * b * r2**2
        delta = f / f_prime
        r2 = r2 - delta
        if abs(delta) < 1e-11:
            break
    r2_cubed = r2**3

    # Step 9: slant ranges along each line of sight
    num1 = (
        6 * (D[2, 0] * tau1 / tau3 + D[1, 0] * tau / tau3) * r2_cubed
        + mu * D[2, 0] * (tau**2 - tau1**2) * tau1 / tau3
    )
    den1 = 6 * r2_cubed + mu * (tau**2 - tau3**2)
    num3 = (
        6 * (D[0, 2] * tau3 / tau1 - D[1, 2] * tau / tau1) * r2_cubed
        + mu * D[0, 2] * (tau**2 - tau3**2) * tau3 / tau1
    )
    den3 = 6 * r2_cubed + mu * (tau**2 - tau1**2)
    rho = jnp.array(
        [
            (1 / D0) * (num1 / den1 - D[0, 0]),
            A + mu * B / r2_cubed,
            (1 / D0) * (num3 / den3 - D[2, 2]),
        ]
    )

    # Step 10: position vectors r_i = R_i + rho_i * rho_hat_i (one row each)
    r = R + rho[:, None] * jnp.stack([rho0, rho1, rho2])

    # Step 11: truncated Lagrange coefficients from the middle epoch to the
    # outer epochs, then the velocity at the middle epoch.
    f1 = 1 - mu * tau1**2 / (2 * r2_cubed)
    f3 = 1 - mu * tau3**2 / (2 * r2_cubed)
    g1 = tau1 - mu * tau1**3 / (6 * r2_cubed)
    g3 = tau3 - mu * tau3**3 / (6 * r2_cubed)
    v2 = (-f3 * r[0] + f1 * r[2]) / (f1 * g3 - f3 * g1)

    # Propagate (r[1], v2) to the first epoch (dt = tau1) using the time
    # derivatives of the same truncated series:
    #   f(dt) = 1 - mu dt^2 / (2 r^3)   ->  fdot(dt) = -mu dt / r^3
    #   g(dt) = dt - mu dt^3 / (6 r^3)  ->  gdot(dt) = 1 - mu dt^2 / (2 r^3)
    # Equivalently, v1 ~ v2 + acceleration * tau1 with acceleration = -mu r / r^3.
    fdot21 = -mu * tau1 / r2_cubed
    gdot21 = 1 - mu * tau1**2 / (2 * r2_cubed)
    v1 = fdot21 * r[1] + gdot21 * v2

    return CartesianState(
        x=jnp.array([r[0]]),
        v=jnp.array([v1]),
        time_reference=obs.times[0],
        acceleration_func_kwargs={
            "c2": SPEED_OF_LIGHT**2,
        },
    )
def simple_circular(ra: float, dec: float, semi: float, time: float) -> KeplerianState:
    """Compute a circular orbit of a given size that passes through a given coordinate.

    A simpler alternative to Gauss's method. It assumes the particle is
    observed at its highest excursion from the ecliptic:
    - the argument of latitude is 90 deg when the direction is above the
      ecliptic and 270 deg otherwise;
    - the inclination equals the absolute ecliptic latitude of the observed
      direction.

    The observer's position is not used. The observed direction is treated
    as if seen from the origin of the ecliptic frame.

    Args:
        ra (float): Right ascension of the object in radians, ICRS.
        dec (float): Declination of the object in radians, ICRS.
        semi (float): Semi-major axis of the orbit in AU.
        time (float): Time of the observation in JD, tdb.

    Returns:
        KeplerianState: The implied circular orbit (``ecc = 0``, ``omega = 0``)
        with angles in degrees. Each element is a length-1 array.
    """
    # Unit vector toward the target in ICRS (theta is the polar angle).
    phi = ra
    theta = jnp.pi / 2 - dec
    x_icrs = jnp.hstack(
        [jnp.sin(theta) * jnp.cos(phi), jnp.sin(theta) * jnp.sin(phi), jnp.cos(theta)]
    )
    x = icrs_to_horizons_ecliptic(x_icrs)

    # Highest excursion from the ecliptic => inclination = |ecliptic latitude|.
    # The clip stops round-off in the frame rotation from pushing |z| slightly
    # above 1, which would make arcsin return NaN for targets near the pole.
    inc = jnp.array([jnp.abs(jnp.arcsin(jnp.clip(x[2], -1.0, 1.0))) * 180 / jnp.pi])

    # Ecliptic longitude of the direction, in [0, 360).
    varphi = (jnp.arctan2(x[1], x[0]) * 180 / jnp.pi) % 360

    # With omega = 0 the argument of latitude equals nu.
    # - At nu = 90 deg the body's longitude is Omega + 90.
    # - At nu = 270 deg it is Omega - 90.
    # jnp.where avoids Python control flow on a JAX array value.
    above = x[2] > 0
    Omega = jnp.array([jnp.where(above, varphi - 90, varphi + 90)]) % 360
    nu = jnp.array([jnp.where(above, 90.0, 270.0)])

    return KeplerianState(
        semi=jnp.array([semi]),
        ecc=jnp.array([0.0]),
        nu=nu,
        inc=inc,
        Omega=Omega,
        omega=jnp.array([0.0]),
        time_reference=time,
        acceleration_func_kwargs={
            "c2": SPEED_OF_LIGHT**2,
        },
    )