Skip to content

[FSDP2+Float8] Training loss diverges when enable_fsdp_float8_all_gather=True vs False on H20 with Qwen3-8B #4300

@szf-Copper

Description

@szf-Copper

Problem

[FSDP2+Float8] Training loss diverges when enable_fsdp_float8_all_gather=True vs False on H20 with Qwen3-8B

Description

I am running a 2-GPU FSDP2 + Float8 pre-training experiment on Qwen3-8B using torchao 0.14.1 and torch 2.9.0 on NVIDIA H20 GPUs. I have observed a significant divergence in the training loss convergence behavior depending on the setting of enable_fsdp_float8_all_gather in Float8LinearConfig. Specifically, the loss curves when this flag is set to True are consistently different from when it is set to False, and the final loss values do not match. This is unexpected because the official PyTorch blog post "Supercharging Training using float8 and FSDP2" states that float8 all_gather should produce "identical loss convergence" compared to the bf16 baseline.

Environment

  • Hardware: 2 × NVIDIA H20 GPUs
  • PyTorch: 2.9.0
  • TorchAO: 0.14.1
  • CUDA: 12.8
  • Model: Qwen3-8B
  • Dataset: WikiText-103 (first 4000 samples)
  • Training Framework: Hugging Face Transformers Trainer with FSDP2

Steps to Reproduce

  1. Use the attached training script below.
  2. Run the script with enable_fsdp_float8_all_gather=True.
  3. Run the same script with enable_fsdp_float8_all_gather=False.
  4. Compare the loss values at each logging step.

Training Script (Minimal Example)

import os
import sys
import time
import torch
import logging
from contextlib import redirect_stdout, redirect_stderr
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    TrainerState,
    TrainerControl
)

from torchao.float8 import (
    convert_to_float8_training,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
    ScalingGranularity
)

import torch.distributed as dist
import torch.profiler as profiler
from torch.profiler import profile, ProfilerActivity
from datasets import load_from_disk

# ========================
# Globally disable all logging
# ========================
# Silence noisy third-party loggers so only this script's own output
# (filtered by CustomOutputFilter) reaches the console.
for noisy_name in ("transformers", "huggingface_hub", "accelerate", "torch", "datasets"):
    noisy_logger = logging.getLogger(noisy_name)
    noisy_logger.setLevel(logging.ERROR)
    noisy_logger.propagate = False

# ========================
# Step 2: Output redirection
# ========================
class CustomOutputFilter:
    """File-like stdout/stderr proxy used with redirect_stdout/redirect_stderr.

    Only rank 0 emits anything. Before the "Starting training..." marker is
    seen, any non-JSON-looking line passes through; afterwards only
    '[training-log]' lines, profiler-related lines, and trace-file messages
    are forwarded to the real stdout.
    """

    def __init__(self):
        self.allowed_prefix = "training-log"
        self.train_started = False
        # Rank is resolved lazily on the first write() so the filter can be
        # constructed before torch.distributed is initialized.
        self.rank = None
        self.profiler_prefixes = ["[Profiler]", "Profiling", "float8", "trace", "tensorboard", "✅", "❌", "📌"]

    def _emit(self, msg):
        # Bypass the redirection and write straight to the original stdout.
        sys.__stdout__.write(msg)
        sys.__stdout__.flush()

    def write(self, msg):
        if self.rank is None:
            self.rank = dist.get_rank() if dist.is_initialized() else 0
        if self.rank != 0:
            return
        stripped = msg.strip()
        if self.train_started:
            wanted = (
                stripped.startswith(self.allowed_prefix)
                or any(token in stripped for token in self.profiler_prefixes)
                or "trace_step_" in stripped
            )
            if wanted:
                self._emit(msg)
            return
        # Pre-training phase: forward everything except raw dict-style logs.
        if stripped and not stripped.startswith("{"):
            self._emit(msg)
        if "Starting training..." in msg:
            self.train_started = True

    def flush(self):
        sys.__stdout__.flush()

