
Covariance

Covariance propagation via variational equations.

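Provides a generic mechanism to augment any ODE dynamics function with State Transition Matrix (STM) propagation. The STM Φ(t, t₀) obeys the variational equation dΦ/dt = A(t, x) Φ with Φ(t₀, t₀) = I, where A = ∂f/∂x is the Jacobian of the dynamics evaluated along the trajectory. A is computed automatically via jax.jacfwd, so no hand-derived partial derivatives are needed.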

The augmented state vector is [x, vec(Φ)], where x has n entries and vec(Φ) has n² entries. It can be integrated with any existing integrator (RK4, RKF45, DP54, etc.). This approach works with orbit dynamics, attitude dynamics, relative motion, or any other ODE system that JAX can differentiate.
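As a rough illustration of the mechanism (not astrojax's own code), the sketch below re-implements the augmentation with a hypothetical `make_variational` helper and a hand-rolled `rk4` step, assuming vec(Φ) is packed row-major. For the harmonic oscillator f(t, x) = [x[1], −x[0]] the exact STM is [[cos t, sin t], [−sin t, cos t]], so the integrated Φ can be checked against a closed form.

```python
import math

import jax
import jax.numpy as jnp


def make_variational(f, n):
    """Hypothetical helper: augment f(t, x) with dPhi/dt = A Phi, A = df/dx."""

    def f_aug(t, aug):
        x = aug[:n]
        Phi = aug[n:].reshape(n, n)  # assumption: row-major vec(Phi)
        A = jax.jacfwd(f, argnums=1)(t, x)  # Jacobian w.r.t. the state only
        return jnp.concatenate([f(t, x), (A @ Phi).reshape(-1)])

    return f_aug


def rk4(f, t, y, h):
    """Classic fourth-order Runge-Kutta step, written out for this sketch."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def harmonic(t, x):
    return jnp.array([x[1], -x[0]])


n, h, steps = 2, 0.01, 100  # integrate from t = 0 to t = 1
f_aug = make_variational(harmonic, n)
step = jax.jit(lambda t, y: rk4(f_aug, t, y, h))

# Initial augmented state [x0, vec(I)]
aug = jnp.concatenate([jnp.array([1.0, 0.0]), jnp.eye(n).reshape(-1)])
t = 0.0
for _ in range(steps):
    aug = step(t, aug)
    t += h

x, Phi = aug[:n], aug[n:].reshape(n, n)

# Exact STM of the harmonic oscillator at t = 1, for comparison with Phi
Phi_exact = jnp.array([[math.cos(1.0), math.sin(1.0)],
                       [-math.sin(1.0), math.cos(1.0)]])
```

Because the augmented system is just a larger ODE, the same step function serves both x and Φ; no integrator changes are needed.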

Functions:

| Name | Description |
| --- | --- |
| create_variational_dynamics | Augment a dynamics function with STM propagation. |
| augmented_initial_state | Build the initial augmented state vector. |
| extract_state_and_stm | Split an augmented state into (x, Φ). |
| propagate_covariance | Map a covariance through the STM. |

augmented_initial_state(x0, n)

Build the initial augmented state vector.

Concatenates the initial state x0 with the vectorised identity matrix vec(I_n), since the STM at t=t₀ is the identity.
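A minimal sketch of this construction, using a hypothetical `augmented_initial_state_sketch` name so it is not mistaken for the library function. For the identity matrix, row-major and column-major vectorisation give the same result, so the packing order does not matter here.

```python
import jax.numpy as jnp


def augmented_initial_state_sketch(x0, n):
    """Hypothetical stand-in: concatenate x0 with vec(I_n)."""
    x0 = jnp.asarray(x0)
    # vec(I_n) is identical under row-major or column-major flattening
    return jnp.concatenate([x0, jnp.eye(n, dtype=x0.dtype).reshape(-1)])


aug0 = augmented_initial_state_sketch(jnp.array([1.0, 0.0]), n=2)
# aug0 -> [1., 0., 1., 0., 0., 1.]
```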

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x0 | ArrayLike | Initial state vector of shape (n,). | required |
| n | int | Dimension of the state vector. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Augmented state of shape (n + n²,). |

create_variational_dynamics(dynamics, n)

Augment a dynamics function with STM propagation.

Returns a new dynamics function whose state is the original state concatenated with the vectorised STM. The Jacobian A = ∂f/∂x is computed inside the closure via jax.jacfwd, so the user never needs to derive partials by hand.
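To see what the closure's Jacobian step produces, note that `jax.jacfwd(f, argnums=1)` differentiates f(t, x) with respect to x only. The sketch below assumes the closure differentiates the state argument in this way; the nonlinear pendulum is a hypothetical example system, not part of astrojax. It compares the automatic Jacobian with the hand-derived one.

```python
import jax
import jax.numpy as jnp


def pendulum(t, x):
    # Nonlinear pendulum, state x = [theta, omega], theta'' = -sin(theta)
    return jnp.array([x[1], -jnp.sin(x[0])])


x = jnp.array([0.3, 0.1])
# Differentiate with respect to the state argument (index 1), not time
A = jax.jacfwd(pendulum, argnums=1)(0.0, x)

# Hand-derived Jacobian for comparison: [[0, 1], [-cos(theta), 0]]
A_hand = jnp.array([[0.0, 1.0], [-jnp.cos(x[0]), 0.0]])
```

For a nonlinear system A depends on the current state, which is why it has to be re-evaluated inside the augmented right-hand side at every stage rather than computed once.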

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dynamics | Callable[[ArrayLike, ArrayLike], Array] | Original ODE right-hand side f(t, x) -> dx/dt. Must be differentiable by JAX. | required |
| n | int | Dimension of the original state vector x. | required |

Returns:

| Type | Description |
| --- | --- |
| Callable[[ArrayLike, ArrayLike], Array] | Augmented dynamics function f_aug(t, aug_state) -> d(aug_state)/dt, where aug_state has length n + n². |

Examples:

```python
import jax.numpy as jnp
from astrojax.covariance import (
    create_variational_dynamics,
    augmented_initial_state,
    extract_state_and_stm,
)
from astrojax.integrators import rk4_step

def harmonic(t, x):
    return jnp.array([x[1], -x[0]])

aug_dynamics = create_variational_dynamics(harmonic, n=2)
aug_x0 = augmented_initial_state(jnp.array([1.0, 0.0]), n=2)
result = rk4_step(aug_dynamics, 0.0, aug_x0, 0.01)
x, Phi = extract_state_and_stm(result.state, n=2)
```

extract_state_and_stm(aug_state, n)

Split an augmented state into the state vector and STM.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| aug_state | ArrayLike | Augmented state of shape (n + n²,). | required |
| n | int | Dimension of the original state vector. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, Array] | Tuple (x, Phi) where x has shape (n,) and Phi has shape (n, n). |
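The split is the inverse of the packing. A minimal sketch, assuming vec(Φ) is stored row-major (this page does not state the ordering) and using a hypothetical `extract_state_and_stm_sketch` name:

```python
import jax.numpy as jnp


def extract_state_and_stm_sketch(aug_state, n):
    """Hypothetical stand-in: split [x, vec(Phi)] assuming row-major vec."""
    aug_state = jnp.asarray(aug_state)
    x = aug_state[:n]
    Phi = aug_state[n:].reshape(n, n)
    return x, Phi


# Round trip: pack a known state and STM, then split them again
x_in = jnp.array([1.0, 2.0])
Phi_in = jnp.array([[1.0, 2.0], [3.0, 4.0]])
aug = jnp.concatenate([x_in, Phi_in.reshape(-1)])
x, Phi = extract_state_and_stm_sketch(aug, n=2)
```

If the packing were column-major instead, the reshaped block would need a transpose (`.T`) to recover Φ; a round-trip test like this one is a cheap way to confirm which convention is in use.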

propagate_covariance(Phi, P0, Q=None)

Propagate a covariance matrix through the STM.

Computes P = Φ P₀ Φᵀ + Q, where the process noise Q defaults to the zero matrix.
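A minimal sketch of the formula, using a hypothetical `propagate_covariance_sketch` name. The example uses the exact harmonic-oscillator STM at t = π/2, Φ = [[0, 1], [−1, 0]], which rotates phase space by a quarter turn and therefore swaps the position and velocity variances of a diagonal P₀.

```python
import jax.numpy as jnp


def propagate_covariance_sketch(Phi, P0, Q=None):
    """Hypothetical stand-in: P = Phi @ P0 @ Phi.T (+ Q if given)."""
    Phi = jnp.asarray(Phi)
    P = Phi @ jnp.asarray(P0) @ Phi.T
    if Q is not None:
        P = P + jnp.asarray(Q)  # additive process noise
    return P


Phi = jnp.array([[0.0, 1.0], [-1.0, 0.0]])  # harmonic-oscillator STM at t = pi/2
P0 = jnp.diag(jnp.array([4.0, 1.0]))  # position variance 4, velocity variance 1

P = propagate_covariance_sketch(Phi, P0)  # -> diag(1, 4)
P_noisy = propagate_covariance_sketch(Phi, P0, Q=0.01 * jnp.eye(2))  # -> diag(1.01, 4.01)
```

Since Φ P₀ Φᵀ is a congruence transform, a symmetric positive semi-definite P₀ stays symmetric positive semi-definite, and adding a symmetric PSD Q preserves that.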

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| Phi | ArrayLike | State transition matrix of shape (n, n). | required |
| P0 | ArrayLike | Initial covariance matrix of shape (n, n). | required |
| Q | ArrayLike \| None | Process noise covariance of shape (n, n). Defaults to zero. | None |

Returns:

| Type | Description |
| --- | --- |
| Array | Propagated covariance matrix of shape (n, n). |