# Covariance Propagation
The `astrojax.covariance` module provides generic covariance propagation
via variational equations. It augments any differentiable ODE dynamics
function with automatic State Transition Matrix (STM) propagation, computing
the required Jacobians with `jax.jacfwd`, so no hand-derived Jacobians are needed.
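For reference, these are the standard (textbook) variational equations that such an augmentation integrates alongside the state. `jax.jacfwd` supplies the Jacobian A(t); everything else is ordinary ODE integration. The notation is the usual one, not copied from the astrojax source.

$$
\dot{\mathbf{x}} = f(t, \mathbf{x}), \qquad
\dot{\Phi}(t, t_0) = A(t)\,\Phi(t, t_0), \qquad
A(t) = \left.\frac{\partial f}{\partial \mathbf{x}}\right|_{\mathbf{x}(t)}, \qquad
\Phi(t_0, t_0) = I_n
$$

Stacking $\mathbf{x}$ with the $n^2$ entries of $\Phi$ gives an augmented state of length $n + n^2$, which is why a 6-element orbit state becomes a 42-element vector.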
## Purpose and When to Use
Use covariance propagation when you need to track how state uncertainty evolves over time. Common scenarios:
| Scenario | Why Covariance Propagation Helps |
|---|---|
| Orbit uncertainty growth | Predict how position/velocity errors grow over time |
| Sensor scheduling | Determine when uncertainty exceeds a threshold |
| Initial orbit determination | Assess solution quality via covariance realism |
| Filter initialization | Provide an initial P₀ for the EKF/UKF |
## Workflow Overview

Covariance propagation follows a four-step workflow:

- **Create**: `create_variational_dynamics(dynamics, n)` augments your dynamics with the STM equations.
- **Initialize**: `augmented_initial_state(x0, n)` builds `[x₀, vec(I)]`.
- **Integrate**: use any integrator (RK4, RKF45, etc.) on the augmented state.
- **Extract and map**: `extract_state_and_stm` splits the result into (x, Φ), and `propagate_covariance` computes P = Φ P₀ Φᵀ + Q.
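The P = Φ P₀ Φᵀ + Q rule in the last step follows from linearizing about the reference trajectory. A short derivation, assuming zero-mean errors small enough for the linearization to hold and a process-noise contribution $\mathbf{w}$ (covariance $Q$) that is uncorrelated with the initial error, so the cross terms vanish:

$$
\delta\mathbf{x}(t) \approx \Phi(t, t_0)\,\delta\mathbf{x}_0 + \mathbf{w}
\;\;\Longrightarrow\;\;
P(t) = \mathbb{E}\!\left[\delta\mathbf{x}\,\delta\mathbf{x}^\mathsf{T}\right]
= \Phi\,\mathbb{E}\!\left[\delta\mathbf{x}_0\,\delta\mathbf{x}_0^\mathsf{T}\right]\Phi^\mathsf{T} + \mathbb{E}\!\left[\mathbf{w}\,\mathbf{w}^\mathsf{T}\right]
= \Phi P_0 \Phi^\mathsf{T} + Q
$$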
## Available Functions

| Function | Description |
|---|---|
| `create_variational_dynamics` | Augment dynamics with STM propagation (auto-Jacobian) |
| `augmented_initial_state` | Build the initial `[x₀, vec(I_n)]` augmented state |
| `extract_state_and_stm` | Split the augmented state into `(x, Φ)` |
| `propagate_covariance` | Compute P = Φ P₀ Φᵀ + Q |
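To make the table concrete, here is a minimal pure-JAX sketch of what each of the four functions does. The names `make_variational_dynamics`, `make_augmented_state`, `split_augmented_state`, and `map_covariance` are hypothetical stand-ins, not the astrojax API, and the row-major layout of Φ inside the augmented vector is an assumption (it matches the `Phi_00, Phi_01, ...` ordering in the example below). A linear system x' = Mx makes a convenient check, because its exact STM is the matrix exponential expm(Mt).

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)  # float64 for a tight comparison


def make_variational_dynamics(dynamics, n):
    """Hypothetical stand-in: f_aug(t, z) for z = [x, Phi.ravel()] (row-major)."""
    def f_aug(t, z):
        x, Phi = z[:n], z[n:].reshape(n, n)
        A = jax.jacfwd(lambda xx: dynamics(t, xx))(x)  # A = df/dx at the current x
        dPhi = A @ Phi                                  # variational equation: Phi' = A Phi
        return jnp.concatenate([dynamics(t, x), dPhi.reshape(-1)])
    return f_aug


def make_augmented_state(x0, n):
    """Hypothetical stand-in: [x0, vec(I_n)], length n + n*n."""
    return jnp.concatenate([x0, jnp.eye(n, dtype=x0.dtype).reshape(-1)])


def split_augmented_state(z, n):
    """Hypothetical stand-in: inverse of make_augmented_state, returns (x, Phi)."""
    return z[:n], z[n:].reshape(n, n)


def map_covariance(Phi, P0, Q=None):
    """Hypothetical stand-in: P = Phi P0 Phi^T, plus Q when given."""
    P = Phi @ P0 @ Phi.T
    return P if Q is None else P + Q


def rk4(f, t, z, h):
    """One classical fixed-step RK4 step."""
    k1 = f(t, z)
    k2 = f(t + h / 2, z + h / 2 * k1)
    k3 = f(t + h / 2, z + h / 2 * k2)
    k4 = f(t + h, z + h * k3)
    return z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


# Linear test system x' = M x (a damped oscillator); exact STM is expm(M t)
M = jnp.array([[0.0, 1.0], [-2.0, -0.3]])
n, h, steps = 2, 0.01, 100
x0 = jnp.array([1.0, 0.0])

f_aug = make_variational_dynamics(lambda t, x: M @ x, n)
step = jax.jit(lambda t, z: rk4(f_aug, t, z, h))

z = make_augmented_state(x0, n)
for k in range(steps):
    z = step(k * h, z)

x, Phi = split_augmented_state(z, n)
P = map_covariance(Phi, jnp.diag(jnp.array([0.01, 0.001])))
print(jnp.max(jnp.abs(Phi - expm(M * (steps * h)))))  # small: RK4 truncation error only
```

Storing Φ flattened inside one vector keeps the augmented system an ordinary ODE, so any integrator that accepts `f(t, z)` can propagate it unchanged.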
## Basic Example: Harmonic Oscillator

A simple 2-D system to illustrate the workflow:

```python
import jax.numpy as jnp
from astrojax.covariance import (
    create_variational_dynamics,
    augmented_initial_state,
    extract_state_and_stm,
    propagate_covariance,
)
from astrojax.integrators import rk4_step

# 1. Define dynamics: x'' = -x, written as a first-order system in [x, v]
def harmonic(t, x):
    return jnp.array([x[1], -x[0]])

# 2. Augment with variational equations
n = 2
aug_dynamics = create_variational_dynamics(harmonic, n)

# 3. Build initial augmented state
x0 = jnp.array([1.0, 0.0])
aug_x0 = augmented_initial_state(x0, n)
# aug_x0 has shape (6,): [x0, x1, Phi_00, Phi_01, Phi_10, Phi_11]

# 4. Integrate (single step for illustration)
result = rk4_step(aug_dynamics, 0.0, aug_x0, 0.1)

# 5. Extract state and STM
x, Phi = extract_state_and_stm(result.state, n)
print(f"State: {x}")   # propagated [position, velocity]
print(f"STM:\n{Phi}")  # 2x2 state transition matrix

# 6. Map a covariance forward
P0 = jnp.diag(jnp.array([0.01, 0.001]))
P = propagate_covariance(Phi, P0)
print(f"Propagated covariance:\n{P}")
```
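A single RK4 step is hard to check by eye. For x'' = -x the exact STM is a rotation, Φ(t) = [[cos t, sin t], [-sin t, cos t]], so a longer integration should reproduce it. The sketch below is independent of astrojax: `aug_dynamics` and `rk4` are hand-written stand-ins for `create_variational_dynamics` and `rk4_step`, assumed to behave equivalently for this purpose.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)

n = 2


def harmonic(t, x):
    return jnp.array([x[1], -x[0]])


def aug_dynamics(t, z):
    # Hand-written stand-in for create_variational_dynamics(harmonic, 2)
    x, Phi = z[:n], z[n:].reshape(n, n)
    A = jax.jacfwd(lambda xx: harmonic(t, xx))(x)
    return jnp.concatenate([harmonic(t, x), (A @ Phi).reshape(-1)])


def rk4(f, t, z, h):
    # Classical RK4 step, standing in for astrojax's rk4_step
    k1 = f(t, z)
    k2 = f(t + h / 2, z + h / 2 * k1)
    k3 = f(t + h / 2, z + h / 2 * k2)
    k4 = f(t + h, z + h * k3)
    return z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


h, steps = 0.01, 100  # 100 steps of 0.01 reach t = 1.0
t_end = steps * h
z0 = jnp.concatenate([jnp.array([1.0, 0.0]), jnp.eye(n).reshape(-1)])
z_end, _ = lax.scan(lambda z, t: (rk4(aug_dynamics, t, z, h), None),
                    z0, jnp.arange(steps) * h)
Phi = z_end[n:].reshape(n, n)

# Closed-form STM of x'' = -x: a rotation by angle t
c, s = jnp.cos(t_end), jnp.sin(t_end)
Phi_exact = jnp.array([[c, s], [-s, c]])
print(jnp.max(jnp.abs(Phi - Phi_exact)))  # tiny; only RK4 truncation error remains
```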
## Orbit Dynamics Example: LEO Two-Body

A more realistic example propagates the covariance of a circular orbit at 500 km altitude over one orbital period (rounded down to a whole number of 10 s steps):
```python
import jax
import jax.numpy as jnp
from astrojax import set_dtype, Epoch
from astrojax.constants import GM_EARTH, R_EARTH
from astrojax.covariance import (
    create_variational_dynamics,
    augmented_initial_state,
    extract_state_and_stm,
    propagate_covariance,
)
from astrojax.eop import zero_eop
from astrojax.integrators import rk4_step
from astrojax.orbit_dynamics.factory import create_orbit_dynamics
from astrojax.orbits import orbital_period

set_dtype(jnp.float64)

# Circular LEO at 500 km altitude
a = R_EARTH + 500e3
v_circ = jnp.sqrt(GM_EARTH / a)
x0 = jnp.array([a, 0.0, 0.0, 0.0, v_circ, 0.0])

# Initial covariance: 100 m position, 0.1 m/s velocity (1-sigma)
P0 = jnp.diag(jnp.array([100.0**2, 100.0**2, 100.0**2,
                         0.1**2, 0.1**2, 0.1**2]))

# Create dynamics and augment with variational equations
epoch_0 = Epoch(2026, 3, 11, 0, 0, 0.0)
dynamics = create_orbit_dynamics(zero_eop(), epoch_0)
n = 6
aug_dynamics = create_variational_dynamics(dynamics, n)

# Build augmented initial state
aug_x0 = augmented_initial_state(x0, n)  # shape (42,): 6 state + 36 STM

# Propagate one orbit with lax.scan (whole steps only, so this can stop
# up to one dt short of the exact period)
T_orbit = orbital_period(a)
dt = 10.0
n_steps = int(float(T_orbit) / dt)

def scan_step(carry, _):
    t, aug_state = carry
    result = rk4_step(aug_dynamics, t, aug_state, dt)
    return (t + dt, result.state), result.state

init = (jnp.float64(0.0), aug_x0)
(t_final, aug_final), aug_history = jax.lax.scan(
    scan_step, init, None, length=n_steps
)
# aug_history: (n_steps, 42), the augmented state after every step

# Extract final STM and propagate covariance
x_final, Phi = extract_state_and_stm(aug_final, n)
P_final = propagate_covariance(Phi, P0)

# 1-sigma uncertainties
sigma = jnp.sqrt(jnp.diag(P_final))
print(f"Final 1-sigma position: {sigma[:3]} m")
print(f"Final 1-sigma velocity: {sigma[3:]} m/s")
print(f"STM determinant: {jnp.linalg.det(Phi):.6f}")  # ≈ 1.0 for Hamiltonian (conservative) dynamics
```
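Because the variational equations are integrated with the same fixed-step RK4 scheme as the state, the resulting Φ should match `jax.jacfwd` of the discrete flow map x₀ ↦ x(t) up to rounding error; for explicit Runge–Kutta methods, linearizing and discretizing commute. The self-contained sketch below checks this for point-mass two-body motion over ten minutes. The μ and Earth-radius values, the hand-written `rk4`, and the point-mass `two_body` function are assumptions standing in for astrojax's `create_orbit_dynamics` and `rk4_step`, which may model more than point-mass gravity.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)

MU = 3.986004418e14  # Earth GM [m^3/s^2], assumed value for this sketch
n = 6


def two_body(t, x):
    # Point-mass gravity only: r' = v, v' = -mu r / |r|^3
    r, v = x[:3], x[3:]
    return jnp.concatenate([v, -MU * r / jnp.linalg.norm(r) ** 3])


def aug_dynamics(t, z):
    x, Phi = z[:n], z[n:].reshape(n, n)
    A = jax.jacfwd(lambda xx: two_body(t, xx))(x)
    return jnp.concatenate([two_body(t, x), (A @ Phi).reshape(-1)])


def rk4(f, t, z, h):
    k1 = f(t, z)
    k2 = f(t + h / 2, z + h / 2 * k1)
    k3 = f(t + h / 2, z + h / 2 * k2)
    k4 = f(t + h, z + h * k3)
    return z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


h, steps = 10.0, 60  # ten minutes of flight


def propagate(f, z0):
    # Fixed-step RK4 over `steps` steps using lax.scan
    z_end, _ = lax.scan(lambda z, t: (rk4(f, t, z, h), None),
                        z0, jnp.arange(steps) * h)
    return z_end


a = 6378137.0 + 500e3  # assumed Earth radius + 500 km altitude
x0 = jnp.array([a, 0.0, 0.0, 0.0, jnp.sqrt(MU / a), 0.0])

# Route 1: STM from the variational equations
z0 = jnp.concatenate([x0, jnp.eye(n).reshape(-1)])
Phi_var = propagate(aug_dynamics, z0)[n:].reshape(n, n)

# Route 2: Jacobian of the discrete flow map x0 -> x(t), by forward-mode AD
Phi_ad = jax.jacfwd(lambda x: propagate(two_body, x))(x0)

print(jnp.max(jnp.abs(Phi_var - Phi_ad) / (jnp.abs(Phi_ad) + 1.0)))
print(jnp.linalg.det(Phi_var))  # close to 1 for conservative two-body dynamics
```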
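## Multi-Step Covariance with jax.lax.scan

You can record the covariance at every timestep by extracting the STM inside the scan body. Because the augmented state carries the cumulative STM Φ(tᵢ, t₀), mapping the initial P0 with it gives the covariance at tᵢ directly. The snippet reuses `aug_dynamics`, `dt`, `n`, `P0`, `init`, and `n_steps` from the LEO example: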
```python
def scan_step_with_cov(carry, _):
    t, aug_state = carry
    result = rk4_step(aug_dynamics, t, aug_state, dt)
    # Extract the cumulative STM and map P0 to this timestep
    _, Phi_i = extract_state_and_stm(result.state, n)
    P_i = propagate_covariance(Phi_i, P0)
    sigma_i = jnp.sqrt(jnp.diag(P_i))
    return (t + dt, result.state), sigma_i

_, sigma_history = jax.lax.scan(
    scan_step_with_cov, init, None, length=n_steps
)
# sigma_history: (n_steps, 6), the 1-sigma values at each timestep
```
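The same pattern can be checked end to end on the harmonic oscillator, where the answer is known: at t = π/2 the STM is a 90° rotation, so the position and velocity 1-sigma values should swap. This sketch is self-contained and hand-rolls the augmentation and RK4 step instead of calling astrojax, an assumption made so it runs on plain JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)

n = 2
sig_x0, sig_v0 = 0.1, 0.03
P0 = jnp.diag(jnp.array([sig_x0**2, sig_v0**2]))


def harmonic(t, x):
    return jnp.array([x[1], -x[0]])


def aug_dynamics(t, z):
    x, Phi = z[:n], z[n:].reshape(n, n)
    A = jax.jacfwd(lambda xx: harmonic(t, xx))(x)
    return jnp.concatenate([harmonic(t, x), (A @ Phi).reshape(-1)])


def rk4(f, t, z, h):
    k1 = f(t, z)
    k2 = f(t + h / 2, z + h / 2 * k1)
    k3 = f(t + h / 2, z + h / 2 * k2)
    k4 = f(t + h, z + h * k3)
    return z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


steps = 100
h = (jnp.pi / 2) / steps  # the last step lands exactly on t = pi/2


def body(z, t):
    z_next = rk4(aug_dynamics, t, z, h)
    Phi = z_next[n:].reshape(n, n)          # cumulative Phi(t_i, t_0)
    P = Phi @ P0 @ Phi.T                    # no process noise in this sketch
    return z_next, jnp.sqrt(jnp.diag(P))    # emit 1-sigma after each step


z0 = jnp.concatenate([jnp.array([1.0, 0.0]), jnp.eye(n).reshape(-1)])
_, sigma_history = lax.scan(body, z0, jnp.arange(steps) * h)
print(sigma_history.shape)  # (100, 2)
print(sigma_history[-1])    # approximately [0.03, 0.1]: the sigmas have swapped
```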
## JAX Compatibility

All covariance functions are pure JAX and work with:

- `jax.jit`: the augmented dynamics compile to a single XLA program
- `jax.vmap`: vectorize over batches of initial states or covariances
- `jax.lax.scan`: efficient multi-step propagation (see the examples above)

```python
# Batch covariance propagation over multiple initial covariances
# (Phi and P0 come from the LEO example above)
batch_propagate = jax.vmap(propagate_covariance, in_axes=(None, 0))
P0_batch = jnp.stack([P0 * scale for scale in [0.5, 1.0, 2.0]])
P_batch = batch_propagate(Phi, P0_batch)  # (3, 6, 6)
```
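Batching over covariances can reuse a single Φ, but batching over initial states cannot: for nonlinear dynamics Φ depends on the trajectory, so each state needs its own propagation. A pure-JAX sketch of that pattern, using a pendulum as a stand-in nonlinear system and hand-written helpers (`aug_dynamics`, `rk4`, `propagate_stm`, all hypothetical) in place of the astrojax functions:

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)

n = 2


def pendulum(t, x):
    # Nonlinear, so the STM depends on the initial state
    return jnp.array([x[1], -jnp.sin(x[0])])


def aug_dynamics(t, z):
    x, Phi = z[:n], z[n:].reshape(n, n)
    A = jax.jacfwd(lambda xx: pendulum(t, xx))(x)
    return jnp.concatenate([pendulum(t, x), (A @ Phi).reshape(-1)])


def rk4(f, t, z, h):
    k1 = f(t, z)
    k2 = f(t + h / 2, z + h / 2 * k1)
    k3 = f(t + h / 2, z + h / 2 * k2)
    k4 = f(t + h, z + h * k3)
    return z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


@jax.jit
def propagate_stm(x0):
    """Phi(t_f, t_0) for one initial state (fixed-step RK4, t_f = 2.0)."""
    h, steps = 0.02, 100
    z0 = jnp.concatenate([x0, jnp.eye(n, dtype=x0.dtype).reshape(-1)])
    z_end, _ = lax.scan(lambda z, t: (rk4(aug_dynamics, t, z, h), None),
                        z0, jnp.arange(steps) * h)
    return z_end[n:].reshape(n, n)


# Three pendulum amplitudes (radians), all starting at rest
x0_batch = jnp.array([[0.1, 0.0], [1.0, 0.0], [2.5, 0.0]])
Phi_batch = jax.vmap(propagate_stm)(x0_batch)  # shape (3, 2, 2)
print(Phi_batch.shape)
```

Each batch entry could then be passed to a covariance mapping such as `propagate_covariance` with its own P0.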
## Relationship to the Estimation Module

The covariance module is a building block that the estimation module uses internally:

- `ekf_predict` uses `jax.jacfwd` to compute the STM and then applies the same Φ P Φᵀ + Q update.
- The covariance module exposes this machinery directly, so you can use it for open-loop uncertainty analysis without measurements.
Use the covariance module when you want to study uncertainty growth without sensor updates. Use the estimation module (EKF/UKF) when you have measurements to incorporate.
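The parallel with `ekf_predict` can be made concrete. The sketch below linearizes one discrete RK4 step with `jax.jacfwd` (EKF-style) and compares the result with integrating the variational equations over the same step; for explicit Runge–Kutta steps the two agree up to rounding. `ekf_predict_sketch` is a hypothetical helper, not astrojax's `ekf_predict`, whose signature and internals may differ.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n, h = 2, 0.1


def pendulum(t, x):
    return jnp.array([x[1], -jnp.sin(x[0])])


def rk4(f, t, z, h):
    k1 = f(t, z)
    k2 = f(t + h / 2, z + h / 2 * k1)
    k3 = f(t + h / 2, z + h / 2 * k2)
    k4 = f(t + h, z + h * k3)
    return z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def ekf_predict_sketch(step, x, P, Q):
    """Hypothetical EKF predict: linearize the discrete step with jacfwd."""
    F = jax.jacfwd(step)(x)  # one-step STM of the discrete map
    return step(x), F @ P @ F.T + Q


def aug_dynamics(t, z):
    x, Phi = z[:n], z[n:].reshape(n, n)
    A = jax.jacfwd(lambda xx: pendulum(t, xx))(x)
    return jnp.concatenate([pendulum(t, x), (A @ Phi).reshape(-1)])


x0 = jnp.array([1.0, 0.2])
P0 = jnp.diag(jnp.array([1e-2, 1e-3]))
Q = 1e-6 * jnp.eye(n)

# Route 1: EKF-style linearization of one RK4 step
x_ekf, P_ekf = ekf_predict_sketch(lambda x: rk4(pendulum, 0.0, x, h), x0, P0, Q)

# Route 2: integrate the variational equations over the same step
z1 = rk4(aug_dynamics, 0.0, jnp.concatenate([x0, jnp.eye(n).reshape(-1)]), h)
Phi = z1[n:].reshape(n, n)
P_var = Phi @ P0 @ Phi.T + Q
print(jnp.max(jnp.abs(P_ekf - P_var)))  # rounding-level difference
```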
**Configurable precision**

All covariance functions respect `astrojax.set_dtype()`. Call
`set_dtype()` before JIT compilation to control float32 vs. float64
precision. Float64 is recommended for long-duration propagation, where
STM elements can grow large.
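A small numerical illustration of why float64 matters once Φ has grown: a process-noise term that is tiny relative to a large mapped variance simply disappears in float32. The values are illustrative rather than taken from an astrojax run, and the sketch uses JAX's own `jax_enable_x64` flag; whether `set_dtype()` sets that flag internally is an assumption.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # plain-JAX switch that makes float64 available


def added_noise(mapped_var, q, dtype):
    """Return (mapped_var + q) - mapped_var, computed entirely in `dtype`."""
    big = jnp.asarray(mapped_var, dtype=dtype)
    small = jnp.asarray(q, dtype=dtype)
    return (big + small) - big


# A mapped position variance of 1e10 m^2 (100 km 1-sigma) plus a 1e-3 m^2 Q term
lost = added_noise(1e10, 1e-3, jnp.float32)
kept = added_noise(1e10, 1e-3, jnp.float64)
print(lost)  # zero: float32 values near 1e10 are spaced 1024 apart
print(kept)  # about 1e-3: float64 values near 1e10 are spaced about 1.9e-6 apart
```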
**Dynamics must be JAX-differentiable**

The dynamics function passed to `create_variational_dynamics` must
be composed of JAX operations, because `jax.jacfwd` is used internally
to compute the Jacobian A = ∂f/∂x. This is the same requirement
as for `ekf_predict`.
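A quick way to see the requirement: a dynamics function built from `jax.numpy` differentiates cleanly, while one that converts its input to a NumPy array fails as soon as `jax.jacfwd` traces it. The two functions below are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp


def good_dynamics(t, x):
    # Built only from jax.numpy ops, so jax.jacfwd can trace it
    return jnp.array([x[1], -jnp.sin(x[0])])


def bad_dynamics(t, x):
    # np.asarray needs a concrete value, which a JAX tracer cannot provide
    x_np = np.asarray(x)
    return jnp.array([x_np[1], -np.sin(x_np[0])])


x = jnp.array([0.5, 0.0])
A = jax.jacfwd(lambda xx: good_dynamics(0.0, xx))(x)
print(A)  # the Jacobian [[0, 1], [-cos(0.5), 0]]

try:
    jax.jacfwd(lambda xx: bad_dynamics(0.0, xx))(x)
    failed = False
except jax.errors.TracerArrayConversionError:
    failed = True
print(failed)  # True
```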