【infra之路】LLM 预测一个 Token 的完整流程:从文本输入到概率输出
前言当你向 ChatGPT 输入今天天气真模型回答好。这个看似简单的接一个字背后经历了分词、向量化、几十层 Transformer 的矩阵运算、概率采样等一系列精密步骤。而且模型不是在想好一整句话再输出——它是一个 token 一个 token 地生成每次生成一个 token 都要把整个模型跑一遍。这篇文章用一个具体的例子跟踪一个 token 从输入到输出的完整旅程。以 LLaMA-7B32 层 Transformer为例拆解每一步发生了什么、数据怎么变换、计算量有多大。一、总览一次前向传播的六个阶段不管输入多长的 prompt模型预测下一个 token的流程都可以拆成六步① 分词Tokenization 文本 → 整数 ID 序列 ② 嵌入Embedding 整数 ID → 稠密向量 ③ 位置编码Position 给每个向量加上位置信息 ④ Transformer 层 × N 自注意力 前馈网络逐层处理 ⑤ 输出投影LM Head 最后一层的隐藏状态 → 词汇表维度的 logits ⑥ 采样Sampling logits → 概率分布 → 选出下一个 token下面用一个具体的例子走一遍。假设输入是“今天天气真”模型需要预测下一个 token。二、第一步分词TokenizationLLM 不认识中文字符它只认识整数。分词器Tokenizer把文本切分成一系列 token ID。以 LLaMA 使用的 SentencePieceBPE为例输入文本: 今天天气真 分词结果: [2521, 1083, 29576, 17582]注意几点一个汉字不一定对应一个 token。今天可能被分成今和天两个 token也可能被合并成一个 token取决于 BPE 训练的合并规则。高频词更容易被合并生僻字更容易被拆成多个子词。LLaMA 的词汇表大小是32,000vocab_size32000意味着每个 token ID 是 0~31999 之间的一个整数。分词是确定性的——同一段文本永远被分成相同的 token 序列。三、第二步嵌入Embedding Lookup拿到 token ID 后需要从嵌入矩阵中查表把每个整数变成一个稠密向量。嵌入矩阵 E: 形状 [32000, 4096]vocab_size × hidden_dim token ID 2521 → 取 E 的第 2521 行 → 一个 4096 维的浮点向量 token ID 1083 → 取 E 的第 1083 行 → 一个 4096 维的浮点向量 ...对于 4 个 token 的输入嵌入层的输出是一个[4, 4096]的矩阵——4 个 token每个 4096 维。嵌入矩阵本身是模型训练出来的参数占了模型参数的很大一部分。对于 LLaMA-7B嵌入矩阵有 32000 × 4096 × 2 bytes ≈262 MB。四、第三步位置编码嵌入向量本身不包含这个词在句子的什么位置的信息。Transformer 需要额外注入位置信息否则今天天气真和天今真气天会被认为是一样的。主流 LLM 使用RoPERotary Position Embedding它的做法是在每一层计算注意力时对 Query 和 Key 向量做旋转旋转角度取决于 token 的位置。位置 0今: Q₀ 和 K₀ 旋转 0° 位置 1天: Q₁ 和 K₁ 旋转 θ° 位置 2气: Q₂ 和 K₂ 旋转 2θ° 位置 3真: Q₃ 和 K₃ 旋转 3θ°RoPE 的优势在于两个 token 之间的注意力得分只取决于它们的相对位置位置差而不是绝对位置。这让模型可以泛化到比训练时更长的序列。五、第四步Transformer 层核心这是 LLM 计算量最大的部分。LLaMA-7B 有32 层 Transformer每一层的结构相同数据从第 1 层流向第 32 层。图1Decoder-only Transformer 的单层结构。输入向量经过自注意力机制和全连接网络每步都有残差连接和层归一化。5.1 自注意力机制Self-Attention这是 Transformer 最核心的计算。它的目的是让每个 token 看看其他所有 token决定该关注谁、关注多少。计算 Q、K、V输入矩阵 X形状 [4, 4096]分别乘以三个权重矩阵得到 Query、Key、ValueQ X × W_Q 形状: [4, 4096] 4 个 token每个产生一个 4096 维的 Query K X × W_K 形状: [4, 4096] V X × W_V 形状: [4, 4096]实际中 4096 维会被拆成32 个注意力头每个头 128 维多头并行计算。计算注意力分数Attention(Q, K, V) softmax(Q × K^T / √d_k) × V拆开看Q × K^T每个 token 的 Query 和所有 token 的 Key 做点积得到一个 [4, 4] 的注意力分数矩阵。分数越高表示两个 token 之间的关联度越强。/ √d_k除以 √128 ≈ 11.3 做缩放防止点积值太大导致 softmax 梯度消失。Causal Mask把未来位置的注意力分数设为 -∞。位置 0 只能看位置 0位置 1 能看 0 和 1位置 2 能看 0、1、2……这是 Decoder 的关键约束——模型不能偷看后面的内容。softmax把每一行的分数归一化成概率和为 1。× V用注意力权重对 Value 做加权求和得到每个 token 的融合了上下文信息的输出。Causal Mask 矩阵4 个 token: tok0 tok1 tok2 tok3 tok0 [ ✓ -∞ -∞ -∞ ] ← 今只能看自己 tok1 [ ✓ ✓ -∞ -∞ ] ← 天能看今和天 tok2 [ ✓ ✓ ✓ -∞ ] ← 气能看前三个 tok3 [ ✓ ✓ ✓ ✓ ] ← 真能看所有输出投影多头注意力的结果拼接后再乘以一个输出权重矩阵 W_O得到自注意力的最终输出。加上残差连接X_attn LayerNorm(X MultiHeadAttention(X)) 形状: [4, 4096]5.2 前馈网络Feed-Forward Network自注意力处理的是 token 之间的交互前馈网络处理的是每个 token 自身的特征变换。LLaMA 使用SwiGLU激活函数比 GELU 效果更好计算过程FFN(x) (x × W₁ ⊙ Swish(x × W₃)) × W₂其中 W₁ 和 W₃ 把维度从 4096 扩展到 11008约 2.7 倍W₂ 再压缩回 4096。⊙表示逐元素相乘Swish(x) x × σ(x)是激活函数。同样加上残差连接和层归一化X_ffn LayerNorm(X_attn FFN(X_attn)) 形状: [4, 4096]5.3 32 层堆叠上面这个过程自注意力 前馈网络重复32 次。每一层的输出作为下一层的输入信息逐层传递和提炼。最后一层输出的 [4, 4096] 矩阵就是模型对整个输入序列的最终理解。每一层都有独立的权重参数W_Q、W_K、W_V、W_O、W₁、W₂、W₃ 以及 LayerNorm 的参数LLaMA-7B 的大部分参数都花在了这 32 层的权重矩阵上。六、第五步输出投影LM Head经过 32 层 Transformer 后取最后一个 token“真”的隐藏状态一个 4096 维向量通过一个线性层投影到词汇表维度最后一个 token 的隐藏状态 h: 形状 [4096] LM Head 权重矩阵 W_lm: 形状 [32000, 4096] logits h × W_lm^T 形状 [32000]得到的logits是一个 32000 维的向量每个值对应词汇表中一个 token 的原始得分。比如logits[6521] 8.7 好的得分 logits[29892] 5.3 热的得分 logits[1523] 3.1 冷的得分 logits[892] 0.2 狗的得分 ...有些模型如 LLaMA使用权重共享Weight TyingLM Head 的权重矩阵和嵌入矩阵是同一个矩阵 E。这样可以节省 262 MB 参数。七、第六步采样Samplinglogits 只是原始分数需要转成概率才能选 token。Softmax 归一化prob_i exp(logit_i / T) / Σ exp(logit_j / T)其中 T 是温度参数Temperature控制概率分布的锐利度。温度参数的作用Temperature 0.1低温: 好 98.2%, 热 1.3%, 冷 0.3%, ... ← 几乎确定 Temperature 1.0常温: 好 72.5%, 热 15.1%, 冷 5.2%, ... ← 有一定随机性 Temperature 2.0高温: 好 42.3%, 热 28.7%, 冷 15.1%, ... ← 很随机温度越低模型越保守倾向于选概率最高的 token温度越高分布越平坦低概率 token 也有机会被选中。常见采样策略Greedy贪心直接选概率最高的 token。确定性输出但容易生成重复和无聊的文本。Top-K只保留概率最高的 K 个 token从中采样。比如 K50就把 32000 个候选缩减到 50 个。Top-PNucleus保留累积概率达到 P 的最小 token 集合。比如 P0.9就取概率从高到低累加到 90% 的那些 token。好处是候选数量动态变化——当模型很确定时候选少不确定时候选多。Top-P0.9 的例子: 好 72.5% → 累积 72.5%继续 热 15.1% → 累积 87.6%继续 冷 5.2% → 累积 92.8%超过 90%停止 候选集合: {好, 热, 冷}重新归一化后从中采样Repetition Penalty重复惩罚对已经出现过的 token 降低概率防止模型陷入重复循环。选出 Token采样完成后得到下一个 token 的 ID。把它追加到序列末尾原始序列: [2521, 1083, 29576, 17582] → 今天天气真 新 token: 6521 → 好 更新序列: [2521, 1083, 29576, 17582, 6521] → 今天天气真好然后把这个新序列再次送入模型预测下一个 token。如此循环直到模型生成EOSEnd of Sequencetoken 或达到最大长度限制。这就是自回归生成Autoregressive Generation。八、Prefill 与 Decode推理的两个阶段在实际推理中LLM 的工作分为两个截然不同的阶段图2LLM 推理的两个阶段。Prefill 阶段并行处理所有输入 token 并初始化 KV CacheDecode 阶段逐 token 生成每步只处理一个 token 但需要读取完整 KV Cache。Prefill预填充阶段用户提交 prompt 后模型一次性并行处理所有输入 token。这个阶段的特点是所有 token 同时通过 Embedding、Transformer 层GPU 的计算单元被充分利用compute-bound计算出所有 token 的 K 和 V 向量写入KV Cache产出第一个生成的 token性能指标TTFTTime to First Token即用户从提交到看到第一个字的延迟Decode解码阶段Prefill 产出第一个 token 后进入逐 token 生成阶段。每生成一个 token只有一个新 token通过 Transformer 层batch_size1GPU 的计算单元利用率很低瓶颈变成了从显存读取权重和 KV Cache 的速度memory-bound每生成一个 token 的时间基本固定不像 Prefill 那样和 prompt 长度成正比性能指标TPSTokens Per Second即每秒生成多少个 token这就是为什么 ChatGPT 有时候想了很久才蹦出第一个字Prefill 处理长 prompt但一旦开始输出就一个字一个字很稳定Decode 阶段。九、KV Cache避免重复计算的关键每次 Decode 生成新 token 时如果从头重算所有 token 的 K 和 V计算量会随着序列长度平方增长。KV Cache通过缓存已计算的 K、V 向量来避免这个问题。图3KV Cache 的工作原理。Prefill 阶段计算并缓存所有输入 token 的 K/VDecode 阶段只需计算新 token 的 K/V 并追加到缓存中注意力计算时用新 Q 和缓存中的所有 K/V 做点积。工作原理第 t 步生成新 token 时: 1. 新 token 经过线性投影得到 Q_new, K_new, V_new 2. K_new, V_new 追加到 KV Cache 3. 注意力计算: Attention(Q_new, [K_cached, K_new], [V_cached, V_new]) 4. 只有 Q_new 是本次新算的K 和 V 大部分来自缓存计算量对比无 KV Cache有 KV Cache每步向量投影计算O(t × d²)O(d²)每步注意力计算O(t² × d)O(t × d)生成 N 个 token 总计算量O(N³)O(N²)对于生成长度为 1000 token 的文本KV Cache 带来的加速约为1000 倍。代价是显存占用——每个 token 在每一层都要缓存 K 和 V。KV Cache 的显存占用LLaMA-7B 的 KV Cache 显存公式mem 2 × num_layers × num_heads × seq_len × head_dim × dtype_bytes × batch_size 2 × 32 × 32 × seq_len × 128 × 2 × 1 ≈ 0.5 MB × seq_len一个 2000 token 的序列KV Cache 需要约1 GB显存。这也是为什么 LLM 有上下文长度限制——不是模型记不住而是 KV Cache 的显存撑不下。十、计算量估算预测一个 Token 到底做了多少运算以 LLaMA-7B 为例估算 Decode 阶段生成一个 token 的计算量计算环节FLOPs近似每层自注意力QKV 投影 注意力 输出投影~2 × 4096² × 3 ≈ 1 亿每层 FFNSwiGLU 两次线性变换~2 × 4096 × 11008 × 2 ≈ 1.8 亿32 层合计~32 × 2.8 亿 ≈90 亿 FLOPsLM Head4096 → 32000~2.6 亿单 token 总计~93 亿 FLOPs每生成一个 token模型要做约93 亿次浮点运算。一张 A100 的峰值算力是 312 TFLOPs理论上每秒能生成约 33,000 个 token——但实际受限于显存带宽Decode 阶段是 memory-boundA100 上 LLaMA-7B 的实际生成速度大约是100-200 tokens/s。十一、总结LLM 预测一个 token 的过程本质上是一次完整的前向传播文本经分词变成整数经嵌入变成向量经过几十层 Transformer 的注意力和前馈网络逐层提炼最后通过 LM Head 投影到词汇表空间得到一个 32000 维的 logit 向量经 Softmax 转成概率后采样选出下一个 token。一个实用的理解框架Embedding 把离散符号变成连续空间中的点Transformer 层在这个空间中做上下文感知的特征变换LM Head 把最终特征映射回离散的词汇表空间。整个模型做的事情就是在连续空间中找最合适的下一个词。Decode 阶段通过 KV Cache 缓存历史计算避免重复使得逐 token 生成在工程上可行。参考资料大模型推理机制解析预填充与生成阶段LLM 模型推理全流程解析从输入到输出的技术实现KV Cache 深度解析从原理到显存优化从输入到输出大语言模型一次完整推理简单解析大语言模型 Next Token Prediction 与 Transformer 架构