
UKF

Unscented Kalman Filter (UKF) predict and update functions.

Implements the scaled Unscented Kalman Filter using the Van der Merwe sigma point algorithm. The propagation and measurement functions are applied to sigma points via jax.vmap for efficient parallel evaluation.
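The scaled sigma-point construction can be sketched as follows. This is a minimal illustration that assumes the standard Van der Merwe parameterization (α, β, κ) with the commonly quoted values α = 1e-3, β = 2, κ = 0. The function names are hypothetical, and the defaults of UKFConfig may differ. The sketch enables float64 because a small α gives the center point a weight of order -10⁶, and float32 cannot resolve the resulting cancellation.

```python
import jax
import jax.numpy as jnp

# Float64 keeps the large-magnitude weights produced by a small alpha accurate.
jax.config.update("jax_enable_x64", True)


def merwe_weights(n, alpha=1e-3, beta=2.0, kappa=0.0):
    """Mean and covariance weights of the scaled (Van der Merwe) sigma points."""
    lam = alpha**2 * (n + kappa) - n
    w_i = 1.0 / (2.0 * (n + lam))                      # weight shared by the 2n outer points
    wm = jnp.full(2 * n + 1, w_i).at[0].set(lam / (n + lam))
    wc = wm.at[0].add(1.0 - alpha**2 + beta)           # extra term only on the center point
    return lam, wm, wc


def merwe_sigma_points(x, P, lam):
    """Return the 2n+1 sigma points as the rows of a (2n+1, n) array."""
    n = x.shape[0]
    L = jnp.linalg.cholesky((n + lam) * P)             # lower factor: L @ L.T == (n + lam) P
    offsets = L.T                                      # row i holds column i of L
    return jnp.concatenate([x[None, :], x + offsets, x - offsets], axis=0)


x = jnp.array([1.0, 0.5])
P = jnp.array([[0.04, 0.01], [0.01, 0.09]])
lam, wm, wc = merwe_weights(x.shape[0])
chi = merwe_sigma_points(x, P, lam)

# The weighted sample statistics of the sigma points reproduce (x, P).
mean = wm @ chi
dev = chi - mean
cov = (wc[:, None] * dev).T @ dev
```

The last three lines check the defining property of the sigma points: their weighted mean and covariance recover x and P.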

The covariance update uses the Joseph form for guaranteed symmetry and positive semi-definiteness. The Cholesky decomposition is regularized with a dtype-adaptive epsilon so that it does not fail when float32 precision loss leaves the covariance slightly indefinite.
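Both numerical safeguards can be sketched in JAX as below. The function names are hypothetical. The jitter rule (ten machine epsilons of the input dtype, scaled by the mean diagonal magnitude) is an assumption chosen for illustration, not the library's actual epsilon. The Joseph form appears in its textbook linear-measurement form, and how the library adapts it to the sigma-point update is not covered here.

```python
import jax.numpy as jnp


def regularized_cholesky(P, scale=10.0):
    """Lower Cholesky factor of P with a dtype-dependent diagonal jitter.

    Hypothetical rule: jitter = scale * eps(dtype) * max(mean |diag(P)|, 1),
    so float32 inputs receive a much larger jitter than float64 inputs.
    """
    n = P.shape[0]
    eps = scale * jnp.finfo(P.dtype).eps
    jitter = eps * jnp.maximum(jnp.mean(jnp.abs(jnp.diag(P))), 1.0)
    P_sym = 0.5 * (P + P.T)  # discard round-off asymmetry before factoring
    return jnp.linalg.cholesky(P_sym + jitter * jnp.eye(n, dtype=P.dtype))


def joseph_update(P, K, H, R):
    """Joseph-form update: (I - K H) P (I - K H)^T + K R K^T."""
    I_KH = jnp.eye(P.shape[0], dtype=P.dtype) - K @ H
    return I_KH @ P @ I_KH.T + K @ R @ K.T


# A singular (rank-1) float32 covariance still factors without NaNs.
P_singular = jnp.array([[1.0, 1.0], [1.0, 1.0]], dtype=jnp.float32)
L = regularized_cholesky(P_singular)

# Scalar position measurement of a 2-state system.
P = jnp.diag(jnp.array([4.0, 1.0]))
H = jnp.array([[1.0, 0.0]])
R = jnp.array([[1.0]])
S = H @ P @ H.T + R
K_opt = P @ H.T @ jnp.linalg.inv(S)
P_joseph = joseph_update(P, K_opt, H, R)   # matches P - K S K^T at the optimal gain
P_short = P - K_opt @ S @ K_opt.T

# With a non-optimal gain the Joseph form is still symmetric and PSD.
K_bad = jnp.array([[0.5], [0.3]])
P_bad = joseph_update(P, K_bad, H, R)
```

At the optimal gain the Joseph form agrees with the shorter P - K S Kᵀ. It is also a sum of symmetric positive semi-definite terms for any gain, which is why it holds up better under round-off.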

These are building-block functions designed to compose with jax.lax.scan for sequential filtering. See the user guide for a complete orbit determination example.
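A minimal sketch of the scan pattern follows. To keep it self-contained and runnable without astrojax, hypothetical linear Kalman filter steps on a 1-D constant-velocity model stand in for ukf_predict and ukf_update. In real use the carry would be a FilterState, and the updated state would need to be pulled out of the returned FilterResult, whose field names are not documented in this section. The step function has the (carry, measurement) -> (carry, output) shape that jax.lax.scan expects.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for ukf_predict / ukf_update: a linear Kalman filter
# on a 1-D constant-velocity model with state [position, velocity].
dt = 1.0
F = jnp.array([[1.0, dt], [0.0, 1.0]])
H = jnp.array([[1.0, 0.0]])
Q = 1e-4 * jnp.eye(2)
R = jnp.array([[0.25]])


def predict(carry):
    x, P = carry
    return F @ x, F @ P @ F.T + Q


def update(carry, z):
    x, P = carry
    S = H @ P @ H.T + R
    K = jnp.linalg.solve(S, H @ P).T        # K = P H^T S^-1 (S and P symmetric)
    x_new = x + K @ (z - H @ x)
    P_new = P - K @ S @ K.T
    return x_new, P_new


def filter_step(carry, z):
    """One predict/update cycle; returns the new carry and the state to record."""
    carry = update(predict(carry), z)
    return carry, carry[0]


x0 = jnp.zeros(2)
P0 = 10.0 * jnp.eye(2)
t = jnp.arange(1, 51) * dt
zs = (2.0 * t)[:, None]                      # noise-free positions of a 2 m/s target, shape (50, 1)
(x_final, P_final), xs = jax.lax.scan(filter_step, (x0, P0), zs)
```

The scan body is traced once, so the whole filter loop can be wrapped in jax.jit without unrolling.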

ukf_predict(filter_state, propagate_fn, Q, config=_DEFAULT_UKF_CONFIG)

Propagate the filter state forward one timestep using sigma points.

Generates sigma points from the current state and covariance, propagates each through propagate_fn via jax.vmap, and reconstructs the predicted mean and covariance from the weighted propagated points.
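The predict step described above can be sketched as follows. ukf_predict_sketch is a hypothetical re-implementation written for illustration, not the astrojax source. It assumes Van der Merwe weights with α = 1e-3, β = 2, κ = 0, adds Q after the weighted reconstruction, and symmetrizes the result. It enables float64 because a small α yields weights of order 10⁶ whose cancellation float32 cannot resolve.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def ukf_predict_sketch(x, P, propagate_fn, Q, alpha=1e-3, beta=2.0, kappa=0.0):
    """Hypothetical UKF predict step (illustration only, not the library code)."""
    n = x.shape[0]
    lam = alpha**2 * (n + kappa) - n
    w_i = 1.0 / (2.0 * (n + lam))
    wm = jnp.full(2 * n + 1, w_i).at[0].set(lam / (n + lam))
    wc = wm.at[0].add(1.0 - alpha**2 + beta)

    # Sigma points: the mean plus/minus the columns of chol((n + lambda) P).
    L = jnp.linalg.cholesky((n + lam) * P)
    chi = jnp.concatenate([x[None, :], x + L.T, x - L.T], axis=0)

    # Propagate every sigma point in parallel.
    chi_pred = jax.vmap(propagate_fn)(chi)

    # Reconstruct the predicted mean and covariance, then add process noise.
    x_pred = wm @ chi_pred
    d = chi_pred - x_pred
    P_pred = (wc[:, None] * d).T @ d + Q
    return x_pred, 0.5 * (P_pred + P_pred.T)   # enforce exact symmetry


# Same toy model as the ukf_predict example: one Euler step of a harmonic oscillator.
def propagate(x):
    return x + jnp.array([x[1], -x[0]]) * 0.01


x0 = jnp.array([1.0, 0.0])
P0 = jnp.eye(2) * 0.01
Q = jnp.eye(2) * 1e-6
x_pred, P_pred = ukf_predict_sketch(x0, P0, propagate, Q)
```

Because this propagate function is linear, with matrix F = [[1, 0.01], [-0.01, 1]], the result can be checked against the exact propagation F x and F P Fᵀ + Q.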

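The user constructs propagate_fn as a closure over their dynamics, integrator, and timestep:

# eop, epoch_0, t, and dt are assumed to be defined by the caller
dynamics = create_orbit_dynamics(eop, epoch_0)
def propagate(x):
    # Advance one state by dt and keep only the propagated state vector
    return rk4_step(dynamics, t, x, dt).state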
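The closure pattern can be tried end to end with hypothetical stand-ins for create_orbit_dynamics and rk4_step. two_body (point-mass gravity), rk4 (a textbook fourth-order Runge-Kutta step), and make_propagate are illustrative names, not astrojax functions. The closure fixes the dynamics, epoch, and step size, leaving a pure function of x that jax.vmap can map over a batch of states such as sigma points.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # orbit states benefit from float64

GM_EARTH = 3.986004418e14  # Earth's gravitational parameter, m^3/s^2


def two_body(t, x):
    """Point-mass gravity: x = [r (m), v (m/s)] -> dx/dt (t is unused)."""
    r, v = x[:3], x[3:]
    a = -GM_EARTH * r / jnp.linalg.norm(r) ** 3
    return jnp.concatenate([v, a])


def rk4(f, t, x, dt):
    """Classic fourth-order Runge-Kutta step (hypothetical stand-in for rk4_step)."""
    k1 = f(t, x)
    k2 = f(t + dt / 2, x + dt / 2 * k1)
    k3 = f(t + dt / 2, x + dt / 2 * k2)
    k4 = f(t + dt, x + dt * k3)
    return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def make_propagate(dynamics, t, dt):
    """Close over the dynamics, epoch, and step so the result is a pure f(x)."""
    def propagate(x):
        return rk4(dynamics, t, x, dt)
    return propagate


propagate = make_propagate(two_body, 0.0, 10.0)

# A circular 7000 km orbit plus two perturbed copies, propagated in one call.
x_circ = jnp.array([7.0e6, 0.0, 0.0, 0.0, (GM_EARTH / 7.0e6) ** 0.5, 0.0])
batch = x_circ + jnp.array([[0.0] * 6, [100.0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1.0, 0]])
batch_next = jax.vmap(propagate)(batch)
```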

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| filter_state | FilterState | Current filter state (x, P). | required |
| propagate_fn | Callable[[Array], Array] | State propagation function f(x) -> x_next. Applied to each sigma point via jax.vmap. | required |
| Q | ArrayLike | Process noise covariance matrix of shape (n, n). | required |
| config | UKFConfig | UKF sigma point configuration. Default: UKFConfig(). | _DEFAULT_UKF_CONFIG |

Returns:

| Name | Type | Description |
|---|---|---|
| FilterState | FilterState | Predicted state and covariance (x_pred, P_pred). |

Examples:

import jax.numpy as jnp
from astrojax.estimation import FilterState, ukf_predict

# Initial state and covariance of a 2-state toy system
x0 = jnp.array([1.0, 0.0])
P0 = jnp.eye(2) * 0.01
fs = FilterState(x=x0, P=P0)
Q = jnp.eye(2) * 1e-6  # process noise covariance

# One explicit Euler step (dt = 0.01) of a unit harmonic oscillator
def propagate(x):
    return x + jnp.array([x[1], -x[0]]) * 0.01

fs_pred = ukf_predict(fs, propagate, Q)

ukf_update(filter_state, z, measurement_fn, R, config=_DEFAULT_UKF_CONFIG)

Incorporate a measurement into the filter state using sigma points.

Generates sigma points from the predicted state, transforms each through the measurement function via jax.vmap, and computes the Kalman gain from the state-measurement cross-covariance and the innovation covariance.
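A corresponding sketch of the update follows. ukf_update_sketch is hypothetical. It assumes the same Van der Merwe parameters (α = 1e-3, β = 2, κ = 0) and float64, and uses the short covariance form P - K S Kᵀ rather than the Joseph form the library describes. It returns the updated covariance plus the quantities listed for FilterResult (updated state, innovation, innovation covariance, Kalman gain) as a plain tuple.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def ukf_update_sketch(x, P, z, measurement_fn, R, alpha=1e-3, beta=2.0, kappa=0.0):
    """Hypothetical UKF measurement update (illustration only, not the library code)."""
    n = x.shape[0]
    lam = alpha**2 * (n + kappa) - n
    w_i = 1.0 / (2.0 * (n + lam))
    wm = jnp.full(2 * n + 1, w_i).at[0].set(lam / (n + lam))
    wc = wm.at[0].add(1.0 - alpha**2 + beta)

    L = jnp.linalg.cholesky((n + lam) * P)
    chi = jnp.concatenate([x[None, :], x + L.T, x - L.T], axis=0)

    # Map every sigma point into measurement space.
    Z = jax.vmap(measurement_fn)(chi)          # (2n+1, m)
    z_pred = wm @ Z
    dz = Z - z_pred
    dx = chi - x

    S = (wc[:, None] * dz).T @ dz + R          # innovation covariance (m, m)
    Pxz = (wc[:, None] * dx).T @ dz            # state-measurement cross-covariance (n, m)
    K = jnp.linalg.solve(S, Pxz.T).T           # K = Pxz S^-1 without an explicit inverse

    innovation = z - z_pred
    x_new = x + K @ innovation
    P_new = P - K @ S @ K.T
    return x_new, 0.5 * (P_new + P_new.T), innovation, S, K


x_pred = jnp.array([1.0, 0.5, 0.0, 0.0, 7.5e3, 0.0])
P_pred = jnp.eye(6) * 100.0
z = jnp.array([1.01, 0.49, 0.01])
R = jnp.eye(3) * 0.01


def measure_position(x):
    return x[:3]


x_upd, P_upd, innovation, S, K = ukf_update_sketch(x_pred, P_pred, z, measure_position, R)
```

With a linear measurement such as this position-only model, the unscented update reproduces the standard Kalman filter update, which makes a convenient correctness check.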

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| filter_state | FilterState | Predicted filter state (x_pred, P_pred), typically from ukf_predict. | required |
| z | ArrayLike | Measurement vector of shape (m,). | required |
| measurement_fn | Callable[[Array], Array] | Measurement model h(x) -> z_pred. Maps the state to the expected measurement. Applied to each sigma point via jax.vmap. | required |
| R | ArrayLike | Measurement noise covariance matrix of shape (m, m). | required |
| config | UKFConfig | UKF sigma point configuration. Must match the config used in ukf_predict. Default: UKFConfig(). | _DEFAULT_UKF_CONFIG |

Returns:

| Name | Type | Description |
|---|---|---|
| FilterResult | FilterResult | Updated state, innovation, innovation covariance, and Kalman gain. |

Examples:

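import jax.numpy as jnp
from astrojax.estimation import FilterState, ukf_update

# Predicted 6-element state [position, velocity] with a broad covariance
x_pred = jnp.array([1.0, 0.5, 0.0, 0.0, 7.5e3, 0.0])
P_pred = jnp.eye(6) * 100.0
fs = FilterState(x=x_pred, P=P_pred)

# Position-only measurement and its noise covariance
z = jnp.array([1.01, 0.49, 0.01])
R = jnp.eye(3) * 0.01

def measure_position(x):
    # h(x): the first three state components are observed directly
    return x[:3]

result = ukf_update(fs, z, measure_position, R)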