Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script

This skill teaches a coding agent how to **add PyTorch FSDP2** to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.


Best use case

This skill is best used when you need a repeatable AI agent workflow instead of a one-off prompt.

Teams using this skill should expect more consistent output, faster repeated execution, and less prompt rewriting.

When to use this skill

  • You want a reusable workflow that can be run more than once with consistent structure.

When not to use this skill

  • You only need a quick one-off answer and do not need a reusable workflow.
  • You cannot install or maintain the underlying files, dependencies, or repository context.

Installation

Claude Code / Cursor / Codex

$ curl -o ~/.claude/skills/pytorch-fsdp2/SKILL.md --create-dirs "https://raw.githubusercontent.com/ComeOnOliver/skillshub/main/skills/Orchestra-Research/AI-Research-SKILLs/pytorch-fsdp2/SKILL.md"

Manual Installation

  1. Download SKILL.md from GitHub
  2. Place it in .claude/skills/pytorch-fsdp2/SKILL.md inside your project
  3. Restart your AI agent — it will auto-discover the skill

How This Skill Compares

| Feature / Agent | This Skill | Standard Approach |
| --- | --- | --- |
| Platform Support | Not specified | Limited / Varies |
| Context Awareness | High | Baseline |
| Installation Complexity | Unknown | N/A |

Frequently Asked Questions

What does this skill do?

This skill teaches a coding agent how to **add PyTorch FSDP2** to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.

Where can I find the source code?

You can find the source code on GitHub in the ComeOnOliver/skillshub repository (the same repository referenced by the installation URL above).

SKILL.md Source

# Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script

This skill teaches a coding agent how to **add PyTorch FSDP2** to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.

> FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in-place to modules. See: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.

---

## When to use this skill

Use FSDP2 when:
- Your model **doesn’t fit** on one GPU (parameters + gradients + optimizer state).
- You want an eager-mode sharding approach that uses **DTensor-based per-parameter sharding** (more inspectable, with simpler sharded state dicts) rather than FSDP1's flat-parameter sharding.
- You may later compose DP with **Tensor Parallel** using **DeviceMesh**.

Avoid (or be careful) if:
- You need strict backwards-compatible checkpoints across PyTorch versions (DCP warns against this).
- You’re forced onto older PyTorch versions without the FSDP2 stack.

## Alternatives (when FSDP2 is not the best fit)

- **DistributedDataParallel (DDP)**: Use the standard data-parallel wrapper when you want classic distributed data parallel training.
- **FullyShardedDataParallel (FSDP1)**: Use the original FSDP wrapper for parameter sharding across data-parallel workers.

Reference: `references/pytorch_ddp_notes.md`, `references/pytorch_fsdp1_api.md`.

---

## Contract the agent must follow

1. **Launch with `torchrun`** and set the CUDA device per process (usually via `LOCAL_RANK`).  
2. **Apply `fully_shard()` bottom-up**, i.e., shard submodules (e.g., Transformer blocks) before the root module.  
3. **Call `model(input)`**, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly `unshard()` or register the forward method).  
4. **Create the optimizer after sharding** and make sure it is built on the **DTensor parameters** (post-`fully_shard`).  
5. **Checkpoint using Distributed Checkpoint (DCP)** or the distributed-state-dict helpers, not naïve `torch.save(model.state_dict())` unless you deliberately gather to full tensors.

(Each of these rules is directly described in the official API docs/tutorial; see references.)

---

## Step-by-step procedure

### 0) Version & environment sanity
- Prefer a recent stable PyTorch release whose documentation covers FSDP2 (`fully_shard`) and the current DCP APIs.
- Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, `LOCAL_RANK` are visible.
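
For example, `torchrun` exports these variables to every process it launches (the script name is illustrative):

```python
# Launched via: torchrun --nproc_per_node <gpus_per_node> train.py   (train.py is illustrative)
import os

rank = int(os.environ["RANK"])              # global rank across all nodes
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
local_rank = int(os.environ["LOCAL_RANK"])  # process index on this node (maps to a GPU)
```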

Reference: `references/pytorch_fsdp2_tutorial.md` (launch commands and setup), `references/pytorch_fully_shard_api.md` (user contract).

---

### 1) Initialize distributed and set device
Minimal, correct pattern:
- `dist.init_process_group(backend="nccl")`
- `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
- Optionally create a `DeviceMesh` to describe the data-parallel group(s)
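
A minimal sketch of that pattern, assuming a single-node NCCL run launched with `torchrun` (the `init_distributed` helper name and the `"dp"` mesh dimension name are illustrative):

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def init_distributed():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # Optional: a 1-D DeviceMesh describing the data-parallel group.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))
    return mesh
```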

Reference: `references/pytorch_device_mesh_tutorial.md` (why DeviceMesh exists & how it manages process groups).

---

### 2) Build model on meta device (recommended for very large models)
For big models, initialize on `meta`, apply sharding, then materialize weights on GPU:
- `with torch.device("meta"): model = ...`
- apply `fully_shard(...)` on submodules, then `fully_shard(model)`
- `model.to_empty(device="cuda")`
- `model.reset_parameters()` (or your init routine)
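
A sketch of that flow, where `MyTransformer`, `config`, and the `.layers` container are placeholders for your own model:

```python
import torch
from torch.distributed.fsdp import fully_shard

with torch.device("meta"):
    model = MyTransformer(config)   # parameters allocated on meta, no real memory used

# Shard while still on meta: submodules first, then the root (see step 3).
for block in model.layers:
    fully_shard(block)
fully_shard(model)

model.to_empty(device="cuda")       # allocate real, sharded storage on this rank's GPU
model.reset_parameters()            # or your own weight-init routine
```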

Reference: `references/pytorch_fsdp2_tutorial.md` (migration guide shows this flow explicitly).

---

### 3) Apply `fully_shard()` bottom-up (wrapping policy = “apply where needed”)
**Do not** only call `fully_shard` on the topmost module.

Recommended sharding pattern for transformer-like models:
- iterate modules, `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- then `fully_shard(model, ...)`

Why:
- `fully_shard` forms “parameter groups” for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.
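
A sketch of the bottom-up pass, assuming `TransformerBlock` is a placeholder for your block class and `mesh` comes from step 1:

```python
from torch.distributed.fsdp import fully_shard

fsdp_kwargs = {"mesh": mesh}                  # add policies from steps 4-5 here as needed
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module, **fsdp_kwargs)    # each block becomes its own parameter group...
fully_shard(model, **fsdp_kwargs)             # ...and the root picks up the remaining params (embeddings, head)
```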

Reference: `references/pytorch_fully_shard_api.md` (bottom-up requirement and why).

---

### 4) Configure `reshard_after_forward` for memory/perf trade-offs
Default behavior:
- `None` means `True` for non-root modules and `False` for root modules (good default).

Heuristics:
- If you’re memory-bound: keep defaults or force `True` on many blocks.
- If you’re throughput-bound and can afford memory: consider keeping unsharded params longer (root often `False`).
- Advanced: use an `int` to reshard to a smaller mesh after forward (e.g., intra-node) if it’s a meaningful divisor.
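
For illustration, a memory-bound configuration might look like this (again with a placeholder `TransformerBlock` class; the root keeps its default):

```python
from torch.distributed.fsdp import fully_shard

# Memory-bound: free each block's unsharded parameters right after its forward.
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module, reshard_after_forward=True)
# The root typically stays unsharded between forward and backward (the default when left as None),
# since its parameters are needed again as soon as backward starts.
fully_shard(model)
```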

Reference: `references/pytorch_fully_shard_api.md` (full semantics).

---

### 5) Mixed precision & offload (optional but common)
FSDP2 uses:
- `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
- `offload_policy=CPUOffloadPolicy()` if you want CPU offload

Rules of thumb:
- Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
- Keep `reduce_dtype` aligned with your gradient reduction expectations.
- If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.
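
A sketch, assuming a recent PyTorch that exports these policy classes from `torch.distributed.fsdp` and a placeholder `TransformerBlock` class; the dtype choices are illustrative:

```python
import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # all-gathered params and compute in bf16
    reduce_dtype=torch.bfloat16,  # gradient reduce-scatter in bf16 (use torch.float32 if reductions are sensitive)
)
# offload_policy = CPUOffloadPolicy()   # pass offload_policy=... below to enable CPU offload

for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```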

Reference: `references/pytorch_fully_shard_api.md` (MixedPrecisionPolicy / OffloadPolicy classes).

---

### 6) Optimizer, gradient clipping, accumulation
- Create the optimizer **after** sharding so it holds DTensor params.
- If you need gradient accumulation / no_sync:
  - use the FSDP2 mechanism (`set_requires_gradient_sync`) instead of FSDP1’s `no_sync()`.

Gradient clipping:
- Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.
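
A sketch of a training loop with accumulation and clipping (`dataloader`, `loss_fn`, `accum_steps`, and the hyperparameters are placeholders):

```python
import torch

# Built after all fully_shard() calls, so it holds the DTensor parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step, (inputs, targets) in enumerate(dataloader):
    is_last_microbatch = (step + 1) % accum_steps == 0
    # FSDP2 gradient accumulation: skip the reduce-scatter on all but the last micro-batch.
    model.set_requires_gradient_sync(is_last_microbatch)

    loss = loss_fn(model(inputs), targets)    # model(inputs), not model.forward(inputs)
    loss.backward()

    if is_last_microbatch:
        # Gradients are DTensors; clip as shown in the tutorial's DTensor section.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```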

Reference: `references/pytorch_fsdp2_tutorial.md`.

---

### 7) Checkpointing: prefer DCP or distributed state dict helpers
Two recommended approaches:

**A) Distributed Checkpoint (DCP) — best default**
- DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
- DCP produces **multiple files** (often at least one per rank) and operates “in place”.

**B) Distributed state dict helpers**
- `get_model_state_dict` / `set_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)`
- For optimizer: `get_optimizer_state_dict` / `set_optimizer_state_dict`

Avoid:
- Saving DTensor state dicts with plain `torch.save` unless you intentionally convert with `DTensor.full_tensor()` and manage memory carefully.
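
A minimal sketch following the DCP recipe's pattern, assuming `model` and `optimizer` already exist and a recent PyTorch that provides `torch.distributed.checkpoint.state_dict`; the checkpoint path is illustrative:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

# --- Save: every rank participates and writes its own shard(s) ---
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id="checkpoints/step_1000")

# --- Load: every rank participates; DCP reshards if the mesh/world size changed ---
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id="checkpoints/step_1000")
set_state_dict(model, optimizer, model_state_dict=model_sd, optim_state_dict=optim_sd)
```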

References:
- `references/pytorch_dcp_overview.md` (DCP behavior and caveats)
- `references/pytorch_dcp_recipe.md` and `references/pytorch_dcp_async_recipe.md` (end-to-end usage)
- `references/pytorch_fsdp2_tutorial.md` (DTensor vs DCP state-dict flows)
- `references/pytorch_examples_fsdp2.md` (working checkpoint scripts)

---

## Workflow checklists (copy-paste friendly)

### Workflow A: Retrofit FSDP2 into an existing training script
- [ ] Launch with `torchrun` and initialize the process group.
- [ ] Set the CUDA device from `LOCAL_RANK`; create a `DeviceMesh` if you need multi-dim parallelism.
- [ ] Build the model (use `meta` if needed), apply `fully_shard` bottom-up, then `fully_shard(model)`.
- [ ] Create the optimizer after sharding so it captures DTensor parameters.
- [ ] Use `model(inputs)` so hooks run; use `set_requires_gradient_sync` for accumulation.
- [ ] Add DCP save/load via `torch.distributed.checkpoint` helpers.

Reference: `references/pytorch_fsdp2_tutorial.md`, `references/pytorch_fully_shard_api.md`, `references/pytorch_device_mesh_tutorial.md`, `references/pytorch_dcp_recipe.md`.

### Workflow B: Add DCP save/load (minimal pattern)
- [ ] Wrap state in `Stateful` or assemble state via `get_state_dict`.
- [ ] Call `dcp.save(...)` from all ranks to a shared path.
- [ ] Call `dcp.load(...)` and restore with `set_state_dict`.
- [ ] Validate any resharding assumptions when loading into a different mesh.
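
A sketch of the `Stateful` wrapper pattern from the DCP recipe (the class name and checkpoint path are illustrative):

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

class AppState(Stateful):
    """Bundle model + optimizer so dcp.save/dcp.load drive their state dicts."""
    def __init__(self, model, optimizer):
        self.model, self.optimizer = model, optimizer

    def state_dict(self):
        model_sd, optim_sd = get_state_dict(self.model, self.optimizer)
        return {"model": model_sd, "optim": optim_sd}

    def load_state_dict(self, state_dict):
        set_state_dict(self.model, self.optimizer,
                       model_state_dict=state_dict["model"],
                       optim_state_dict=state_dict["optim"])

state = {"app": AppState(model, optimizer)}
dcp.save(state, checkpoint_id="checkpoints/latest")   # all ranks
dcp.load(state, checkpoint_id="checkpoints/latest")   # all ranks
```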

Reference: `references/pytorch_dcp_recipe.md`.

## Debug checklist (what the agent should check first)

1. **All ranks on distinct GPUs?**  
   If not, verify `torch.cuda.set_device(LOCAL_RANK)` and your `torchrun` flags.
2. **Did you accidentally call `forward()` directly?**  
   Use `model(input)` or explicitly `unshard()` / register forward.
3. **Is `fully_shard()` applied bottom-up?**  
   If only root is sharded, expect worse memory/perf and possible confusion.
4. **Optimizer created at the right time?**  
   Must be built on DTensor parameters *after* sharding.
5. **Checkpointing path consistent?**  
   - If using DCP, don’t mix with ad-hoc `torch.save` unless you understand conversions.
   - Be mindful of PyTorch-version compatibility warnings for DCP.

---

## Common issues and fixes

- **Forward hooks not running** → Call `model(inputs)` (or `unshard()` explicitly) instead of `model.forward(...)`.
- **Optimizer sees non-DTensor params** → Create optimizer after all `fully_shard` calls.
- **Only root module sharded** → Apply `fully_shard` bottom-up on submodules before the root.
- **Memory spikes after forward** → Set `reshard_after_forward=True` for more modules.
- **Gradient accumulation desync** → Use `set_requires_gradient_sync` instead of FSDP1’s `no_sync()`.

Reference: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.

---

## Minimal reference implementation outline (agent-friendly)

The coding agent should implement a script with these labeled blocks:

- `init_distributed()`: init process group, set device
- `build_model_meta()`: model on meta, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or distributed state dict helpers
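
A skeleton of how those blocks might compose (function bodies follow the earlier steps; `dataloader` and `save_every` are placeholders):

```python
def main():
    mesh = init_distributed()                    # step 1
    model = build_model_meta(mesh)               # steps 2-5: meta init, fully_shard bottom-up, materialize
    optimizer = build_optimizer(model)           # step 6: created after sharding (DTensor params)
    for step, batch in enumerate(dataloader):
        train_step(model, optimizer, batch)      # model(inputs), DTensor-aware clipping
        if step % save_every == 0:
            checkpoint_save(model, optimizer)    # step 7: DCP

if __name__ == "__main__":
    main()
```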

Concrete examples live in `references/pytorch_examples_fsdp2.md` and the official tutorial reference.

---

## References
- `references/pytorch_fsdp2_tutorial.md`
- `references/pytorch_fully_shard_api.md`
- `references/pytorch_ddp_notes.md`
- `references/pytorch_fsdp1_api.md`
- `references/pytorch_device_mesh_tutorial.md`
- `references/pytorch_tp_tutorial.md`
- `references/pytorch_dcp_overview.md`
- `references/pytorch_dcp_recipe.md`
- `references/pytorch_dcp_async_recipe.md`
- `references/pytorch_examples_fsdp2.md`
- `references/torchtitan_fsdp_notes.md` (optional, production notes)
- `references/ray_train_fsdp2_example.md` (optional, integration example)

Related Skills

All of the following are from ComeOnOliver/skillshub:

  • training-machine-learning-models - Build and train machine learning models with automated workflows. Analyzes datasets, selects model types (classification, regression), configures parameters, trains with cross-validation, and saves model artifacts. Use when asked to "train model" or "evalua...
  • pytorch-model-trainer - Auto-activating skill for ML training. Triggers on "pytorch model trainer". Part of the ML Training skill category.
  • managing-database-sharding - Use when you need to work with database sharding. Provides horizontal sharding strategies with comprehensive guidance and automation. Trigger with phrases like "implement sharding", "shard database", or "distribute data".
  • distributed-training-setup - Auto-activating skill for ML training. Triggers on "distributed training setup". Part of the ML Training skill category.
  • when-training-neural-networks-use-flow-nexus-neural - A systematic workflow for training and deploying neural networks using the Flow Nexus platform with distributed E2B sandboxes. Covers architecture selection, distributed training, ...
  • when-debugging-ml-training-use-ml-training-debugger - Debug ML training issues and optimize performance, including loss divergence, overfitting, and slow convergence.
  • ml-training-debugger - Diagnose machine learning training failures including loss divergence, mode collapse, gradient issues, architecture problems, and optimization failures. This skill spawns a specialist ML debugging ...
  • agentdb-reinforcement-learning-training - Train AI agents using AgentDB's 9 reinforcement learning algorithms, including Q-Learning, DQN, PPO, and Actor-Critic. Build self-learning agents, implement RL training loops with experience replay, and deploy optimized models to production.
  • pytorch-patterns - PyTorch deep learning patterns and best practices for building robust, efficient, and reproducible training pipelines, model architectures, and data loading.
  • PyTorch
  • TorchTitan - PyTorch Native Distributed LLM Pretraining
  • torchforge (PyTorch-Native Agentic RL Library) - Meta's PyTorch-native RL library that separates infrastructure concerns from algorithm concerns. It enables rapid RL research by letting you focus on algorithms while handling distributed training, inference, and weight sync automatically.