
[CPU] Support fp8 template qsdpa#4348

Open
Valentine233 wants to merge 2 commits into main from fp8_template_qsdpa_cpu

Conversation

@Valentine233
Collaborator

Currently, there is only a customized kernel for fp8 qsdpa. Similar to int8 qsdpa, this PR adds a template path for fp8 qsdpa through inductor.

Testing plan: python test/prototype/inductor/test_qsdpa_fusion.py -k test_fp8_sdpa_rewriter_cpu

@pytorch-bot

pytorch-bot Bot commented Apr 28, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4348

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit ede12d5 with merge base 9052ece:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 28, 2026
@Valentine233 Valentine233 added the module: not user facing label (use this tag if you don't want this PR to show up in release notes) on Apr 28, 2026
@Valentine233 Valentine233 marked this pull request as draft April 28, 2026 06:38
@Valentine233
Collaborator Author

@mingfeima @Xia-Weiwen Please help review the PR, thanks!

@Xia-Weiwen
Collaborator

@claude review

@claude

claude Bot commented Apr 30, 2026

Claude finished @Xia-Weiwen's task in 2m 59s


Code Review: [CPU] Support fp8 template qsdpa

  • Gather context and read changed files
  • Review test changes
  • Review C++ kernel changes
  • Review codegen and template changes
  • Review lowering changes
  • Post final review

Overall this is a well-structured PR that extends the existing int8 QSDPA template infrastructure to support fp8 (float8_e4m3fn) through inductor. The approach is sound — reusing the existing template class and adding fp8-specific fusion kernels. A few observations below:


Issues

1. Potential logic issue in qsdpa_lowering.py:102-108 — missing dtype guard for int8 path

The original code was:

if is_amx_tile_supported and query.get_dtype() is torch.uint8:
    CppInt8SdpaTemplate.add_choices(...)

The new code is:

if is_amx_tile_supported:
    if query.get_dtype() is torch.float8_e4m3fn:
        q_zp = 0
        ...
    CppQsdpaTemplate.add_choices(...)

This means CppQsdpaTemplate.add_choices() is now called for any dtype when AMX is supported, not just uint8 and float8_e4m3fn. If an unsupported dtype somehow reaches here, the assert at cpp_qsdpa_template.py:2070 would catch it, but an explicit dtype guard would be cleaner. In practice, the earlier check at qsdpa_lowering.py:70-78 prevents other dtypes from reaching this point, so this is low-risk but worth noting.
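
A minimal sketch of what the explicit guard could look like, reusing only the names already quoted above (is_amx_tile_supported, query, CppQsdpaTemplate) and eliding the zero-point setup and add_choices arguments exactly as the snippets do, so it is illustrative rather than the PR's actual code:

if is_amx_tile_supported and query.get_dtype() in (
    torch.uint8,
    torch.float8_e4m3fn,
):
    # explicit guard: only the dtypes the template handles reach add_choices
    if query.get_dtype() is torch.float8_e4m3fn:
        q_zp = 0
        ...
    CppQsdpaTemplate.add_choices(...)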

2. FP8 template doesn't support attention mask in computation (cpp_qsdpa_template.py:1982-1995)

The FP8_SDPA_SEVERAL_LOOPS_TEMPLATE declares attention mask variables (lines 1804-1824) but never uses them in the actual dequant+max step. The int8 template has dequant_mask_max_fusion_kernel with mask support, but the fp8 path only calls fp8_dequant_max_fusion_kernel which doesn't accept a mask parameter. If has_attention_mask is true, the compiled template would have unused variables and the mask would be silently ignored. The test uses has_mask=False so this isn't exercised. Consider either:

  • Adding a fp8_dequant_mask_max_fusion_kernel variant, or
  • Adding a guard to prevent the fp8 template from being selected when an attention mask is present (a sketch of this option follows below)
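
A minimal sketch of the second option, assuming a hypothetical attn_mask variable is visible at the choice site (the real variable name in qsdpa_lowering.py may differ) and reusing query, is_amx_tile_supported, and CppQsdpaTemplate from the snippets above:

is_fp8 = query.get_dtype() is torch.float8_e4m3fn
# the fp8 template currently ignores the mask, so only offer the template
# when no attention mask is present; otherwise the customized kernel is used
if is_amx_tile_supported and not (is_fp8 and attn_mask is not None):
    ...
    CppQsdpaTemplate.add_choices(...)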

3. Stale comment in qsdpa_lowering.py:95

# use template if machine has amx, only support uint8 for now

This should be updated since fp8 is now also supported.
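
For example, the comment could simply become (wording only, not part of the PR):

# use template if machine has amx, support uint8 and float8_e4m3fn for now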


Minor / Style Issues

4. Indentation inconsistency in fp8_dequant_quant_fusion_kernel (cpp_qsdpa_template.py:625-626)

inline void fp8_dequant_quant_fusion_kernel(
    const float* in,
  const float* sum_a_ptr,  // <-- 2-space indent instead of 4-space
    const int& M,

5. Redundant at::vec::convert<float> calls in fp8 kernels

In fp8_dequant_max_fusion_kernel (line 228) and fp8_dequant_quant_fusion_kernel (line 648):

auto tmp2 = tmp1 - vec_sum_a;
auto tmp3 = at::vec::convert<float>(tmp2);  // tmp2 is already float

Since tmp2 is already Vectorized<float>, convert<float> is a no-op. This appears to be copy-paste from the int8 version where it converts from int32_t. Not a bug, but unnecessary.

6. Unused includes in FP8 template (cpp_qsdpa_template.py:1743-1746)

#include <cstdlib>
#include <fstream>
#include <string>

These headers (cstdlib, fstream, string) are not used in the FP8 template and appear to be debug leftovers.

7. No one-loop template variant for FP8

The int8 path selects between INT8_SDPA_ONE_LOOP_TEMPLATE and INT8_SDPA_SEVERAL_LOOPS_TEMPLATE based on use_one_parallel_loop, but fp8 always uses the several-loops variant (line 2289). This means the get_options() logic that computes use_one_parallel_loop is partially unused for fp8. Consider adding a comment explaining why (e.g., fp8 always needs the several-loops approach), or noting it as future work.

8. The CppTemplate.__init__ still uses "int8_sdpa" as the name for fp8 (cpp_qsdpa_template.py:2072)

CppTemplate.__init__(
    self, "int8_sdpa", input_nodes, layout, parallel_num_threads()
)

This name is used for internal template identification. Consider updating to something more generic like "qsdpa" since it now handles both int8 and fp8.
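
For instance, the constructor call could pass a generic name (a hypothetical rename, not something this PR does):

CppTemplate.__init__(
    self, "qsdpa", input_nodes, layout, parallel_num_threads()
)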


Looks Good

  • The removal of unused ompIdx in quantized_sdpa.cpp is a clean fix
  • Renaming cpp_int8_sdpa_template.py → cpp_qsdpa_template.py and CppInt8SdpaTemplate → CppQsdpaTemplate is appropriate for the generalization
  • The test assertion change to check for "QSDPA_TEMPLATE" is cleaner and works for both int8 and fp8
  • The fp8 fusion kernels follow a consistent pattern with the existing int8 kernels
  • The clamp to [-448, 448] is correct for the float8_e4m3fn range (a quick check follows below)
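
As a quick sanity check of that range (not part of the PR), torch.finfo reports the float8_e4m3fn limits directly:

import torch

# float8_e4m3fn has no inf encoding; its finite range is [-448, 448]
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e4m3fn).min)  # -448.0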

@Valentine233 Valentine233 force-pushed the fp8_template_qsdpa_cpu branch from b498317 to 904b80f on April 30, 2026 07:31
@Valentine233 Valentine233 marked this pull request as ready for review April 30, 2026 07:31
@Valentine233 Valentine233 force-pushed the fp8_template_qsdpa_cpu branch from 904b80f to ede12d5 on April 30, 2026 07:33
@Valentine233
Copy link
Copy Markdown
Collaborator Author

@jerryzh168 @drisspg Please help review the PR, thanks!

@Valentine233 Valentine233 requested a review from drisspg April 30, 2026 07:57
