最新的Kimi的注意力残差连接以及字节的“残差”连接两篇论文在最开始的残差连接方案1中核心过程就是 xxf(x)随着不断的叠加卷积层数那么就容易导致 梯度消失以及 退化问题残差连接就是通过跳跃连接skip connection允许输入信息绕过若干层直接传递到后面的层。后续也有很多去对这个过程进行改进比如说使用门控残差连接、加权残差连接、修改连接位置等。不过影响都不是很大因此对于残差连接过程就一直没有变化还是保持最开始的计算方式了。在kimi以及字节最近新发表两篇论文都是对这个过程做的改进具体解释如下。Kimi注意力残差连接首先按照论文中逻辑出发在标准的残差计算中hlhl−1fl−1(hl−1) 对于这个计算方式在计算梯度传播过程中会直接将每一层的贡献是相同直接计算上公式梯度因此后续论文就提出做一个门控的残差连接方式 hlαl⋅hl−1βl⋅fl−1(hl−1)对于上述两种残差注意力方式带来最大的问题就是所有层的贡献都是一致的除此之外后续层只能获取前层的信息导致更加前面层的信息被稀释比如说l层只能获取l-1层信息虽然l-2的信息会融合到l-1层但是l-2还是对l层的作用有限。因此kimi的attention-residual出发点就是让后续层可以看到更加前面层的信息以及通过一个合适权重去控制残差连接基于这个论文2里面提出方案如下图c中描述对于_第n个block我将前面的几层的输出都进加权融合作为第n层的输入_具体融合方式为hlα0→l⋅h1∑i1l−1αi→l⋅fi(hi)其中 ∑i1l−1αi→l1那么对于权重系数 αi→l 的计算方式为其实也就是计算softmax的注意力权重里面的 qlwl 通过一个学习的向量以及历史层的输出去计算softmax值去控制权重特征融合。去看代码具体过程PYTHONCopydef block_attn_res(blocks: list[Tensor], partial_block: Tensor, proj: Linear, norm: RMSNorm) - Tensor:V torch.stack(blocks [partial_block]) # [N1, B, T, D]K norm(V)logits torch.einsum(d, n b t d - n b t, proj.weight.squeeze(), K)h torch.einsum(n b t, n b t d - b t d, logits.softmax(0), V)return hdef forward(self, blocks: list[Tensor], hidden_states: Tensor) - tuple[list[Tensor], Tensor]:partial_block hidden_states # 进入当前层的初始 hidden_states通常是上一层的输出# 在 Attention 子层前先做一次 Block AttnResh block_attn_res(blocks, partial_block, self.attn_res_proj, self.attn_res_norm)# 如果当前层是 Block 的边界层 → 把当前 partial_block 作为完整 Block 保存下来if self.layer_number % (self.block_size // 2) 0:blocks.append(partial_block) # blocks 列表增长新增一个完成的 Block reppartial_block None # 重置 partial新 Block 从零开始代码这里有小问题实际可能要用 h 或重置逻辑# 自注意力子层标准 Transformer attentionattn_out self.attn(self.attn_norm(h))partial_block partial_block attn_out if partial_block is not None else attn_out# ↑ 标准残差partial_block attn_out Block 内部用经典 # 在 MLP 子层前再做一次 Block AttnRes用不同的 proj 和 normh block_attn_res(blocks, partial_block, self.mlp_res_proj, self.mlp_res_norm)# MLP 子层mlp_out self.mlp(self.mlp_norm(h))partial_block partial_block mlp_out # 再次标准残差累加return blocks, partial_block # 返回更新后的 blocks 列表 当前 Block 的 partial sum其实通过代码很容易发现在block计算过程就是输入前将前n层的block特征进行attention-residual方式特征融合在计算完毕之后进行一个普通的残差连接而后在将输出进行mlp处理之前再次通过一次attention-residual连接处理。字节混合注意力在字节论文3中提出混合注意力去解决随着 LLM 的深度增加它们往往会遭遇信号衰减的问题在浅层形成的有用特征会因反复的残差更新而逐渐被稀释使得它们在更深的层中更难恢复出发点和kimi的attention-residual相同。对于上图中提到的read以及write分别表示的是残差连接方式 xxf(x)里面分别对于x以及连接方式比如说对于最开始残差连接我的read就是x不去对x进行其他处理因此论文里面将其标记为identity而我的连接方式是add因此将write处理为add。在上图b中选择直接将所有的信息进行拼接比如说第i层计算输出就行和输入就行concat操作虽然在信息传播过程是无损的可以解决上面的信号衰减问题但是这样会带来显存占用过高。那么论文里面提出Depth Attention处理过程为对于输入通过相面方式处理其中对于 Ki 以及 Vi 表示的GQA过程中我的历史缓存的kv值而 Ql−1 则是上一层的Q结果通过注意力融合方式得到最终的输入 Xlin 直接将这个结果解析attention的注意力计算得到 Xlout在得到结果之后通过又可以得到新的一层的输出结果相当于替代了之前的残差连接通过相加为线性层处理方式。除此之外进一步提出升级的 Mixture-of-Depth Attention方式对于上述过程中depth表示所有前面层的深度 KV cache对应深度部分而QKV则是表示当前层的结果对应序列部分10-23行处理序列部分注意力就是比较常规的注意力计算过程24-29行处理处理深度部分注意力在计算注意力过程中会用softmax去更新同一个m, acc, o相当于将cache部分信息融入到注意力中。总结两篇论文中都是为了解决随着层数的叠加带来的“信息遗忘”问题Kimi中选择直接将“历史block”信息通过注意力融合方式进行加权残差连接attention-residual也就是 yα⋅hl∑i1l−1αhi具体过程为将历史所有的block结果和用一个可学习的向量之间计算softmax作为权重α 具体残差发生在1、mlp处理前2、每一个block处理之后。在字节的mixture-of-depth attention处理方式则是直接将GQA中的kv-cache中的KV值用来计算注意力去弥补信息损失具体过