# ========================
# Step 3: Custom logging callback
# ========================
class CustomLogCallback(TrainerCallback):
    """Emits compact '[training-log]' lines from the main process.

    Tracks wall-clock time between logging events to report an average
    step time, and derives an epoch estimate from steps * batch_size.
    """

    def __init__(self, total_samples, batch_size):
        self.total_samples = total_samples
        self.batch_size = batch_size
        self.step_losses = []
        self.last_log_time = None
        self.last_log_step = 0
        self.train_start_time = time.time()
        # Resolved lazily on the first on_log call, once dist is initialized.
        self.is_main_process = None

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        # NOTE(review): relies on the Trainer passing 'tr_loss' in kwargs —
        # confirm the installed transformers version actually provides it.
        loss_tensor = kwargs.get("tr_loss")
        if loss_tensor is not None:
            self.step_losses.append(loss_tensor.item())

    def on_log(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        if self.is_main_process is None:
            self.is_main_process = (dist.get_rank() == 0) if dist.is_initialized() else True
        if not self.is_main_process:
            return
        if logs is None or "loss" not in logs:
            return

        now = time.time()
        step = state.global_step

        # First log event: measure from the start of training.
        if self.last_log_time is None:
            self.last_log_time = self.train_start_time
            self.last_log_step = 0

        elapsed = now - self.last_log_time
        steps_done = step - self.last_log_step
        avg_step_time = elapsed / steps_done if steps_done > 0 else 0.0

        loss = float(logs["loss"])
        grad_norm = float(logs.get("grad_norm", 0))
        lr = float(logs.get("learning_rate", 0))

        # Epoch estimate from per-device samples; capped at num_train_epochs.
        epoch = min((step * self.batch_size) / self.total_samples, args.num_train_epochs)

        print(
            f"[training-log] - "
            f"step: {step}, "
            f"loss: {loss:.8f}, "
            f"grad_norm: {grad_norm:.4f}, "
            f"learning_rate: {lr:.4e}, "
            f"epoch: {epoch}, "
            f"avg_step_time: {avg_step_time:.4f}s/step"
            f"\n"
        )

        self.last_log_time = now
        self.last_log_step = step

        logs["epoch"] = epoch
        logs["avg_step_time"] = avg_step_time
        state.log_history[-1] = logs

# ========================
# Step 4: Profiler callback
# ========================
class ProfilerCallback(TrainerCallback):
    """Captures a torch.profiler trace over the first `max_steps` training
    steps on the main process and exports it as a Chrome trace JSON file.

    Stale `trace*` files in `output_dir` are deleted on construction so a
    new run never mixes with old results.
    """

    def __init__(self, max_steps=2, output_dir="./qwen3_8b_fsdp2_fp8_profile"):
        self.max_steps = max_steps          # number of leading steps to profile
        self.output_dir = output_dir
        self.profiler = None                # active profile session, if any
        self.trace_count = 0                # how many traces have been exported
        self.is_main_process = None         # resolved lazily on first step
        os.makedirs(self.output_dir, exist_ok=True)
        # Remove traces from a previous run.
        for f in os.listdir(self.output_dir):
            if f.startswith("trace"):
                os.remove(os.path.join(self.output_dir, f))

    def on_step_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        if self.is_main_process is None:
            self.is_main_process = (dist.get_rank() == 0) if dist.is_initialized() else True
        if not self.is_main_process:
            return
        step = state.global_step
        # Start ONE session covering steps [0, max_steps). BUG FIX: the
        # original created a fresh `profile` on every step < max_steps without
        # stopping the previous one, leaking live profiler sessions whenever
        # max_steps > 1; only the last session was ever stopped/exported.
        if step < self.max_steps and self.profiler is None:
            print(f"[Profiler] Starting profile capture for step {step + 1}...\n")
            self.profiler = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
                with_flops=True,
                with_modules=True,
                # NOTE(review): self.profiler.step() is never called, so this
                # schedule never advances past its initial active window — the
                # session records continuously until stopped. Confirm intended.
                schedule=profiler.schedule(
                    wait=0,
                    warmup=0,
                    active=1,
                    repeat=0
                )
            )
            self.profiler.start()

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        # is_main_process is still None if on_step_begin never ran; `not None`
        # is truthy, so such callbacks also return here.
        if not self.is_main_process:
            return
        step = state.global_step
        # global_step has already been incremented by the Trainer at this
        # point, so step == max_steps marks the end of the profiled window.
        if step == self.max_steps and self.profiler is not None:
            self.profiler.stop()
            trace_json_file = os.path.join(self.output_dir, f"trace_step_{step + 1}.json")
            self.profiler.export_chrome_trace(trace_json_file)

            print(f"[Profiler] Profile data for step {step + 1} saved:")
            print(f"  - Chrome Trace: {self.output_dir}/trace_step_{step + 1}.json")

            self.trace_count += 1
            self.profiler = None

