# Source code for jorbit.utils.doubledouble

"""Experimental DoubleDouble precision arithmetic in JAX."""

from __future__ import annotations

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class DoubleDouble:
    """An experimental class for 'DoubleDouble' precision arithmetic.

    This creates a JAX pytree object that stores two jnp.ndarrays, hi and lo,
    which are the high and low parts of a double-double precision array. The
    represented value is ``hi + lo``. Basic arithmetic operations are
    overloaded to use functions that respect the double-double precision
    rules. This is not compensated summation, but summation at "DoubleDouble"
    precision.

    Attributes:
        hi (jnp.ndarray): High part.
        lo (jnp.ndarray): Low part.
    """

    def __init__(self, hi: jnp.ndarray, lo: jnp.ndarray | None = None) -> None:
        """Initialize a DoubleDouble number.

        Args:
            hi: High part (jnp.ndarray, or anything ``jnp.array`` accepts).
            lo: Low part (jnp.ndarray, optional). If None, lo is set to 0.
                When given, it is stored as passed (no conversion).
        """
        # A lone value (Python scalar or array) is exactly representable as
        # (value, 0). Storing it that way keeps |lo| tiny relative to hi, which
        # the Dekker routines below assume: __mul__ drops the lo*lo term, so a
        # Veltkamp-split pair (lo ~ 2**-27 * hi) would lose the extra precision.
        self.hi = jnp.array(hi)
        self.lo = jnp.zeros_like(self.hi) if lo is None else lo

    @staticmethod
    def _split(a: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Split a 64-bit floating point number into high and low components."""
        t = (2**27 + 1) * a
        a_hi = t - (t - a)
        a_lo = a - a_hi
        return a_hi, a_lo

    @staticmethod
    def _two_sum(a: jnp.ndarray, b: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Basic two-sum algorithm: returns the rounded sum and its error term."""
        s = a + b
        v = s - a
        e = (a - (s - v)) + (b - v)
        return s, e

    @staticmethod
    def _mul12(x: jnp.ndarray, y: jnp.ndarray) -> DoubleDouble:
        """The mul12 algorithm from `Dekker 1971 <https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf>`_.

        Args:
            x: First factor (plain floating point array).
            y: Second factor (plain floating point array).

        Returns:
            DoubleDouble: The product ``x * y`` as a (hi, lo) pair.
        """
        # mul12 from https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
        constant = 2**27 + 1

        # Split x into head/tail halves. The evaluation order (x - p) + p is
        # what performs the split; do not re-associate.
        p = x * constant
        hx = x - p + p
        tx = x - hx

        # Same split for y.
        p = y * constant
        hy = y - p + p
        ty = y - hy

        p = hx * hy
        q = hx * ty + tx * hy
        z = p + q
        zz = p - z + q + tx * ty
        return DoubleDouble(z, zz)

    @classmethod
    def from_string(cls, s: str) -> DoubleDouble:
        """Create a DoubleDouble number from a string, similar to mpmath.mpf.

        Args:
            s (str): String representation of a number.

        Returns:
            DoubleDouble: The DoubleDouble representation.

        Raises:
            TypeError: If ``s`` is not a string.
        """
        if not isinstance(s, str):
            raise TypeError(
                f"from_string expects a str, got {type(s).__name__}"
            )

        from decimal import Decimal, localcontext

        # Use a scoped context so the caller's global decimal precision is
        # left untouched.
        with localcontext() as ctx:
            ctx.prec = 50
            d = Decimal(s)
            hi = float(d)
            # Compute low part using exact subtraction
            lo = float(d - Decimal(hi))

        # Normalize the components
        hi, lo = DoubleDouble._two_sum(hi, lo)

        return cls(jnp.array(hi), jnp.array(lo))

    def __str__(self) -> str:
        """String representation of the DoubleDouble array."""
        return f"{self.hi} + {self.lo}"

    def __repr__(self) -> str:
        """Representation of the DoubleDouble array."""
        return f"DoubleDouble({self.hi}, {self.lo})"

    def __getitem__(self, index: int) -> DoubleDouble:
        """Get an item from the DoubleDouble array."""
        return DoubleDouble(self.hi[index], self.lo[index])

    def __setitem__(self, index: int, value: DoubleDouble) -> None:
        """Set an item in the DoubleDouble array (note: mutable, unlike jnp.ndarray)."""
        self.hi = self.hi.at[index].set(value.hi)
        self.lo = self.lo.at[index].set(value.lo)

    def __add__(self, other: DoubleDouble) -> DoubleDouble:
        """Add two DoubleDouble numbers.

        Implementation of add2 from `Dekker 1971 <https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf>`_.
        """
        # add2 from https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
        r = self.hi + other.hi
        # The error term is recovered starting from whichever operand has the
        # larger magnitude; jnp.where selects per element.
        s = jnp.where(
            jnp.abs(self.hi) > jnp.abs(other.hi),
            self.hi - r + other.hi + other.lo + self.lo,
            other.hi - r + self.hi + self.lo + other.lo,
        )
        z = r + s
        zz = r - z + s
        return DoubleDouble(z, zz)

    def __neg__(self) -> DoubleDouble:
        """Negate a DoubleDouble number."""
        return DoubleDouble(-self.hi, -self.lo)

    def __sub__(self, other: DoubleDouble) -> DoubleDouble:
        """Subtract two DoubleDouble numbers.

        Implementation of sub2 from `Dekker 1971 <https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf>`_.
        """
        # sub2 from https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
        r = self.hi - other.hi
        s = jnp.where(
            jnp.abs(self.hi) > jnp.abs(other.hi),
            self.hi - r - other.hi - other.lo + self.lo,
            -other.hi - r + self.hi + self.lo - other.lo,
        )
        z = r + s
        zz = r - z + s
        return DoubleDouble(z, zz)

    def __mul__(self, other: DoubleDouble) -> DoubleDouble:
        """Multiply two DoubleDouble numbers.

        Implementation of mul2 from `Dekker 1971 <https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf>`_.
        """
        # mul2 from https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
        c = DoubleDouble._mul12(self.hi, other.hi)
        # Cross terms only; the lo*lo term is dropped.
        cc = self.hi * other.lo + self.lo * other.hi + c.lo
        z = c.hi + cc
        zz = c.hi - z + cc
        return DoubleDouble(z, zz)

    def __truediv__(self, other: DoubleDouble) -> DoubleDouble:
        """Divide two DoubleDouble numbers.

        Implementation of div2 from `Dekker 1971 <https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf>`_.
        """
        # div2 from https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
        c = self.hi / other.hi
        u = DoubleDouble._mul12(c, other.hi)
        cc = (self.hi - u.hi - u.lo + self.lo - c * other.lo) / other.hi
        z = c + cc
        zz = c - z + cc
        return DoubleDouble(z, zz)

    def __abs__(self) -> DoubleDouble:
        """Absolute value of a DoubleDouble number."""
        # The sign is decided by hi alone; lo is flipped together with it.
        new_hi = jnp.where(self.hi < 0, -self.hi, self.hi)
        new_lo = jnp.where(self.hi < 0, -self.lo, self.lo)
        return DoubleDouble(new_hi, new_lo)

    def __lt__(self, other: DoubleDouble) -> jnp.ndarray:
        """Less than comparison of two DoubleDouble numbers (elementwise, hi first then lo)."""
        return (self.hi < other.hi) | ((self.hi == other.hi) & (self.lo < other.lo))

    def __le__(self, other: DoubleDouble) -> jnp.ndarray:
        """Less than or equal to comparison of two DoubleDouble numbers (elementwise)."""
        return (self.hi < other.hi) | ((self.hi == other.hi) & (self.lo <= other.lo))

    def __gt__(self, other: DoubleDouble) -> jnp.ndarray:
        """Greater than comparison of two DoubleDouble numbers (elementwise)."""
        return (self.hi > other.hi) | ((self.hi == other.hi) & (self.lo > other.lo))

    def __ge__(self, other: DoubleDouble) -> jnp.ndarray:
        """Greater than or equal to comparison of two DoubleDouble numbers (elementwise)."""
        return (self.hi > other.hi) | ((self.hi == other.hi) & (self.lo >= other.lo))

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the DoubleDouble array (taken from the high part)."""
        return self.hi.shape

    def tree_flatten(self) -> tuple:
        """Implementation for JAX pytree: the leaves are (hi, lo), no aux data."""
        children = (self.hi, self.lo)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data: None, children: tuple) -> DoubleDouble:
        """Implementation for JAX pytree.

        The instance is rebuilt without going through ``__init__``: JAX may
        pass leaves that are not arrays (tracers, placeholders, arbitrary
        results of ``tree_map``), and ``__init__`` would try to convert them.
        """
        obj = object.__new__(cls)
        obj.hi, obj.lo = children
        return obj
# @jax.jit
def dd_max(x: DoubleDouble, axis: int | None = None) -> DoubleDouble:
    """Sort-of implements jnp.max on a DoubleDouble array.

    The maximum is taken on the high parts first; among the elements whose
    high part equals that maximum, the largest low part is selected.

    Args:
        x: DoubleDouble array
        axis: Axis to reduce over. If None, reduce over all elements.

    Returns:
        DoubleDouble: The maximum value
    """
    # Keep the reduced axis as size 1 so the comparison below broadcasts
    # against x.hi for any axis, not only axis 0.
    hi_kept = jnp.max(x.hi, axis=axis, keepdims=True)
    at_max = x.hi == hi_kept
    # Elements that are not at the hi maximum must never win the lo maximum.
    lo_max = jnp.max(jnp.where(at_max, x.lo, -jnp.inf), axis=axis)
    # Reducing over the size-1 axis just drops it.
    hi_max = jnp.max(hi_kept, axis=axis)
    return DoubleDouble(hi_max, lo_max)
# @partial(jax.jit, static_argnames=("axis",))
def dd_sum(x: DoubleDouble, axis: int | None = None) -> DoubleDouble:
    """Sort-of implements jnp.sum on a DoubleDouble array.

    Args:
        x (DoubleDouble): The DoubleDouble array to sum.
        axis (int | None): The axis to sum over. If None, the array is flattened.

    Returns:
        DoubleDouble: The sum of the array along the given axis.
    """
    # needed to respect DoubleDouble addition rules when doing sums
    # again- this is *not* compensated summation, but summation at "DoubleDouble" precision
    if axis is None:
        x = DoubleDouble(x.hi.flatten(), x.lo.flatten())
        axis = 0

    # Move the axis to be summed to the front. moveaxis (unlike swapaxes)
    # keeps the remaining axes in their original order, so the result has the
    # same layout jnp.sum would produce.
    moved = DoubleDouble(jnp.moveaxis(x.hi, axis, 0), jnp.moveaxis(x.lo, axis, 0))

    def _accumulate(carry: DoubleDouble, term: DoubleDouble) -> tuple[DoubleDouble, None]:
        # DoubleDouble.__add__ does the double-double addition.
        return carry + term, None

    # Seed the scan with the first slice and fold in the rest.
    total, _ = jax.lax.scan(_accumulate, moved[0], moved[1:])
    return total
# @jax.jit
def dd_sqrt(x: DoubleDouble) -> DoubleDouble:
    """Sort-of implements jnp.sqrt on a DoubleDouble array.

    Uses sqrt2 from `Dekker 1971 <https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf>`_.

    Args:
        x (DoubleDouble): The DoubleDouble array to take the square root of.

    Returns:
        DoubleDouble: The square root of the array. Elements whose high part
        is zero give zero; a negative high part gives NaN (from ``jnp.sqrt``).
    """
    # sqrt2 from https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    c = jnp.sqrt(x.hi)
    u = DoubleDouble._mul12(c, c)

    # For hi == 0 the correction would be 0 / 0 = NaN. Use a harmless
    # denominator there and force the correction to zero instead.
    is_zero = x.hi == 0
    safe_c = jnp.where(is_zero, 1.0, c)
    c_lo = jnp.where(is_zero, 0.0, (x.hi - u.hi - u.lo + x.lo) / (2 * safe_c))

    y = c + c_lo
    yy = c - y + c_lo
    return DoubleDouble(y, yy)
# @partial(jax.jit, static_argnames=("axis",))
def dd_norm(x: DoubleDouble, axis: int | None = None) -> DoubleDouble:
    """Sort-of implements jnp.linalg.norm on a DoubleDouble array.

    Computes ``dd_sqrt(dd_sum(x * x))``, i.e. the Euclidean norm evaluated
    with DoubleDouble arithmetic.

    Args:
        x (DoubleDouble): The DoubleDouble array to take the norm of.
        axis (int | None): The axis to take the norm over. If None, the array is flattened.

    Returns:
        DoubleDouble: The norm of the array.
    """
    # Squaring is elementwise, so it commutes with flattening; dd_sum does the
    # flattening itself when axis is None.
    squares = x * x
    sum_of_squares = dd_sum(squares, axis=axis)
    return dd_sqrt(sum_of_squares)
# @staticmethod # def _two_sum(a: jnp.ndarray, b: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: # s = a + b # v = s - a # e = (a - (s - v)) + (b - v) # return s, e # @staticmethod # def _split(a: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: # t = (2**27 + 1) * a # a_hi = t - (t - a) # a_lo = a - a_hi # return a_hi, a_lo # @staticmethod # def _two_prod(a: jnp.ndarray, b: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: # # https://andrewthall.org/papers/df64_qf128.pdf # x = a * b # a_hi, a_lo = DoubleDouble._split(a) # b_hi, b_lo = DoubleDouble._split(b) # err1 = x - (a_hi * b_hi) # err2 = err1 - (a_lo * b_hi) # err3 = err2 - (a_hi * b_lo) # y = (a_lo * b_lo) - err3 # return x, y