A Flash Attention Triton kernel with support for forward-mode derivatives such as Jacobian-vector products (JVPs), which in turn enable second-order quantities such as Hessian-vector products (HVPs).
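As background, an HVP can be computed by taking the JVP of a gradient function (the "forward-over-reverse" pattern). Here is a minimal NumPy sketch of that identity, using a hypothetical quadratic objective f(x) = ½ xᵀAx whose gradient Ax is written by hand (the finite-difference `jvp` helper is illustrative only, standing in for true forward-mode differentiation):

```python
import numpy as np

# Hypothetical quadratic objective f(x) = 0.5 * x^T A x, with gradient A x.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
A = A + A.T  # symmetrize so the Hessian of f is exactly A

def grad_f(x):
    """Hand-written gradient of f(x) = 0.5 * x^T A x."""
    return A @ x

def jvp(fn, x, v, eps=1e-6):
    """Directional (forward-mode) derivative of fn at x along v,
    approximated here with central finite differences."""
    return (fn(x + eps * v) - fn(x - eps * v)) / (2 * eps)

x = rng.standard_normal(4)
v = rng.standard_normal(4)

# HVP = JVP of the gradient; for this quadratic it equals A @ v.
hvp = jvp(grad_f, x, v)
```

In a real PyTorch workload the same composition is expressed with `torch.func.jvp` applied to `torch.func.grad`, which is where a JVP-capable attention kernel becomes necessary.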
Using pip, one can install `jvp_flash_attention` as follows.

```bash
# Install package
pip install jvp_flash_attention

# [OPTIONAL, for development] Install package and pre-commit hooks
pip install -e .
pre-commit install
```

Once installed, one can use `jvp_flash_attention` in place of PyTorch's `scaled_dot_product_attention` as follows.
```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

from jvp_flash_attention.jvp_attention import attention as jvp_attention

with sdpa_kernel(SDPBackend.MATH):
    # Regular (quadratic) attention
    # x = F.scaled_dot_product_attention(
    #     q,
    #     k,
    #     v,
    #     attn_mask=attn_mask,
    #     dropout_p=attn_dropout_p if self.training else 0.0,
    # )
    # JVP flash attention
    x = jvp_attention(
        q,
        k,
        v,
        # attn_mask=attn_mask,  # NOTE: Attention masking is temporarily unsupported
        # dropout_p=attn_dropout_p if self.training else 0.0,  # NOTE: Attention dropout is currently unsupported
    )
```

Contributions and enhancements are welcome!
To run the unit tests verifying the correctness of the JVP Flash Attention Triton kernel, run the following command.

```bash
python tests/test_jvp_attention.py --dtype {float16,bfloat16,float32}
```

In principle, the kernel should support ROCm systems as well, though it has not yet been tested on them. macOS is currently unsupported.
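For reference, the quantity such tests check — the directional derivative of softmax attention with respect to tangents on (q, k, v) — can be sketched in plain NumPy. This is an illustrative reference implementation of the underlying math, not the Triton kernel or its test suite:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    """Reference softmax attention: softmax(q k^T / sqrt(d)) v."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(s) @ v

def attention_jvp(q, k, v, dq, dk, dv):
    """Directional derivative of attention along tangents (dq, dk, dv)."""
    d = q.shape[-1]
    s = q @ k.T / np.sqrt(d)
    p = softmax(s)
    # Tangent of the scores.
    ds = (dq @ k.T + q @ dk.T) / np.sqrt(d)
    # Softmax JVP: dp_ij = p_ij * (ds_ij - sum_l p_il * ds_il).
    dp = p * (ds - (p * ds).sum(axis=-1, keepdims=True))
    # Product rule on o = p @ v.
    return dp @ v + p @ dv

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 16)) for _ in range(3))
dq, dk, dv = (rng.standard_normal((8, 16)) for _ in range(3))
jvp_out = attention_jvp(q, k, v, dq, dk, dv)
```

A natural sanity check (and roughly what a correctness test does) is to compare `jvp_out` against a central finite difference of `attention` along the same tangent direction.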
Results for float16:
==============================================================================================================
BENCHMARK SUMMARY
==============================================================================================================
Seq Len Causal Mask Method Time (ms) Mem (MB) TFLOP/s Max Error Grad Check
--------------------------------------------------------------------------------------------------------------
32 False additive sdpa 0.785 0.64 0.0 TFLOP/s baseline N/A
32 False additive jvp_attn 0.475 0.23 0.0 TFLOP/s 1.95e-03 ✓
32 False boolean sdpa 0.834 0.65 0.0 TFLOP/s baseline N/A
32 False boolean jvp_attn 0.469 0.22 0.0 TFLOP/s 1.95e-03 ✓
32 False none sdpa 0.546 0.64 0.0 TFLOP/s baseline N/A
32 False none jvp_attn 0.465 0.22 0.0 TFLOP/s 1.95e-03 ✓
32 True none sdpa 0.838 0.65 0.0 TFLOP/s baseline N/A
32 True none jvp_attn 0.464 0.22 0.0 TFLOP/s 1.95e-03 ✓
64 False additive sdpa 0.790 1.41 0.0 TFLOP/s baseline N/A
64 False additive jvp_attn 0.487 0.47 0.0 TFLOP/s 9.77e-04 ✓
64 False boolean sdpa 0.820 1.45 0.0 TFLOP/s baseline N/A
64 False boolean jvp_attn 0.473 0.43 0.0 TFLOP/s 9.77e-04 ✓
64 False none sdpa 0.685 1.41 0.0 TFLOP/s baseline N/A
64 False none jvp_attn 0.474 0.43 0.0 TFLOP/s 9.77e-04 ✓
64 True none sdpa 0.854 1.42 0.0 TFLOP/s baseline N/A
64 True none jvp_attn 0.499 0.43 0.0 TFLOP/s 1.95e-03 ✓
128 False additive sdpa 0.769 3.28 0.0 TFLOP/s baseline N/A
128 False additive jvp_attn 0.494 1.02 0.1 TFLOP/s 9.77e-04 ✓
128 False boolean sdpa 0.831 3.44 0.0 TFLOP/s baseline N/A
128 False boolean jvp_attn 0.468 0.86 0.1 TFLOP/s 9.77e-04 ✓
128 False none sdpa 0.533 3.28 0.0 TFLOP/s baseline N/A
128 False none jvp_attn 0.470 0.86 0.1 TFLOP/s 9.77e-04 ✓
128 True none sdpa 0.984 3.35 0.0 TFLOP/s baseline N/A
128 True none jvp_attn 0.483 0.86 0.0 TFLOP/s 1.95e-03 ✓
256 False additive sdpa 1.142 9.69 0.1 TFLOP/s baseline N/A
256 False additive jvp_attn 0.473 2.35 0.4 TFLOP/s 9.77e-04 ✓
256 False boolean sdpa 0.886 10.32 0.1 TFLOP/s baseline N/A
256 False boolean jvp_attn 0.466 1.72 0.4 TFLOP/s 9.77e-04 ✓
256 False none sdpa 0.715 9.69 0.1 TFLOP/s baseline N/A
256 False none jvp_attn 0.472 1.72 0.4 TFLOP/s 9.77e-04 ✓
256 True none sdpa 0.976 9.94 0.0 TFLOP/s baseline N/A
256 True none jvp_attn 0.464 1.72 0.2 TFLOP/s 1.95e-03 ✓
512 False additive sdpa 1.399 31.88 0.2 TFLOP/s baseline N/A
512 False additive jvp_attn 0.481 5.95 1.4 TFLOP/s 4.88e-04 ✓
512 False boolean sdpa 1.222 34.38 0.3 TFLOP/s baseline N/A
512 False boolean jvp_attn 0.489 3.45 1.4 TFLOP/s 4.88e-04 ✓
512 False none sdpa 1.106 31.88 0.3 TFLOP/s baseline N/A
512 False none jvp_attn 0.475 3.45 1.4 TFLOP/s 4.88e-04 ✓
512 True none sdpa 1.354 32.88 0.1 TFLOP/s baseline N/A
512 True none jvp_attn 0.493 3.45 0.7 TFLOP/s 1.95e-03 ✓
1024 False additive sdpa 2.430 113.77 0.6 TFLOP/s baseline N/A
1024 False additive jvp_attn 0.480 16.89 5.7 TFLOP/s 4.88e-04 ✓
1024 False boolean sdpa 2.889 123.77 0.5 TFLOP/s baseline N/A
1024 False boolean jvp_attn 0.483 6.89 5.7 TFLOP/s 4.88e-04 ✓
1024 False none sdpa 2.457 113.77 0.6 TFLOP/s baseline N/A
1024 False none jvp_attn 0.467 6.89 5.9 TFLOP/s 4.88e-04 ✓
1024 True none sdpa 2.670 117.77 0.3 TFLOP/s baseline N/A
1024 True none jvp_attn 0.500 6.89 2.7 TFLOP/s 1.95e-03 ✓
2048 False additive sdpa 7.791 427.54 0.7 TFLOP/s baseline N/A
2048 False additive jvp_attn 0.696 53.79 15.7 TFLOP/s 2.44e-04 ✓
2048 False boolean sdpa 7.673 467.54 0.7 TFLOP/s baseline N/A
2048 False boolean jvp_attn 0.755 13.79 14.5 TFLOP/s 2.44e-04 ✓
2048 False none sdpa 7.773 427.54 0.7 TFLOP/s baseline N/A
2048 False none jvp_attn 0.614 13.79 17.8 TFLOP/s 2.44e-04 ✓
2048 True none sdpa 8.609 443.54 0.3 TFLOP/s baseline N/A
2048 True none jvp_attn 0.464 13.79 11.8 TFLOP/s 1.95e-03 ✓
================================================================================
MASK TYPE PERFORMANCE COMPARISON
================================================================================
Seq Len Causal Method No Mask Boolean Mask Additive Mask
--------------------------------------------------------------------------------
32 False jvp_attn 0.47 ms 0.47 ms (1.01x) 0.48 ms (1.02x)
32 True jvp_attn 0.46 ms N/A N/A
64 False jvp_attn 0.47 ms 0.47 ms (1.00x) 0.49 ms (1.03x)
64 True jvp_attn 0.50 ms N/A N/A
128 False jvp_attn 0.47 ms 0.47 ms (1.00x) 0.49 ms (1.05x)
128 True jvp_attn 0.48 ms N/A N/A
256 False jvp_attn 0.47 ms 0.47 ms (0.99x) 0.47 ms (1.00x)
256 True jvp_attn 0.46 ms N/A N/A
512 False jvp_attn 0.47 ms 0.49 ms (1.03x) 0.48 ms (1.01x)
512 True jvp_attn 0.49 ms N/A N/A
1024 False jvp_attn 0.47 ms 0.48 ms (1.03x) 0.48 ms (1.03x)
1024 True jvp_attn 0.50 ms N/A N/A
2048 False jvp_attn 0.61 ms 0.75 ms (1.23x) 0.70 ms (1.13x)
2048 True jvp_attn 0.46 ms N/A N/A
============================================================
STATISTICS
============================================================
Average speedup: 4.00x
Min speedup: 1.13x
Max speedup: 18.54x
Accuracy: 28/28 tests passed
✓ All accuracy checks passed!
Results for bfloat16:
==============================================================================================================
BENCHMARK SUMMARY
==============================================================================================================
Seq Len Causal Mask Method Time (ms) Mem (MB) TFLOP/s Max Error Grad Check
--------------------------------------------------------------------------------------------------------------
32 False additive sdpa 0.900 0.64 0.0 TFLOP/s baseline N/A
32 False additive jvp_attn 0.661 0.23 0.0 TFLOP/s 1.56e-02 ✓
32 False boolean sdpa 0.883 0.65 0.0 TFLOP/s baseline N/A
32 False boolean jvp_attn 0.578 0.22 0.0 TFLOP/s 1.56e-02 ✓
32 False none sdpa 0.603 0.64 0.0 TFLOP/s baseline N/A
32 False none jvp_attn 0.513 0.22 0.0 TFLOP/s 1.56e-02 ✓
32 True none sdpa 0.915 0.65 0.0 TFLOP/s baseline N/A
32 True none jvp_attn 0.519 0.22 0.0 TFLOP/s 1.56e-02 ✓
64 False additive sdpa 0.833 1.41 0.0 TFLOP/s baseline N/A
64 False additive jvp_attn 0.531 0.47 0.0 TFLOP/s 7.81e-03 ✓
64 False boolean sdpa 0.870 1.45 0.0 TFLOP/s baseline N/A
64 False boolean jvp_attn 0.545 0.43 0.0 TFLOP/s 7.81e-03 ✓
64 False none sdpa 0.565 1.41 0.0 TFLOP/s baseline N/A
64 False none jvp_attn 0.508 0.43 0.0 TFLOP/s 7.81e-03 ✓
64 True none sdpa 0.939 1.42 0.0 TFLOP/s baseline N/A
64 True none jvp_attn 0.520 0.43 0.0 TFLOP/s 1.56e-02 ✓
128 False additive sdpa 0.864 3.28 0.0 TFLOP/s baseline N/A
128 False additive jvp_attn 0.476 1.02 0.1 TFLOP/s 7.81e-03 ✓
128 False boolean sdpa 0.839 3.44 0.0 TFLOP/s baseline N/A
128 False boolean jvp_attn 0.460 0.86 0.1 TFLOP/s 7.81e-03 ✓
128 False none sdpa 0.798 3.28 0.0 TFLOP/s baseline N/A
128 False none jvp_attn 0.519 0.86 0.1 TFLOP/s 7.81e-03 ✓
128 True none sdpa 0.886 3.35 0.0 TFLOP/s baseline N/A
128 True none jvp_attn 0.504 0.86 0.0 TFLOP/s 1.56e-02 ✓
256 False additive sdpa 1.164 9.69 0.1 TFLOP/s baseline N/A
256 False additive jvp_attn 0.471 2.35 0.4 TFLOP/s 7.81e-03 ✓
256 False boolean sdpa 0.918 10.32 0.1 TFLOP/s baseline N/A
256 False boolean jvp_attn 0.468 1.72 0.4 TFLOP/s 7.81e-03 ✓
256 False none sdpa 0.780 9.69 0.1 TFLOP/s baseline N/A
256 False none jvp_attn 0.463 1.72 0.4 TFLOP/s 3.91e-03 ✓
256 True none sdpa 1.187 9.94 0.0 TFLOP/s baseline N/A
256 True none jvp_attn 0.463 1.72 0.2 TFLOP/s 1.56e-02 ✓
512 False additive sdpa 1.059 31.88 0.3 TFLOP/s baseline N/A
512 False additive jvp_attn 0.476 5.95 1.4 TFLOP/s 3.91e-03 ✓
512 False boolean sdpa 1.053 34.38 0.3 TFLOP/s baseline N/A
512 False boolean jvp_attn 0.463 3.45 1.5 TFLOP/s 3.91e-03 ✓
512 False none sdpa 1.131 31.88 0.3 TFLOP/s baseline N/A
512 False none jvp_attn 0.466 3.45 1.5 TFLOP/s 3.91e-03 ✓
512 True none sdpa 1.651 32.88 0.1 TFLOP/s baseline N/A
512 True none jvp_attn 0.470 3.45 0.7 TFLOP/s 1.56e-02 ✓
1024 False additive sdpa 2.525 113.77 0.5 TFLOP/s baseline N/A
1024 False additive jvp_attn 0.489 16.89 5.6 TFLOP/s 3.91e-03 ✓
1024 False boolean sdpa 2.775 123.77 0.5 TFLOP/s baseline N/A
1024 False boolean jvp_attn 0.471 6.89 5.8 TFLOP/s 3.91e-03 ✓
1024 False none sdpa 2.393 113.77 0.6 TFLOP/s baseline N/A
1024 False none jvp_attn 0.482 6.89 5.7 TFLOP/s 3.91e-03 ✓
1024 True none sdpa 2.319 117.77 0.3 TFLOP/s baseline N/A
1024 True none jvp_attn 0.463 6.89 3.0 TFLOP/s 1.56e-02 ✓
2048 False additive sdpa 8.504 427.54 0.6 TFLOP/s baseline N/A
2048 False additive jvp_attn 0.804 53.79 13.6 TFLOP/s 1.95e-03 ✓
2048 False boolean sdpa 8.944 467.54 0.6 TFLOP/s baseline N/A
2048 False boolean jvp_attn 0.852 13.79 12.9 TFLOP/s 1.95e-03 ✓
2048 False none sdpa 6.810 427.54 0.8 TFLOP/s baseline N/A
2048 False none jvp_attn 0.900 13.79 12.2 TFLOP/s 1.95e-03 ✓
2048 True none sdpa 8.846 443.54 0.3 TFLOP/s baseline N/A
2048 True none jvp_attn 0.508 13.79 10.8 TFLOP/s 3.12e-02 ✓
================================================================================
MASK TYPE PERFORMANCE COMPARISON
================================================================================
Seq Len Causal Method No Mask Boolean Mask Additive Mask
--------------------------------------------------------------------------------
32 False jvp_attn 0.51 ms 0.58 ms (1.13x) 0.66 ms (1.29x)
32 True jvp_attn 0.52 ms N/A N/A
64 False jvp_attn 0.51 ms 0.54 ms (1.07x) 0.53 ms (1.05x)
64 True jvp_attn 0.52 ms N/A N/A
128 False jvp_attn 0.52 ms 0.46 ms (0.89x) 0.48 ms (0.92x)
128 True jvp_attn 0.50 ms N/A N/A
256 False jvp_attn 0.46 ms 0.47 ms (1.01x) 0.47 ms (1.02x)
256 True jvp_attn 0.46 ms N/A N/A
512 False jvp_attn 0.47 ms 0.46 ms (0.99x) 0.48 ms (1.02x)
512 True jvp_attn 0.47 ms N/A N/A
1024 False jvp_attn 0.48 ms 0.47 ms (0.98x) 0.49 ms (1.01x)
1024 True jvp_attn 0.46 ms N/A N/A
2048 False jvp_attn 0.90 ms 0.85 ms (0.95x) 0.80 ms (0.89x)
2048 True jvp_attn 0.51 ms N/A N/A
============================================================
STATISTICS
============================================================
Average speedup: 3.75x
Min speedup: 1.11x
Max speedup: 17.40x
Accuracy: 28/28 tests passed
✓ All accuracy checks passed!
Results for float32:
==============================================================================================================
BENCHMARK SUMMARY
==============================================================================================================
Seq Len Causal Mask Method Time (ms) Mem (MB) TFLOP/s Max Error Grad Check
--------------------------------------------------------------------------------------------------------------
32 False additive sdpa 0.724 0.51 0.0 TFLOP/s baseline N/A
32 False additive jvp_attn 0.523 0.45 0.0 TFLOP/s 7.21e-03 ✓
32 False boolean sdpa 0.764 0.53 0.0 TFLOP/s baseline N/A
32 False boolean jvp_attn 0.500 0.43 0.0 TFLOP/s 7.21e-03 ✓
32 False none sdpa 0.454 0.51 0.0 TFLOP/s baseline N/A
32 False none jvp_attn 0.521 0.43 0.0 TFLOP/s 7.22e-03 ✓
32 True none sdpa 0.771 0.51 0.0 TFLOP/s baseline N/A
32 True none jvp_attn 0.530 0.43 0.0 TFLOP/s 6.18e-03 ✓
64 False additive sdpa 0.731 1.09 0.0 TFLOP/s baseline N/A
64 False additive jvp_attn 0.503 0.94 0.0 TFLOP/s 7.17e-03 ✓
64 False boolean sdpa 0.760 1.17 0.0 TFLOP/s baseline N/A
64 False boolean jvp_attn 0.501 0.86 0.0 TFLOP/s 7.17e-03 ✓
64 False none sdpa 0.447 1.09 0.0 TFLOP/s baseline N/A
64 False none jvp_attn 0.497 0.86 0.0 TFLOP/s 7.03e-03 ✓
64 True none sdpa 0.790 1.11 0.0 TFLOP/s baseline N/A
64 True none jvp_attn 0.507 0.86 0.0 TFLOP/s 6.18e-03 ✓
128 False additive sdpa 0.702 2.81 0.0 TFLOP/s baseline N/A
128 False additive jvp_attn 0.494 2.03 0.1 TFLOP/s 5.41e-03 ✓
128 False boolean sdpa 0.826 3.13 0.0 TFLOP/s baseline N/A
128 False boolean jvp_attn 0.478 1.72 0.1 TFLOP/s 5.41e-03 ✓
128 False none sdpa 0.579 2.81 0.0 TFLOP/s baseline N/A
128 False none jvp_attn 0.514 1.72 0.1 TFLOP/s 5.07e-03 ✓
128 True none sdpa 0.837 2.88 0.0 TFLOP/s baseline N/A
128 True none jvp_attn 0.537 1.72 0.0 TFLOP/s 6.18e-03 ✓
256 False additive sdpa 0.687 8.75 0.1 TFLOP/s baseline N/A
256 False additive jvp_attn 0.481 4.69 0.4 TFLOP/s 3.41e-03 ✓
256 False boolean sdpa 0.797 10.00 0.1 TFLOP/s baseline N/A
256 False boolean jvp_attn 0.506 3.44 0.3 TFLOP/s 3.41e-03 ✓
256 False none sdpa 0.466 8.75 0.2 TFLOP/s baseline N/A
256 False none jvp_attn 0.474 3.44 0.4 TFLOP/s 3.67e-03 ✓
256 True none sdpa 1.024 9.00 0.0 TFLOP/s baseline N/A
256 True none jvp_attn 0.496 3.44 0.2 TFLOP/s 5.78e-03 ✓
512 False additive sdpa 0.982 30.01 0.3 TFLOP/s baseline N/A
512 False additive jvp_attn 0.515 11.88 1.3 TFLOP/s 3.09e-03 ✓
512 False boolean sdpa 1.413 35.01 0.2 TFLOP/s baseline N/A
512 False boolean jvp_attn 0.487 6.88 1.4 TFLOP/s 3.09e-03 ✓
512 False none sdpa 1.362 30.01 0.3 TFLOP/s baseline N/A
512 False none jvp_attn 0.481 6.88 1.4 TFLOP/s 2.88e-03 ✓
512 True none sdpa 0.968 31.01 0.2 TFLOP/s baseline N/A
512 True none jvp_attn 0.481 6.88 0.7 TFLOP/s 5.13e-03 ✓
1024 False additive sdpa 2.381 110.02 0.6 TFLOP/s baseline N/A
1024 False additive jvp_attn 0.708 33.77 3.9 TFLOP/s 2.84e-03 ✓
1024 False boolean sdpa 3.475 130.02 0.4 TFLOP/s baseline N/A
1024 False boolean jvp_attn 0.593 13.77 4.6 TFLOP/s 2.84e-03 ✓
1024 False none sdpa 2.246 110.02 0.6 TFLOP/s baseline N/A
1024 False none jvp_attn 0.488 13.77 5.6 TFLOP/s 2.61e-03 ✓
1024 True none sdpa 2.700 115.02 0.3 TFLOP/s baseline N/A
1024 True none jvp_attn 0.531 13.77 2.6 TFLOP/s 5.61e-03 ✓
2048 False additive sdpa 8.094 420.04 0.7 TFLOP/s baseline N/A
2048 False additive jvp_attn 1.274 107.54 8.6 TFLOP/s 1.57e-03 ✓
2048 False boolean sdpa 7.980 500.04 0.7 TFLOP/s baseline N/A
2048 False boolean jvp_attn 1.332 27.54 8.2 TFLOP/s 1.57e-03 ✓
2048 False none sdpa 6.471 420.04 0.8 TFLOP/s baseline N/A
2048 False none jvp_attn 1.345 27.54 8.1 TFLOP/s 1.56e-03 ✓
2048 True none sdpa 8.308 436.04 0.3 TFLOP/s baseline N/A
2048 True none jvp_attn 0.749 27.54 7.3 TFLOP/s 6.47e-03 ✓
================================================================================
MASK TYPE PERFORMANCE COMPARISON
================================================================================
Seq Len Causal Method No Mask Boolean Mask Additive Mask
--------------------------------------------------------------------------------
32 False jvp_attn 0.52 ms 0.50 ms (0.96x) 0.52 ms (1.00x)
32 True jvp_attn 0.53 ms N/A N/A
64 False jvp_attn 0.50 ms 0.50 ms (1.01x) 0.50 ms (1.01x)
64 True jvp_attn 0.51 ms N/A N/A
128 False jvp_attn 0.51 ms 0.48 ms (0.93x) 0.49 ms (0.96x)
128 True jvp_attn 0.54 ms N/A N/A
256 False jvp_attn 0.47 ms 0.51 ms (1.07x) 0.48 ms (1.01x)
256 True jvp_attn 0.50 ms N/A N/A
512 False jvp_attn 0.48 ms 0.49 ms (1.01x) 0.52 ms (1.07x)
512 True jvp_attn 0.48 ms N/A N/A
1024 False jvp_attn 0.49 ms 0.59 ms (1.22x) 0.71 ms (1.45x)
1024 True jvp_attn 0.53 ms N/A N/A
2048 False jvp_attn 1.34 ms 1.33 ms (0.99x) 1.27 ms (0.95x)
2048 True jvp_attn 0.75 ms N/A N/A
============================================================
STATISTICS
============================================================
Average speedup: 2.83x
Min speedup: 0.87x
Max speedup: 11.09x
Accuracy: 28/28 tests passed
✓ All accuracy checks passed!
This project is covered under the MIT License.
If you use the code associated with this package or otherwise find this work useful, please use GitHub's "Cite this repository" feature or the BibTeX entry below.
```bibtex
@software{Morehead_JVP_Flash_Attention_2025,
  author = {Morehead, Alex},
  doi = {10.5281/zenodo.17050188},
  license = {MIT},
  month = sep,
  title = {{JVP Flash Attention}},
  url = {https://github.com/amorehead/jvp_flash_attention},
  version = {0.0.3},
  year = {2025}
}
```

jvp_flash_attention builds upon the contributions and insights from the following sources:
Thank you to each and every contributor!
