SGP4/SDP4 Propagation
The astrojax.sgp4 module provides a JAX-native implementation of the
SGP4 (Simplified General Perturbations 4) and SDP4 (Simplified Deep-space
Perturbations 4) orbit propagators for propagating Two-Line Element (TLE)
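sets. The propagation functions are compatible with jax.jit, jax.vmap, and
automatic differentiation (see Propagation Variants for the reverse-mode limitation of the unbounded variant).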
Quick Start

```python
import jax.numpy as jnp
from astrojax.sgp4 import parse_tle, sgp4_init, sgp4_propagate

line1 = "1 25544U 98067A   08264.51782528 -.00002182  00000-0 -11606-4 0  2927"
line2 = "2 25544  51.6416 247.4627 0006703 130.5360 325.0288 15.72125391563537"

elements = parse_tle(line1, line2)
params, method = sgp4_init(elements)

# Propagate 60 minutes from epoch
r, v = sgp4_propagate(params, jnp.float64(60.0), method)
# r: position [km] in TEME frame, shape (3,)
# v: velocity [km/s] in TEME frame, shape (3,)
```
Near-Earth vs Deep-Space
SGP4 automatically classifies satellites into two regimes based on their orbital period:

| Regime | Period | Method flag | Propagation model |
|---|---|---|---|
| Near-Earth | < 225 min | 'n' | SGP4 (atmospheric drag, secular/periodic perturbations) |
| Deep-Space | ≥ 225 min | 'd' | SDP4 (lunar-solar perturbations, resonance effects) |
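The regime test can be reproduced by hand as a sanity check. The sketch below is a simplification, not astrojax code: classify_regime and orbital_period_min are hypothetical helpers that compute the period directly from the TLE mean motion in revolutions per day. The reference SGP4 initialization applies the 225-minute test to the un-Kozai'd mean motion, so a TLE sitting almost exactly on the boundary could classify differently.

```python
MINUTES_PER_DAY = 1440.0
DEEP_SPACE_PERIOD_MIN = 225.0  # SGP4 boundary between near-Earth and deep-space


def orbital_period_min(mean_motion_rev_per_day: float) -> float:
    """Orbital period in minutes from a TLE mean motion in revolutions per day."""
    return MINUTES_PER_DAY / mean_motion_rev_per_day


def classify_regime(mean_motion_rev_per_day: float) -> str:
    """Hypothetical helper: 'n' for near-Earth (SGP4), 'd' for deep-space (SDP4).

    Approximation: uses the TLE (Kozai) mean motion directly, whereas the
    reference SGP4 initialization tests the un-Kozai'd mean motion.
    """
    period = orbital_period_min(mean_motion_rev_per_day)
    return "d" if period >= DEEP_SPACE_PERIOD_MIN else "n"


# ISS mean motion from the Quick Start TLE, and a GPS-like orbit of about 2 rev/day
for n_rev_day in (15.72125391, 2.0056):
    print(f"{n_rev_day:>12} rev/day -> {orbital_period_min(n_rev_day):7.1f} min, "
          f"method {classify_regime(n_rev_day)!r}")
```

The 225-minute boundary corresponds to 1440 / 225 = 6.4 revolutions per day. Mean motions below that value take the SDP4 path.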
The method string returned by sgp4_init tells you which regime
the satellite falls into. You pass it to sgp4_propagate so that JAX
traces only the relevant code path (zero overhead from the unused branch).
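The toy sketch below shows why a plain string flag costs nothing at run time. It uses a hypothetical _toy_propagate whose arithmetic is arbitrary and unrelated to SGP4. Marking method with static_argnames makes the Python if resolve during tracing, so each flag gets its own compiled version, and jax.make_jaxpr confirms that only the selected branch is recorded. This assumes, without confirming, that astrojax treats its flag in an equivalent way.

```python
import functools

import jax
import jax.numpy as jnp


def _toy_propagate(x, tsince, method):
    # `method` is a plain Python string, so this `if` is resolved while tracing;
    # only the selected branch is recorded in the computation.
    if method == "n":
        return x + 0.5 * tsince      # stand-in for the near-Earth path
    return x * jnp.cos(tsince)       # stand-in for the deep-space path


# Mark `method` static: each distinct string gets its own compiled version.
toy_propagate = jax.jit(_toy_propagate, static_argnames="method")

x = jnp.array([1.0, 2.0, 3.0])
print(toy_propagate(x, 2.0, method="n"))
print(toy_propagate(x, 2.0, method="d"))

# The traced 'n' computation contains no cos primitive; the 'd' one does.
jaxpr_n = jax.make_jaxpr(functools.partial(_toy_propagate, method="n"))(x, 2.0)
jaxpr_d = jax.make_jaxpr(functools.partial(_toy_propagate, method="d"))(x, 2.0)
print("cos" in str(jaxpr_n), "cos" in str(jaxpr_d))
```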
Propagation Variants
The module provides two propagation variants that differ in how deep-space resonance integration is performed. For near-Earth satellites, both variants behave identically.
Default: sgp4_propagate
Uses jax.lax.scan internally for deep-space resonance integration.
This supports both forward-mode and reverse-mode automatic differentiation,
but imposes a configurable upper bound on how far from epoch a deep-space
satellite can be propagated.
```python
import jax.numpy as jnp
from astrojax.sgp4 import sgp4_propagate

tsince = 30 * 1440.0  # minutes since epoch (30 days)

# Default: up to ~100 days from epoch (200 iterations × 720 min)
r, v = sgp4_propagate(params, jnp.float64(tsince), method)

# Extend to ~200 days by increasing max_dspace_iters
r, v = sgp4_propagate(params, jnp.float64(tsince), method,
                      max_dspace_iters=400)
```
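The toy sketch below illustrates why a scan-based loop is bounded but differentiable. It is not astrojax's resonance integrator. The hypothetical bounded_decay function runs a fixed number of Euler steps of dx/dt = -k·x with lax.scan, using the same 720-minute step. Steps past tsince shrink to zero, and any tsince beyond max_iters × 720 min is truncated.

```python
import jax
import jax.numpy as jnp

STEP_MIN = 720.0  # fixed step, mirroring the 720-min deep-space step above


def bounded_decay(k, tsince, max_iters=200):
    """Toy Euler integration of dx/dt = -k * x from x(0) = 1 using lax.scan.

    The scan always runs `max_iters` steps. Once t reaches tsince the step
    size becomes zero, so the remaining iterations are no-ops, and any
    tsince beyond max_iters * STEP_MIN is silently truncated to that horizon.
    """
    k = jnp.asarray(k, dtype=jnp.float32)
    tsince = jnp.asarray(tsince, dtype=jnp.float32)

    def step(carry, _):
        x, t = carry
        h = jnp.clip(tsince - t, 0.0, STEP_MIN)  # 0 after tsince is reached
        return (x - k * x * h, t + h), None

    init = (jnp.ones_like(k), jnp.zeros_like(k))
    (x, t_reached), _ = jax.lax.scan(step, init, None, length=max_iters)
    return x, t_reached


x, t_reached = bounded_decay(1e-4, 1440.0)           # two full 720-min steps
_, t_capped = bounded_decay(1e-4, 1e5, max_iters=2)  # horizon = 2 * 720 min
# Reverse-mode AD works because scan has a fixed, trace-time length.
dx_dk = jax.grad(lambda k: bounded_decay(k, 1440.0)[0])(1e-4)
print(x, t_reached, t_capped, dx_dk)
```

The scan length has to be a Python int, so max_iters is fixed when the function is traced. That fixed length is what lets reverse mode store a known number of intermediate states, and it is also the source of the horizon limit.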
Unbounded: sgp4_propagate_unbounded
Uses jax.lax.while_loop internally for deep-space resonance
integration. This imposes no upper bound on propagation time,
but supports only forward-mode automatic differentiation for deep-space satellites.

```python
import jax.numpy as jnp
from astrojax.sgp4 import sgp4_propagate_unbounded

tsince = 5 * 365.25 * 1440.0  # about 5 years after epoch, in minutes
# No time limit: works for any tsince value
r, v = sgp4_propagate_unbounded(params, jnp.float64(tsince), method)
```
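Here is the same toy integration rewritten with lax.while_loop as a hypothetical unbounded_decay. It is again only a sketch, not astrojax's integrator. The loop runs exactly as many 720-minute steps as tsince requires, so there is no iteration budget to tune. Forward-mode differentiation via jax.jvp still works.

```python
import jax
import jax.numpy as jnp

STEP_MIN = 720.0


def unbounded_decay(k, tsince):
    """Toy Euler integration of dx/dt = -k * x from x(0) = 1 using lax.while_loop.

    The loop runs exactly as many 720-min steps as tsince requires, so there
    is no horizon limit, but JAX can only differentiate it in forward mode.
    """
    k = jnp.asarray(k, dtype=jnp.float32)
    tsince = jnp.asarray(tsince, dtype=jnp.float32)

    def not_done(carry):
        _, t = carry
        return t < tsince

    def step(carry):
        x, t = carry
        h = jnp.minimum(tsince - t, STEP_MIN)  # last step may be shorter
        return x - k * x * h, t + h

    x, _ = jax.lax.while_loop(not_done, step, (jnp.ones_like(k), jnp.zeros_like(k)))
    return x


x_long = unbounded_decay(1e-6, 1e6)  # about 1,389 steps, no budget to set
# Forward-mode AD (jvp / jacfwd) works through while_loop.
x, dx_dk = jax.jvp(lambda k: unbounded_decay(k, 1440.0),
                   (jnp.float32(1e-4),), (jnp.float32(1.0),))
print(x_long, x, dx_dk)
```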
Choosing a Variant

| | sgp4_propagate | sgp4_propagate_unbounded |
|---|---|---|
| Deep-space loop | jax.lax.scan (fixed iterations) | jax.lax.while_loop (dynamic) |
| Reverse-mode AD (jax.grad, jax.vjp) | Yes | No (deep-space satellites) |
| Forward-mode AD (jax.jacfwd, jax.jvp) | Yes | Yes |
| Time limit (deep-space) | max_dspace_iters × 720 min (~100 days default) | Unlimited |
| Near-Earth behavior | Identical | Identical |

Use sgp4_propagate (default) when:
- You need reverse-mode gradients (e.g., gradient-based orbit determination, optimization, training)
- Your propagation horizon is within the iteration budget (adjustable via max_dspace_iters)

Use sgp4_propagate_unbounded when:
- You need to propagate deep-space satellites over long time spans (months to years)
- You only need forward-mode differentiation, or no differentiation
- You don't want to worry about tuning max_dspace_iters
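For sgp4_propagate, the required budget follows directly from the horizon. The sketch below defines a hypothetical required_dspace_iters helper, which is not part of astrojax. It assumes the horizon is exactly max_dspace_iters × 720 min, as the comparison table states. It also assumes that backward propagation (negative tsince) draws on the same budget as |tsince|.

```python
import math

STEP_MIN = 720.0  # deep-space integration step, from the comparison table


def required_dspace_iters(tsince_min: float) -> int:
    """Smallest max_dspace_iters whose horizon (iters * 720 min) covers tsince.

    Hypothetical helper. Assumes the horizon is exactly max_dspace_iters * 720
    minutes and that negative tsince uses the same budget as |tsince|.
    """
    return max(1, math.ceil(abs(tsince_min) / STEP_MIN))


MIN_PER_DAY = 1440.0
print(required_dspace_iters(100 * MIN_PER_DAY))     # 100 days
print(required_dspace_iters(365.25 * MIN_PER_DAY))  # one Julian year
```

The iteration count sets the scan length, so it presumably has to be a concrete Python int when the function is traced. Compute it from a plain number rather than from a traced tsince.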
Reverse-mode AD and sgp4_propagate_unbounded
The unbounded variant uses jax.lax.while_loop, which does not
support reverse-mode differentiation in JAX. Calling jax.grad
or jax.vjp through sgp4_propagate_unbounded with a deep-space
satellite will raise an error.
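The limitation belongs to JAX itself and is not specific to SGP4. The minimal sketch below uses a hypothetical doubled_until function built on lax.while_loop. Forward mode (jax.jacfwd) succeeds, while jax.grad raises a ValueError. This happens even though the trip count here is a Python constant, because while_loop never exposes a fixed trip count to JAX.

```python
import jax
import jax.numpy as jnp


def doubled_until(x, n):
    """Double x n times using lax.while_loop (trip count opaque to JAX)."""
    def keep_going(carry):
        i, _ = carry
        return i < n

    def body(carry):
        i, y = carry
        return i + 1, 2.0 * y

    init = (jnp.int32(0), jnp.asarray(x, dtype=jnp.float32))
    _, y = jax.lax.while_loop(keep_going, body, init)
    return y


def f(x):
    return doubled_until(x, 3)  # f(x) = 8 * x


print(jax.jacfwd(f)(1.0))  # forward mode succeeds

try:
    jax.grad(f)(1.0)  # reverse mode is rejected by JAX
    reverse_mode_ok = True
except ValueError as err:
    reverse_mode_ok = False
    print("jax.grad failed:", err)
```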
Unified Variants
Both propagation functions have "unified" counterparts that auto-detect
the satellite regime from the parameter array (no method flag needed).
These use jax.lax.cond internally, enabling jax.vmap over mixed
batches of near-Earth and deep-space satellites:
```python
import jax.numpy as jnp
from astrojax.sgp4 import (
    sgp4_init_jax,
    sgp4_propagate_unified,
    sgp4_propagate_unified_unbounded,
    elements_to_array,
)

arr = elements_to_array(elements)
params = sgp4_init_jax(arr)

# Auto-detects near-Earth vs deep-space from params
r, v = sgp4_propagate_unified(params, jnp.float64(60.0))
r, v = sgp4_propagate_unified_unbounded(params, jnp.float64(60.0))
```
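The toy sketch below shows the lax.cond pattern that makes a single vmap over a mixed batch possible. The function toy_unified and its parameter layout are hypothetical: element 0 flags deep-space and the rest is a state. The branch arithmetic is arbitrary and does not reproduce astrojax's parameter array or physics.

```python
import jax
import jax.numpy as jnp


def toy_unified(params, tsince):
    """Hypothetical layout: params[0] > 0.5 flags deep-space; params[1:] is a state."""
    is_deep = params[0] > 0.5
    return jax.lax.cond(
        is_deep,
        lambda s: s * jnp.cos(tsince),  # stand-in for the deep-space branch
        lambda s: s + 0.5 * tsince,     # stand-in for the near-Earth branch
        params[1:],
    )


batch = jnp.array([
    [0.0, 1.0, 2.0, 3.0],  # flagged near-Earth
    [1.0, 1.0, 2.0, 3.0],  # flagged deep-space
])
# One vmap over a mixed batch: no per-satellite Python string is needed.
out = jax.vmap(lambda p: toy_unified(p, 2.0))(batch)
print(out)
```

When the predicate is batched under vmap, JAX converts lax.cond into a select that evaluates both branches for every element. A mixed batch therefore pays for both code paths. This is presumably the trade-off the unified propagators make in exchange for not needing a static method flag.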
TLE Input Formats
The module accepts TLEs in several formats:
From TLE Strings

```python
import jax.numpy as jnp
from astrojax.sgp4 import create_sgp4_propagator

params, propagate_fn = create_sgp4_propagator(line1, line2)
r, v = propagate_fn(jnp.float64(60.0))
```
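Raw TLE lines use a fixed-column layout of 69 characters per line. The last character is a modulo-10 checksum: digits count at face value, each minus sign counts as 1, and everything else counts as 0. The sketch below adds a pre-flight check with hypothetical helpers tle_checksum and validate_tle_line, applied to the ISS lines with their standard column padding. Whether create_sgp4_propagator performs such a check itself is not stated here.

```python
def tle_checksum(line: str) -> int:
    """Modulo-10 TLE checksum over columns 1-68: digits count at face value,
    each minus sign counts as 1, and all other characters count as 0."""
    total = 0
    for ch in line[:68]:
        if ch in "0123456789":
            total += int(ch)
        elif ch == "-":
            total += 1
    return total % 10


def validate_tle_line(line: str) -> bool:
    """True if the line has the standard 69 columns and a matching checksum."""
    return (
        len(line) == 69
        and line[68] in "0123456789"
        and tle_checksum(line) == int(line[68])
    )


line1 = "1 25544U 98067A   08264.51782528 -.00002182  00000-0 -11606-4 0  2927"
line2 = "2 25544  51.6416 247.4627 0006703 130.5360 325.0288 15.72125391563537"
print(validate_tle_line(line1), validate_tle_line(line2))
# Collapsing the column padding breaks the fixed-width layout:
print(validate_tle_line(" ".join(line1.split())))
```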
From OMM Fields

```python
from astrojax.sgp4 import create_sgp4_propagator_from_omm

fields = {
    "OBJECT_ID": "1998-067A",
    "EPOCH": "2008-09-20T12:25:40.104192",
    "MEAN_MOTION": "15.72125391",
    "ECCENTRICITY": "0.0006703",
    "INCLINATION": "51.6416",
    "RA_OF_ASC_NODE": "247.4627",
    "ARG_OF_PERICENTER": "130.5360",
    "MEAN_ANOMALY": "325.0288",
    "BSTAR": "-0.11606e-4",
    "MEAN_MOTION_DOT": "-0.00002182",
    "MEAN_MOTION_DDOT": "0",
}
params, propagate_fn = create_sgp4_propagator_from_omm(fields)
```
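The OMM EPOCH above is the same instant as the TLE epoch field 08264.51782528, which is day 264.51782528 of 2008. The sketch below reproduces that conversion with a hypothetical tle_epoch_to_iso helper built on Python's datetime. It is not an astrojax function.

```python
from datetime import datetime, timedelta


def tle_epoch_to_iso(epoch_field: str) -> str:
    """Convert a TLE epoch field 'YYDDD.DDDDDDDD' to an ISO 8601 string (UTC).

    Two-digit years follow the usual TLE convention: 57-99 map to 19xx and
    00-56 map to 20xx. The day of year is 1-based and fractional.
    """
    yy = int(epoch_field[:2])
    year = 1900 + yy if yy >= 57 else 2000 + yy
    day_of_year = float(epoch_field[2:])
    # timedelta rounds the fractional day to the nearest microsecond
    return (datetime(year, 1, 1) + timedelta(days=day_of_year - 1.0)).isoformat()


# Epoch field from the Quick Start TLE (line 1, columns 19-32)
print(tle_epoch_to_iso("08264.51782528"))
```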
From GPRecord

```python
from astrojax.sgp4 import create_sgp4_propagator_from_gp_record

# record: an existing GPRecord instance
params, propagate_fn = create_sgp4_propagator_from_gp_record(record)
```
JIT-Compilable Initialization
For workflows that need to differentiate or vmap through initialization
(e.g., optimizing orbital elements), use sgp4_init_jax:

```python
import jax
import jax.numpy as jnp
from astrojax.sgp4 import WGS72, sgp4_init_jax, sgp4_propagate_unified, elements_to_array

arr = elements_to_array(elements)
params = sgp4_init_jax(arr, gravity=WGS72, opsmode="i")

# Differentiate through init + propagation
def loss(elems):
    p = sgp4_init_jax(elems, gravity=WGS72, opsmode="i")
    r, v = sgp4_propagate_unified(p, jnp.float64(60.0))
    return jnp.sum(r**2)

grads = jax.grad(loss)(arr)
```
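The same init-then-propagate structure can be checked against an analytic answer using a toy model. The sketch below uses hypothetical toy_init and toy_propagate functions that implement a planar circular two-body orbit, not SGP4. The initialization converts mean motion to semi-major axis via Kepler's third law with the WGS72 value μ = 398600.8 km³/s². Gradients then flow from the loss back to the raw elements.

```python
import jax
import jax.numpy as jnp

MU_KM3_S2 = 398600.8  # WGS72 gravitational parameter


def toy_init(elems):
    """elems = [mean motion (rev/day), mean anomaly at epoch (rad)] -> [a_km, n_rad_min, m0]."""
    n_rad_min = elems[0] * 2.0 * jnp.pi / 1440.0
    n_rad_s = n_rad_min / 60.0
    a_km = (MU_KM3_S2 / n_rad_s**2) ** (1.0 / 3.0)  # Kepler's third law
    return jnp.stack([a_km, n_rad_min, elems[1]])


def toy_propagate(p, tsince_min):
    """Planar circular two-body position [km] at tsince_min minutes after epoch."""
    theta = p[2] + p[1] * tsince_min
    return p[0] * jnp.stack([jnp.cos(theta), jnp.sin(theta), jnp.zeros_like(theta)])


def loss(elems):
    r = toy_propagate(toy_init(elems), 60.0)
    return jnp.sum(r**2)


elems = jnp.array([15.72125391, 0.5])
grads = jax.grad(loss)(elems)
print(grads)
```

On a circular orbit |r|² = a², so the gradient with respect to mean anomaly is zero up to rounding. Because a ∝ n^(-2/3), the gradient with respect to mean motion is -(4/3)·a²/n. Analytic checks like these are a cheap sanity test for a differentiable pipeline.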
Batch Propagation with vmap
Over Time

```python
import jax
import jax.numpy as jnp
from astrojax.sgp4 import sgp4_propagate

times = jnp.array([0.0, 60.0, 360.0, 1440.0])
r_batch, v_batch = jax.vmap(
    lambda t: sgp4_propagate(params, t, method)
)(times)
# r_batch: shape (4, 3), v_batch: shape (4, 3)
```
Over Satellites

```python
import jax
import jax.numpy as jnp
from astrojax.sgp4 import WGS72, sgp4_init_jax, sgp4_propagate_unified, elements_to_array

# Stack element arrays for multiple satellites
# (element_list: a list of element sets, e.g. from parse_tle)
batch = jnp.stack([elements_to_array(e) for e in element_list])
init_vmap = jax.vmap(lambda e: sgp4_init_jax(e, gravity=WGS72, opsmode="i"))
params_batch = init_vmap(batch)

# Propagate all at once
prop_vmap = jax.vmap(lambda p: sgp4_propagate_unified(p, jnp.float64(60.0)))
r_batch, v_batch = prop_vmap(params_batch)
```
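The two patterns compose: an outer vmap over satellites wrapping an inner vmap over times produces a (satellites, times, 3) grid. The sketch below uses a hypothetical toy_propagate that models a planar circular orbit with parameters [radius_km, rate_rad_per_min]. It stands in for sgp4_propagate_unified, and the nesting is assumed to carry over unchanged to the real function.

```python
import jax
import jax.numpy as jnp


def toy_propagate(p, tsince):
    """Stand-in propagator: p = [radius_km, rate_rad_per_min], planar circular orbit."""
    theta = p[1] * tsince
    return p[0] * jnp.stack([jnp.cos(theta), jnp.sin(theta), jnp.zeros_like(theta)])


# Illustrative radii and angular rates, roughly LEO, MEO (GPS-like) and GEO
params_batch = jnp.array([[6778.0, 0.0679], [26560.0, 0.00875], [42164.0, 0.00438]])
times = jnp.array([0.0, 60.0, 360.0, 1440.0])


def per_satellite(p):
    # Inner vmap: one satellite, all times -> shape (n_times, 3)
    return jax.vmap(lambda t: toy_propagate(p, t))(times)


# Outer vmap: all satellites -> shape (n_sats, n_times, 3)
grid = jax.vmap(per_satellite)(params_batch)
print(grid.shape)  # (3, 4, 3): satellites x times x xyz
```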
Gravity Models
Three standard gravity models are available:

| Model | Constant | Description |
|---|---|---|
| WGS72 | WGS72 | Standard SGP4 gravity model (default) |
| WGS84 | WGS84 | Modern WGS84 model |
| WGS72 (legacy) | WGS72OLD | Original WGS72 constants |
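The sketch below lists the constant sets that Vallado's SGP4 reference code (getgravconst) associates with these three names. The values are quoted from that reference and are not read from astrojax, so the layout of astrojax's WGS72, WGS84 and WGS72OLD objects is not shown here. SGP4 works in Earth radii and minutes, and derives xke = 60 / sqrt(re³ / μ) from μ and the Earth radius. The legacy WGS72OLD set hard-codes xke = 0.0743669161 instead of deriving it.

```python
import math

# Constant sets as defined in Vallado's SGP4 reference code (getgravconst).
# mu in km^3/s^2, radius in km; these values are not read from astrojax.
GRAVITY = {
    "WGS72OLD": {"mu": 398600.79964, "radius": 6378.135, "j2": 0.001082616},
    "WGS72": {"mu": 398600.8, "radius": 6378.135, "j2": 0.001082616},
    "WGS84": {"mu": 398600.5, "radius": 6378.137, "j2": 0.00108262998905},
}


def xke(mu: float, radius: float) -> float:
    """sqrt(mu) in SGP4 canonical units (Earth radii**1.5 per minute)."""
    return 60.0 / math.sqrt(radius**3 / mu)


for name, c in GRAVITY.items():
    print(f"{name:8s} xke = {xke(c['mu'], c['radius']):.10f}, J2 = {c['j2']}")
```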
High-Level TLE Class
For convenience, the TLE class wraps initialization and propagation
with frame transformations:

```python
import jax.numpy as jnp
from astrojax.sgp4 import TLE

tle = TLE(line1, line2)
print(tle.epoch)   # Epoch as astrojax.Epoch
print(tle.method)  # 'n' or 'd'
r, v = tle.state_teme(jnp.float64(60.0))  # TEME frame
r, v = tle.state_gcrf(jnp.float64(60.0))  # GCRF frame (requires EOP)
r, v = tle.state_itrf(jnp.float64(60.0))  # ITRF frame (requires EOP)
```
Output Format
All propagation functions return the same output, a tuple (r, v) where:
- r: jax.Array of shape (3,), position in km (TEME frame)
- v: jax.Array of shape (3,), velocity in km/s (TEME frame)

If the orbit becomes invalid (e.g., reentry or numerical divergence),
both arrays are filled with NaN.
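Invalid states are signalled with NaN rather than an exception, so batched results can be screened after propagation. The minimal sketch below works on hand-written illustrative arrays, not real propagation output, and uses a hypothetical valid_mask helper. It flags rows whose position and velocity are fully finite, then shows two ways to handle the invalid rows.

```python
import jax
import jax.numpy as jnp

# Illustrative batched output: the second satellite's state is invalid (all NaN).
r_batch = jnp.array([[6778.0, 0.0, 0.0],
                     [jnp.nan, jnp.nan, jnp.nan],
                     [0.0, 42164.0, 0.0]])
v_batch = jnp.array([[0.0, 7.67, 0.0],
                     [jnp.nan, jnp.nan, jnp.nan],
                     [-3.07, 0.0, 0.0]])


@jax.jit
def valid_mask(r, v):
    # A row is valid only if every position and velocity component is finite.
    return jnp.all(jnp.isfinite(r), axis=-1) & jnp.all(jnp.isfinite(v), axis=-1)


valid = valid_mask(r_batch, v_batch)
# jnp.where keeps the shape fixed, so this form also works inside jit.
r_filled = jnp.where(valid[:, None], r_batch, 0.0)
# Boolean indexing drops invalid rows but has a data-dependent shape (outside jit only).
r_valid = r_batch[valid]
print(valid, r_valid.shape)
```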