EKF

Extended Kalman Filter (EKF) predict and update functions.

Implements the standard EKF using automatic differentiation to compute the state transition matrix (STM) and measurement Jacobian. The user provides a propagate_fn(x) -> x_next for prediction and a measurement_fn(x) -> z_pred for update; JAX computes the required Jacobians via jax.jacfwd.
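As a rough illustration of what the automatic differentiation step produces, the sketch below applies jax.jacfwd to a toy propagation function and a toy measurement function. The names propagate, measure, and dt are illustrative assumptions and are not part of the library.

```python
import jax
import jax.numpy as jnp

dt = 0.01  # illustrative timestep (assumption)

def propagate(x):
    # One explicit-Euler step of a harmonic oscillator: pos' = vel, vel' = -pos.
    return x + jnp.array([x[1], -x[0]]) * dt

def measure(x):
    # Observe only the first state component.
    return x[:1]

x = jnp.array([1.0, 0.0])

# State transition matrix: Jacobian of propagate at x, shape (2, 2).
Phi = jax.jacfwd(propagate)(x)

# Measurement Jacobian: Jacobian of measure at x, shape (1, 2).
H = jax.jacfwd(measure)(x)
```

For this toy model the STM is [[1, dt], [-dt, 1]] and the measurement Jacobian is [[1, 0]], which matches what one would derive by hand.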

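The covariance update uses the Joseph form, P = (I - K H) P_pred (I - K H)^T + K R K^T. In exact arithmetic this form is symmetric and positive semi-definite for any gain K, and in floating point it is much less sensitive to round-off than the shorter (I - K H) P_pred form, which is important for float32 stability.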

These are building-block functions designed to compose with jax.lax.scan for sequential filtering. See the user guide for a complete orbit determination example.
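To make the scan composition concrete, here is a minimal self-contained sketch. It inlines the textbook predict and update equations in plain JAX instead of calling the library functions, so the step function, the toy oscillator model, and the measurement values are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

dt = 0.1
Q = jnp.eye(2) * 1e-6   # process noise (illustrative)
R = jnp.eye(1) * 0.01   # measurement noise (illustrative)

def propagate(x):
    # Toy oscillator step; stands in for a real propagator.
    return x + jnp.array([x[1], -x[0]]) * dt

def measure(x):
    # Observe position only.
    return x[:1]

def step(carry, z):
    x, P = carry
    # Predict: linearize about the prior estimate, then propagate.
    Phi = jax.jacfwd(propagate)(x)
    x = propagate(x)
    P = Phi @ P @ Phi.T + Q
    # Update: Kalman gain and Joseph-form covariance.
    H = jax.jacfwd(measure)(x)
    S = H @ P @ H.T + R
    K = jnp.linalg.solve(S, H @ P).T   # K = P H^T S^-1 (S and P symmetric)
    x = x + K @ (z - measure(x))
    A = jnp.eye(2) - K @ H
    P = A @ P @ A.T + K @ R @ K.T
    # Carry the new estimate forward and also emit it as a per-step output.
    return (x, P), x

zs = jnp.array([[1.0], [0.99], [0.97], [0.94]])  # made-up measurements
init = (jnp.array([1.0, 0.0]), jnp.eye(2) * 0.01)
(x_final, P_final), xs = jax.lax.scan(step, init, zs)
```

The carry holds the running estimate and covariance, while xs collects the updated state at every measurement epoch with shape (4, 2).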

ekf_predict(filter_state, propagate_fn, Q)

Propagate the filter state forward one timestep.

Advances the state estimate through the nonlinear propagation function and updates the covariance using the automatically computed state transition matrix (STM).
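The prediction step described above can be sketched in a few lines of plain JAX. This is a hypothetical stand-in that works on bare arrays, not the library's implementation; predict_sketch and the toy propagate function are names introduced here for illustration.

```python
import jax
import jax.numpy as jnp

def predict_sketch(x, P, propagate_fn, Q):
    """Hypothetical stand-in for an EKF predict step (not the library code)."""
    # State transition matrix, linearized about the current estimate.
    Phi = jax.jacfwd(propagate_fn)(x)
    # Push the estimate through the full nonlinear model.
    x_pred = propagate_fn(x)
    # Propagate the covariance and add process noise.
    P_pred = Phi @ P @ Phi.T + Q
    return x_pred, P_pred

def propagate(x):
    # Toy oscillator step with a timestep of 0.01.
    return x + jnp.array([x[1], -x[0]]) * 0.01

x_pred, P_pred = predict_sketch(
    jnp.array([1.0, 0.0]), jnp.eye(2) * 0.01, propagate, jnp.eye(2) * 1e-6
)
```

Note that the Jacobian is evaluated at the prior estimate x, while the state itself is advanced through the nonlinear function, not through the linearization.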

The user constructs propagate_fn by closing over their dynamics, integrator, and timestep:

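# eop, epoch_0, t, and dt are assumed to be defined by the caller.
dynamics = create_orbit_dynamics(eop, epoch_0)
def propagate(x):
    # Closes over dynamics, t, and dt; only the state x is an argument.
    return rk4_step(dynamics, t, x, dt).state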

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| filter_state | FilterState | Current filter state (x, P). | required |
| propagate_fn | Callable[[Array], Array] | State propagation function f(x) -> x_next. Must be differentiable by JAX (composed of JAX operations). | required |
| Q | ArrayLike | Process noise covariance matrix of shape (n, n). | required |

Returns:

| Name | Type | Description |
|---|---|---|
| FilterState | FilterState | Predicted state and covariance (x_pred, P_pred). |

Examples:

import jax.numpy as jnp
from astrojax.estimation import FilterState, ekf_predict

x0 = jnp.array([1.0, 0.0])
P0 = jnp.eye(2) * 0.01
fs = FilterState(x=x0, P=P0)
Q = jnp.eye(2) * 1e-6

def propagate(x):
    return x + jnp.array([x[1], -x[0]]) * 0.01

fs_pred = ekf_predict(fs, propagate, Q)

ekf_update(filter_state, z, measurement_fn, R)

Incorporate a measurement into the filter state.

Computes the Kalman gain, updates the state estimate, and updates the covariance using the Joseph form for numerical stability.
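As a minimal sketch of these three steps, assuming the standard EKF equations, the function below computes the gain, the state correction, and the Joseph-form covariance on bare arrays. update_sketch is a hypothetical name and its return tuple is an assumption; it does not reproduce the library's FilterResult type.

```python
import jax
import jax.numpy as jnp

def update_sketch(x_pred, P_pred, z, measurement_fn, R):
    """Hypothetical stand-in for an EKF update step (not the library code)."""
    H = jax.jacfwd(measurement_fn)(x_pred)    # measurement Jacobian, (m, n)
    y = z - measurement_fn(x_pred)            # innovation
    S = H @ P_pred @ H.T + R                  # innovation covariance
    # K = P H^T S^-1, obtained by solving S K^T = H P (S and P symmetric).
    K = jnp.linalg.solve(S, H @ P_pred).T
    x_upd = x_pred + K @ y
    # Joseph form: (I - K H) P (I - K H)^T + K R K^T.
    A = jnp.eye(x_pred.shape[0]) - K @ H
    P_upd = A @ P_pred @ A.T + K @ R @ K.T
    return x_upd, P_upd, y, S, K

x_pred = jnp.array([1.0, 0.5, 0.0, 0.0, 7.5e3, 0.0])
P_pred = jnp.eye(6) * 100.0
z = jnp.array([1.01, 0.49, 0.01])
R = jnp.eye(3) * 0.01

x_upd, P_upd, y, S, K = update_sketch(x_pred, P_pred, z, lambda x: x[:3], R)
```

With a diagonal prior and a position-only measurement, the velocity components are left unchanged and the position variances shrink to roughly the measurement noise level, since the prior is far less certain than the measurement.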

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| filter_state | FilterState | Predicted filter state (x_pred, P_pred), typically from ekf_predict. | required |
| z | ArrayLike | Measurement vector of shape (m,). | required |
| measurement_fn | Callable[[Array], Array] | Measurement model h(x) -> z_pred. Maps the state to the expected measurement. Must be differentiable by JAX. | required |
| R | ArrayLike | Measurement noise covariance matrix of shape (m, m). | required |

Returns:

| Name | Type | Description |
|---|---|---|
| FilterResult | FilterResult | Updated state, innovation, innovation covariance, and Kalman gain. |

Examples:

import jax.numpy as jnp
from astrojax.estimation import FilterState, ekf_update

x_pred = jnp.array([1.0, 0.5, 0.0, 0.0, 7.5e3, 0.0])
P_pred = jnp.eye(6) * 100.0
fs = FilterState(x=x_pred, P=P_pred)

z = jnp.array([1.01, 0.49, 0.01])  # measured position
R = jnp.eye(3) * 0.01

def measure_position(x):
    return x[:3]

result = ekf_update(fs, z, measure_position, R)
updated_state = result.state