Config
Module-wide floating-point precision configuration.
Provides set_dtype and get_dtype to control the float dtype used
throughout astrojax. The default is jnp.float32 for GPU/TPU
compatibility. Switching to jnp.float64 automatically enables
JAX's 64-bit mode (jax_enable_x64).
Call set_dtype before any JIT compilation, just as you would call
jax.config.update("jax_enable_x64", True). Under JIT, get_dtype()
runs at trace time and its result is baked into the compiled program,
so a function that has already been compiled does not pick up a later
change until it is retraced. JAX retraces when input dtypes change, so
passing float64 inputs after set_dtype(jnp.float64) triggers a retrace
that uses the new dtype.
Integer components (e.g. Epoch _jd) are always jnp.int32
regardless of this setting.
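The trace-time behaviour described above can be sketched as follows. This does not import astrojax: `_dtype`, `get_dtype`, and `set_dtype` are simplified stand-ins written for illustration, and `scale` is a hypothetical function. The sketch switches to float16 rather than float64 so that it isolates the baked-in value from JAX's 64-bit mode.

```python
import jax
import jax.numpy as jnp

# Stand-in for the module-wide setting (assumption, not astrojax's code).
_dtype = jnp.float32


def get_dtype():
    return _dtype


def set_dtype(dtype):
    global _dtype
    _dtype = dtype


@jax.jit
def scale(x):
    # get_dtype() is plain Python, so it runs only while JAX traces scale.
    # The dtype it returns becomes a constant of the compiled program.
    return x.astype(get_dtype()) * 2


x32 = jnp.ones(3, dtype=jnp.float32)
out_before = scale(x32)    # first call: traced with float32

set_dtype(jnp.float16)
out_cached = scale(x32)    # same input signature: cached program, still float32

x16 = jnp.ones(3, dtype=jnp.float16)
out_retraced = scale(x16)  # new input dtype: retraced, now float16

print(out_before.dtype, out_cached.dtype, out_retraced.dtype)
# float32 float32 float16
```

The middle call is the case to avoid in practice, which is why the dtype should be set once, before anything is compiled.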
get_dtype()
Return the current module-wide float dtype.
Returns:
| Name | Type | Description |
|---|---|---|
| DTypeLike | DTypeLike | The active float dtype (default jnp.float32). |
get_epoch_eq_tolerance()
Return the dtype-adaptive tolerance for Epoch equality comparisons.
The tolerance scales with the precision of the configured float dtype:
- float16: 0.1 s
- bfloat16: 0.1 s
- float32: 1e-3 s
- float64: 1e-9 s
Returns:
| Name | Type | Description |
|---|---|---|
| float | float | Tolerance in seconds. |
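As a rough illustration of how a dtype-adaptive tolerance might be used, the sketch below keys the values listed above by dtype name and applies them to a comparison of two times in seconds. The names `epoch_eq_tolerance` and `seconds_equal` are hypothetical, and the `<=` comparison is an assumption; astrojax's actual Epoch comparison may differ.

```python
# Tolerances in seconds, copied from the list above, keyed by dtype name.
_EPOCH_EQ_TOLERANCE = {
    "float16": 0.1,
    "bfloat16": 0.1,
    "float32": 1e-3,
    "float64": 1e-9,
}


def epoch_eq_tolerance(dtype_name: str) -> float:
    """Look up the tolerance for a dtype name (hypothetical helper)."""
    try:
        return _EPOCH_EQ_TOLERANCE[dtype_name]
    except KeyError:
        raise ValueError(f"unsupported float dtype: {dtype_name!r}") from None


def seconds_equal(a: float, b: float, dtype_name: str) -> bool:
    """Treat two times as equal if they differ by at most the tolerance."""
    return abs(a - b) <= epoch_eq_tolerance(dtype_name)


# A 0.5 ms difference is within the float32 tolerance but not the float64 one.
print(seconds_equal(100.0, 100.0005, "float32"))  # True
print(seconds_equal(100.0, 100.0005, "float64"))  # False
```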
set_dtype(dtype)
Set the module-wide float dtype for astrojax.
Must be called before any jax.jit compilation. In eager mode
the change takes effect immediately. Under JIT, get_dtype() runs
during tracing and its value is baked into the compiled program.
If dtype is jnp.float64, JAX's 64-bit mode is automatically
enabled via jax.config.update("jax_enable_x64", True).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dtype | DTypeLike | One of | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If dtype is not a supported float type. |