[CPU] Support fp8 template qsdpa#4348
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4348
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit ede12d5 with merge base 9052ece. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@mingfeima @Xia-Weiwen Please help review the PR, thanks!
@claude review |
Claude finished @Xia-Weiwen's task in 2m 59s. Code Review: [CPU] Support fp8 template qsdpa
Overall this is a well-structured PR that extends the existing int8 QSDPA template infrastructure to support fp8 (float8_e4m3fn) through inductor. The approach is sound: reusing the existing template class and adding fp8-specific fusion kernels. A few observations below:

Issues

1. Potential logic issue in the dispatch condition

The original code was:

```python
if is_amx_tile_supported and query.get_dtype() is torch.uint8:
    CppInt8SdpaTemplate.add_choices(...)
```

The new code is:

```python
if is_amx_tile_supported:
    if query.get_dtype() is torch.float8_e4m3fn:
        q_zp = 0
        ...
    CppQsdpaTemplate.add_choices(...)
```

This means `add_choices` is now called for any query dtype whenever AMX is supported; if queries that are neither uint8 nor fp8 can reach this point, they will incorrectly take the template path.

2. FP8 template doesn't support attention mask in computation
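The dispatch change discussed above can be sketched as a small standalone helper. This is a hypothetical illustration, not the PR's actual code: the function name `qsdpa_template_params` and the string dtype tags are assumptions made so the sketch runs without torch.

```python
def qsdpa_template_params(dtype_name: str, amx_supported: bool):
    """Sketch of the reviewed dispatch (hypothetical helper, not PR code).

    Returns None when the template path should not be taken, otherwise a
    dict of quantization parameters to pass to the template's add_choices.
    """
    if not amx_supported:
        return None  # no AMX tiles: fall back to the non-template path
    if dtype_name == "float8_e4m3fn":
        # fp8 carries no zero-point, so the PR's snippet forces q_zp = 0
        return {"q_zp": 0}
    if dtype_name == "uint8":
        return {}  # int8 path keeps the caller-provided zero-points/scales
    # gating on dtype here avoids the issue flagged above, where
    # add_choices could be reached for unsupported dtypes
    return None
```

With this shape, unsupported dtypes explicitly fall through to the non-template path instead of reaching `add_choices`.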
3. Stale comment

```python
# use template if machine has amx, only support uint8 for now
```

This should be updated since fp8 is now also supported.

Minor / Style Issues

4. Indentation inconsistency

```cpp
inline void fp8_dequant_quant_fusion_kernel(
    const float* in,
  const float* sum_a_ptr, // <-- 2-space indent instead of 4-space
    const int& M,
```

5. Redundant conversion

```cpp
auto tmp2 = tmp1 - vec_sum_a;
auto tmp3 = at::vec::convert<float>(tmp2); // tmp2 is already float
```

Since `tmp2` is already a float vector, the `convert<float>` call is redundant.

6. Unused includes in FP8 template

```cpp
#include <cstdlib>
#include <fstream>
#include <string>
```

These headers appear to be unused and can be removed.

7. No one-loop template variant for FP8

The int8 path selects between one-loop and two-loop variants; the fp8 path does not offer a one-loop variant.

8. Template name

```python
CppTemplate.__init__(
    self, "int8_sdpa", input_nodes, layout, parallel_num_threads()
)
```

This name is used for internal template identification. Consider updating it to something more generic.

Looks Good
b498317 to
904b80f
904b80f to
ede12d5
@jerryzh168 @drisspg Please help review the PR, thanks!
Currently there is only a customized kernel for fp8 qsdpa. Similar to int8 qsdpa, this PR adds a template path for fp8 qsdpa through inductor.
Testing plan:

```shell
python test/prototype/inductor/test_qsdpa_fusion.py -k test_fp8_sdpa_rewriter_cpu
```