pytorch-fsdp2
Skill: Use PyTorch FSDP2 (fully_shard) correctly in a training script
This skill teaches a coding agent how to add PyTorch FSDP2 to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.
FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in place to modules. See: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.
When to use this skill
Use FSDP2 when:
- Your model doesn’t fit on one GPU (parameters + gradients + optimizer state).
- You want eager-mode, DTensor-based per-parameter sharding, which is more inspectable and yields simpler sharded state dicts than FSDP1's flat-parameter approach.
- You may later compose DP with Tensor Parallel using DeviceMesh.
Avoid (or be careful) if:
- You need strict backwards-compatible checkpoints across PyTorch versions (DCP warns against this).
- You’re forced onto older PyTorch versions without the FSDP2 stack.
Alternatives (when FSDP2 is not the best fit)
- DistributedDataParallel (DDP): Use the standard data-parallel wrapper when you want classic distributed data parallel training.
- FullyShardedDataParallel (FSDP1): Use the original FSDP wrapper for parameter sharding across data-parallel workers.
Reference: references/pytorch_ddp_notes.md, references/pytorch_fsdp1_api.md.
Contract the agent must follow
- Launch with `torchrun` and set the CUDA device per process (usually via `LOCAL_RANK`).
- Apply `fully_shard()` bottom-up, i.e., shard submodules (e.g., Transformer blocks) before the root module.
- Call `model(input)`, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly `unshard()` or register the forward method).
- Create the optimizer after sharding and make sure it is built on the DTensor parameters (post-`fully_shard`).
- Checkpoint using Distributed Checkpoint (DCP) or the distributed-state-dict helpers, not naïve `torch.save(model.state_dict())`, unless you deliberately gather to full tensors.
(Each of these rules is directly described in the official API docs/tutorial; see references.)
Step-by-step procedure
0) Version & environment sanity
- Prefer a recent stable PyTorch release whose docs cover FSDP2 (`fully_shard`) and up-to-date DCP.
- Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` are visible.
Reference: references/pytorch_fsdp2_tutorial.md (launch commands and setup), references/pytorch_fully_shard_api.md (user contract).
1) Initialize distributed and set device
Minimal, correct pattern:
- `dist.init_process_group(backend="nccl")`
- `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
- Optionally create a `DeviceMesh` to describe the data-parallel group(s).
Reference: references/pytorch_device_mesh_tutorial.md (why DeviceMesh exists & how it manages process groups).
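The pattern above can be sketched as follows; this assumes a `torchrun` launch with NCCL-capable GPUs, and the `local_rank` / `init_distributed` helper names are ours, not a PyTorch API:

```python
import os

import torch
import torch.distributed as dist


def local_rank() -> int:
    # torchrun exports LOCAL_RANK per process; default to 0 for
    # single-process (non-torchrun) runs.
    return int(os.environ.get("LOCAL_RANK", "0"))


def init_distributed() -> None:
    # NCCL backend for GPU collectives; torchrun supplies
    # RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT in the environment.
    dist.init_process_group(backend="nccl")
    # Pin this process to its GPU before any CUDA allocations happen.
    torch.cuda.set_device(local_rank())

# Call init_distributed() once at the top of your training script.
```

Setting the device before any allocation matters: otherwise all ranks may default to GPU 0.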
2) Build model on meta device (recommended for very large models)
For big models, initialize on meta, apply sharding, then materialize weights on GPU:
- `with torch.device("meta"): model = ...`
- apply `fully_shard(...)` on submodules, then `fully_shard(model)`
- `model.to_empty(device="cuda")`
- `model.reset_parameters()` (or your init routine)
Reference: references/pytorch_fsdp2_tutorial.md (migration guide shows this flow explicitly).
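A sketch of the meta-device step, with a hypothetical `Block` class standing in for your transformer layer; the `fully_shard` and materialization calls are left as comments because they need an initialized process group and a GPU:

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical stand-in for a transformer block."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def build_model_meta(dim: int = 16, n_layers: int = 2) -> nn.Module:
    # Meta tensors carry shape/dtype but no storage, so construction is
    # essentially free even for models that would not fit in host memory.
    with torch.device("meta"):
        model = nn.Sequential(*[Block(dim) for _ in range(n_layers)])
    return model

# In the real flow you would now apply fully_shard(...) bottom-up, then:
#   model.to_empty(device="cuda")   # allocate (sharded) storage on GPU
#   ...and run reset_parameters() or your own init routine per module.
```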
3) Apply fully_shard() bottom-up (wrapping policy = “apply where needed”)
Do not only call fully_shard on the topmost module.
Recommended sharding pattern for transformer-like models:
- iterate modules: `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- then `fully_shard(model, ...)`
Why: `fully_shard` forms "parameter groups" for collective efficiency and excludes params already grouped by earlier calls. Bottom-up application gives better overlap and lower peak memory.
Reference: references/pytorch_fully_shard_api.md (bottom-up requirement and why).
4) Configure reshard_after_forward for memory/perf trade-offs
Default behavior:
`None` means `True` for non-root modules and `False` for the root module (a good default).
Heuristics:
- If you're memory-bound: keep defaults or force `True` on many blocks.
- If you're throughput-bound and can afford the memory: consider keeping unsharded params longer (root often `False`).
- Advanced: use an `int` to reshard to a smaller mesh after forward (e.g., intra-node) if it is a meaningful divisor.
Reference: references/pytorch_fully_shard_api.md (full semantics).
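The default resolution rule above can be made concrete with a tiny helper; `effective_reshard` is our illustrative name, not a PyTorch function:

```python
from typing import Union


def effective_reshard(value: Union[None, bool, int], is_root: bool) -> Union[bool, int]:
    """Resolve reshard_after_forward per the documented default:
    None means True for non-root modules and False for the root."""
    if value is None:
        return not is_root
    # An explicit bool, or an int (reshard to a smaller mesh after
    # forward), passes through unchanged.
    return value
```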
5) Mixed precision & offload (optional but common)
FSDP2 uses:
- `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
- `offload_policy=CPUOffloadPolicy()` if you want CPU offload
Rules of thumb:
- Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
- Keep `reduce_dtype` aligned with your gradient reduction expectations.
- If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.
Reference: references/pytorch_fully_shard_api.md (MixedPrecisionPolicy / OffloadPolicy classes).
6) Optimizer, gradient clipping, accumulation
- Create the optimizer after sharding so it holds DTensor params.
- If you need gradient accumulation / `no_sync`-style behavior, use the FSDP2 mechanism (`set_requires_gradient_sync`) instead of FSDP1's `no_sync()`.
Gradient clipping:
- Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.
Reference: references/pytorch_fsdp2_tutorial.md.
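A sketch of an accumulation-aware step, assuming `model` is an FSDP2-sharded module (so it exposes `set_requires_gradient_sync`); the `train_step` name and the placeholder loss are ours:

```python
import torch


def train_step(model, optimizer, micro_batches, max_norm: float = 1.0):
    last = len(micro_batches) - 1
    for i, batch in enumerate(micro_batches):
        # FSDP2 replacement for FSDP1's no_sync(): skip the gradient
        # reduce-scatter on every micro-batch except the last one.
        model.set_requires_gradient_sync(i == last)
        loss = model(batch).mean()  # placeholder loss for this sketch
        loss.backward()
    # Per the FSDP2 tutorial, clip_grad_norm_ handles DTensor gradients.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    optimizer.zero_grad()
```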
7) Checkpointing: prefer DCP or distributed state dict helpers
Two recommended approaches:
A) Distributed Checkpoint (DCP) — best default
- DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
- DCP produces multiple files (often at least one per rank) and operates “in place”.
B) Distributed state dict helpers
- `get_model_state_dict` / `set_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)`
- For the optimizer: `get_optimizer_state_dict` / `set_optimizer_state_dict`
Avoid:
- Saving DTensor state dicts with plain `torch.save`, unless you intentionally convert with `DTensor.full_tensor()` and manage memory carefully.
References:
- `references/pytorch_dcp_overview.md` (DCP behavior and caveats)
- `references/pytorch_dcp_recipe.md` and `references/pytorch_dcp_async_recipe.md` (end-to-end usage)
- `references/pytorch_fsdp2_tutorial.md` (DTensor vs DCP state-dict flows)
- `references/pytorch_examples_fsdp2.md` (working checkpoint scripts)
Workflow checklists (copy-paste friendly)
Workflow A: Retrofit FSDP2 into an existing training script
- Launch with `torchrun` and initialize the process group.
- Set the CUDA device from `LOCAL_RANK`; create a `DeviceMesh` if you need multi-dim parallelism.
- Build the model (use `meta` if needed), apply `fully_shard` bottom-up, then `fully_shard(model)`.
- Create the optimizer after sharding so it captures DTensor parameters.
- Use `model(inputs)` so hooks run; use `set_requires_gradient_sync` for accumulation.
- Add DCP save/load via `torch.distributed.checkpoint` helpers.
Reference: references/pytorch_fsdp2_tutorial.md, references/pytorch_fully_shard_api.md, references/pytorch_device_mesh_tutorial.md, references/pytorch_dcp_recipe.md.
Workflow B: Add DCP save/load (minimal pattern)
- Wrap state in `Stateful` or assemble state via `get_state_dict`.
- Call `dcp.save(...)` from all ranks to a shared path.
- Call `dcp.load(...)` and restore with `set_state_dict`.
- Validate any resharding assumptions when loading into a different mesh.
Reference: references/pytorch_dcp_recipe.md.
Debug checklist (what the agent should check first)
- All ranks on distinct GPUs? If not, verify `torch.cuda.set_device(LOCAL_RANK)` and your `torchrun` flags.
- Did you accidentally call `forward()` directly? Use `model(input)` or explicitly `unshard()` / register the forward method.
- Is `fully_shard()` applied bottom-up? If only the root is sharded, expect worse memory/perf and possible confusion.
- Optimizer created at the right time? It must be built on DTensor parameters after sharding.
- Checkpointing path consistent?
  - If using DCP, don't mix with ad-hoc `torch.save` unless you understand the conversions.
  - Be mindful of PyTorch-version compatibility warnings for DCP.
Common issues and fixes
- Forward hooks not running → call `model(inputs)` (or `unshard()` explicitly) instead of `model.forward(...)`.
- Optimizer sees non-DTensor params → create the optimizer after all `fully_shard` calls.
- Only root module sharded → apply `fully_shard` bottom-up on submodules before the root.
- Memory spikes after forward → set `reshard_after_forward=True` for more modules.
- Gradient accumulation desync → use `set_requires_gradient_sync` instead of FSDP1's `no_sync()`.
Reference: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.
Minimal reference implementation outline (agent-friendly)
The coding agent should implement a script with these labeled blocks:
- `init_distributed()`: init process group, set device
- `build_model_meta()`: model on meta, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or distributed state dict helpers
Concrete examples live in references/pytorch_examples_fsdp2.md and the official tutorial reference.
References
- `references/pytorch_fsdp2_tutorial.md`
- `references/pytorch_fully_shard_api.md`
- `references/pytorch_ddp_notes.md`
- `references/pytorch_fsdp1_api.md`
- `references/pytorch_device_mesh_tutorial.md`
- `references/pytorch_tp_tutorial.md`
- `references/pytorch_dcp_overview.md`
- `references/pytorch_dcp_recipe.md`
- `references/pytorch_dcp_async_recipe.md`
- `references/pytorch_examples_fsdp2.md`
- `references/torchtitan_fsdp_notes.md` (optional, production notes)
- `references/ray_train_fsdp2_example.md` (optional, integration example)
Source
https://github.com/Orchestra-Research/AI-Research-SKILLs/blob/main/08-distributed-training/pytorch-fsdp2/SKILL.md
Overview
This skill explains how to integrate PyTorch FSDP2 (`fully_shard`) into a training script, including correct initialization, per-parameter sharding, mixed precision/offload configuration, and distributed checkpointing. It targets models that exceed single-GPU memory and enables DTensor-based sharding with DeviceMesh for scalable training.
How This Skill Works
FSDP2 uses fully_shard() to shard submodules bottom-up before the root module and relies on DTensor-capable parameter handling. Training is launched with torchrun, the optimizer is created after sharding on DTensor parameters, and checkpoints are saved with Distributed Checkpoint (DCP) rather than raw state_dict saves.
When to Use It
- Your model doesn’t fit on a single GPU (parameters, gradients, and optimizer state exceed memory).
- You want DTensor-based per-parameter sharding for clearer, sharded state dicts compared to FSDP1.
- You plan to compose Data Parallel with Tensor Parallel using DeviceMesh.
- You need distributed checkpointing (DCP) to resume training across nodes or after failures.
- You require mixed precision with offload to manage memory footprint.
Quick Start
- Step 1: Launch the job with torchrun --nproc_per_node <gpus_per_node> and ensure RANK/WORLD_SIZE/LOCAL_RANK are in the environment.
- Step 2: Initialize distributed and set the CUDA device; optionally create a DeviceMesh to describe the data-parallel groups.
- Step 3: Build the model on meta device, apply fully_shard() to submodules (then fully_shard(model)), materialize on CUDA, create the optimizer after sharding, and enable DCP-based checkpointing.
Best Practices
- Launch with torchrun and ensure LOCAL_RANK, RANK, and WORLD_SIZE are visible in the environment.
- Apply fully_shard() bottom-up to submodules before the root module.
- Call model(input) (not model.forward(input)) so FSDP2 hooks run unless unsharded or forward overridden.
- Create the optimizer after sharding so it’s built on DTensor parameters post-shard.
- Checkpoint with Distributed Checkpoint (DCP) or distributed-state-dict helpers instead of torch.save(state_dict) unless you need full tensors.
Example Use Cases
- Train a 2B-parameter Transformer model across 16 GPUs with DTensor per-parameter sharding and a DeviceMesh.
- Fine-tune a large Vision Transformer across 8 GPUs using FSDP2 with mixed precision and model offload.
- Scale a language model across 32 GPUs, leveraging DTensor sharding, block-wise sharding, and DCP for robust checkpoints.
- Combine DP with Tensor Parallel on a multi-branch model using DeviceMesh to distribute workload across 24 GPUs.
- Resume training from distributed checkpoints after a node outage using DCP across multiple nodes.
Related Skills (all in Orchestra-Research/AI-Research-SKILLs)
- tensorboard: Visualize training metrics, debug models with histograms, compare experiments, visualize model graphs, and profile performance with TensorBoard, Google's ML visualization toolkit.
- huggingface-accelerate: Simplest distributed training API; four lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP, automatic device placement, mixed precision (FP16/BF16/FP8), interactive config, single launch command. HuggingFace ecosystem standard.
- deepspeed: Expert guidance for distributed training with DeepSpeed: ZeRO optimization stages, pipeline parallelism, FP16/BF16/FP8, 1-bit Adam, sparse attention.
- optimizing-attention-flash: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), hitting GPU memory limits in attention, or needing faster inference. Supports PyTorch native SDPA, the flash-attn library, H100 FP8, and sliding-window attention.
- ray-train: Distributed training orchestration across clusters. Scales PyTorch/TensorFlow/HuggingFace from laptop to 1000s of nodes, with built-in hyperparameter tuning via Ray Tune, fault tolerance, and elastic scaling. Use when training massive models across multiple machines or running distributed hyperparameter sweeps.
- ray-data: Scalable data processing for ML workloads. Streaming execution across CPU/GPU; supports Parquet/CSV/JSON/images; integrates with Ray Train, PyTorch, and TensorFlow; scales from a single machine to 100s of nodes. Use for batch inference, data preprocessing, multi-modal data loading, or distributed ETL pipelines.