Add "32x1 transposed" variant to MXFP8 3D quantization kernel#4356
Add "32x1 transposed" variant to MXFP8 3D quantization kernel#4356alexsamardzic wants to merge 3 commits intogh/alexsamardzic/1/basefrom
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4356
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Unrelated Failures as of commit e84da90 with merge base 6367fd6.
NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Benchmarking results:
@alexsamardzic can you benchmark this against the 2-stage approach we do here: ao/torchao/prototype/moe_training/mxfp8_grouped_mm.py, lines 555 to 558 in 9052ece?
Here is an adapted benchmarking script to compare the two: bench_quantize_3d_vs_triton.py. And here are the results:
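For illustration only (not part of the PR): a minimal sketch of a CUDA-event timing harness along the lines of such a comparison. The callables `quantize_3d_cutedsl` and `quantize_3d_two_stage` are hypothetical placeholders for the two paths; the actual comparison lives in bench_quantize_3d_vs_triton.py.

```python
import torch

def bench_ms(fn, *args, warmup=10, iters=100):
    # Warm up first so compilation / autotuning does not pollute the timing.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call

E, N, K = 8, 8192, 5120
weight = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")
# Hypothetical placeholders -- substitute the real entry points being compared:
# ms_cutedsl = bench_ms(quantize_3d_cutedsl, weight)
# ms_two_stage = bench_ms(quantize_3d_two_stage, weight)
```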
danielvegamyhre left a comment:
LGTM with some minor comments/questions
    x_clone = x.clone().requires_grad_(True)
    w_t_clone = w_t.clone().requires_grad_(True)

    fn = torch.compile(_to_mxfp8_then_scaled_grouped_mm, fullgraph=True)
Why was compile removed here?
Wrong edit, reverted.
    input_act = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    weight = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
    mat2 = weight.transpose(-2, -1)
    scale_block_k = 1
As an aside, in the future we should probably refactor to not use the "scale_block_n" / "scale_block_k" naming everywhere, since it leads to cases like this where we are quantizing the pre-transposed input shape (E, K, N) along K, yet setting scale_block_k=1 for API consistency.
Maybe in a follow-up we can rename them to scale_block_dim1, scale_block_dim2 or something?
Indeed, these are better names; I did the rename now.
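As a side illustration (hypothetical names only, not the actual torchao API): with per-dimension naming, the same call reads unambiguously whether the (E, N, K) weight or its (E, K, N) transpose is passed, e.g. the "32x1 transposed" variant from the PR title becomes a 32-element scale block along dim1 of the transposed tensor.

```python
import torch

def mx_scale_shape(x, scale_block_dim1, scale_block_dim2):
    # Hypothetical helper, not the torchao API: shape of the per-block scales
    # when every (scale_block_dim1 x scale_block_dim2) tile of the last two
    # dimensions of an (E, dim1, dim2) tensor shares a single scale.
    E, d1, d2 = x.shape
    return (E, d1 // scale_block_dim1, d2 // scale_block_dim2)

weight = torch.randn(4, 8192, 5120, dtype=torch.bfloat16)  # (E, N, K)
mat2 = weight.transpose(-2, -1)                            # (E, K, N)
# "32x1 transposed" variant: 32-element scale blocks along K of the transpose.
print(mx_scale_shape(mat2, scale_block_dim1=32, scale_block_dim2=1))
```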
        * x_scale.unsqueeze(-1).to(torch.bfloat16)
    ).reshape(M, K)

    input_scale_ref = input_scale.repeat_interleave(block_size, dim=1)
Why is the LHS input activation (scaled with 1x32) being repeat-interleaved here? I would think repeat_interleave would only be necessary for replicating the scale in a 32x32 weight-scaling reference impl?
This is just dequantization for the BF16 reference: the scales have shape (M, dim // 32), so we expand them to (M, dim).
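For illustration (placeholder names, not the PR's test code): a minimal sketch of how 1x32 block scales of shape (M, K // 32) are expanded back to (M, K) with repeat_interleave to build a bfloat16 dequantization reference.

```python
import torch

M, K, block_size = 128, 256, 32
x_q = torch.randint(-127, 128, (M, K), dtype=torch.int8)  # stand-in for the quantized data
x_scale = torch.rand(M, K // block_size)                  # one scale per 1x32 block

# Replicate each scale across its 32-element block, then dequantize.
scale_full = x_scale.repeat_interleave(block_size, dim=1)  # (M, K // 32) -> (M, K)
x_ref = x_q.to(torch.bfloat16) * scale_full.to(torch.bfloat16)
assert x_ref.shape == (M, K)
```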
    if cutlass.const_expr(INPUT_TRANSPOSED_VALUE):
        staged_layout_in = cute.make_layout(
            (STAGE_COUNT_VALUE, 1, TILE_N, TILE_K),
            stride=(STAGE_ELEMS, STAGE_ELEMS, 1, TILE_N),
CuTeDSL question: why are the "num stages" instances of TILE_N x TILE_K tiles represented as a 4D tensor of shape (stages, 1, tile_n, tile_k), rather than a 3D tensor of shape (stages, tile_n, tile_k)?
I mechanically used (stages, 1, tile_n, tile_k) to mirror the per-stage TMA tile shape (1, tile_n, tile_k). The singleton dim is indeed not needed at all, so it's removed now.
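Side note, for illustration only: a plain-Python check (independent of CuTeDSL) that the singleton dimension is redundant, since with the strides shown in the diff above both layouts map a (stage, n, k) coordinate to the same linear offset.

```python
# 4D layout: shape (stages, 1, tile_n, tile_k), stride (STAGE_ELEMS, STAGE_ELEMS, 1, TILE_N).
# The singleton dim's index is always 0, so it contributes nothing to the offset,
# and a 3D layout (stages, tile_n, tile_k) with stride (STAGE_ELEMS, 1, TILE_N)
# addresses exactly the same elements.
STAGE_COUNT, TILE_N, TILE_K = 4, 32, 128
STAGE_ELEMS = TILE_N * TILE_K

def offset_4d(s, n, k):
    return s * STAGE_ELEMS + 0 * STAGE_ELEMS + n * 1 + k * TILE_N

def offset_3d(s, n, k):
    return s * STAGE_ELEMS + n * 1 + k * TILE_N

assert all(
    offset_4d(s, n, k) == offset_3d(s, n, k)
    for s in range(STAGE_COUNT) for n in range(TILE_N) for k in range(TILE_K)
)
```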
        stride=(tile_n * tile_k, tile_k, 1),
    )
    def _make_tile_smem_layouts(
        cute,
Why does the cute package need to be a param here? Is it because the import is done with guards elsewhere instead of at the top? Not a huge deal, but this feels a bit awkward.
This is a remnant of the first attempt to make the code work when the CuTeDSL package is not installed; it's now replaced with a simple import within the body (the same change was also made for the analogous methods of the 2D kernels).
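For illustration only, a minimal sketch of the local-import pattern described above, assuming CuTeDSL (`cutlass.cute`) is an optional dependency; the helper name and layout arguments are placeholders reusing values from the diff.

```python
def _make_tile_smem_layouts_sketch(tile_n, tile_k, stage_count):
    # Importing inside the body keeps this module importable even when the
    # optional CuTeDSL package is not installed; the import only runs when a
    # kernel that actually needs these layouts is built.
    import cutlass.cute as cute

    return cute.make_layout(
        (stage_count, tile_n, tile_k),
        stride=(tile_n * tile_k, tile_k, 1),
    )
```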
Stack from ghstack (oldest at bottom):