npx machina-cli add skill Orchestra-Research/AI-Research-SKILLs/pytorch-fsdp2 --openclaw
Files (1): SKILL.md (10.7 KB)

Skill: Use PyTorch FSDP2 (fully_shard) correctly in a training script

This skill teaches a coding agent how to add PyTorch FSDP2 to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.

FSDP2 in PyTorch is exposed primarily via torch.distributed.fsdp.fully_shard and the FSDPModule methods it adds in-place to modules. See: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.


When to use this skill

Use FSDP2 when:

  • Your model doesn’t fit on one GPU (parameters + gradients + optimizer state).
  • You want eager-mode, DTensor-based per-parameter sharding (more inspectable, with simpler sharded state dicts) rather than FSDP1’s flat-parameter sharding.
  • You may later compose DP with Tensor Parallel using DeviceMesh.

Avoid (or be careful) if:

  • You need checkpoints that stay strictly compatible across PyTorch versions (the DCP docs warn this is not guaranteed).
  • You’re forced onto older PyTorch versions without the FSDP2 stack.

Alternatives (when FSDP2 is not the best fit)

  • DistributedDataParallel (DDP): Use the standard data-parallel wrapper when the full model (plus gradients and optimizer state) fits on each GPU and you only need gradient synchronization.
  • FullyShardedDataParallel (FSDP1): Use the original flat-parameter FSDP wrapper, e.g., when pinned to PyTorch versions that predate FSDP2.

Reference: references/pytorch_ddp_notes.md, references/pytorch_fsdp1_api.md.


Contract the agent must follow

  1. Launch with torchrun and set the CUDA device per process (usually via LOCAL_RANK).
  2. Apply fully_shard() bottom-up, i.e., shard submodules (e.g., Transformer blocks) before the root module.
  3. Call model(input), not model.forward(input), so the FSDP2 hooks run (unless you explicitly unshard() or register the forward method).
  4. Create the optimizer after sharding and make sure it is built on the DTensor parameters (post-fully_shard).
  5. Checkpoint using Distributed Checkpoint (DCP) or the distributed-state-dict helpers, not naïve torch.save(model.state_dict()) unless you deliberately gather to full tensors.

(Each of these rules is directly described in the official API docs/tutorial; see references.)


Step-by-step procedure

0) Version & environment sanity

  • Prefer a recent stable PyTorch release whose documentation covers the current FSDP2 and DCP APIs.
  • Use torchrun --nproc_per_node <gpus_per_node> ... and ensure RANK, WORLD_SIZE, LOCAL_RANK are visible.

Reference: references/pytorch_fsdp2_tutorial.md (launch commands and setup), references/pytorch_fully_shard_api.md (user contract).


1) Initialize distributed and set device

Minimal, correct pattern:

  • dist.init_process_group(backend="nccl")
  • torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
  • Optionally create a DeviceMesh to describe the data-parallel group(s)

Reference: references/pytorch_device_mesh_tutorial.md (why DeviceMesh exists & how it manages process groups).


2) Build model on meta device (recommended for very large models)

For big models, initialize on meta, apply sharding, then materialize weights on GPU:

  • with torch.device("meta"): model = ...
  • apply fully_shard(...) on submodules, then fully_shard(model)
  • model.to_empty(device="cuda")
  • model.reset_parameters() (or your init routine)

Reference: references/pytorch_fsdp2_tutorial.md (migration guide shows this flow explicitly).


3) Apply fully_shard() bottom-up (wrapping policy = “apply where needed”)

Do not only call fully_shard on the topmost module.

Recommended sharding pattern for transformer-like models:

  • iterate modules, if isinstance(m, TransformerBlock): fully_shard(m, ...)
  • then fully_shard(model, ...)

Why:

  • fully_shard forms “parameter groups” for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.

Reference: references/pytorch_fully_shard_api.md (bottom-up requirement and why).


4) Configure reshard_after_forward for memory/perf trade-offs

Default behavior:

  • The default (None) resolves to True for non-root modules and False for the root module, which is a good starting point.

Heuristics:

  • If you’re memory-bound: keep defaults or force True on many blocks.
  • If you’re throughput-bound and can afford memory: consider keeping unsharded params longer (root often False).
  • Advanced: use an int to reshard to a smaller mesh after forward (e.g., intra-node) if it’s a meaningful divisor.

Reference: references/pytorch_fully_shard_api.md (full semantics).


5) Mixed precision & offload (optional but common)

FSDP2 uses:

  • mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)
  • offload_policy=CPUOffloadPolicy() if you want CPU offload

Rules of thumb:

  • Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
  • Keep reduce_dtype aligned with your gradient reduction expectations.
  • If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.

Reference: references/pytorch_fully_shard_api.md (MixedPrecisionPolicy / OffloadPolicy classes).


6) Optimizer, gradient clipping, accumulation

  • Create the optimizer after sharding so it holds DTensor params.
  • If you need gradient accumulation / no_sync:
    • use the FSDP2 mechanism (set_requires_gradient_sync) instead of FSDP1’s no_sync().

Gradient clipping:

  • Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.

Reference: references/pytorch_fsdp2_tutorial.md.


7) Checkpointing: prefer DCP or distributed state dict helpers

Two recommended approaches:

A) Distributed Checkpoint (DCP) — best default

  • DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
  • DCP produces multiple files (often at least one per rank) and operates “in place”.

B) Distributed state dict helpers

  • get_model_state_dict / set_model_state_dict with StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)
  • For optimizer: get_optimizer_state_dict / set_optimizer_state_dict

Avoid:

  • Saving DTensor state dicts with plain torch.save unless you intentionally convert with DTensor.full_tensor() and manage memory carefully.

References:

  • references/pytorch_dcp_overview.md (DCP behavior and caveats)
  • references/pytorch_dcp_recipe.md and references/pytorch_dcp_async_recipe.md (end-to-end usage)
  • references/pytorch_fsdp2_tutorial.md (DTensor vs DCP state-dict flows)
  • references/pytorch_examples_fsdp2.md (working checkpoint scripts)

Workflow checklists (copy-paste friendly)

Workflow A: Retrofit FSDP2 into an existing training script

  • Launch with torchrun and initialize the process group.
  • Set the CUDA device from LOCAL_RANK; create a DeviceMesh if you need multi-dim parallelism.
  • Build the model (use meta if needed), apply fully_shard bottom-up, then fully_shard(model).
  • Create the optimizer after sharding so it captures DTensor parameters.
  • Use model(inputs) so hooks run; use set_requires_gradient_sync for accumulation.
  • Add DCP save/load via torch.distributed.checkpoint helpers.

Reference: references/pytorch_fsdp2_tutorial.md, references/pytorch_fully_shard_api.md, references/pytorch_device_mesh_tutorial.md, references/pytorch_dcp_recipe.md.

Workflow B: Add DCP save/load (minimal pattern)

  • Wrap state in Stateful or assemble state via get_state_dict.
  • Call dcp.save(...) from all ranks to a shared path.
  • Call dcp.load(...) and restore with set_state_dict.
  • Validate any resharding assumptions when loading into a different mesh.

Reference: references/pytorch_dcp_recipe.md.

Debug checklist (what the agent should check first)

  1. All ranks on distinct GPUs?
    If not, verify torch.cuda.set_device(LOCAL_RANK) and your torchrun flags.
  2. Did you accidentally call forward() directly?
    Use model(input) or explicitly unshard() / register forward.
  3. Is fully_shard() applied bottom-up?
    If only root is sharded, expect worse memory/perf and possible confusion.
  4. Optimizer created at the right time?
    Must be built on DTensor parameters after sharding.
  5. Checkpointing path consistent?
    • If using DCP, don’t mix with ad-hoc torch.save unless you understand conversions.
    • Be mindful of PyTorch-version compatibility warnings for DCP.

Common issues and fixes

  • Forward hooks not running → Call model(inputs) (or unshard() explicitly) instead of model.forward(...).
  • Optimizer sees non-DTensor params → Create optimizer after all fully_shard calls.
  • Only root module sharded → Apply fully_shard bottom-up on submodules before the root.
  • Memory spikes after forward → Set reshard_after_forward=True for more modules.
  • Gradient accumulation desync → Use set_requires_gradient_sync instead of FSDP1’s no_sync().

Reference: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.


Minimal reference implementation outline (agent-friendly)

The coding agent should implement a script with these labeled blocks:

  • init_distributed(): init process group, set device
  • build_model_meta(): model on meta, apply fully_shard, materialize weights
  • build_optimizer(): optimizer created after sharding
  • train_step(): forward/backward/step with model(inputs) and DTensor-aware patterns
  • checkpoint_save/load(): DCP or distributed state dict helpers

Concrete examples live in references/pytorch_examples_fsdp2.md and the official tutorial reference.


