Integrators
The astrojax.integrators module provides numerical ODE integrators for
propagating orbital dynamics forward (or backward) in time. All
integrators are implemented in pure JAX and are compatible with
jax.jit and jax.vmap.
Available Integrators
| Integrator | Order | Step Control | Stages | Use Case |
|---|---|---|---|---|
| rk4_step | 4 | Fixed | 4 | Simple propagation, differentiable control |
| rkf45_step | 4(5) | Adaptive | 6 | General-purpose adaptive integration |
| dp54_step | 5(4) | Adaptive | 7 | High-accuracy adaptive integration |
| rkn1210_step | 12(10) | Adaptive | 17 | Tight-tolerance second-order ODE integration |
Common Interface
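All step functions take the same positional arguments, in this order:
- dynamics(t, x) -> dx: The ODE right-hand side.
- t: Current time (scalar).
- state: Current state vector (1-D array).
- dt: Timestep to take (may be negative for backward integration).
Optional keyword arguments are control, accepted by all integrators (see Control Inputs), and config, which sets the tolerances of the adaptive integrators (see Adaptive Configuration).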
The result is a StepResult named tuple with fields state, dt_used,
error_estimate, and dt_next.
RK4: Fixed-Step Integration
The classic 4th-order Runge-Kutta method takes exactly the step size
you request. It is the simplest integrator and supports reverse-mode
differentiation (jax.grad), making it suitable for differentiable
control and optimization:
import jax
import jax.numpy as jnp
from astrojax.integrators import rk4_step
def harmonic(t, x):
    # Simple harmonic oscillator, x = [position, velocity]
    return jnp.array([x[1], -x[0]])
# Single step
result = rk4_step(harmonic, 0.0, jnp.array([1.0, 0.0]), 0.01)
print(result.state)  # ~[cos(0.01), -sin(0.01)]
# Multi-step propagation with lax.scan; the carry holds (time, state)
dt = 0.01
def scan_step(carry, _):
    t, state = carry
    result = rk4_step(harmonic, t, state, dt)
    return (t + dt, result.state), result.state
(t_final, final), trajectory = jax.lax.scan(
    scan_step, (0.0, jnp.array([1.0, 0.0])), None, length=100
)
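For reference, the classic RK4 update fits in a few lines of JAX. The sketch below is a standalone illustration, not astrojax's implementation: rk4_step_sketch is a hypothetical name, and it returns the new state directly rather than a StepResult. Because the step is plain array arithmetic, jax.grad differentiates straight through it, which is the property that makes the fixed-step method suitable for differentiable control.

```python
import jax
import jax.numpy as jnp


def rk4_step_sketch(dynamics, t, x, dt):
    """One classic 4th-order Runge-Kutta step (illustrative, not astrojax's code)."""
    k1 = dynamics(t, x)
    k2 = dynamics(t + dt / 2, x + dt / 2 * k1)
    k3 = dynamics(t + dt / 2, x + dt / 2 * k2)
    k4 = dynamics(t + dt, x + dt * k3)
    # Weighted average of the four stage slopes
    return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def harmonic(t, x):
    return jnp.array([x[1], -x[0]])


x0 = jnp.array([1.0, 0.0])
x1 = rk4_step_sketch(harmonic, 0.0, x0, 0.01)
exact = jnp.array([jnp.cos(0.01), -jnp.sin(0.01)])
print(jnp.max(jnp.abs(x1 - exact)))  # close to 0: the local error is O(dt^5)

# Reverse-mode gradient of the new position with respect to the initial state
g = jax.grad(lambda s: rk4_step_sketch(harmonic, 0.0, s, 0.01)[0])(x0)
print(g)  # ~[cos(0.01), sin(0.01)] for this linear problem
```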
Adaptive Methods: RKF45 and DP54
The adaptive integrators automatically adjust the step size to keep the local error within configurable tolerances. If a step produces an error above the tolerance, it is rejected and retried with a smaller step.
import jax.numpy as jnp
from astrojax.integrators import rkf45_step, dp54_step, AdaptiveConfig
GM = 398600.4418  # Earth's gravitational parameter (km^3/s^2)
def two_body(t, state):
    r = state[:3]
    v = state[3:]
    r_norm = jnp.linalg.norm(r)
    a = -GM * r / r_norm**3
    return jnp.concatenate([v, a])
# Example initial state: circular orbit of radius 6778 km (positions in km, velocities in km/s)
state0 = jnp.array([6778.0, 0.0, 0.0, 0.0, (GM / 6778.0) ** 0.5, 0.0])
# Use default tolerances
result = rkf45_step(two_body, 0.0, state0, 60.0)
# Custom tolerances for higher accuracy
config = AdaptiveConfig(abs_tol=1e-10, rel_tol=1e-8)
result = dp54_step(two_body, 0.0, state0, 60.0, config=config)
# The result tells you what happened
print(f"Step taken: {result.dt_used}")
print(f"Error: {result.error_estimate}")
print(f"Suggested next dt: {result.dt_next}")
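In embedded Runge-Kutta pairs such as RKF45 and DP54, the error estimate comes from two solutions of different order that share the same stage evaluations: their difference approximates the local error of the lower-order solution at no extra cost in dynamics calls. The sketch below shows the idea with the much simpler Heun-Euler 2(1) pair swapped in for the 6- and 7-stage tableaus; heun_euler_step is a hypothetical helper, not part of astrojax.

```python
import jax.numpy as jnp


def heun_euler_step(dynamics, t, x, dt):
    """Heun-Euler 2(1) embedded pair: returns (2nd-order solution, error estimate)."""
    k1 = dynamics(t, x)
    k2 = dynamics(t + dt, x + dt * k1)
    x_low = x + dt * k1               # 1st-order (explicit Euler) solution
    x_high = x + dt / 2 * (k1 + k2)   # 2nd-order (Heun) solution from the same two stages
    return x_high, jnp.abs(x_high - x_low)


def harmonic(t, x):
    return jnp.array([x[1], -x[0]])


x0 = jnp.array([1.0, 0.0])
for dt in (0.1, 0.05, 0.025):
    _, err = heun_euler_step(harmonic, 0.0, x0, dt)
    # For this problem the estimate equals dt**2 / 2, so it drops 4x per halving
    print(dt, float(jnp.max(err)))
```

An adaptive integrator compares this kind of estimate with the tolerances and either accepts the step or retries it with a smaller dt.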
RKN1210: Second-Order ODE Specialist
The rkn1210_step integrator is a Runge-Kutta-Nyström 12(10) method
specialized for second-order ODEs of the form \(y'' = f(t, y)\). It
achieves 12th-order accuracy with 17 stages per step by exploiting the
second-order structure: the state is split into position and velocity
halves, and only the acceleration is computed at each stage.
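This makes RKN1210 considerably more efficient than general-purpose RK methods for orbital mechanics and other second-order systems, particularly when tight tolerances (< \(10^{-10}\)) are required: its high order allows much larger steps at a given tolerance, so a propagation needs fewer dynamics evaluations overall even though each step costs 17 stages.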
The public API is identical to that of the other integrators: the dynamics
function returns [velocity, acceleration] for a state [position,
velocity], and the integrator internally extracts only the acceleration
half:
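import jax.numpy as jnp
from astrojax.integrators import rkn1210_step, AdaptiveConfig
GM_EARTH = 398600.4418  # km^3/s^2, defined here so the example is self-contained
def two_body(t, state):
    r = state[:3]
    v = state[3:]
    r_norm = jnp.linalg.norm(r)
    a = -GM_EARTH * r / r_norm**3
    return jnp.concatenate([v, a])
# Example initial state: circular orbit of radius 6778 km (km, km/s)
state0 = jnp.array([6778.0, 0.0, 0.0, 0.0, (GM_EARTH / 6778.0) ** 0.5, 0.0])
# Tight tolerances where RKN1210 excels
config = AdaptiveConfig(abs_tol=1e-12, rel_tol=1e-10)
result = rkn1210_step(two_body, 0.0, state0, 60.0, config=config)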
State vector requirement
The state vector must have an even number of elements, with the
first half representing positions and the second half velocities.
This is the standard format for orbital mechanics state vectors
[x, y, z, vx, vy, vz].
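To see how a Nyström method gets by with only the acceleration half, here is a sketch of a classical 3-stage, 4th-order Runge-Kutta-Nyström step for \(y'' = f(t, y)\). It is a deliberately low-order stand-in (RKN1210's 17-stage tableau is far too long to reproduce here), and rkn4_step_sketch is a hypothetical name. It follows the conventions above: dynamics returns [velocity, acceleration] for a state [position, velocity], the state length must be even, and the sketch assumes the acceleration does not depend on velocity.

```python
import jax.numpy as jnp


def rkn4_step_sketch(dynamics, t, state, dt):
    """Classical 3-stage, 4th-order Runge-Kutta-Nystrom step for y'' = f(t, y).

    Hypothetical helper; assumes the acceleration does not depend on velocity.
    """
    n = state.shape[0]
    if n % 2 != 0:
        raise ValueError("state must be [positions, velocities] with even length")
    half = n // 2
    y, v = state[:half], state[half:]

    def accel(t_stage, y_stage):
        # Evaluate the full dynamics but keep only the acceleration half;
        # the velocity passed in is irrelevant when y'' = f(t, y).
        return dynamics(t_stage, jnp.concatenate([y_stage, v]))[half:]

    f1 = accel(t, y)
    f2 = accel(t + dt / 2, y + dt / 2 * v + dt**2 / 8 * f1)
    f3 = accel(t + dt, y + dt * v + dt**2 / 2 * f2)
    y_new = y + dt * v + dt**2 / 6 * (f1 + 2 * f2)
    v_new = v + dt / 6 * (f1 + 4 * f2 + f3)
    return jnp.concatenate([y_new, v_new])


def harmonic(t, state):
    # state = [position, velocity]; returns [velocity, acceleration]
    return jnp.array([state[1], -state[0]])


t, dt = 0.0, 0.01
state = jnp.array([1.0, 0.0])
for _ in range(100):
    state = rkn4_step_sketch(harmonic, t, state, dt)
    t += dt
print(state)  # close to [cos(1), -sin(1)] ~ [0.5403, -0.8415]
```

Each step uses three acceleration evaluations, whereas RK4 applied to the equivalent first-order system needs four derivative evaluations for the same order. Exploiting the second-order structure this way, at much higher order, is the idea behind RKN1210.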
Adaptive Configuration
The AdaptiveConfig named tuple controls step-size adaptation:
| Parameter | Default | Description |
|---|---|---|
| abs_tol | 1e-6 | Absolute error tolerance per component |
| rel_tol | 1e-3 | Relative error tolerance per component |
| safety_factor | 0.9 | Conservative scaling of step predictions |
| min_scale_factor | 0.2 | Smallest allowed step scale factor (limits shrinkage) |
| max_scale_factor | 10.0 | Largest allowed step scale factor (limits growth) |
| min_step | 1e-12 | Absolute minimum step size |
| max_step | 900.0 | Absolute maximum step size (seconds) |
| max_step_attempts | 10 | Maximum retries per step |
The per-component error tolerance is:
A step is accepted when the normalized error (infinity norm) is \(\leq 1.0\).
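The acceptance test and the step-size parameters above fit together roughly as in the following sketch. It assumes the widely used per-component tolerance tol_i = abs_tol + rel_tol * max(|x_i|, |x_new_i|) and the standard controller dt_next = dt * clip(safety_factor * err**(-1/(q + 1)), min_scale_factor, max_scale_factor), where q is the lower order of the embedded pair. Both formulas, the helper control_step_size, and ConfigSketch (a stand-in mirroring AdaptiveConfig's documented fields and defaults) are assumptions and may differ in detail from astrojax.

```python
from typing import NamedTuple

import jax.numpy as jnp


class ConfigSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented AdaptiveConfig defaults."""
    abs_tol: float = 1e-6
    rel_tol: float = 1e-3
    safety_factor: float = 0.9
    min_scale_factor: float = 0.2
    max_scale_factor: float = 10.0
    min_step: float = 1e-12
    max_step: float = 900.0


def control_step_size(x, x_new, err_vec, dt, order, cfg=ConfigSketch()):
    """Return (accepted, dt_next); order is the lower order of the embedded pair."""
    # Assumed mixed absolute/relative tolerance, one value per state component
    tol = cfg.abs_tol + cfg.rel_tol * jnp.maximum(jnp.abs(x), jnp.abs(x_new))
    # Normalized error: infinity norm of error / tolerance; accept when <= 1
    err = jnp.max(jnp.abs(err_vec) / tol)
    accepted = err <= 1.0
    # Standard controller; a zero error estimate gets the maximum growth factor
    safe_err = jnp.where(err > 0, err, 1.0)
    scale = jnp.where(
        err > 0, cfg.safety_factor * safe_err ** (-1.0 / (order + 1)), cfg.max_scale_factor
    )
    scale = jnp.clip(scale, cfg.min_scale_factor, cfg.max_scale_factor)
    # Clamp the step magnitude, keeping the sign of dt for backward integration
    dt_next = jnp.sign(dt) * jnp.clip(jnp.abs(dt) * scale, cfg.min_step, cfg.max_step)
    return accepted, dt_next


x = jnp.array([1.0, 1.0])
print(control_step_size(x, x, jnp.array([1e-9, 0.0]), 60.0, order=4))  # accepted, dt grows 10x
print(control_step_size(x, x, jnp.array([1e-2, 0.0]), 60.0, order=4))  # rejected, dt shrinks
```

A rejected step would be retried with dt_next, up to max_step_attempts times; that retry loop is omitted here.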
Control Inputs
All integrators support an optional additive control function. This is useful for thrust manoeuvres, perturbation forces, or feedback control:
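import jax.numpy as jnp
from astrojax.integrators import rk4_step
def gravity(t, x):
    r = x[:3]
    r_norm = jnp.linalg.norm(r)
    a = -398600.4418 * r / r_norm**3  # GM in km^3/s^2: positions in km, velocities in km/s
    return jnp.concatenate([x[3:], a])
def thrust(t, x):
    # Constant along-track acceleration of 1 mm/s^2 (= 1e-6 km/s^2 in these units)
    v = x[3:]
    v_hat = v / jnp.linalg.norm(v)
    return jnp.concatenate([jnp.zeros(3), 1e-6 * v_hat])
# Example initial state: circular orbit of radius 6778 km (km, km/s)
state0 = jnp.array([6778.0, 0.0, 0.0, 0.0, (398600.4418 / 6778.0) ** 0.5, 0.0])
result = rk4_step(gravity, 0.0, state0, 60.0, control=thrust)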
The effective derivative at each stage is dynamics(t, x) + control(t, x).
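Conceptually, then, integrating with control=thrust evaluates the sum of the two functions at every stage. A minimal sketch of that composition, using a hypothetical helper with_control (not part of astrojax), shows that the along-track thrust only perturbs the velocity derivatives:

```python
import jax.numpy as jnp

GM = 398600.4418  # km^3/s^2, the same constant as the gravity example


def gravity(t, x):
    r = x[:3]
    a = -GM * r / jnp.linalg.norm(r) ** 3
    return jnp.concatenate([x[3:], a])


def thrust(t, x):
    # Along-track acceleration of 1 mm/s^2 = 1e-6 km/s^2
    v = x[3:]
    return jnp.concatenate([jnp.zeros(3), 1e-6 * v / jnp.linalg.norm(v)])


def with_control(dynamics, control):
    """Hypothetical helper: the effective right-hand side dynamics + control."""
    return lambda t, x: dynamics(t, x) + control(t, x)


x = jnp.array([7000.0, 0.0, 0.0, 0.0, 7.5, 0.0])
rhs = with_control(gravity, thrust)
diff = rhs(0.0, x) - gravity(0.0, x)
print(diff)  # only the y-velocity derivative changes, by 1e-6 km/s^2
```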
Backward Integration
All integrators support backward integration: pass a negative dt, for example rk4_step(dynamics, t, state, -60.0), to step from t back to t - 60.
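One way to check backward integration is a round trip: step forward with dt, then back with -dt, and compare against the starting state. The sketch below does this with a standalone RK4 step inside jax.lax.scan; rk4 and propagate are hypothetical helpers, not astrojax functions.

```python
import jax
import jax.numpy as jnp


def rk4(dynamics, t, x, dt):
    """Classic RK4 step; works unchanged for negative dt (hypothetical helper)."""
    k1 = dynamics(t, x)
    k2 = dynamics(t + dt / 2, x + dt / 2 * k1)
    k3 = dynamics(t + dt / 2, x + dt / 2 * k2)
    k4 = dynamics(t + dt, x + dt * k3)
    return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def harmonic(t, x):
    return jnp.array([x[1], -x[0]])


def propagate(x, t0, dt, n):
    """Take n RK4 steps of size dt from time t0 (dt < 0 integrates backward)."""
    def body(carry, _):
        t, x = carry
        return (t + dt, rk4(harmonic, t, x, dt)), None
    (t_end, x_end), _ = jax.lax.scan(body, (t0, x), None, length=n)
    return t_end, x_end


x0 = jnp.array([1.0, 0.0])
t1, x1 = propagate(x0, 0.0, 0.01, 100)    # forward to t = 1
t2, x2 = propagate(x1, t1, -0.01, 100)    # backward to t = 0
print(float(t2), float(jnp.max(jnp.abs(x2 - x0))))  # both close to 0
```

The round trip is not exact, because RK4 is not time-symmetric, but for this problem the mismatch is far below float32 roundoff.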
JAX Compatibility
All integrators work with jax.jit for compilation:
# Close over the dynamics function (two_body from above) so only t, x, and dt are traced
jit_step = jax.jit(lambda t, x, dt: rk4_step(two_body, t, x, dt))
result = jit_step(0.0, state0, 60.0)
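The introduction also states that the integrators work with jax.vmap. The usual pattern is to close over the dynamics and step size and map over a batch of states. The sketch below runs without astrojax by using a hypothetical standalone rk4 helper as the step function; with astrojax the same wrapping would apply to a call such as rk4_step(two_body, t, x, dt), and because NamedTuples are JAX pytrees, a StepResult would come back with a batch axis on each field.

```python
import jax
import jax.numpy as jnp


def rk4(dynamics, t, x, dt):
    """Classic RK4 step; hypothetical stand-in for an integrator step function."""
    k1 = dynamics(t, x)
    k2 = dynamics(t + dt / 2, x + dt / 2 * k1)
    k3 = dynamics(t + dt / 2, x + dt / 2 * k2)
    k4 = dynamics(t + dt, x + dt * k3)
    return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def harmonic(t, x):
    return jnp.array([x[1], -x[0]])


# Close over dynamics, t and dt; vmap maps over axis 0 of the state batch
batched_step = jax.jit(jax.vmap(lambda x: rk4(harmonic, 0.0, x, 0.01)))

states = jnp.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])  # batch of 3 initial states
next_states = batched_step(states)
print(next_states.shape)  # (3, 2)
```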
RK4 also supports jax.grad for differentiable simulation. The adaptive
methods (RKF45, DP54, RKN1210) use jax.lax.while_loop internally, which supports
forward-mode differentiation (jax.jvp, jax.jacfwd) but not reverse-mode jax.grad.
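The distinction is easy to see in isolation. The toy function below (a hypothetical stand-in, not an integrator) halves a trial step inside jax.lax.while_loop until it drops below 1, mimicking a rejection loop. jax.jacfwd differentiates through it, while jax.grad raises an error.

```python
import jax
import jax.numpy as jnp


def shrink_until_small(x):
    """Toy rejection loop: halve a trial step until it is below 1, scaling x each time."""

    def cond(carry):
        step, _ = carry
        return step >= 1.0

    def body(carry):
        step, y = carry
        return step / 2.0, y / 2.0

    # while_loop does not fix its trip count at trace time, unlike lax.scan
    _, y = jax.lax.while_loop(cond, body, (jnp.asarray(4.0), x))
    return y


x = 3.0
# Forward mode works through while_loop: three halvings give a derivative of 1/8
print(jax.jacfwd(shrink_until_small)(x))  # 0.125

# Reverse mode cannot transpose a while_loop and raises an error
try:
    jax.grad(shrink_until_small)(x)
except ValueError as err:
    print("jax.grad failed:", type(err).__name__)
```

jax.lax.scan, by contrast, runs a trip count fixed at trace time, which is what lets jax.grad run through the scan-based rk4_step pattern.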
Differentiability trade-off
The adaptive integrators (RKF45, DP54, RKN1210) use jax.lax.while_loop for
step rejection, which makes them JIT-compatible but blocks reverse-mode
autodiff. This is intentional: for the primary use case (orbit
propagation), forward simulation without gradients is the norm. Users
who need differentiable simulation should use rk4_step inside
jax.lax.scan, which fully supports jax.grad:
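import jax
import jax.numpy as jnp
from astrojax.integrators import rk4_step
def dynamics(t, x):  # harmonic oscillator, as in the RK4 example
    return jnp.array([x[1], -x[0]])
dt, n_steps = 0.01, 100
target = jnp.array([0.0, -1.0])  # example target state
def loss(x0):
    def scan_step(state, _):
        result = rk4_step(dynamics, 0.0, state, dt)
        return result.state, None
    final, _ = jax.lax.scan(scan_step, x0, None, length=n_steps)
    return jnp.sum((final - target) ** 2)
grad = jax.grad(loss)(jnp.array([1.0, 0.0]))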
This keeps each integrator focused on its strength rather than compromising with a one-size-fits-all approach.
Configurable precision
All integrators respect astrojax.set_dtype(). Call set_dtype()
before JIT compilation to control whether computations use float32
or float64 precision.
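Precision matters most for tight tolerances: a relative tolerance at or below the dtype's machine epsilon cannot be resolved. jnp.finfo shows the limits. The helper tolerance_is_resolvable below is hypothetical and gives only a necessary condition; it says nothing about what set_dtype() does internally. Since JAX specializes compiled code to the dtypes seen at trace time, a precision change after compilation would not affect already-compiled functions.

```python
import jax.numpy as jnp


def tolerance_is_resolvable(rel_tol, dtype):
    """Hypothetical check (necessary, not sufficient): rel_tol must exceed machine epsilon."""
    return rel_tol > float(jnp.finfo(dtype).eps)


print(jnp.finfo(jnp.float32).eps)  # ~1.19e-07
print(jnp.finfo(jnp.float64).eps)  # ~2.22e-16
# The RKN1210 example's rel_tol=1e-10 is only meaningful in float64
print(tolerance_is_resolvable(1e-10, jnp.float32))  # False
print(tolerance_is_resolvable(1e-10, jnp.float64))  # True
```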