分布式训练为何艰难:DTensor、正确性与抽象的代价
Runway 工程师用四次失败的并行化尝试把分布式训练的正确性难题拆解得明明白白,还给出了 DTensor 在规模下吃掉的 MFU 和编译陷阱的一手数据,做大模型训练的人值得从头读到尾。
本文探讨了分布式训练中的正确性难题及DTensor方案的权衡。DTensor通过为张量附加放置元数据(如Shard、Replicate)来自动管理通信,确保计算正确性。文章通过一个并行化案例,展示了不使用DTensor时手动处理梯度计算可能引发的静默错误(如梯度为零或倍增),从而凸显了正确性的复杂性。然而,DTensor的抽象层在简化开发的同时,也可能在大规模场景下引入隐性的性能开销。因此,在设计分布式系统时,需要在抽象的开发便利与底层的计算效率之间做出审慎权衡。
DTensor 通过为每个张量附加放置元数据,使分布式训练保持正确性。在大规模场景下,它也会引入一些成本,如果你不针对这些成本设计,它们会悄然侵蚀吞吐量。
为什么分布式训练很难
当你将一个张量分片到某个进程组上时,流经该分片的每个梯度都必须与你在单个 GPU 上得到的梯度一致。手动正确实现这一点意味着需要在模型中分散放置集合通信操作,在算子内部管理放置假设,并维护 FSDP、张量并行和流水线并行的专用代码路径。它非常容易出错,而且这些 bug 几乎总是静默的。
DTensor(PyTorch 的分布式张量)试图将这些问题统一起来。每个张量携带一小段描述其放置的元数据:Replicate、Shard(dim) 或 Partial(sum)。然后,算子会自动传播放置信息,并在张量需要在不同布局之间移动时插入正确的集合通信操作。
理论上,这能为你提供更清晰的抽象和更安全的扩展。实践中,它解决了一类问题,却又创造了另一类问题。
四次尝试并行化一个三行模块
说明 DTensor 必要性的最清晰方法是尝试替代方案。考虑这个玩具级的扩散 Transformer 调制模块。每个 token 属于批次中的一个样本,对于每个样本,我们都有一个条件嵌入(时间步、类别标签、文本特征……),它需要对该样本的 token 进行调制。该模块将条件嵌入投影为每个通道的缩放因子,并将其乘入 token 激活值。这是一个简化版的 AdaLN 调制模式(不含移位和归一化):
class Modulation(torch.nn.Module): def __init__(self, hidden_dim: int): super().__init__() # 可学习投影:条件嵌入 -> 每个通道的缩放因子。 self.weight = torch.nn.Parameter( torch.randn(hidden_dim, hidden_dim, device=torch.cuda.current_device()) )
```python def forward( self, tokens: torch.Tensor, # [numtokens, hiddendim] cond: torch.Tensor, # [numsamples, hiddendim] sampleids: torch.Tensor, # [numtokens] -- which sample each token belongs to ) -> torch.Tensor: # 1. One scale vector per sample. # 1. 每个样本一个缩放向量。 persamplescale = torch.nn.functional.linear(cond, self.weight) # 2. Broadcast each sample's scale out to its tokens. # 2. 将每个样本的缩放广播到其 token 上。 pertokenscale = persamplescale.indexselect(0, sampleids) # 3. Modulate. # 3. 调制。 return pertokenscale tokens ```
The goal: shard tokens across a process group, compute locally, gather the result back, and produce four things that match the single-GPU baseline exactly: the forward result and the gradients on tokens, cond and self.weight.
目标是:将 tokens 分片到进程组中,本地计算,收集结果回来,并产生四个与单 GPU 基线完全匹配的东西:前向结果以及 tokens、cond 和 self.weight 上的梯度。
Getting the forward result right is easy. Getting the gradients right is not.
让前向结果正确很容易。让梯度正确就不容易了。
Attempt 1: torch.chunk and allgather
尝试一:torch.chunk 和 allgather
The obvious first try: split tokens with torch.chunk, compute, all-gather and concatenate. The forward result is correct. But every gradient is wrong!
最明显的第一尝试:用 torch.chunk 拆分 tokens,计算,all-gather 然后拼接。前向结果是正确的。但每个梯度都是错的!
The problem is the backward of torch.chunk. Locally, it looks fine: it places the incoming gradient into the corresponding slice of the output and zero-fills the rest. With four tokens on two ranks, what each rank sees in tokens.grad after backward is:
问题出在 torch.chunk 的反向传播上。在本地看来似乎没问题:它将传入的梯度放到输出的对应切片中,并将其余部分置零。四个 tokens 分布在两个 rank 上,反向传播后每个 rank 在 tokens.grad 中看到的是:
rank 0: tokens.grad = [g0, g1, 0, 0] rank 1: tokens.grad = [0, 0, g2, g3]
From rank 0's perspective this is correct: rank 0 never touched the second half of tokens, so it has no gradient to contribute there. But in the distributed setting we need the full gradient on every rank, and chunk has no idea other ranks exist. Single-GPU ops are oblivious to other ranks, and that obliviousness is the entire source of every bug in this section.
从 rank 0 的视角看这是正确的:rank 0 从未接触过后半部分的 tokens,因此它对那里没有梯度贡献。但在分布式场景中,我们需要每个 rank 上都拥有完整的梯度,而 chunk 并不知道有其他 rank 存在。单 GPU 操作对其他 rank 毫无感知,而正是这种无感知导致了本节所有 bug 的根源。
Attempt 2: a custom scatter
尝试二:自定义 scatter
We replace torch.chunk with a custom autograd function whose backward all-gathers the partial gradients and concatenates them. Now tokens.grad is consistent across ranks.
我们用自定义的 autograd 函数替换 torch.chunk,该函数的反向传播会 all-gather 部分梯度并将它们拼接起来。现在 tokens.grad 在各个 rank 间是一致的了。
It is also exactly twice the baseline. With TP=2, allgather's backward calls reducescatter: sum across ranks, then split. But the upstream gradient is identical on both ranks (the loss is computed on the gathered replicated output), so summing doubles it:
但梯度也正好是基线的两倍。在 TP=2 的情况下,allgather 的反向传播调用 reducescatter:跨 rank 求和,然后拆分。但上游梯度在两个 rank 上是相同的(损失是在收集到的复制输出上计算的),因此求和将其翻倍:
reduce: [o0, o1, o2, o3] + [o0, o1, o2, o3] = [2o0, 2o1, 2o2, 2o3] scatter: rank 0 gets [2o0, 2o1], rank 1 gets [2o2, 2o3] correct: rank 0 gets [o0, o1], rank 1 gets [o2, o3]
reduce: [o0, o1, o2, o3] + [o0, o1, o2, o3] = [2o0, 2o1, 2o2, 2o3] scatter: rank 0 得到 [2o0, 2o1],rank 1 得到 [2o2, 2o3] 正确: rank 0 得到 [o0, o1],rank 1 得到 [o2, o3]
Every value is TPworldsize x instead of x. The root cause is a mismatch: our custom scatter's backward does all-gather-then-concat, so the allgather in the forward is going from sharded to replicated. Its backward should be a plain chunk (each rank takes its slice), not reducescatter. PyTorch ships reducescatter because that's correct when the upstream gradient is partial; in our graph it's replicated, so the reduction double-counts.
每个值都是 TPworldsize × x 而不是 x。根本原因是一个不匹配:我们自定义 scatter 的反向传播执行的是 all-gather-then-concat,因此前向的 allgather 是从分片(sharded)变为复制(replicated)。它的反向传播应该是一个简单的 chunk(每个 rank 取自己的切片),而不是 reducescatter。PyTorch 默认提供了 reducescatter,因为当上游梯度是部分(partial)时这是正确的;但在我们的计算图中梯度是复制的,所以归约会导致重复计数。
Attempt 3: a custom all-gather-to-replicate
尝试 3:一个自定义的 all-gather-to-replicate
We write a second autograd function: forward is a normal all-gather, backward is just a chunk (each rank takes its own slice of the upstream gradient, no reduction). This is the right backward when the output is replicated and we are returning to a sharded state.
我们编写了第二个 autograd 函数:前向是正常的 all-gather,反向只是 chunk(每个 rank 取自己那一片上游梯度,不进行归约)。当输出是复制的,并且我们要返回分片状态时,这就是正确的反向传播。
tokens.grad finally matches! But cond.grad and weight.grad are still wrong: different on every rank, summing to the baseline.
tokens.grad 终于匹配了!但 cond.grad 和 weight.grad 仍然错误:每个 rank 上不同,加起来等于 baseline。
Why does this happen? cond is replicated, but on each rank it only interacts with the local slice of tokens, so each rank's cond.grad contains only the contribution from its half of the work. Whenever a replicated tensor is consumed alongside a sharded one, its gradient lands as a partial sum and needs an explicit reduction.
为什么会这样?cond 是复制的,但在每个 rank 上它只与 tokens 的本地切片交互,因此每个 rank 的 cond.grad 只包含来自其一半工作的贡献。每当一个复制的张量与一个分片的张量一起被消费时,其梯度会以部分和的形式出现,需要显式归约。
Attempt 4: copy-parallel for cond and self.weight
尝试 4:对 cond 和 self.weight 做 copy-parallel
The rule: any replicated tensor that flows into a sharded computation needs its gradient summed across ranks. We add a third autograd function: identity in the forward, allreduce in the backward. The forward exists only to insert the function into the autograd graph; the rest of the model sees the tensor unchanged. On the way back, the allreduce sums per-rank partials into the true gradient on every rank.
规则:任何流入分片计算的复制张量都需要将其梯度跨 rank 求和。我们添加了第三个 autograd 函数:前向是恒等变换,反向是 allreduce。前向仅用于将该函数插入 autograd 图;模型其余部分看到的张量不变。在反向传播时,allreduce 将每个 rank 的部分和汇总为每个 rank 上的真实梯度。
Both cond and self.weight are replicated tensors consumed alongside sharded tokens, so both need wrapping. Wrap them and all gradients finally match.
cond 和 self.weight 都是与分片 tokens 一起被消费的复制张量,因此两者都需要包装。包装之后,所有梯度终于匹配了。
The scoreboard across all four attempts:
以下为所有四次尝试的结果表:
| # | What changed | tokens.grad | cond.grad | weight.grad | | --- | --- | --- | --- | --- | | 1 | torch.chunk + allgather | wrong: half is zero-filled / rank | wrong | wrong | | 2 | + custom scatter (backward = allgather) | wrong: TPworldsize x | wrong | wrong | | 3 | + custom all-gather (backward = chunk) | match | wrong: partial per rank | wrong: partial per rank | | 4 | + all-reduce wrapper on cond and self.weight | match | match | match |
| # | 变更内容 | tokens.grad | cond.grad | weight.grad | | --- | --- | --- | --- | --- | | 1 | torch.chunk + allgather | 错误:每个 rank 上一半为零填充 | 错误 | 错误 | | 2 | + 自定义 scatter(反向 = allgather) | 错误:TPworldsize × x | 错误 | 错误 | | 3 | + 自定义 all-gather(反向 = chunk) | 匹配 | 错误:每个 rank 上为部分和 | 错误:每个 rank 上为部分和 | | 4 | + 对 cond 和 self.weight 使用 all-reduce 包装 | 匹配 | 匹配 | 匹配 |
Three custom autograd functions. A module whose forward pass is three lines. And we still are not done, because the parallelization is now coupled to the exact shape of the forward.
三个自定义 autograd 函数。一个前向传播只有三行的模块。而且我们还没有完成,因为并行化现在已经与前向的具体形状耦合在一起了。
重构前向传播,破坏梯度
假设有人重构了模块,先索引 cond 再执行 projection。数学上是等价的。但并行化以不明显的方式发生了变化。sampleids 不再需要分片。condpertoken 现在变成了 token 形状,必须被分片。关键是:必须从 self.weight 中移除 all-reduce 包装器,因为在新版本中它会与已经分片的输入交互。如果保留它,你会悄悄地使梯度翻倍。
这是那种不会产生错误、不会产生 NaN、不会导致崩溃的 bug。模型正常训练。损失在下降。结果明显比单 GPU 基线更差,而且你可能几周都不会注意到。
DTensor 实际做了什么
DTensor 的贡献在于一个类型系统:每个张量都携带放置元数据,运行时会拒绝将 DTensor 与常规张量混用。
DTensor 是一个常规的 PyTorch 张量加上两条分布式信息:
- 一个描述进程组拓扑的 DeviceMesh。 - 一个 Placement:Replicate、Shard(dim) 或 Partial(sum)。
我们手动编写的每个自定义 autograd 函数都映射为两个 placement 之间的一个 redistribute 调用。完整的转换集合:
AllGather Shard(dim) ──────────► Replicate ▲ │ │ │ │ ▼ ReduceScatter Scatter │ │ ▲ ▼ Partial(sum) ──────────► Replicate AllReduce
Shard(X) ─── AllToAll ──► Shard(Y)
每个箭头都是一个 redistribute 调用。DTensor 会根据源 placement 和目标 placement 自动选择正确的调用。
具体而言,以下是 TP=2 且全局值为 [a, b, c, d] 的 4 元素张量下每个转换的样子:
Scatter (Replicate → Shard):每个 rank 持有完整张量;redistribute 后,每个 rank 持有自己的切片。
before: rank 0: [a, b, c, d] rank 1: [a, b, c, d] after: rank 0: [a, b] rank 1: [c, d]
AllGather (Shard → Replicate):相反的过程。每个 rank 持有一个切片;redistribute 后,每个 rank 持有完整张量。
before: rank 0: [a, b] rank 1: [c, d] after: rank 0: [a, b, c, d] rank 1: [a, b, c, d]
AllReduce (Partial → Replicate):每个 rank 持有一个部分贡献(相同形状,不同值)。redistribute 会跨 rank 求和,从而每个 rank 得到真实总和。
之前:rank 0: [1, 2, 3, 4] rank 1: [5, 6, 7, 8] (部分和) 之后:rank 0: [6, 8, 10, 12] rank 1: [6, 8, 10, 12] (真实梯度)
ReduceScatter(部分 → 分片):与 AllReduce 相同,但每个 rank 不是复制完整和,而是只保留该和的一个分片。
之前:rank 0: [1, 2, 3, 4] rank 1: [5, 6, 7, 8] 和: [6, 8, 10, 12] 之后:rank 0: [6, 8] rank 1: [10, 12]
这些正是我们四次尝试中的原语:Scatter 将输入拆分到各 rank,AllGather 重新组装,AllReduce 对部分梯度求和。DTensor 的职责是,当两个相邻操作在放置方式上不一致时,自动选择合适的箭头。
对于每个算子,DTensor 会跟踪分片策略:给定输入放置方式,哪些输出放置方式是有效的。矩阵乘法是典型例子。如果 left-hand side 在其内部维度上做了分片,而 right-hand side 在其外部维度上做了分片,则输出自然是 Partial(sum):每个 rank 持有一部分点积结果,全局结果是它们的和。这正是张量并行中列线性层和行线性层的工作方式;DTensor 只是让状态转换变得明确,而不是隐含的。
当一个算子没有注册策略时(我们在 indexadd 等算子上经常遇到这种情况),DTensor 会拒绝运行,而不是猜测;你需要自己注册策略。其心智模型是:宁可出现类型错误,也不要静默地产生错误的梯度。
你通过编写一个计划来并行化一个模块:
1. 划分参数:将模块参数转换为具有所需放置方式的 DTensor。 2. 准备输入:在前向钩子(pre-forward hook)中将传入的常规张量转换为 DTensor。 3. 准备输出:在后向钩子(post-forward hook)中将 DTensor 转换回常规张量。
对于我们的调制模块,该计划包含三次 redistribute 调用,且无需自定义自动求导函数:
输入钩子:提升为 DTensor,对 tokens 和 sampleids 进行分片 tokens = DTensor.fromlocal(tokens, mesh, placements=(Replicate(),)) tokens = tokens.redistribute(mesh, placements=(Shard(0),)) # Replicate → Shard cond = DTensor.fromlocal(cond, mesh, placements=(Replicate(),)) # 保持 Replicate sampleids = DTensor.fromlocal(sampleids, mesh, placements=(Replicate(),)) sampleids = sampleids.redistribute(mesh, placements=(Shard(0),))
self.weight 通过 distributetensor 注册为 Replicate
输出钩子:收集并返回常规张量 output = output.redistribute(mesh, placements=(Replicate(),)).tolocal()
该方案明确规定了分片布局;DTensor 会插入正确的反向集合通信。调用点与单 GPU 版本完全一致。将前向计算重构为先索引 cond、再投影,你调整的是方案本身,而非 autograd 计算图。
如果你只关心正确性,故事到这里就可以结束了。
DTensor 在大规模场景下的代价
一旦 DTensor 被部署到实际的训练任务中,就会暴露出一些 API 层面并不明显的代价。
第一个是布局开销。对于普通张量,像 `torch.mm` 这样的操作会直接调度到 CUDA 内核。而对于 DTensor,运行时必须先检查每个输入的布局,根据这些布局为该操作查找正确的分片策略,判断是否需要重新分布,然后才能调度实际的内核。一次 Transformer 前向传播可能在每个微批次中执行数百次 DTensor 感知的操作,每次都要付出这个查找成本。单个开销很小,但会在整个模型中累积。
第二个是重新分布频率。当两个连续的操作期望不同的布局时(例如,注意力模块想要 Shard(dim=1),而 MLP 想要 Shard(dim=0)),DTensor 会在它们之间插入一次 redistribute。每一次 redistribute 都是一个真实的集合通信操作:all-gather、reduce-scatter 或 all-to-all。注意力层的形状调整、序列并行和激活检查点都会造成布局不匹配,从而强制产生这些额外的转换。即使每次单独的集合通信都很快,它们引入的同步点也会割裂计算图并限制重叠能力。
在实践中,这些代价是可测量的。在相同的工作负载下,仅从 FSDP 切换到 DTensor + TP 就使我们的 MFU 明显下降,而在此基础上再添加动态布局则使其进一步下降。性能下降的根源很少是通信本身,而是累积的调度开销、额外的重新分布以及计算图碎片化。
对 DTensor 运行时开销的自然反应是编译:将计算图追踪一次,把布局检查融合到编译后的内核中,将调度成本放到编译时而非每次迭代中支付。原则上,`torch.compile` 做的正是这件事。但在实践中,让编译能够可靠地与 DTensor 协作,是整个技术栈中最难的问题之一。
编译器为何在这里遇到困难
PyTorch 的编译管线(Dynamo、Inductor、FX passes)是为常规张量构建的,在那里表现最佳。DTensor 支持仍在完善中。编译器现在必须追踪 placement 传播、重分配逻辑以及设备网格感知调度——所有这些都位于它已知如何优化的算子之上。结果是更多的图断点、更差的融合以及更低的核函数效率。
两种失败模式反复出现。
编译器错误掩盖了原始 bug
一个简单的算子不匹配可能爆炸成数页的 DTensor 调度追踪、FX 图内部机制、placement 传播失败以及 Inductor 降级错误。原始 bug 仍然存在;你只是需要从抽象层之下把它挖出来。最可靠的缓解方法是限制 placement 的动态性:避免在热循环内改变布局,保持张量 placement 可预测,并在各模块间标准化分片模式。编译器看到的特化越少,失败信息就越清晰可读。
重编译风暴
对于常规张量,编译器能优雅地处理形状变化:如果维度 0 发生变化,Dynamo 会将其标记为动态,后续沿该维度的不同尺寸输入不会触发重编译。DTensor 则得不到这种处理。DTensor 上的形状变化总是会触发完整的重编译,因为编译器目前无法像对常规张量那样将 DTensor 的维度标记为动态。
这使得 DTensor 图对输入变化更加敏感。一个用常规张量每小时只会产生少量重编译的工作负载,在使用 DTensor 时可能会产生数百次重编译,因为每个新的序列长度或批次形状都会使已编译的图失效。
两种策略有所帮助:
- 将张量填充到固定大小,让编译器始终看到相同的形状。这会在填充上浪费一些计算,但完全消除了重编译。 - 避免在图边界使用 DTensor。仅在编译区域内转换为 DTensor,并在边界处转换回来,这样编译器追踪 DTensor 算子时形状固定,变化留在了编译图之外。
一个模块的改动,8 个百分点的 MFU 下降
在 DTensor 系统中,性能不是你正在查看的代码的局部属性。某个模块中的一处改动可能触发另一个模块中的重分配、破坏下游的融合、改变通信重叠、改变编译器特化,或在远离原始编辑位置的地方插入同步点。
具体来说,在相同的训练栈中:
将一个注意力块重构得更加整洁,导致我们某次运行的模型 FLOPs 利用率(MFU)下降了大约 8 个百分点,因为这一改动改变了其他地方的 redistribution 模式。
在我们对形状和放置位置的变异性加以约束之前,运行间吞吐量差异在早期稳定阶段达到了 ±22%。
在编译稳定过程中,GPU 空闲时间在最差的工作负载下介于 18% 到 27% 之间。
其反直觉的结果是,高性能的 DTensor 代码往往需要选择性地跳出该抽象层:对最热路径使用手动集合通信,在自动放置传播产生不佳调度的地方使用自定义融合内核,在隐式 redistribution 落在错误位置时使用显式 redistribution。在我们的工作负载中,混合栈始终优于完全抽象化的栈。DTensor 不仅仅是一种编程抽象;它更是一项系统层面的决策,同时影响着编译器行为、运行时调度、图稳定性以及运维可靠性。
要点总结
DTensor 是一层很好的抽象层,能够防止一类静默梯度错误。不过,正确性和性能是两个不同的课题,而 DTensor 目前只解决了前者。本文所描述的代价是成长中的阵痛:编译器集成、动态形状支持和算力算子覆盖度都在随每个 PyTorch 版本发布而不断改善。在它们趋于成熟之前,将 DTensor 用于保证正确性,同时用选择性跳出抽象层来优化性能,是当前行之有效的组合。
发现更多
新闻 Runway 与 Lionsgate 达成合作 客户故事 《大卫之家》如何借助 Runway 成为亚马逊最新热播剧集 新闻 探索电影制作的未来:Runway 与 Tribeca 电影节 2024 的项目合作