# ========================
# Step 5: Trainer override
# ========================
class NoLogTrainer(Trainer):
    """Trainer that suppresses the built-in logging output.

    Logs are still appended to `state.log_history` and forwarded to the
    callback handler (so CustomLogCallback receives them), but the parent
    implementation's console printing / report_to integrations are bypassed.
    """

    def log(self, logs, start_time=None):
        # NOTE(review): signature mirrors recent transformers' Trainer.log
        # (logs, start_time) — confirm against the installed release.
        if logs is not None:
            self.state.log_history.append(logs)
            self.control = self.callback_handler.on_log(
                self.args, self.state, self.control, logs
            )

# ========================
# Constants
# ========================
MODEL_NAME = "/home/public_data/model/Qwen3-8B"     # local checkpoint path
BATCH_SIZE = 4                                      # per-device train batch size
MAX_SEQ_LEN = 512                                   # tokenizer truncation/padding length
NUM_EPOCHS = 10
OUTPUT_DIR = "./qwen3_8b_fsdp2_torchao_fp8_random"  # checkpoints + profiler traces
DATASET_RANGE_SIZE = 4000                           # number of WikiText rows used

# ========================
# Dataset loading
# ========================
# Load a pre-saved WikiText-103 train split from disk and keep only the
# first DATASET_RANGE_SIZE rows for a quick, reproducible run.
dataset = load_from_disk("../dataset_wiki/wikitext-103-raw-v1-train")
dataset = dataset.select(range(DATASET_RANGE_SIZE))
total_samples = len(dataset)

print(f"Config are: MODEL_NAME: {MODEL_NAME}, "
      f"BATCH_SIZE of each DEVICE: {BATCH_SIZE}, "
      f"NUM_EPOCHS: {NUM_EPOCHS}, "
      f"MAX_SEQ_LEN: {MAX_SEQ_LEN}, "
      f"DATASET_RANGE_SIZE: {DATASET_RANGE_SIZE}, "
      f"OUTPUT_DIR: {OUTPUT_DIR}")

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    local_files_only=True
)
# Reuse EOS as the padding token.
# NOTE(review): labels below are a raw copy of input_ids with no -100
# masking, so loss is also computed over EOS padding — confirm intended.
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    """Tokenize a batch of raw text rows to fixed-length input_ids.

    Each row is truncated and padded to exactly MAX_SEQ_LEN tokens so all
    samples share one shape; PyTorch tensors are returned.
    """
    batch_text = examples["text"]
    return tokenizer(
        batch_text,
        truncation=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
        return_tensors="pt"
    )

# Tokenize every row and drop the raw text column. set_format restricts
# model inputs to input_ids only — the tokenizer's attention_mask is
# excluded (padding positions are therefore attended to; intended?).
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_dataset.set_format(type="torch", columns=["input_ids"])

