TomoJAX is a fully differentiable, memory-efficient parallel-beam CT projector implemented in JAX. It provides exact gradients for 5-DOF rigid-body alignment, making it useful for CT reconstruction, per-view alignment, and deep learning workloads that require data consistency.
Left to right: ground truth phantom, naive reconstructions (misaligned and noisy), aligned reconstructions.

```bash
# Install uv (https://docs.astral.sh/uv/), then:
uv sync --extra cuda12 --group dev
uv run tomojax-test-gpu
```

For CPU-only setups, use `--extra cpu` instead. See the full installation guide for prerequisites, verification utilities, and troubleshooting.
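
Separately from TomoJAX's own `tomojax-test-gpu` check, the generic snippet below shows which backend JAX picked up after `uv sync`. It uses only core JAX calls (`jax.default_backend`, `jax.devices`, `jax.jit`) and is a sketch, not one of the project's verification utilities.

```python
import jax
import jax.numpy as jnp

# Platform JAX selected: typically "gpu" with a working CUDA install, otherwise "cpu".
print("JAX backend:", jax.default_backend())
print("Devices:", jax.devices())


@jax.jit
def smoke_test(x):
    # Compiling and running a tiny kernel confirms the backend actually executes code.
    return jnp.sum(x * x)


result = smoke_test(jnp.arange(4.0))
print("Smoke test result:", float(result))  # 0 + 1 + 4 + 9 = 14.0
```

A `cpu` result after syncing with `--extra cuda12` usually means the CUDA-enabled `jaxlib` was not picked up; the installation guide's troubleshooting section is the place to look.
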

- Differentiable projector with exact gradients via JAX autodiff, trilinear interpolation, and O(detector pixels) memory via `lax.scan`
- 5-DOF rigid-body alignment (`alpha`, `beta`, `phi`, `dx`, `dz`) with Gauss-Newton, gradient descent, or L-BFGS optimisers
- Multi-resolution alignment with a coarse-to-fine pyramid (e.g. 4x, 2x, 1x) and alternating reconstruct/align steps
- Alignment early stopping driven by profiles: compute-saving defaults, plus a conservative `robust` profile for final real-data runs
- Three reconstruction algorithms: FBP, FISTA-TV, and SPDHG-TV, with automatic Lipschitz estimation and TV proximal operators
- Laminography support with tilted rotation-axis geometry, 360-degree scans, and sample-frame reconstructions
- Complete CLI toolkit with 11 commands for simulation, alignment, reconstruction, inspection, validation, and benchmarking
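
The projector bullet combines three ideas: autodiff, trilinear interpolation, and a `lax.scan` whose carry is only a detector-sized accumulator. The sketch below is a hypothetical, stripped-down illustration of that pattern, not TomoJAX's projector. It assumes a single rotation about the z axis (no 5-DOF pose), unit voxels, `nx == ny`, and a detector of shape `(nx, nz)` centred on the volume; `project_view` is an illustrative name.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def project_view(vol, theta, step=1.0):
    """Parallel-beam line integrals through vol (nx, ny, nz) at angle theta (radians).

    Hypothetical geometry: rotation axis = z, detector shape (nx, nz), unit voxels.
    """
    nx, ny, nz = vol.shape
    cx, cy = (nx - 1) / 2.0, (ny - 1) / 2.0
    # Ray direction and in-plane detector axis for this view.
    d = jnp.stack([jnp.cos(theta), jnp.sin(theta)])
    e_u = jnp.stack([-jnp.sin(theta), jnp.cos(theta)])
    # Detector pixel grid: u offsets from the centre, v = z index (no tilt).
    u = jnp.arange(nx, dtype=vol.dtype) - cx
    v = jnp.arange(nz, dtype=vol.dtype)
    uu, vv = jnp.meshgrid(u, v, indexing="ij")  # both (nx, nz)
    # Ray samples spanning the x-y diagonal (static length, so it can be scanned).
    n_steps = int(math.ceil(math.hypot(nx, ny) / step)) + 1
    ts = (jnp.arange(n_steps, dtype=vol.dtype) - (n_steps - 1) / 2.0) * step

    def body(acc, t):
        x = cx + uu * e_u[0] + t * d[0]
        y = cy + uu * e_u[1] + t * d[1]
        # order=1 in 3-D is trilinear interpolation; points outside the volume read 0.
        samples = map_coordinates(vol, [x, y, vv], order=1, mode="constant", cval=0.0)
        return acc + step * samples, None

    # The forward pass carries only the (nx, nz) accumulator between ray samples.
    proj, _ = jax.lax.scan(body, jnp.zeros((nx, nz), vol.dtype), ts)
    return proj


vol = jnp.ones((16, 16, 8), jnp.float32)
proj0 = project_view(vol, 0.0)  # every ray crosses 16 unit voxels
# Gradient of a scalar loss with respect to the view angle, via autodiff.
dloss_dtheta = jax.grad(lambda th: jnp.sum(project_view(vol, th) ** 2))(0.3)
```

Reverse-mode differentiation through `lax.scan` saves per-step intermediates by default, so gradient memory grows with the number of ray samples; wrapping `body` in `jax.checkpoint` is the standard JAX way to trade recomputation for memory. How TomoJAX itself handles this is not described here.
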

```bash
# Simulate a small phantom, misalign it, then align and reconstruct
uv run tomojax-simulate \
  --out data/sim.nxs \
  --nx 64 --ny 64 --nz 64 --nu 64 --nv 64 --n-views 60 \
  --phantom random_shapes --seed 42 --progress

uv run tomojax-misalign \
  --data data/sim.nxs --out data/sim_mis.nxs \
  --rot-deg 1.0 --trans-px 5 --seed 0 --progress

uv run tomojax-align \
  --data data/sim_mis.nxs --levels 4 2 1 \
  --outer-iters 4 --recon-iters 15 --lambda-tv 0.003 \
  --opt-method gn --gn-damping 1e-3 \
  --out out/aligned.nxs --progress
```

See the quickstart guide for a full walkthrough and the CLI reference for all commands.
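
The `--opt-method gn --gn-damping 1e-3` flags select a damped Gauss-Newton optimiser. The sketch below shows the generic textbook form of one damped step, solving (JᵀJ + λI)δ = −Jᵀr with the Jacobian from `jax.jacfwd`; it assumes, without confirmation from this README, that `--gn-damping` plays the role of λ. The residual is a hypothetical stand-in: instead of comparing forward projections with measured views, it registers a 3-D point cloud under a 5-parameter pose `(alpha, beta, phi, dx, dz)` using an illustrative `Rz(phi) @ Ry(beta) @ Rx(alpha)` convention that may differ from TomoJAX's.

```python
import jax
import jax.numpy as jnp


def rotation(alpha, beta, phi):
    # Hypothetical convention: R = Rz(phi) @ Ry(beta) @ Rx(alpha).
    ca, sa = jnp.cos(alpha), jnp.sin(alpha)
    cb, sb = jnp.cos(beta), jnp.sin(beta)
    cp, sp = jnp.cos(phi), jnp.sin(phi)
    rx = jnp.array([[1.0, 0.0, 0.0], [0.0, ca, -sa], [0.0, sa, ca]])
    ry = jnp.array([[cb, 0.0, sb], [0.0, 1.0, 0.0], [-sb, 0.0, cb]])
    rz = jnp.array([[cp, -sp, 0.0], [sp, cp, 0.0], [0.0, 0.0, 1.0]])
    return rz @ ry @ rx


def apply_pose(params, points):
    alpha, beta, phi, dx, dz = params
    # Translations act in x and z only, mirroring the dx/dz parameters.
    return points @ rotation(alpha, beta, phi).T + jnp.array([dx, 0.0, dz])


def damped_gn_step(residual_fn, params, damping=1e-3):
    """One damped Gauss-Newton step: solve (J^T J + damping * I) delta = -J^T r."""
    r = residual_fn(params)
    J = jax.jacfwd(residual_fn)(params)  # shape (n_residuals, n_params)
    H = J.T @ J + damping * jnp.eye(params.shape[0], dtype=params.dtype)
    delta = jnp.linalg.solve(H, -J.T @ r)
    return params + delta


# Toy problem: recover a known pose from 50 random 3-D points.
points = jax.random.normal(jax.random.PRNGKey(0), (50, 3))
true_params = jnp.array([0.05, -0.03, 0.10, 2.0, -1.5])
target = apply_pose(true_params, points)


def residual(params):
    return (apply_pose(params, points) - target).ravel()


params = jnp.zeros(5)
for _ in range(10):
    params = damped_gn_step(residual, params, damping=1e-3)
print(params)  # should approach true_params
```

In an alternating scheme where the volume is held fixed during the align step, each view's residual depends only on that view's five parameters, so JᵀJ for the full multi-view problem would be block-diagonal with 5×5 blocks and each view could be updated independently.
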

```python
import jax.numpy as jnp
from tomojax.core.geometry import Grid, Detector, ParallelGeometry
from tomojax.core.projector import forward_project_view, forward_project_view_T
from tomojax.recon import fbp, fista_tv
from tomojax.align import align, AlignConfig

# Define grid, detector, geometry
grid = Grid(nx=128, ny=128, nz=128, vx=1.0, vy=1.0, vz=1.0)
det = Detector(nu=128, nv=128, du=1.0, dv=1.0)
thetas = jnp.linspace(0, 180, 128, endpoint=False)
geom = ParallelGeometry(grid=grid, detector=det, thetas_deg=thetas)
vol = jnp.ones((grid.nx, grid.ny, grid.nz), jnp.float32)

# Single-view projection (pose from geometry)
p0 = forward_project_view(geom, grid, det, vol, view_index=0)

# FBP and FISTA reconstructions
x_fbp = fbp(geom, grid, det, p0[None, ...])
x_fista, info = fista_tv(
    geom,
    grid,
    det,
    p0[None, ...],
    iters=10,
    lambda_tv=0.001,
    recon_rel_tol=1e-4,
    recon_patience=3,
)

# Alignment (toy 1-view example; use many views in practice)
cfg = AlignConfig(
    outer_iters=1,
    recon_iters=5,
    lambda_tv=0.001,
    recon_rel_tol=1e-4,
    recon_patience=3,
)
x_aligned, params5, info = align(geom, grid, det, p0[None, ...], cfg=cfg)
```

See the API reference for the full public surface and the data format reference for the HDF5/NXtomo schema used by the CLIs.

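
The feature list mentions automatic Lipschitz estimation for the iterative solvers. A common approach, sketched here as an assumption rather than TomoJAX's exact code, is power iteration on AᵀA, whose largest eigenvalue L gives the FISTA step size 1/L. The adjoint Aᵀ comes from `jax.vjp` on the linear forward operator, so no explicit transpose has to be written. `estimate_lipschitz` and the diagonal test operator are hypothetical; in a CT setting `forward` would be the projector over all views.

```python
import jax
import jax.numpy as jnp


def estimate_lipschitz(forward, x_shape, n_iters=50, seed=0):
    """Power iteration for L = ||A^T A||_2, where forward(x) = A x is linear."""
    x = jax.random.normal(jax.random.PRNGKey(seed), x_shape)
    x = x / jnp.linalg.norm(x)
    # For a linear map the VJP is the adjoint A^T, independent of the linearisation point.
    _, vjp_fn = jax.vjp(forward, x)

    def body(_, carry):
        x, _ = carry
        y = vjp_fn(forward(x))[0]  # y = A^T A x
        lam = jnp.linalg.norm(y)  # with ||x|| = 1 this approaches L
        return y / lam, lam

    _, L = jax.lax.fori_loop(0, n_iters, body, (x, jnp.zeros((), x.dtype)))
    return L


# Toy linear operator with singular values 3, 1, 0.5, so L = 3**2 = 9.
A = jnp.diag(jnp.array([3.0, 1.0, 0.5]))
L = estimate_lipschitz(lambda x: A @ x, (3,))
print(float(L))  # close to 9.0
step_size = 1.0 / L  # FISTA gradient step size for 0.5 * ||A x - b||^2
```
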
| Section | Description |
|---|---|
| Installation | Prerequisites, GPU/CPU setup, verification |
| Quickstart | Five-minute workflow on a small dataset |
| **Tutorials** | |
| End-to-end | Full 256³ workflow with alignment |
| Laminography | Tilted rotation-axis geometry |
| Single sample | Sphere or cube phantom |
| **Concepts** | |
| Geometry | Grid, detector, 5-DOF parameterisation |
| Alignment | Multi-resolution pipeline and optimisers |
| Reconstruction | FBP, FISTA-TV, SPDHG-TV |
| **CLI reference** | |
| CLI overview | All 11 commands, config files, env vars |
| **Reference** | |
| Data format | NXtomo HDF5 schema |
| Loss functions | Available losses and schedules |
| Config files | TOML configuration system |
| API | Python API reference |
| Misalignment modes | Deterministic perturbation schedules |
| Troubleshooting | Common issues and solutions |

| Phantom | Projections | Sinogram | Reconstruction |
|---|---|---|---|
| Top: slice; bottom: volume projection | Animated over 360° | Angle vs detector | Top: slice; bottom: volume projection |

| Clean Misaligned Projections | Noisy Misaligned Projections |
|---|---|
| Random rigid-body misalignments | Same misalignments + Poisson noise |

| Clean Misaligned Sinogram | Noisy Misaligned Sinogram |
|---|---|
| Clear view-to-view inconsistencies | Inconsistencies + noise artifacts |

| Clean Data Alignment | Noisy Data Alignment |
|---|---|
| 4x, 2x, 1x resolution refinement | Robust alignment despite noise |