
Configuration

AstroJAX uses a module-wide configuration to control the floating-point precision of all computations.

Float Dtype

By default, all floating-point computations use jnp.float32 for GPU/TPU compatibility. You can switch to jnp.float64 for higher precision, or to jnp.float16 or jnp.bfloat16 for faster, lower-precision workloads.

import astrojax
import jax.numpy as jnp

# Switch to float64 for high-precision work
astrojax.set_dtype(jnp.float64)

# Check the current dtype
print(astrojax.get_dtype())  # <class 'jax.numpy.float64'>

Supported Dtypes

| Dtype | Epoch precision | Use case |
|---|---|---|
| jnp.float16 | ~seconds | Fast training, low fidelity |
| jnp.bfloat16 | ~seconds | TPU training |
| jnp.float32 (default) | ~8 ms | GPU simulation, most applications |
| jnp.float64 | Sub-nanosecond | High-fidelity analysis, validation |
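The float32 and float64 precision figures match the spacing between adjacent representable values near 86,400 s. This assumes that the seconds component holds time within a single day, which this page does not state outright. The minimal NumPy sketch below covers only float32 and float64:

```python
import numpy as np

# Assumption: the seconds component can approach one full day (86,400 s).
SECONDS_PER_DAY = 86400.0

# np.spacing returns the gap to the next representable value, i.e. the
# finest time step each format can resolve at this magnitude.
ulp32 = float(np.spacing(np.float32(SECONDS_PER_DAY)))
ulp64 = float(np.spacing(np.float64(SECONDS_PER_DAY)))

print(f"float32: {ulp32 * 1e3:.4f} ms")  # about 7.8 ms
print(f"float64: {ulp64 * 1e9:.4f} ns")  # well under a nanosecond
```

Closer to the start of a day the spacing is finer, so under this assumption these values are worst-case figures.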

When to Call set_dtype

Call set_dtype before any jax.jit compilation, just as you would call JAX's own jax.config.update("jax_enable_x64", True). Under JIT, get_dtype() runs during tracing, so its result is baked into the compiled program. Calling set_dtype afterwards does not change functions that have already been traced.

import astrojax
import jax
import jax.numpy as jnp

# Set dtype BEFORE defining JIT functions
astrojax.set_dtype(jnp.float64)

@jax.jit
def propagate(state, dt):
    # All internal computations use float64
    ...
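You can observe this trace-time behavior with plain JAX. In this sketch, a hypothetical module-level variable, _DTYPE, stands in for AstroJAX's dtype setting; it is not the library's internal name. The jitted function reads the variable only while it is being traced. Changing the variable afterwards therefore has no effect on later calls with the same input shape and dtype:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a module-wide dtype setting.
_DTYPE = jnp.float32

@jax.jit
def double(x):
    # This Python-level read of _DTYPE happens only while JAX traces the function.
    return x.astype(_DTYPE) * 2

x = jnp.ones(3)
first = double(x)      # traced and compiled while _DTYPE is float32

_DTYPE = jnp.float16   # changed after compilation
second = double(x)     # same input shape/dtype, so the cached trace is reused

print(first.dtype, second.dtype)  # float32 float32
```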

Float64 and JAX x64 Mode

Setting the dtype to jnp.float64 automatically enables JAX's 64-bit mode via jax.config.update("jax_enable_x64", True). This is a process-wide setting that cannot be reverted.
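You can check the effect of the x64 flag directly in JAX. With the flag off, a request for float64 falls back to float32, and JAX emits a warning. Once the flag is on, float64 arrays are created as requested. This sketch calls jax.config.update itself instead of going through astrojax.set_dtype. It also assumes that the JAX_ENABLE_X64 environment variable was not set at startup:

```python
import warnings

import jax
import jax.numpy as jnp

# Default configuration (x64 off): the float64 request is truncated to float32.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide JAX's truncation warning
    before = jnp.zeros(3, dtype=jnp.float64).dtype

# Enable 64-bit mode for the whole process.
jax.config.update("jax_enable_x64", True)
after = jnp.zeros(3, dtype=jnp.float64).dtype

print(before, after)  # float32 float64 (if x64 was off at startup)
```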

Integer Components

The Epoch class uses jnp.int32 for its Julian Day number (_jd) regardless of the configured float dtype. Only the float components follow the configured dtype: the seconds value (_seconds) and its Kahan compensation term (_kahan_c).
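A Kahan compensation term such as _kahan_c usually carries the rounding error that is lost when many small increments are added to a larger running value. The sketch below is a generic NumPy float32 illustration, not AstroJAX's implementation, and kahan_add is a hypothetical helper. It adds a 1 ms step ten thousand times, once naively and once with compensation:

```python
import numpy as np

def kahan_add(total, comp, x):
    """Hypothetical helper: add x to total with Kahan compensated summation.

    comp holds the rounding error from earlier additions so it can be fed back in.
    """
    y = x - comp            # correct the increment by the previously lost error
    t = total + y           # large + small: low-order bits of y may be rounded away
    comp = (t - total) - y  # rounding error of this addition (actual minus intended)
    return t, comp

step = np.float32(0.001)    # 1 ms time step
naive = np.float32(0.0)
total, comp = np.float32(0.0), np.float32(0.0)

for _ in range(10_000):
    naive = naive + step
    total, comp = kahan_add(total, comp, step)

# The exact target is 10 s; the naive float32 sum drifts noticeably from it.
print(f"naive: {naive:.7f} s, compensated: {total:.7f} s")
```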

Epoch Equality Tolerance

The Epoch.__eq__ comparison uses a dtype-adaptive tolerance that scales with precision:

| Dtype | Tolerance |
|---|---|
| float16 / bfloat16 | 0.1 s |
| float32 | 1e-3 s |
| float64 | 1e-9 s |
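One way to read this table is that two epochs compare equal when their separation in seconds is within the tolerance for the active dtype. The sketch below is hypothetical: epochs_equal and the (day_number, seconds) tuple layout are for illustration only and are not the Epoch API. The arithmetic is done in Python floats rather than the configured JAX dtype:

```python
# Tolerances from the table above, keyed by dtype name.
EQ_TOLERANCE_S = {
    "float16": 0.1,
    "bfloat16": 0.1,
    "float32": 1e-3,
    "float64": 1e-9,
}

SECONDS_PER_DAY = 86400.0

def epochs_equal(a, b, dtype_name):
    """Hypothetical comparison of two (day_number, seconds) epochs.

    The day numbers are subtracted as integers, so the large absolute day
    count never enters floating-point arithmetic; only the day difference does.
    """
    day_a, sec_a = a
    day_b, sec_b = b
    diff = (day_a - day_b) * SECONDS_PER_DAY + (sec_a - sec_b)
    return abs(diff) <= EQ_TOLERANCE_S[dtype_name]

t0 = (2451545, 43200.0)
t1 = (2451545, 43200.0005)  # 0.5 ms later
print(epochs_equal(t0, t1, "float32"))  # True
print(epochs_equal(t0, t1, "float64"))  # False
```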