# ========================
# Model loading
# ========================
print("Loading model with BF16 precision...")

# Build the model from config alone (random init — pre-training from
# scratch, no pretrained weights), in bfloat16.
model_config = AutoConfig.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    local_files_only=True
)

model = AutoModelForCausalLM.from_config(
    model_config,
    torch_dtype=torch.bfloat16,
)

# Configure Float8 conversion
# Dynamic tensorwise scaling, applied uniformly to the input, weight, and
# grad_output casts of each converted linear layer.
cast_config = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.TENSORWISE
)

config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,
    cast_config_grad_output=cast_config,
    enable_fsdp_float8_all_gather=True  # <-- Set to False for the second run
)

# Convert to Float8 training mode
# Swaps eligible nn.Linear modules for Float8Linear in place.
convert_to_float8_training(model, config=config)
print("The model has been converted to Float8 training mode.")

# enable torch.compile for competitive performance
# NOTE(review): compile is applied before the Trainer performs FSDP2
# wrapping — confirm this ordering is the supported combination for
# float8 all-gather with the installed torch/torchao versions.
model = torch.compile(model)

# ========================
# Training arguments
# ========================
# HF TrainingArguments: bf16 compute, no checkpoint saving, console logging
# disabled (our callbacks handle output), FSDP2 full-shard wrapping at the
# Qwen3DecoderLayer boundary.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="no",
    fp16=False,
    bf16=True,
    report_to="none",
    disable_tqdm=True,
    log_level="error",
    log_level_replica="error",
    remove_unused_columns=False,
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "fsdp_version": 2,
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_transformer_layer_cls_to_wrap": ["Qwen3DecoderLayer"],
        "fsdp_backward_prefetch": "BACKWARD_PRE",
        "fsdp_forward_prefetch": False,
        # NOTE(review): use_orig_params / sync_module_states are FSDP1-era
        # options — confirm they are honored when fsdp_version is 2.
        "fsdp_use_orig_params": True,
        "fsdp_cpu_ram_efficient_loading": True,
        "fsdp_sync_module_states": True,
    }
)

# Per-device samples per optimizer step.
# NOTE(review): this does not multiply by world size, so CustomLogCallback's
# epoch estimate undercounts progress in multi-GPU runs — confirm intended.
effective_batch_size = BATCH_SIZE * training_args.gradient_accumulation_steps

# ========================
# Initialize trainer and callbacks
# ========================

# Announce training start from the main process only. Guard against the
# process group not being initialized (e.g. single-process debugging):
# the original called dist.get_rank() unconditionally, which raises a
# RuntimeError/ValueError when torch.distributed is uninitialized — every
# other rank check in this script already uses this guard.
if not dist.is_initialized() or dist.get_rank() == 0:
    print("Starting training...")
custom_log_callback = CustomLogCallback(total_samples, effective_batch_size)
profiler_callback = ProfilerCallback(max_steps=0, output_dir=OUTPUT_DIR)

trainer = NoLogTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    # Causal-LM objective: labels are an independent copy of input_ids.
    # NOTE(review): padding positions (EOS reused as pad) are not set to
    # -100, so loss is also computed on padding tokens — confirm intended.
    data_collator=lambda data: {
        "input_ids": torch.stack([d["input_ids"].detach().clone() for d in data]),
        "labels": torch.stack([d["input_ids"].detach().clone() for d in data])
    },
    callbacks=[custom_log_callback, profiler_callback]
)

# ========================
# Run training
# ========================
# Route all stdout/stderr through the filter so only whitelisted lines
# from rank 0 reach the console during training.
output_filter = CustomOutputFilter()
with redirect_stdout(output_filter), redirect_stderr(output_filter):
    trainer.train()

time.sleep(5)  # brief grace period for async log/trace writers to finish
sys.exit(0)

Loss Comparison

Image

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions