数据并行是最早、最直观的分布式训练策略。核心思想把数据切开每张卡处理一部分然后同步梯度。但朴素的数据并行有严重的内存冗余问题ZeROZero Redundancy Optimizer通过三个阶段逐步消除冗余是 DeepSpeed 和 FSDP 的核心。回顾单卡训练的显存去哪了在讲分布式之前先搞清楚单卡训练的显存构成。以 7B 参数模型、FP16 混合精度训练为例┌─────────────────────────────────────────────────────────┐ │ 单卡显存占用 │ ├─────────────────┬───────────────┬───────────────────────┤ │ 模型参数 │ 7B × 2B │ 14 GB │ │ 梯度 │ 7B × 2B │ 14 GB │ │ 优化器状态 │ 7B × 12B │ 84 GB │ │ (Adam: FP32副本 │ 7B × 4B 28 GB │ │ 一阶动量 m │ 7B × 4B 28 GB │ │ 二阶动量 v │ 7B × 4B 28 GB) │ │ 激活值 │ 变动 │ ~10-50 GB │ ├─────────────────┼───────────────┼───────────────────────┤ │ 合计 │ │ ~122-162 GB │ └─────────────────┴───────────────┴───────────────────────┘关键观察优化器状态占了总显存的 68%84 / 122。这是 ZeRO 重点优化的目标。想知道为什么优化器占用了这么多显存可以点这里朴素数据并行Data Parallelism, DP工作原理训练数据 / | \ / | \ ┌────┐ ┌────┐ ┌────┐ │GPU0│ │GPU1│ │GPU2│ 每张卡持有完整模型副本 │完整│ │完整│ │完整│ 每张卡处理不同的数据 batch │模型│ │模型│ │模型│ └──┬─┘ └──┬─┘ └──┬─┘ │ │ │ 梯度0 梯度1 梯度2 各自计算梯度 │ │ │ └──────┼──────┘ ▼ AllReduce 所有梯度求平均 │ ┌──────┼──────┐ ▼ ▼ ▼ 更新0 更新1 更新2 各自用平均梯度更新模型 (相同) (相同) (相同) 更新后模型仍然一致每张 GPU 持有完整的模型副本处理不同的数据子集计算完梯度后做 AllReduce 求平均然后各自更新参数。因为初始参数相同、梯度相同、学习率相同所以更新后的参数也相同——保持同步。PyTorch DDPDistributed Data ParallelDDP 是 DP 的生产级实现比旧版DataParallel单进程多线程更稳定、更高效多进程。importtorchimporttorch.distributedasdistfromtorch.nn.parallelimportDistributedDataParallelasDDP# 初始化torchrun 自动设置环境变量dist.init_process_group(backendnccl)local_rankint(os.environ[LOCAL_RANK])torch.cuda.set_device(local_rank)# 包装模型modelMyModel().to(local_rank)modelDDP(model,device_ids[local_rank])# 训练循环和单卡完全一样forbatchindataloader:lossmodel(batch).loss loss.backward()# DDP 自动在 backward 中插入 AllReduceoptimizer.step()optimizer.zero_grad()DDP 的核心优化梯度桶化Gradient Bucketing。它不会等所有梯度都算完再 AllReduce而是把梯度分成若干 bucket边算边同步——和反向传播的计算重叠隐藏通信延迟。反向传播: [layer_n 梯度] [layer_n-1 梯度] [layer_n-2 梯度] ... AllReduce: [bucket 1 ] [bucket 2 ] ... ↑ 计算和通信重叠overlapDDP 的致命问题显存冗余DDP 中每张卡都保存完整的模型参数、梯度、优化器状态。4 卡 DDP 训练 7B 模型 每张卡: 14 GB (参数) 14 GB (梯度) 84 GB (优化器) 112 GB 4 张卡总计: 112 GB × 4 448 GB 但其中真正不同的只有梯度因为数据不同 参数和优化器状态在每张卡上完全相同 → 3/4 是浪费的冗余这就是 ZeRO 要解决的问题。ZeRO零冗余优化器ZeRO 的核心思想把冗余的数据切分到不同 GPU 上每个 GPU 只保存 1/N 的状态。ZeRO 分三个阶段逐步切分更多内容切分内容 显存节省 ┌──────────────┐ ┌──────────┐ DDP (无切分) │ 参数 梯度 优化 │ │ 0% │ ├──────────────┤ ├──────────┤ ZeRO-1 │ 参数 梯度 [优化]│ → 切分 │ ~4x │ ├──────────────┤ ├──────────┤ ZeRO-2 │ 参数 [梯度][优化]│ → 切分 │ ~8x │ ├──────────────┤ ├──────────┤ ZeRO-3 │[参数][梯度][优化]│ → 全切分 │ ~N x │ └──────────────┘ └──────────┘ZeRO Stage 1切分优化器状态原理Adam 的优化器状态FP32 参数副本、一阶动量 m、二阶动量 v占 84 GB是显存的大头。把它平均分成 N 份每个 GPU 只保存 1/N。GPU 0 GPU 1 GPU 2 GPU 3 ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ 模型参数 (FP16) │ 完整 │ │ 完整 │ │ 完整 │ │ 完整 │ ← 未切分 ├──────┤ ├──────┤ ├──────┤ ├──────┤ 梯度 (FP16) │ 完整 │ │ 完整 │ │ 完整 │ │ 完整 │ ← 未切分 ├──────┤ ├──────┤ ├──────┤ ├──────┤ 优化器状态 │ 1/4 │ │ 1/4 │ │ 1/4 │ │ 1/4 │ ← 切分! (FP32参数mv) │ 21GB │ │ 21GB │ │ 21GB │ │ 21GB │ └──────┘ └──────┘ └──────┘ └──────┘训练流程变化1. 前向传播和 DDP 一样每张卡有完整参数 2. 反向传播和 DDP 一样每张卡算出完整梯度 3. 梯度同步ReduceScatter不是 AllReduce → 每个 GPU 只保留自己负责那 1/N 参数的梯度 4. 参数更新每个 GPU 只更新自己负责的 1/N 参数 5. 参数同步AllGather把更新后的参数收集给所有人DP/DDP 用 AllReduce: GPU 0: [全部梯度] ─┐ GPU 1: [全部梯度] ─┼─ AllReduce → 每个 GPU 拿到 [全部平均梯度] GPU 2: [全部梯度] ─┤ 然后各自更新全部参数 GPU 3: [全部梯度] ─┘ ZeRO-1 用 ReduceScatter AllGather: GPU 0: [全部梯度] ─┐ GPU 1: [全部梯度] ─┼─ ReduceScatter → GPU 0 拿 [梯度 1/4] GPU 2: [全部梯度] ─┤ GPU 1 拿 [梯度 2/4] GPU 3: [全部梯度] ─┘ GPU 2 拿 [梯度 3/4] GPU 3 拿 [梯度 4/4] 各自更新自己负责的 1/4 参数 ↓ AllGather → 每个 GPU 拿到更新后的完整参数通信量分析步骤DDP (AllReduce)ZeRO-1 (ReduceScatter AllGather)通信量2 × 参数量参数量 参数量 2 × 参数量结论通信量完全相同但显存省了 ~4x这是 ZeRO 最精妙的地方——Stage 1 在不增加通信量的情况下把优化器状态切分了。显存计算4 卡7B 模型FP16每张卡显存 参数(完整) 梯度(完整) 优化器状态(1/4) 14 GB 14 GB 21 GB 49 GB 对比 DDP: 14 14 84 112 GB → 节省 56%ZeRO Stage 2切分优化器状态 梯度原理既然梯度在 ReduceScatter 之后每个 GPU 只需要 1/N那就只保存 1/N 的梯度。GPU 0 GPU 1 GPU 2 GPU 3 ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ 模型参数 (FP16) │ 完整 │ │ 完整 │ │ 完整 │ │ 完整 │ ← 未切分 ├──────┤ ├──────┤ ├──────┤ ├──────┤ 梯度 (FP16) │ 1/4 │ │ 1/4 │ │ 1/4 │ │ 1/4 │ ← 也切分了! ├──────┤ ├──────┤ ├──────┤ ├──────┤ 优化器状态 │ 1/4 │ │ 1/4 │ │ 1/4 │ │ 1/4 │ ← 切分 └──────┘ └──────┘ └──────┘ └──────┘关键实现细节反向传播时梯度仍然是完整计算的因为参数是完整的。但每算完一层的梯度就立刻做 ReduceScatter把这一层的梯度分发到负责的 GPU释放掉其他 GPU 上的副本。这样峰值显存更低。反向传播过程 Layer N: 计算梯度 → ReduceScatter → 只保留自己负责的 1/N 梯度 → 释放其余 Layer N-1: 计算梯度 → ReduceScatter → 只保留自己负责的 1/N 梯度 → 释放其余 ... Layer 0: 计算梯度 → ReduceScatter → 只保留自己负责的 1/N 梯度 → 释放其余通信量和 Stage 1 相同ReduceScatter 的通信量不变。显存计算4 卡7B 模型FP16每张卡显存 参数(完整) 梯度(1/4) 优化器状态(1/4) 14 GB 3.5 GB 21 GB 38.5 GB 对比 DDP: 112 GB → 节省 66% 对比 Stage 1: 49 GB → 再省 21%ZeRO Stage 3切分优化器状态 梯度 参数原理参数也切成 N 份每个 GPU 只保存 1/N 的参数。这是最激进的切分。GPU 0 GPU 1 GPU 2 GPU 3 ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ 模型参数 (FP16) │ 1/4 │ │ 1/4 │ │ 1/4 │ │ 1/4 │ ← 全切分了! ├──────┤ ├──────┤ ├──────┤ ├──────┤ 梯度 (FP16) │ 1/4 │ │ 1/4 │ │ 1/4 │ │ 1/4 │ ├──────┤ ├──────┤ ├──────┤ ├──────┤ 优化器状态 │ 1/4 │ │ 1/4 │ │ 1/4 │ │ 1/4 │ └──────┘ └──────┘ └──────┘ └──────┘问题前向/反向传播时需要完整参数但每个 GPU 只有 1/N。怎么办解决方案按需 AllGather。前向传播每一层: 1. AllGather 该层的参数从 4 个 GPU 收集完整参数 2. 用完整参数做前向计算 3. 丢弃非自己负责的参数释放显存 反向传播每一层: 1. AllGather 该层的参数 2. 用完整参数做反向计算 3. ReduceScatter 梯度分发给负责的 GPU 4. 丢弃非自己负责的参数Layer 3: AllGather → 前向 → 丢弃 | AllGather → 反向 → ReduceScatter → 丢弃 Layer 2: AllGather → 前向 → 丢弃 | AllGather → 反向 → ReduceScatter → 丢弃 Layer 1: AllGather → 前向 → 丢弃 | AllGather → 反向 → ReduceScatter → 丢弃 Layer 0: AllGather → 前向 → 丢弃 | AllGather → 反向 → ReduceScatter → 丢弃通信量分析阶段通信量前向传播每层 AllGather 参数量 → 总计 参数量 × 层数反向传播每层 AllGather ReduceScatter → 总计 2 × 参数量 × 层数参数更新ReduceScatter 梯度 AllGather 参数 → 2 × 参数量总计~3 × 参数量 × 层数Stage 3 的通信量远大于 Stage 1/2乘以层数这就是代价。显存计算4 卡7B 模型FP16每张卡显存 参数(1/4) 梯度(1/4) 优化器状态(1/4) 临时完整参数(前向/反向) 3.5 GB 3.5 GB 21 GB 14 GB (临时) ≈ 42 GB (峰值含临时参数) 但参数和状态部分随 GPU 数量线性下降 8 卡时: 1.75 1.75 10.5 14 28 GB 16 卡时: 0.875 0.875 5.25 14 21 GB注意临时参数14 GB不随 GPU 数量变化这是 Stage 3 的显存下限。可以通过参数预取prefetch和梯度检查点gradient checkpointing来优化。ZeRO 三阶段完整对比显存占用7B 模型FP164 卡: DDP: ████████████████████████████████████████████████████ 112 GB/卡 ZeRO-1: █████████████████████████ 49 GB/卡 ZeRO-2: ███████████████████ 38.5 GB/卡 ZeRO-3: ███████████████████ 42 GB/卡 (峰值含临时参数) 通信量每步: DDP: ████████████████ 2M (AllReduce) ZeRO-1: ████████████████ 2M (ReduceScatter AllGather) ZeRO-2: ████████████████ 2M (逐层 ReduceScatter) ZeRO-3: ██████████████████████████████████████████ ~3M × 层数DDPZeRO-1ZeRO-2ZeRO-3切分内容无优化器状态优化器梯度优化器梯度参数通信量2M2M2M3M × 层数通信原语AllReduceReduceScatter AllGather逐层 ReduceScatter逐层 AllGather ReduceScatter显存节省无~4x~8x~N x随 GPU 数线性适用场景模型能放进单卡模型参数能放进单卡同左梯度也放不下时模型参数都放不下时实现难度低低中高DeepSpeed 实战代码基本使用ZeRO-2# ds_config.json{train_batch_size:32,gradient_accumulation_steps:4,fp16:{enabled:true},zero_optimization:{stage:2,offload_optimizer:{device:none//或cpu启用 ZeRO-Offload},allgather_partitions:true,allgather_bucket_size:5e8,reduce_scatter:true,reduce_bucket_size:5e8,overlap_comm:true//通信和计算重叠}}# train.pyimportdeepspeedimportargparse# DeepSpeed 会自动解析 configparserargparse.ArgumentParser()parserdeepspeed.add_config_arguments(parser)argsparser.parse_args()modelMyModel()optimizertorch.optim.Adam(model.parameters(),lr1e-4)# 一行代码启用 ZeROmodel,optimizer,_,_deepspeed.initialize(argsargs,modelmodel,optimizeroptimizer,model_parametersmodel.parameters(),configds_config.json)# 训练循环和单卡几乎一样forbatchindataloader:lossmodel(batch).loss model.backward(loss)# 自动处理梯度 ZeRO 切分model.step()# 自动处理参数更新 AllGather启动deepspeed--num_gpus4train.py--deepspeed_configds_config.json# 或多机deepspeed--num_gpus4--num_nodes2--hostfilehostfile train.pyZeRO-3 配置{zero_optimization:{stage:3,stage3_max_live_parameters:1e9,stage3_max_reuse_distance:1e9,stage3_prefetch_bucket_size:5e7,stage3_param_persistence_threshold:1e5,reduce_bucket_size:5e8,sub_group_size:1e9,overlap_comm:true,offload_optimizer:{device:cpu,pin_memory:true},offload_param:{device:cpu,pin_memory:true}}}ZeRO-Offload把显存压力转移到 CPU 内存正常 ZeRO-3: GPU 显存: [参数 1/N] [梯度 1/N] [优化器 1/N] [临时参数] ZeRO-Offload: GPU 显存: [临时参数] [计算中的激活值] CPU 内存: [参数 1/N] [梯度 1/N] [优化器 1/N] ← 移到 CPU CPU 做 Adam 更新 → 通过 PCIe 传回 GPU适用场景GPU 显存实在不够时比如单卡训练 7B 模型用 CPU 内存换 GPU 显存。代价是 PCIe 带宽 (~32 GB/s) 远慢于 GPU 显存带宽 (~2 TB/s)训练速度会降 2-5 倍。PyTorch FSDPFully Sharded Data ParallelFSDP 是 PyTorch 原生实现的 ZeRO-3不需要 DeepSpeed 依赖。FSDP2PyTorch 2.x基于 DTensor更加灵活。fromtorch.distributed.fsdpimportFullyShardedDataParallelasFSDPfromtorch.distributed.fsdp.wrapimporttransformer_auto_wrap_policyfromfunctoolsimportpartial# 定义哪些模块要被独立切分通常是 Transformer Layerauto_wrap_policypartial(transformer_auto_wrap_policy,transformer_layer_cls{TransformerBlock})# 包装模型modelFSDP(model,auto_wrap_policyauto_wrap_policy,sharding_strategyShardingStrategy.FULL_SHARD,# ZeRO-3# sharding_strategyShardingStrategy.SHARD_GRAD_OP, # ZeRO-2# sharding_strategyShardingStrategy.NO_SHARD, # DDPmixed_precisionMixedPrecision(param_dtypetorch.float16,reduce_dtypetorch.float16,buffer_dtypetorch.float16,),device_idtorch.cuda.current_device(),)# 训练循环和 DDP 完全一样forbatchindataloader:lossmodel(batch).loss loss.backward()# FSDP 自动处理 AllGather ReduceScatteroptimizer.step()optimizer.zero_grad()FSDP vs DeepSpeed 对比FSDPDeepSpeed维护方PyTorch 官方Microsoft集成度PyTorch 原生无需额外依赖需要pip install deepspeedZeRO 阶段Stage 2 (SHARD_GRAD_OP), Stage 3 (FULL_SHARD)Stage 1/2/3 都支持ZeRO-Offload有限支持完整支持MoE 支持无有灵活性与 PyTorch 生态无缝集成功能更全但耦合度高三阶段通信量总结N 个 GPU模型参数量 M阶段前向通信反向通信更新通信每步总通信DDP0AllReduce 2M02MZeRO-10ReduceScatter MAllGather M2MZeRO-20逐层 ReduceScatter MAllGather M2MZeRO-3逐层 AllGather M×L逐层 AllGatherReduceScatter 2M×LAllGather M3M×L ML 模型层数。可以看到 Stage 1 和 Stage 2 的通信量和 DDP 完全相同但显存大幅减少——这就是为什么它们是首选方案。本课小结概念要点DDP完整模型副本 梯度 AllReduce简单但有显存冗余ZeRO-1切分优化器状态省 ~4x 显存通信量不变ZeRO-2切分梯度省 ~8x 显存通信量不变ZeRO-3切分参数显存随 GPU 数线性下降但通信量 ×层数ZeRO-Offload把优化器/参数移到 CPU 内存用 PCIe 带宽换显存FSDPPyTorch 原生 ZeRO-3 实现无需 DeepSpeed 依赖自检DDP 中每张卡保存的优化器状态是一样的吗答是的完全相同——这就是冗余ZeRO-1 和 DDP 的通信量一样为什么 ZeRO-1 更好答通信量相同但显存减少 ~4x因为优化器状态不再冗余ZeRO-3 的通信量为什么乘以层数答每层前向/反向都需要 AllGather 完整参数训练 70B 模型4 卡 A100 80GB至少需要 ZeRO 几答70B × 18 bytes ≈ 1260 GB4 卡 × 80 GB 320 GB需要 ZeRO-3 Offload或者更多卡