tristanmanchester/tomojax

TomoJAX: Differentiable CT Projector with Alignment

TomoJAX is a fully differentiable, memory-efficient parallel-beam CT projector implemented in JAX. It provides exact gradients for 5-DOF rigid-body alignment, making it useful for CT reconstruction, per-view alignment, and deep learning workloads that require data consistency.
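To show what "exact gradients for 5-DOF rigid-body alignment" means in plain JAX terms, here is a minimal sketch. It applies a rigid transform parameterised by (alpha, beta, phi, dx, dz) to a few 3D points and differentiates a toy misfit with jax.grad. The axis conventions and the helpers (rot_x, rot_y, rot_z, pose_transform, misfit) are hypothetical illustrations, not TomoJAX's API. The library's actual parameterisation is described in its Geometry concepts page.

```python
import jax
import jax.numpy as jnp


def rot_x(a):
    c, s = jnp.cos(a), jnp.sin(a)
    return jnp.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def rot_y(a):
    c, s = jnp.cos(a), jnp.sin(a)
    return jnp.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def rot_z(a):
    c, s = jnp.cos(a), jnp.sin(a)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def pose_transform(params, points):
    """Apply a hypothetical 5-DOF rigid transform to (N, 3) points.

    ASSUMED convention (illustration only): alpha, beta, phi rotate about
    x, y, z respectively; dx and dz translate along x and z.
    """
    alpha, beta, phi, dx, dz = params
    R = rot_z(phi) @ rot_y(beta) @ rot_x(alpha)
    t = jnp.stack([dx, jnp.zeros_like(dx), dz])
    return points @ R.T + t


def misfit(params, points, target):
    # Sum of squared distances between transformed points and their targets.
    return jnp.sum((pose_transform(params, points) - target) ** 2)


points = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
target = points + jnp.array([1.0, 0.0, 0.0])  # ground truth: a pure shift dx = 1

params0 = jnp.zeros(5)
grad = jax.grad(misfit)(params0, points, target)
print(grad)  # [0, -4, 4, -8, 0]: the dx entry is 2 * 4 points * residual (-1)
```

Autodiff propagates through every rotation and translation, so the same mechanism extends unchanged when the misfit is a projection residual rather than a point-distance sum.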

Figure: alignment demonstration. Left to right: ground-truth phantom, naive reconstructions (misaligned and noisy), and aligned reconstructions.

Installation

# Install uv (https://docs.astral.sh/uv/), then:
uv sync --extra cuda12 --group dev
uv run tomojax-test-gpu

For CPU-only setups, use --extra cpu instead. See the full installation guide for prerequisites, verification utilities, and troubleshooting.
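After installing, you can quickly confirm which backend JAX picked up by listing the visible devices with standard JAX calls. This complements the bundled tomojax-test-gpu check. It only inspects JAX itself, not TomoJAX.

```python
import jax
import jax.numpy as jnp

# With the cuda12 extra you would expect GPU devices here; with the cpu extra, only CPU.
devices = jax.devices()
print("default backend:", jax.default_backend())
print("devices:", devices)

# Run a tiny computation to make sure the backend actually executes work.
x = jnp.arange(4.0)
total = float(jnp.sum(x * x))
print("sum of squares:", total)  # 0 + 1 + 4 + 9 = 14.0
```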

Key features

  • Differentiable projector with exact gradients via JAX autodiff, trilinear interpolation, and O(detector pixels) memory via lax.scan
  • 5-DOF rigid-body alignment (alpha, beta, phi, dx, dz) with Gauss-Newton, gradient descent, or L-BFGS optimisers
  • Multi-resolution alignment with coarse-to-fine pyramid (e.g. 4x, 2x, 1x) and alternating reconstruct/align steps
  • Alignment early stopping with selectable profiles: compute-saving defaults, plus a conservative robust profile for final runs on real data
  • Three reconstruction algorithms: FBP, FISTA-TV, and SPDHG-TV with automatic Lipschitz estimation and TV proximal operators
  • Laminography support with tilted rotation-axis geometry, 360-degree scans, and sample-frame reconstructions
  • Complete CLI toolkit with 11 commands for simulation, alignment, reconstruction, inspection, validation, and benchmarking
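The projector features above (interpolation along rays, autodiff gradients, accumulation via lax.scan) can be illustrated with a toy 2D analogue. The sketch below is not TomoJAX's projector, which works in 3D with trilinear interpolation. Here bilinear_sample and project_view are hypothetical helpers. The sum over ray samples is carried through jax.lax.scan, so the forward pass keeps only one detector-row accumulator live. Note that under reverse-mode AD, scan saves per-step residuals unless the body is wrapped in jax.checkpoint, so this sketch is only memory-lean in the forward direction.

```python
import jax
import jax.numpy as jnp


def bilinear_sample(img, x, y):
    """Bilinearly sample img (ny, nx) at continuous pixel coordinates (x, y); zero outside."""
    ny, nx = img.shape
    x0f, y0f = jnp.floor(x), jnp.floor(y)
    wx, wy = x - x0f, y - y0f
    x0, y0 = x0f.astype(jnp.int32), y0f.astype(jnp.int32)

    def tap(ix, iy):
        inside = (ix >= 0) & (ix < nx) & (iy >= 0) & (iy < ny)
        v = img[jnp.clip(iy, 0, ny - 1), jnp.clip(ix, 0, nx - 1)]
        return jnp.where(inside, v, 0.0)

    return ((1 - wx) * (1 - wy) * tap(x0, y0) + wx * (1 - wy) * tap(x0 + 1, y0)
            + (1 - wx) * wy * tap(x0, y0 + 1) + wx * wy * tap(x0 + 1, y0 + 1))


def project_view(img, theta, n_det, n_steps, step=1.0):
    """Parallel-beam line integrals of a 2D image at angle theta (radians).

    The sum over ray samples is carried through lax.scan, so the forward pass
    holds a single (n_det,) accumulator rather than an (n_steps, n_det) array.
    """
    ny, nx = img.shape
    cx, cy = (nx - 1) / 2.0, (ny - 1) / 2.0
    u = jnp.arange(n_det) - (n_det - 1) / 2.0         # detector coordinate of each ray
    d = jnp.array([jnp.cos(theta), jnp.sin(theta)])   # ray direction
    n = jnp.array([-jnp.sin(theta), jnp.cos(theta)])  # detector axis
    ts = (jnp.arange(n_steps) - (n_steps - 1) / 2.0) * step

    def body(acc, t):
        x = cx + u * n[0] + t * d[0]
        y = cy + u * n[1] + t * d[1]
        return acc + step * bilinear_sample(img, x, y), None

    row, _ = jax.lax.scan(body, jnp.zeros(n_det), ts)
    return row


img = jax.random.uniform(jax.random.PRNGKey(0), (32, 32))
row = project_view(img, 0.0, n_det=32, n_steps=32)  # theta = 0: rays run along x
# Gradients of a data-fidelity loss flow back through the scan to the image.
grad_img = jax.grad(lambda im: jnp.sum(project_view(im, 0.3, 32, 48) ** 2))(img)
```

At theta = 0 the samples land exactly on pixel centres, so each detector value equals a row sum of the image. This is a handy sanity check for any projector of this kind.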

Quick start

# Simulate a small phantom, misalign it, then align and reconstruct
uv run tomojax-simulate \
  --out data/sim.nxs \
  --nx 64 --ny 64 --nz 64 --nu 64 --nv 64 --n-views 60 \
  --phantom random_shapes --seed 42 --progress

uv run tomojax-misalign \
  --data data/sim.nxs --out data/sim_mis.nxs \
  --rot-deg 1.0 --trans-px 5 --seed 0 --progress

uv run tomojax-align \
  --data data/sim_mis.nxs --levels 4 2 1 \
  --outer-iters 4 --recon-iters 15 --lambda-tv 0.003 \
  --opt-method gn --gn-damping 1e-3 \
  --out out/aligned.nxs --progress

See the quickstart guide for a full walkthrough and the CLI reference for all commands.
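The quick start selects --opt-method gn with --gn-damping 1e-3. A common reading of such a damping term, assumed here because the README does not spell it out, is the Levenberg-Marquardt style update (JᵀJ + λI)δ = −Jᵀr. The sketch below applies that generic update to a hypothetical toy residual: recovering the shift and width of a 1D Gaussian bump, which stands in for the projection misfit. It is not TomoJAX's alignment code. The helpers profile, residual, and gn_step are illustrative names.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-10.0, 10.0, 201)  # toy 1D detector coordinates


def profile(shift, width):
    # Toy "projection": a Gaussian bump displaced by `shift`.
    return jnp.exp(-0.5 * ((x - shift) / width) ** 2)


true = jnp.array([1.7, 2.0])  # (shift, width) used to synthesise the data
data = profile(*true)


def residual(p):
    return profile(p[0], p[1]) - data


def gn_step(p, damping):
    """One damped Gauss-Newton update: solve (J^T J + damping*I) delta = -J^T r."""
    r = residual(p)
    J = jax.jacfwd(residual)(p)  # (n_samples, n_params)
    H = J.T @ J + damping * jnp.eye(p.size)
    delta = -jnp.linalg.solve(H, J.T @ r)
    return p + delta


p = jnp.array([1.2, 1.8])  # initial guess
for _ in range(10):
    p = gn_step(p, damping=1e-3)
print(p)  # should approach [1.7, 2.0]
```

A small damping keeps the normal equations well conditioned without slowing the near-quadratic convergence that makes Gauss-Newton attractive for a handful of pose parameters per view.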

Python API

import jax.numpy as jnp
from tomojax.core.geometry import Grid, Detector, ParallelGeometry
from tomojax.core.projector import forward_project_view, forward_project_view_T
from tomojax.recon import fbp, fista_tv
from tomojax.align import align, AlignConfig

# Define grid, detector, geometry
grid = Grid(nx=128, ny=128, nz=128, vx=1.0, vy=1.0, vz=1.0)
det  = Detector(nu=128, nv=128, du=1.0, dv=1.0)
thetas = jnp.linspace(0, 180, 128, endpoint=False)
geom = ParallelGeometry(grid=grid, detector=det, thetas_deg=thetas)

vol = jnp.ones((grid.nx, grid.ny, grid.nz), jnp.float32)

# Single-view projection (pose from geometry)
p0 = forward_project_view(geom, grid, det, vol, view_index=0)

# FBP and FISTA reconstructions
x_fbp = fbp(geom, grid, det, p0[None, ...])
x_fista, info = fista_tv(
    geom,
    grid,
    det,
    p0[None, ...],
    iters=10,
    lambda_tv=0.001,
    recon_rel_tol=1e-4,
    recon_patience=3,
)

# Alignment (toy 1-view example; use many views in practice)
cfg = AlignConfig(
    outer_iters=1,
    recon_iters=5,
    lambda_tv=0.001,
    recon_rel_tol=1e-4,
    recon_patience=3,
)
x_aligned, params5, info = align(geom, grid, det, p0[None, ...], cfg=cfg)

See the API reference for the full public surface and the data format reference for the HDF5/NXtomo schema used by the CLIs.
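The fista_tv call above needs a step size. The key features mention automatic Lipschitz estimation for this purpose. The standard technique is power iteration on AᵀA, the Lipschitz constant of the gradient of ½‖Ax − b‖². It only needs the forward operator, because jax.vjp supplies the transpose. The sketch below uses a small dense matrix as a stand-in for the projector. estimate_lipschitz is a hypothetical helper, not the tomojax function.

```python
import jax
import jax.numpy as jnp


def estimate_lipschitz(A_op, x_shape, n_iters=200, seed=0):
    """Power iteration for lambda_max(A^T A), the Lipschitz constant of grad 0.5*||Ax - b||^2.

    A_op must be linear; jax.vjp provides A^T without writing the adjoint by hand.
    """
    x = jax.random.normal(jax.random.PRNGKey(seed), x_shape)
    x = x / jnp.linalg.norm(x.ravel())
    L = jnp.array(0.0)
    for _ in range(n_iters):
        y, A_T = jax.vjp(A_op, x)  # y = A x; A_T applies A^T to a cotangent
        (AtAx,) = A_T(y)           # A^T A x
        L = jnp.linalg.norm(AtAx.ravel())
        x = AtAx / L               # renormalise for the next iteration
    return L


# Stand-in "projector": a small dense matrix (in practice, a CT forward projector).
A = jax.random.normal(jax.random.PRNGKey(1), (30, 20))
L_est = estimate_lipschitz(lambda v: A @ v, (20,))
L_exact = jnp.linalg.eigvalsh(A.T @ A).max()
step = 1.0 / L_est  # FISTA step size for the data-fidelity term
```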

Documentation

  • Installation: Prerequisites, GPU/CPU setup, verification
  • Quickstart: Five-minute workflow on a small dataset

Tutorials
  • End-to-end: Full 256³ workflow with alignment
  • Laminography: Tilted rotation-axis geometry
  • Single sample: Sphere or cube phantom

Concepts
  • Geometry: Grid, detector, 5-DOF parameterisation
  • Alignment: Multi-resolution pipeline and optimisers
  • Reconstruction: FBP, FISTA-TV, SPDHG-TV

CLI reference
  • CLI overview: All 11 commands, config files, env vars

Reference
  • Data format: NXtomo HDF5 schema
  • Loss functions: Available losses and schedules
  • Config files: TOML configuration system
  • API: Python API reference
  • Misalignment modes: Deterministic perturbation schedules
  • Troubleshooting: Common issues and solutions

Visual examples

Basic projector workflow

Figure panels, left to right: Phantom (top: slice; bottom: volume projection), Projections (animated over 360°), Sinogram (angle vs detector), Reconstruction (top: slice; bottom: volume projection).

Alignment and reconstruction workflow

Misaligned input data

Figure panels: Clean misaligned projections (random rigid-body misalignments); Noisy misaligned projections (same misalignments plus Poisson noise).
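The noisy panels add Poisson noise to the misaligned projections. A common way to do this for line-integral data is assumed here, since the README does not say how TomoJAX scales photon counts. Convert each line integral to expected counts with Beer-Lambert, I = I0·exp(−p), draw Poisson samples, then convert back with −log(I/I0). The helper add_poisson_noise and the incident count i0 are hypothetical.

```python
import jax
import jax.numpy as jnp


def add_poisson_noise(key, proj, i0=1e4):
    """Add photon-counting noise to line-integral projections (hypothetical helper).

    proj: non-negative line integrals; i0: assumed incident photon count per pixel.
    """
    expected = i0 * jnp.exp(-proj)              # Beer-Lambert expected counts
    counts = jax.random.poisson(key, expected)  # photon-counting noise
    counts = jnp.maximum(counts, 1)             # avoid log(0) for fully absorbed rays
    return -jnp.log(counts / i0)


key = jax.random.PRNGKey(0)
clean = jnp.full((64, 64), 0.5)  # toy projection with constant line integral 0.5
noisy = add_poisson_noise(key, clean, i0=1e4)
```

Lowering i0 raises the relative noise, which is the knob you would turn to mimic low-dose data.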

Sinogram analysis

Figure panels: Clean misaligned sinogram (clear view-to-view inconsistencies); Noisy misaligned sinogram (inconsistencies plus noise artifacts).

Multi-resolution alignment process

Figure panels: Clean data alignment (4x, 2x, 1x resolution refinement); Noisy data alignment (robust alignment despite noise).
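The 4x, 2x, 1x refinement corresponds to --levels 4 2 1 in the quick start. One simple way to build such a coarse-to-fine pyramid is block-averaging the detector images by each factor. This is assumed here, since the README does not say how TomoJAX downsamples. Translations estimated at a coarse level are then rescaled when handed to a finer one. Both downsample_projections and upscale_translation are hypothetical helpers.

```python
import jax.numpy as jnp


def downsample_projections(proj, factor):
    """Block-average (n_views, nv, nu) projections by an integer factor (hypothetical helper)."""
    if factor == 1:
        return proj
    n, nv, nu = proj.shape
    nv_c, nu_c = nv // factor, nu // factor
    p = proj[:, : nv_c * factor, : nu_c * factor]  # crop to a multiple of the factor
    return p.reshape(n, nv_c, factor, nu_c, factor).mean(axis=(2, 4))


def upscale_translation(shift_px, from_factor, to_factor):
    # A shift of k pixels at level `from_factor` is k * from_factor / to_factor finer pixels.
    return shift_px * from_factor / to_factor


levels = (4, 2, 1)
proj = jnp.arange(2 * 64 * 64, dtype=jnp.float32).reshape(2, 64, 64)
pyramid = {f: downsample_projections(proj, f) for f in levels}
```

Rotations need no rescaling between levels. Only pixel-valued parameters such as dx and dz do.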

About

Differentiable parallel-beam CT projector, reconstruction, and alignment toolkit built in JAX.