References

  • references/pytorch_fsdp2_tutorial.md
  • references/pytorch_fully_shard_api.md
  • references/pytorch_ddp_notes.md
  • references/pytorch_fsdp1_api.md
  • references/pytorch_device_mesh_tutorial.md
  • references/pytorch_tp_tutorial.md
  • references/pytorch_dcp_overview.md
  • references/pytorch_dcp_recipe.md
  • references/pytorch_dcp_async_recipe.md
  • references/pytorch_examples_fsdp2.md
  • references/torchtitan_fsdp_notes.md (optional, production notes)
  • references/ray_train_fsdp2_example.md (optional, integration example)

Source

git clone https://github.com/Orchestra-Research/AI-Research-SKILLs
(the skill lives at 08-distributed-training/pytorch-fsdp2/SKILL.md in the repo)

Overview

This skill explains how to integrate PyTorch FSDP2 (fully_shard) into a training script, including correct initialization, per-parameter sharding, mixed precision/offload configuration, and distributed checkpointing. It targets models that exceed single-GPU memory and enables DTensor-based sharding with DeviceMesh for scalable training.

How This Skill Works

FSDP2 uses fully_shard() to shard submodules bottom-up before the root module and relies on DTensor-capable parameter handling. Training is launched with torchrun, the optimizer is created after sharding on DTensor parameters, and checkpoints are saved with Distributed Checkpoint (DCP) rather than raw state_dict saves.

When to Use It

  • Your model doesn’t fit on a single GPU (parameters, gradients, and optimizer state exceed memory).
  • You want DTensor-based per-parameter sharding for clearer, sharded state dicts compared to FSDP1.
  • You plan to compose Data Parallel with Tensor Parallel using DeviceMesh.
  • You need distributed checkpointing (DCP) to resume training across nodes or after failures.
  • You require mixed precision with offload to manage memory footprint.

Quick Start

  1. Step 1: Launch the job with torchrun --nproc_per_node <gpus_per_node> and ensure RANK/WORLD_SIZE/LOCAL_RANK are in the environment.
  2. Step 2: Initialize distributed and set the CUDA device; optionally create a DeviceMesh to describe the data-parallel groups.
  3. Step 3: Build the model on meta device, apply fully_shard() to submodules (then fully_shard(model)), materialize on CUDA, create the optimizer after sharding, and enable DCP-based checkpointing.

Best Practices

  • Launch with torchrun and ensure LOCAL_RANK, RANK, and WORLD_SIZE are visible in the environment.
  • Apply fully_shard() bottom-up to submodules before the root module.
  • Call model(input) (not model.forward(input)) so FSDP2 hooks run unless unsharded or forward overridden.
  • Create the optimizer after sharding so it’s built on DTensor parameters post-shard.
  • Checkpoint with Distributed Checkpoint (DCP) or distributed-state-dict helpers instead of torch.save(state_dict) unless you need full tensors.

Example Use Cases

  • Train a 2B-parameter Transformer model across 16 GPUs with DTensor per-parameter sharding and a DeviceMesh.
  • Fine-tune a large Vision Transformer across 8 GPUs using FSDP2 with mixed precision and model offload.
  • Scale a language model across 32 GPUs, leveraging DTensor per-parameter sharding, block-wise fully_shard wrapping, and DCP for robust checkpoints.
  • Combine DP with Tensor Parallel on a multi-branch model using DeviceMesh to distribute workload across 24 GPUs.
  • Resume training from distributed checkpoints after a node outage using DCP across multiple nodes.

Related Skills

tensorboard

Orchestra-Research/AI-Research-SKILLs

Visualize training metrics, debug models with histograms, compare experiments, visualize model graphs, and profile performance with TensorBoard - Google's ML visualization toolkit

huggingface-accelerate

Orchestra-Research/AI-Research-SKILLs

Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.

deepspeed

Orchestra-Research/AI-Research-SKILLs

Expert guidance for distributed training with DeepSpeed - ZeRO optimization stages, pipeline parallelism, FP16/BF16/FP8, 1-bit Adam, sparse attention

optimizing-attention-flash

Orchestra-Research/AI-Research-SKILLs

Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.

ray-train

Orchestra-Research/AI-Research-SKILLs

Distributed training orchestration across clusters. Scales PyTorch/TensorFlow/HuggingFace from laptop to 1000s of nodes. Built-in hyperparameter tuning with Ray Tune, fault tolerance, elastic scaling. Use when training massive models across multiple machines or running distributed hyperparameter sweeps.

ray-data

Orchestra-Research/AI-Research-SKILLs

Scalable data processing for ML workloads. Streaming execution across CPU/GPU, supports Parquet/CSV/JSON/images. Integrates with Ray Train, PyTorch, TensorFlow. Scales from single machine to 100s of nodes. Use for batch inference, data preprocessing, multi-modal data loading, or distributed ETL pipelines.
