EKF
Extended Kalman Filter (EKF) predict and update functions.
Implements the standard EKF using automatic differentiation to compute
the state transition matrix (STM) and measurement Jacobian. The user
provides a propagate_fn(x) -> x_next for prediction and a
measurement_fn(x) -> z_pred for update; JAX computes the required
Jacobians via jax.jacfwd.
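The sketch below shows the automatic-differentiation step in isolation. Applied to a one-step propagation function, `jax.jacfwd` returns that function's Jacobian, which is the STM. Applied to a measurement function, it returns the measurement Jacobian. The constant-velocity model, the timestep `DT`, and the range measurement are hypothetical stand-ins, not astrojax dynamics.

```python
import jax
import jax.numpy as jnp

DT = 10.0  # hypothetical timestep (s)


def propagate(x):
    # Constant-velocity model: x = [r (3), v (3)] advanced by DT seconds.
    r, v = x[:3], x[3:]
    return jnp.concatenate([r + DT * v, v])


def measure_range(x):
    # Nonlinear measurement: distance from the origin, returned with shape (1,).
    return jnp.linalg.norm(x[:3])[None]


x0 = jnp.array([7.0e6, 0.0, 0.0, 0.0, 7.5e3, 0.0])

stm = jax.jacfwd(propagate)(x0)    # d(propagate)/dx at x0, shape (6, 6)
H = jax.jacfwd(measure_range)(x0)  # d(measure_range)/dx at x0, shape (1, 6)

# For this linear model the STM is [[I, DT*I], [0, I]] at every x0.
expected_stm = jnp.block([
    [jnp.eye(3), DT * jnp.eye(3)],
    [jnp.zeros((3, 3)), jnp.eye(3)],
])
```

For the range measurement the Jacobian is the unit position vector padded with zeros, so at `x0` it is `[1, 0, 0, 0, 0, 0]`.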
The covariance update uses the Joseph form. This form keeps the updated covariance symmetric and positive semi-definite far more reliably than the shorter (I - KH)P form, which matters for numerical stability in float32.
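For reference, the textbook Joseph-form update is written below with predicted covariance P⁻, measurement Jacobian H, measurement noise R, innovation covariance S, and gain K. This notation is ours and may differ from the astrojax source. With the optimal gain, the Joseph form reduces algebraically to the short form, so the two differ only in how rounding errors accumulate:

```latex
\begin{aligned}
S   &= H P^- H^\top + R, \qquad K = P^- H^\top S^{-1} \\
P^+ &= (I - K H)\, P^- (I - K H)^\top + K R K^\top \\
    &= (I - K H) P^- - P^- H^\top K^\top + K S K^\top \\
    &= (I - K H) P^- \qquad \text{since } K S K^\top = P^- H^\top K^\top .
\end{aligned}
```

The second line is a sum of two terms of the form A M Aᵀ with M positive semi-definite, so it stays positive semi-definite for any K. The last line has no such structure.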
These are building-block functions designed to compose with
jax.lax.scan for sequential filtering. See the user guide for
a complete orbit determination example.
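One way a predict/update pair can be driven by `jax.lax.scan` is sketched below. To stay self-contained the sketch does not import astrojax. `State`, `predict`, and `update` are hypothetical stand-ins for `FilterState`, `ekf_predict`, and `ekf_update`. The 1-D constant-velocity model, the noise levels, and the synthetic measurements are all assumptions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    # Hypothetical stand-in for FilterState: mean x and covariance P.
    x: jax.Array
    P: jax.Array


DT = 1.0
Q = 1e-3 * jnp.eye(2)  # assumed process noise
R = 0.25 * jnp.eye(1)  # assumed measurement noise


def propagate(x):
    # 1-D constant velocity: x = [position, velocity].
    return jnp.array([x[0] + DT * x[1], x[1]])


def measure(x):
    # Observe position only.
    return x[:1]


def predict(s):
    F = jax.jacfwd(propagate)(s.x)  # STM via forward-mode AD
    return State(propagate(s.x), F @ s.P @ F.T + Q)


def update(s, z):
    H = jax.jacfwd(measure)(s.x)
    S = H @ s.P @ H.T + R
    K = jnp.linalg.solve(S, H @ s.P).T  # K = P H^T S^-1 (P, S symmetric)
    I_KH = jnp.eye(s.x.shape[0]) - K @ H
    P = I_KH @ s.P @ I_KH.T + K @ R @ K.T  # Joseph form
    return State(s.x + K @ (z - measure(s.x)), P)


def step(carry, z):
    # One scan iteration: predict to the measurement time, then update.
    new = update(predict(carry), z)
    return new, new.x


# Synthetic noise-free positions of an object moving 2 units per step.
zs = (2.0 * jnp.arange(1, 21, dtype=jnp.float32))[:, None]
init = State(jnp.zeros(2), 10.0 * jnp.eye(2))
final, xs = jax.lax.scan(step, init, zs)
```

The carry is a pytree of fixed-shape arrays, so `scan` traces `step` once and runs it as a loop instead of unrolling it. The per-step estimates come back stacked in `xs`.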
ekf_predict(filter_state, propagate_fn, Q)
Propagate the filter state forward one timestep.
Advances the state estimate through the nonlinear propagation function and updates the covariance using the automatically computed state transition matrix (STM).
The user constructs propagate_fn by closing over their dynamics,
integrator, and timestep:

```python
dynamics = create_orbit_dynamics(eop, epoch_0)

def propagate(x):
    # t (current epoch) and dt (step size) are captured from the enclosing scope
    return rk4_step(dynamics, t, x, dt).state
```
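The predict step of a standard EKF pushes the mean through the propagation function and the covariance through that function's Jacobian: x_next = f(x), Φ = ∂f/∂x, P_next = Φ P Φᵀ + Q. The sketch below restates that in plain JAX. `ekf_predict_sketch` is a hypothetical helper, not astrojax's implementation. The Euler-stepped pendulum (with assumed `DT` and `G_OVER_L`) stands in for real orbit dynamics and uses the same closure pattern as `propagate` above.

```python
import jax
import jax.numpy as jnp

DT = 0.1        # hypothetical timestep
G_OVER_L = 1.0  # hypothetical pendulum constant g / L


def propagate(x):
    # One explicit-Euler step of a pendulum, x = [theta, omega].
    theta, omega = x[0], x[1]
    return jnp.array([theta + DT * omega, omega - DT * G_OVER_L * jnp.sin(theta)])


def ekf_predict_sketch(x, P, propagate_fn, Q):
    """Hypothetical EKF predict: returns (x_next, P_next, STM)."""
    x_next = propagate_fn(x)
    phi = jax.jacfwd(propagate_fn)(x)  # state transition matrix at x
    P_next = phi @ P @ phi.T + Q       # linearized covariance propagation
    return x_next, P_next, phi


x_next, P_next, phi = ekf_predict_sketch(
    jnp.array([0.0, 0.5]), jnp.eye(2), propagate, 0.01 * jnp.eye(2)
)
```

At theta = 0 the STM is [[1, 0.1], [-0.1, 1]], so with P = I and Q = 0.01 I the predicted covariance is 1.02 I.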
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `filter_state` | `FilterState` | Current filter state | required |
| `propagate_fn` | `Callable[[Array], Array]` | State propagation function, x -> x_next | required |
| `Q` | `ArrayLike` | Process noise covariance matrix of shape (n, n), where n is the state dimension | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `FilterState` | `FilterState` | Predicted state and covariance |
ekf_update(filter_state, z, measurement_fn, R)
Incorporate a measurement into the filter state.
Computes the Kalman gain, updates the state estimate, and updates the covariance using the Joseph form for numerical stability.
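The sketch below spells out those update equations in plain JAX, using the same numbers as the `ekf_update` example further down. `ekf_update_sketch` is a hypothetical helper that returns a plain tuple rather than a `FilterResult`. It assumes P and S are symmetric, so the gain can be computed with a linear solve instead of an explicit inverse.

```python
import jax
import jax.numpy as jnp


def ekf_update_sketch(x, P, z, measurement_fn, R):
    """Hypothetical EKF update with a Joseph-form covariance.

    Returns (x_upd, P_upd, innovation y, innovation covariance S, gain K).
    """
    H = jax.jacfwd(measurement_fn)(x)         # measurement Jacobian, (m, n)
    y = z - measurement_fn(x)                 # innovation
    S = H @ P @ H.T + R                       # innovation covariance
    K = jnp.linalg.solve(S, H @ P).T          # K = P H^T S^-1
    I_KH = jnp.eye(x.shape[0]) - K @ H
    P_upd = I_KH @ P @ I_KH.T + K @ R @ K.T   # Joseph form
    return x + K @ y, P_upd, y, S, K


x_pred = jnp.array([1.0, 0.5, 0.0, 0.0, 7.5e3, 0.0])
P_pred = jnp.eye(6) * 100.0
z = jnp.array([1.01, 0.49, 0.01])
R = jnp.eye(3) * 0.01

x_upd, P_upd, y, S, K = ekf_update_sketch(x_pred, P_pred, z, lambda x: x[:3], R)
```

Because the prior position variance (100) dwarfs R (0.01), the updated position lands almost exactly on the measurement. Its variance drops to 100 · 0.01 / 100.01 ≈ 0.01, while the unobserved velocity block is left unchanged.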
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `filter_state` | `FilterState` | Predicted filter state | required |
| `z` | `ArrayLike` | Measurement vector of shape (m,), where m is the measurement dimension | required |
| `measurement_fn` | `Callable[[Array], Array]` | Measurement model, x -> z_pred | required |
| `R` | `ArrayLike` | Measurement noise covariance matrix of shape (m, m) | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `FilterResult` | `FilterResult` | Updated state, innovation, innovation covariance, and Kalman gain |
Examples:
```python
import jax.numpy as jnp
from astrojax.estimation import FilterState, ekf_update

x_pred = jnp.array([1.0, 0.5, 0.0, 0.0, 7.5e3, 0.0])
P_pred = jnp.eye(6) * 100.0
fs = FilterState(x=x_pred, P=P_pred)

z = jnp.array([1.01, 0.49, 0.01])  # measured position
R = jnp.eye(3) * 0.01

def measure_position(x):
    # Measurement model: observe the three position components only
    return x[:3]

result = ekf_update(fs, z, measure_position, R)
updated_state = result.state
```