[FSDP2+Float8] Training loss diverges when enable_fsdp_float8_all_gather=True vs False on H20 with Qwen3-8B

Description

I am running a 2-GPU FSDP2 + Float8 pre-training experiment on Qwen3-8B with torchao 0.14.1 and torch 2.9.0 on NVIDIA H20 GPUs. The training loss depends noticeably on the enable_fsdp_float8_all_gather setting in Float8LinearConfig: the loss curves with the flag set to True are consistently different from those with it set to False, and the final loss values do not match. This is unexpected, because the official PyTorch blog post "Supercharging Training using float8 and FSDP2" states that float8 all_gather should produce "identical loss convergence" compared to the bf16 baseline.

Environment

- torchao 0.14.1, torch 2.9.0
- 2x NVIDIA H20 GPUs
- Hugging Face Trainer with FSDP2

Steps to Reproduce

1. Run the training script below with enable_fsdp_float8_all_gather=True.
2. Run the same script again with enable_fsdp_float8_all_gather=False.
3. Compare the logged loss curves and final loss values.

Training Script (Minimal Example)
import os
import sys
import time
import torch
import logging
from contextlib import redirect_stdout, redirect_stderr
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    TrainerState,
    TrainerControl
)
from torchao.float8 import (
    convert_to_float8_training,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
    ScalingGranularity
)
import torch.distributed as dist
import torch.profiler as profiler
from torch.profiler import profile, ProfilerActivity
from datasets import load_from_disk
# ========================
# Globally disable all logging
# ========================
for logger_name in ["transformers", "huggingface_hub", "accelerate", "torch", "datasets"]:
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.ERROR)
    logger.propagate = False
# ========================
# Step 2: Output redirection
# ========================
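# Filters stdout/stderr so that only rank 0 writes: before "Starting training..." is seen,
# non-JSON lines pass through; afterwards only "[training-log]" and profiler-related lines are printed.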
class CustomOutputFilter:
    def __init__(self):
        self.allowed_prefix = "training-log"
        self.train_started = False
        self.rank = None
        self.profiler_prefixes = ["[Profiler]", "Profiling", "float8", "trace", "tensorboard", "✅", "❌", "📌"]

    def write(self, msg):
        if self.rank is None:
            self.rank = dist.get_rank() if dist.is_initialized() else 0
        if self.rank != 0:
            return
        msg_stripped = msg.strip()
        if not self.train_started:
            if msg_stripped and not msg_stripped.startswith("{"):
                sys.__stdout__.write(msg)
                sys.__stdout__.flush()
            if "Starting training..." in msg:
                self.train_started = True
        else:
            if (msg_stripped.startswith(self.allowed_prefix) or
                    any(p in msg_stripped for p in self.profiler_prefixes) or
                    "trace_step_" in msg_stripped):
                sys.__stdout__.write(msg)
                sys.__stdout__.flush()

    def flush(self):
        sys.__stdout__.flush()
# ========================
# Step 3: Custom logging callback
# ========================
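# Emits one greppable "[training-log]" line per logging step with step, loss, grad_norm,
# learning rate, epoch progress, and the average wall-clock time per step since the last log.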
class CustomLogCallback(TrainerCallback):
    def __init__(self, total_samples, batch_size):
        self.total_samples = total_samples
        self.batch_size = batch_size
        self.step_losses = []
        self.last_log_time = None
        self.last_log_step = 0
        self.train_start_time = time.time()
        self.is_main_process = None

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        tr_loss = kwargs.get("tr_loss", None)
        if tr_loss is not None:
            self.step_losses.append(tr_loss.item())

    def on_log(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        if self.is_main_process is None:
            self.is_main_process = (dist.get_rank() == 0) if dist.is_initialized() else True
        if not self.is_main_process:
            return
        if logs is None or "loss" not in logs:
            return
        current_time = time.time()
        current_step = state.global_step
        if self.last_log_time is None:
            self.last_log_time = self.train_start_time
            self.last_log_step = 0
        time_diff = current_time - self.last_log_time
        step_diff = current_step - self.last_log_step
        avg_step_time = time_diff / step_diff if step_diff > 0 else 0.0
        step = current_step
        loss = float(logs["loss"])
        grad_norm = float(logs.get("grad_norm", 0))
        lr = float(logs.get("learning_rate", 0))
        trained_samples = step * self.batch_size
        epoch = trained_samples / self.total_samples
        epoch = min(epoch, args.num_train_epochs)
        log_msg = (
            f"[training-log] - "
            f"step: {step}, "
            f"loss: {loss:.8f}, "
            f"grad_norm: {grad_norm:.4f}, "
            f"learning_rate: {lr:.4e}, "
            f"epoch: {epoch}, "
            f"avg_step_time: {avg_step_time:.4f}s/step"
            f"\n"
        )
        print(log_msg)
        self.last_log_time = current_time
        self.last_log_step = current_step
        logs["epoch"] = epoch
        logs["avg_step_time"] = avg_step_time
        state.log_history[-1] = logs
# ========================
# Step 4: Profiler callback
# ========================
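# Captures a torch.profiler trace on rank 0 for the first `max_steps` training steps
# and exports it as a Chrome trace JSON in output_dir.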
class ProfilerCallback(TrainerCallback):
    def __init__(self, max_steps=2, output_dir="./qwen3_8b_fsdp2_fp8_profile"):
        self.max_steps = max_steps
        self.output_dir = output_dir
        self.profiler = None
        self.trace_count = 0
        self.is_main_process = None
        os.makedirs(self.output_dir, exist_ok=True)
        for f in os.listdir(self.output_dir):
            if f.startswith("trace"):
                os.remove(os.path.join(self.output_dir, f))

    def on_step_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        if self.is_main_process is None:
            self.is_main_process = (dist.get_rank() == 0) if dist.is_initialized() else True
        if not self.is_main_process:
            return
        step = state.global_step
        if step < self.max_steps:
            print(f"[Profiler] Starting profile capture for step {step + 1}...\n")
            self.profiler = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
                with_flops=True,
                with_modules=True,
                schedule=profiler.schedule(
                    wait=0,
                    warmup=0,
                    active=1,
                    repeat=0
                )
            )
            self.profiler.start()

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        if not self.is_main_process:
            return
        step = state.global_step
        if step == self.max_steps and self.profiler is not None:
            self.profiler.stop()
            trace_json_file = os.path.join(self.output_dir, f"trace_step_{step + 1}.json")
            self.profiler.export_chrome_trace(trace_json_file)
            print(f"[Profiler] Profile data for step {step + 1} saved:")
            print(f"  - Chrome Trace: {self.output_dir}/trace_step_{step + 1}.json")
            self.trace_count += 1
            self.profiler = None
# ========================
# Step 5: Trainer override
# ========================
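# Overrides Trainer.log to skip the default console printing while still appending to
# state.log_history and dispatching on_log to the callbacks above.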
class NoLogTrainer(Trainer):
    def log(self, logs, start_time=None):
        if logs is not None:
            self.state.log_history.append(logs)
            self.control = self.callback_handler.on_log(
                self.args, self.state, self.control, logs
            )
# ========================
# Constants
# ========================
MODEL_NAME = "/home/public_data/model/Qwen3-8B"
BATCH_SIZE = 4
MAX_SEQ_LEN = 512
NUM_EPOCHS = 10
OUTPUT_DIR = "./qwen3_8b_fsdp2_torchao_fp8_random"
DATASET_RANGE_SIZE = 4000
# ========================
# Dataset loading
# ========================
dataset = load_from_disk("../dataset_wiki/wikitext-103-raw-v1-train")
dataset = dataset.select(range(DATASET_RANGE_SIZE))
total_samples = len(dataset)
print(f"Config are: MODEL_NAME: {MODEL_NAME}, "
f"BATCH_SIZE of each DEVICE: {BATCH_SIZE}, "
f"NUM_EPOCHS: {NUM_EPOCHS}, "
f"MAX_SEQ_LEN: {MAX_SEQ_LEN}, "
f"DATASET_RANGE_SIZE: {DATASET_RANGE_SIZE}, "
f"OUTPUT_DIR: {OUTPUT_DIR}")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    local_files_only=True
)
tokenizer.pad_token = tokenizer.eos_token
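# Reuse the EOS token as the padding token so fixed-length batches can be built.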
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        return_tensors="pt"
    )
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_dataset.set_format(type="torch", columns=["input_ids"])
# ========================
# Model loading
# ========================
print("Loading model with BF16 precision...")
model_config = AutoConfig.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    local_files_only=True
)
model = AutoModelForCausalLM.from_config(
    model_config,
    torch_dtype=torch.bfloat16,
)
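# Note: from_config builds a randomly initialized Qwen3-8B in bf16 (pre-training from
# scratch); no pretrained weights are loaded.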
# Configure Float8 conversion
cast_config = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.TENSORWISE
)
config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,
    cast_config_grad_output=cast_config,
    enable_fsdp_float8_all_gather=True  # <-- Set to False for the second run
)
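# As described in the "Supercharging Training using float8 and FSDP2" post, with
# enable_fsdp_float8_all_gather=True the sharded weights are cast to float8 before the
# FSDP2 all-gather (the gather itself runs in float8); with False, weights are
# all-gathered in bf16 and only cast to float8 inside Float8Linear afterwards.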
# Convert to Float8 training mode
convert_to_float8_training(model, config=config)
print("The model has been converted to Float8 training mode.")
# Enable torch.compile (recommended for competitive float8 training performance)
model = torch.compile(model)
# ========================
# Training arguments
# ========================
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="no",
    fp16=False,
    bf16=True,
    report_to="none",
    disable_tqdm=True,
    log_level="error",
    log_level_replica="error",
    remove_unused_columns=False,
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "fsdp_version": 2,
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_transformer_layer_cls_to_wrap": ["Qwen3DecoderLayer"],
        "fsdp_backward_prefetch": "BACKWARD_PRE",
        "fsdp_forward_prefetch": False,
        "fsdp_use_orig_params": True,
        "fsdp_cpu_ram_efficient_loading": True,
        "fsdp_sync_module_states": True,
    }
)
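# The fsdp_config above requests FSDP2 ("fsdp_version": 2) and wraps each Qwen3DecoderLayer
# as its own shard unit via the transformer-based auto-wrap policy.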
effective_batch_size = BATCH_SIZE * training_args.gradient_accumulation_steps
# ========================
# Initialize trainer and callbacks
# ========================
if dist.get_rank() == 0:
    print("Starting training...")
custom_log_callback = CustomLogCallback(total_samples, effective_batch_size)
profiler_callback = ProfilerCallback(max_steps=0, output_dir=OUTPUT_DIR)
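# The inline collator below stacks input_ids and reuses them as labels (standard
# causal-LM objective); the model shifts the labels internally when computing the loss.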
trainer = NoLogTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=lambda data: {
        "input_ids": torch.stack([d["input_ids"].detach().clone() for d in data]),
        "labels": torch.stack([d["input_ids"].detach().clone() for d in data])
    },
    callbacks=[custom_log_callback, profiler_callback]
)
# ========================
# Run training
# ========================
output_filter = CustomOutputFilter()
with redirect_stdout(output_filter), redirect_stderr(output_filter):
    trainer.train()
time.sleep(5)
sys.exit(0)
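Both runs use the same 2-GPU launch (e.g. torchrun --nproc_per_node=2; the exact launch command is an assumption, not shown in the script), and the only difference between them is the enable_fsdp_float8_all_gather value on the marked line in the Float8LinearConfig above.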
Loss Comparison

The logged loss curves for the enable_fsdp_float8_all_gather=True and =False runs consistently differ, and the final loss values do not match.