UKF
Unscented Kalman Filter (UKF) predict and update functions.
Implements the scaled Unscented Kalman Filter using the Van der Merwe
sigma point algorithm. The propagation and measurement functions are
applied to sigma points via jax.vmap for efficient parallel
evaluation.
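The scaled (Van der Merwe) sigma point set can be sketched as follows. For state dimension n and parameters alpha, beta, kappa, it sets lambda = alpha^2 (n + kappa) - n and places 2n + 1 points at x and at x plus and minus each column of the Cholesky factor of (n + lambda) P. The center point gets mean weight lambda / (n + lambda), every other point gets 1 / (2 (n + lambda)), and the center covariance weight is increased by 1 - alpha^2 + beta. The helper names `merwe_weights` and `merwe_sigma_points` and the default parameter values are illustrative assumptions and need not match the fields of astrojax's `UKFConfig`.

```python
import jax
import jax.numpy as jnp

# float64 keeps the large-magnitude weights produced by a small alpha accurate
# in this demonstration.
jax.config.update("jax_enable_x64", True)


def merwe_weights(n, alpha=1e-3, beta=2.0, kappa=0.0):
    """Return (Wm, Wc, lam) for the scaled sigma point set (hypothetical helper)."""
    lam = alpha**2 * (n + kappa) - n
    wi = 1.0 / (2.0 * (n + lam))
    Wm = jnp.full(2 * n + 1, wi).at[0].set(lam / (n + lam))
    Wc = Wm.at[0].add(1.0 - alpha**2 + beta)
    return Wm, Wc, lam


def merwe_sigma_points(x, P, lam):
    """Return the 2n + 1 sigma points as rows of an array of shape (2n + 1, n)."""
    n = x.shape[0]
    # Row i of L.T is column i of the lower Cholesky factor of (n + lam) P.
    L = jnp.linalg.cholesky((n + lam) * P)
    return jnp.vstack([x[None, :], x + L.T, x - L.T])


x = jnp.array([1.0, 2.0])
P = jnp.array([[2.0, 0.3], [0.3, 1.0]])
Wm, Wc, lam = merwe_weights(2)
X = merwe_sigma_points(x, P, lam)

# Evaluate a nonlinear function at every sigma point in one batched call.
Y = jax.vmap(lambda s: jnp.stack([s[0] * s[1], s[0] + s[1]]))(X)

# Weighted mean and covariance of the sigma points themselves.
mean = Wm @ X
dev = X - mean
cov = (Wc[:, None] * dev).T @ dev
```

Because the weights sum to one and the offsets come in symmetric pairs, `mean` and `cov` recover `x` and `P` up to round-off, while `jax.vmap` evaluates the function at all 2n + 1 points without a Python loop.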
The covariance update uses the Joseph form for guaranteed symmetry and positive semi-definiteness. The Cholesky decomposition is regularized with a dtype-adaptive epsilon so that it does not fail when float32 round-off leaves the covariance slightly indefinite.
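One way to make the regularization dtype-adaptive is to add a diagonal jitter proportional to the machine epsilon of the covariance's own dtype, so the same code adds a larger correction in float32 than in float64. In the sketch below, the function name `regularized_cholesky`, the factor `scale=100.0`, and the scaling by the mean diagonal magnitude are illustrative assumptions, not astrojax's actual rule.

```python
import jax.numpy as jnp


def regularized_cholesky(P, scale=100.0):
    """Lower Cholesky factor of P after adding a small, dtype-aware diagonal jitter."""
    n = P.shape[0]
    # Symmetrize to remove round-off asymmetry before factoring.
    P = 0.5 * (P + P.T)
    # Machine epsilon of P's dtype (about 1.2e-7 for float32, 2.2e-16 for float64),
    # converted to a Python float so it does not change P's dtype.
    eps = float(jnp.finfo(P.dtype).eps)
    # Scale by the mean diagonal magnitude, floored at 1 so that an all-zero P
    # still receives some jitter (illustrative rule).
    size = jnp.maximum(jnp.mean(jnp.abs(jnp.diag(P))), 1.0)
    jitter = scale * eps * size
    return jnp.linalg.cholesky(P + jitter * jnp.eye(n, dtype=P.dtype))


# A singular float32 covariance: every entry equals 1, so its rank is 1.
P = jnp.ones((3, 3), dtype=jnp.float32)
L = regularized_cholesky(P)
```

A plain Cholesky factorization of this all-ones matrix hits a zero pivot; with the jitter the factor is finite, and `L @ L.T` differs from `P` only by the added diagonal term.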
These are building-block functions designed to compose with
jax.lax.scan for sequential filtering. See the user guide for
a complete orbit determination example.
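The composition with `jax.lax.scan` can be sketched end to end with a compact, self-contained UKF. Here `predict` and `update` are hypothetical stand-ins for `ukf_predict` and `ukf_update`: they take plain arrays instead of a `FilterState`, use fixed illustrative Van der Merwe parameters, and use the simple P - K S K^T covariance update rather than the Joseph form. The constant-velocity model and noise-free measurements are made up for the demonstration.

```python
import jax
import jax.numpy as jnp

# float64 keeps the large sigma point weights from a small alpha well conditioned.
jax.config.update("jax_enable_x64", True)

ALPHA, BETA, KAPPA = 1e-3, 2.0, 0.0  # illustrative Van der Merwe parameters


def sigma_points(x, P):
    """Sigma points (as rows) plus mean and covariance weights."""
    n = x.shape[0]
    lam = ALPHA**2 * (n + KAPPA) - n
    Wm = jnp.full(2 * n + 1, 1.0 / (2.0 * (n + lam))).at[0].set(lam / (n + lam))
    Wc = Wm.at[0].add(1.0 - ALPHA**2 + BETA)
    L = jnp.linalg.cholesky((n + lam) * P)
    return jnp.vstack([x[None, :], x + L.T, x - L.T]), Wm, Wc


def predict(x, P, f, Q):
    X, Wm, Wc = sigma_points(x, P)
    Xp = jax.vmap(f)(X)  # propagate every sigma point
    xp = Wm @ Xp
    d = Xp - xp
    return xp, (Wc[:, None] * d).T @ d + Q


def update(x, P, z, h, R):
    X, Wm, Wc = sigma_points(x, P)
    Z = jax.vmap(h)(X)  # predicted measurement for each sigma point
    zh = Wm @ Z
    dX, dZ = X - x, Z - zh
    S = (Wc[:, None] * dZ).T @ dZ + R  # innovation covariance
    Pxz = (Wc[:, None] * dX).T @ dZ  # state-measurement cross-covariance
    K = jnp.linalg.solve(S, Pxz.T).T  # K = Pxz S^-1 (S is symmetric)
    return x + K @ (z - zh), P - K @ S @ K.T


DT = 1.0
Q = 1e-4 * jnp.eye(2)
R = jnp.array([[0.25]])


def f(s):
    """Constant-velocity motion over one step of length DT."""
    return jnp.stack([s[0] + DT * s[1], s[1]])


def h(s):
    """Observe position only."""
    return s[:1]


def step(carry, z):
    """One scan step: predict, then update with measurement z."""
    x, P = carry
    x, P = predict(x, P, f, Q)
    x, P = update(x, P, z, h, R)
    return (x, P), x  # new carry, and the estimate recorded for this step


# Noise-free positions of a target starting at 0 and moving 2 units per step.
zs = 2.0 * jnp.arange(1, 51, dtype=jnp.float64)[:, None]
init = (jnp.zeros(2), 10.0 * jnp.eye(2))
(x_final, P_final), estimates = jax.lax.scan(step, init, zs)
```

The carry holds the state estimate and covariance, each row of `zs` is one measurement, and `scan` stacks the per-step estimates into `estimates`.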
ukf_predict(filter_state, propagate_fn, Q, config=_DEFAULT_UKF_CONFIG)
Propagate the filter state forward one timestep using sigma points.
Generates sigma points from the current state and covariance,
propagates each through propagate_fn via jax.vmap, and
reconstructs the predicted mean and covariance from the weighted
propagated points.
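The prediction steps just described can be sketched on plain arrays. `ut_predict` is a hypothetical helper, not astrojax's `ukf_predict`: it takes the mean and covariance directly instead of a `FilterState`, uses illustrative Van der Merwe parameters, and symmetrizes its output as an assumed safeguard.

```python
import jax
import jax.numpy as jnp

# float64 keeps the large-magnitude weights from a small alpha accurate here.
jax.config.update("jax_enable_x64", True)


def ut_predict(x, P, propagate_fn, Q, alpha=1e-3, beta=2.0, kappa=0.0):
    """Unscented-transform prediction on plain arrays (hypothetical helper)."""
    n = x.shape[0]
    lam = alpha**2 * (n + kappa) - n
    Wm = jnp.full(2 * n + 1, 1.0 / (2.0 * (n + lam))).at[0].set(lam / (n + lam))
    Wc = Wm.at[0].add(1.0 - alpha**2 + beta)

    # Sigma points as rows: x, then x plus and minus each Cholesky column.
    L = jnp.linalg.cholesky((n + lam) * P)
    X = jnp.vstack([x[None, :], x + L.T, x - L.T])

    # Propagate all 2n + 1 sigma points in a single batched call.
    Xp = jax.vmap(propagate_fn)(X)

    # Weighted mean and covariance of the propagated points, plus process noise.
    x_pred = Wm @ Xp
    dev = Xp - x_pred
    P_pred = (Wc[:, None] * dev).T @ dev + Q
    return x_pred, 0.5 * (P_pred + P_pred.T)


# Linear constant-velocity model, for which the unscented transform is exact.
dt = 0.1
F = jnp.array([[1.0, dt], [0.0, 1.0]])
x = jnp.array([0.0, 1.0])
P = jnp.diag(jnp.array([1.0, 0.5]))
Q = 1e-3 * jnp.eye(2)

x_pred, P_pred = ut_predict(x, P, lambda s: F @ s, Q)
```

For a linear model the result matches F x and F P F^T + Q; a nonlinear propagation function is passed in exactly the same way.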
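The user constructs propagate_fn by closing over their dynamics,
integrator, and timestep:
```python
# Assumes create_orbit_dynamics, rk4_step, eop, epoch_0, t, and dt are already defined.
dynamics = create_orbit_dynamics(eop, epoch_0)

def propagate(x):
    # One RK4 step of length dt from time t; returns the new state vector
    return rk4_step(dynamics, t, x, dt).state
```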
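The closure pattern can be tried without astrojax by using stand-ins. In the sketch below, the hypothetical point-mass `two_body` dynamics and the hand-written `rk4` step replace `create_orbit_dynamics` and `rk4_step`, whose real signatures and return types may differ (the stand-in returns the state array directly rather than an object with a `.state` attribute).

```python
import jax
import jax.numpy as jnp

MU = 3.986004418e14  # Earth's gravitational parameter [m^3/s^2]


def two_body(t, x):
    """Hypothetical point-mass dynamics: d/dt [r, v] = [v, -MU r / |r|^3]."""
    r, v = x[:3], x[3:]
    a = -MU * r / jnp.linalg.norm(r) ** 3
    return jnp.concatenate([v, a])


def rk4(f, t, x, dt):
    """Classic fourth-order Runge-Kutta step (stand-in for rk4_step)."""
    k1 = f(t, x)
    k2 = f(t + dt / 2, x + dt / 2 * k1)
    k3 = f(t + dt / 2, x + dt / 2 * k2)
    k4 = f(t + dt, x + dt * k3)
    return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


t, dt = 0.0, 10.0  # captured by the closure below


def propagate(x):
    # Closes over the dynamics, integrator, time, and step size, leaving a
    # function of the state alone, which is the form the filter expects.
    return rk4(two_body, t, x, dt)


# Propagate a batch of perturbed states at once, as the filter does with sigma points.
x0 = jnp.array([7.0e6, 0.0, 0.0, 0.0, 7.5e3, 0.0])
offsets = jnp.linspace(-1.0, 1.0, 13)[:, None] * jnp.array([100.0, 100.0, 100.0, 0.1, 0.1, 0.1])
out = jax.vmap(propagate)(x0 + offsets)
```

Because `propagate` depends only on its argument, `jax.vmap` can map it over every row of the batch without changes.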
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `filter_state` | `FilterState` | Current filter state | required |
| `propagate_fn` | `Callable[[Array], Array]` | State propagation function, mapping one state vector to the propagated state vector | required |
| `Q` | `ArrayLike` | Process noise covariance matrix; square, with side equal to the state dimension | required |
| `config` | `UKFConfig` | UKF sigma point configuration | `_DEFAULT_UKF_CONFIG` |
Returns:
| Type | Description |
|---|---|
| `FilterState` | Predicted state and covariance |
ukf_update(filter_state, z, measurement_fn, R, config=_DEFAULT_UKF_CONFIG)
Incorporate a measurement into the filter state using sigma points.
Generates sigma points from the predicted state, transforms each
through the measurement function via jax.vmap, and computes
the Kalman gain from the cross-covariance between state and
measurement spaces.
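The update can be sketched the same way: transform the sigma points through the measurement model, form the innovation covariance S and the state-measurement cross-covariance P_xz, and take K = P_xz S^-1. `ut_update` is a hypothetical helper on plain arrays, not astrojax's `ukf_update`; for brevity it uses the simple covariance update P - K S K^T rather than the Joseph form described at the top of this page, and its Van der Merwe parameters are illustrative.

```python
import jax
import jax.numpy as jnp

# float64 keeps the large-magnitude weights from a small alpha accurate here.
jax.config.update("jax_enable_x64", True)


def ut_update(x, P, z, measurement_fn, R, alpha=1e-3, beta=2.0, kappa=0.0):
    """Unscented measurement update on plain arrays (hypothetical helper)."""
    n = x.shape[0]
    lam = alpha**2 * (n + kappa) - n
    Wm = jnp.full(2 * n + 1, 1.0 / (2.0 * (n + lam))).at[0].set(lam / (n + lam))
    Wc = Wm.at[0].add(1.0 - alpha**2 + beta)
    L = jnp.linalg.cholesky((n + lam) * P)
    X = jnp.vstack([x[None, :], x + L.T, x - L.T])

    # Map every sigma point into measurement space in one batched call.
    Z = jax.vmap(measurement_fn)(X)
    z_hat = Wm @ Z
    dX, dZ = X - x, Z - z_hat

    S = (Wc[:, None] * dZ).T @ dZ + R  # innovation covariance
    Pxz = (Wc[:, None] * dX).T @ dZ  # state-measurement cross-covariance
    K = jnp.linalg.solve(S, Pxz.T).T  # K = Pxz S^-1, using S = S^T

    y = z - z_hat  # innovation
    x_new = x + K @ y
    P_new = P - K @ S @ K.T  # simple form, not the Joseph form
    return x_new, P_new, y, S, K


# Linear position-only measurement, so the result equals the classic KF update.
x = jnp.array([1.0, 0.5])
P = jnp.array([[2.0, 0.3], [0.3, 1.0]])
H = jnp.array([[1.0, 0.0]])
R = jnp.array([[0.5]])
z = jnp.array([1.4])

x_new, P_new, y, S, K = ut_update(x, P, z, lambda s: H @ s, R)
```

With a linear measurement model the result coincides with the ordinary Kalman filter update; for these numbers S = 2.5, K = [0.8, 0.12]^T, and `x_new` is approximately [1.32, 0.548].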
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `filter_state` | `FilterState` | Predicted filter state | required |
| `z` | `ArrayLike` | Measurement vector, one entry per measured quantity | required |
| `measurement_fn` | `Callable[[Array], Array]` | Measurement model, mapping one state vector to a predicted measurement vector | required |
| `R` | `ArrayLike` | Measurement noise covariance matrix; square, with side equal to the length of `z` | required |
| `config` | `UKFConfig` | UKF sigma point configuration. Must match the config used in … | `_DEFAULT_UKF_CONFIG` |
Returns:
| Type | Description |
|---|---|
| `FilterResult` | Updated state, innovation, innovation covariance, and Kalman gain |
Examples:
```python
import jax.numpy as jnp
from astrojax.estimation import FilterState, ukf_update

# Predicted state and covariance
x_pred = jnp.array([1.0, 0.5, 0.0, 0.0, 7.5e3, 0.0])
P_pred = jnp.eye(6) * 100.0
fs = FilterState(x=x_pred, P=P_pred)

# Position measurement and its noise covariance
z = jnp.array([1.01, 0.49, 0.01])
R = jnp.eye(3) * 0.01

def measure_position(x):
    # Observe the first three state components (position)
    return x[:3]

result = ukf_update(fs, z, measure_position, R)
```