Triton编译流程:调了一个通宵后,我终于搞懂了 Triton 的编译流程
昨晚又踩了一个大坑。用 PyTorch 跑一个自定义 kernel看着代码写得挺优雅——几行 Pythonblock 级别抽象不用手动管 shared memory 和 warp。结果一跑崩了。报错信息扔过来一串ttgir、llir、MLIR之类的词满脸问号。Triton 到底是怎么把我的 Python 代码变成 GPU 上跑的二进制指令的这个问题不搞清楚debug 就是在盲人摸象。花了一个通宵把 Triton 的编译管线从头到尾捋了一遍。这篇文章就是我的复盘笔记。一、Triton 编译器长什么样先看一张全景图心里有个底。Triton 编译器是一个三段式结构架构相当清晰前端Frontend把 Python kernel 代码或 PyTorch Inductor 生成的 TritonKernel转成 Triton IR即 TTIR优化器Optimizer通过一系列 pass把 TTIR 逐步 lowering 并优化为 Triton GPU IRTTGIR后端Backend把 TTGIR 最终转成 LLVM IR交给 LLVM/NVPTX 后端生成 PTX 和 SASS这三个阶段构成了从「Python 语义」到「GPU 机器码」的完整 lowering 管道。二、编译全流程一张图说完整个编译管线可以用下面这张图概括Triton Python DSL ↓ Triton AST ↓ TTIR ← 高级张量抽象block 级 ↓ TTGIR ← 引入 GPU 概念warp / shared memory / tensor core ↓ LLVM IR ← 平台无关可交给 LLVM 后端 ↓ PTX ← NVIDIA 虚拟指令集 ↓ SASS ← GPU 硬件实际执行的机器码核心思想一句话用 MLIR 的多层 IR把「Tensor 计算语义」逐步 lowering 到「GPU 硬件执行细节」。接下来逐个拆开看。三、Triton IR还在聊「算法」的层次TTIRTriton Tensor IR是整个编译流程的第一层 MLIR Dialect。它的核心特征是硬件无关但 tile 感知。在这一层你看到的不是 thread不是 warp不是 shared memory。你看到的是 block——一块小张量。%A tt.load %ptrA : tensor128x128xf16 %B tt.load %ptrB : tensor128x128xf16 %C tt.dot %A, %B : tensor128x128xf32 tt.store %ptrC, %C : tensor128x128xf32这里面没有任何 CUDA 概念只有 block 级别的 load / compute / store。TTIR 阶段做的优化包括Tile shape 推断Operator fusion算子融合维度重排自动插入 reduction / broadcast到这里TTIR 的任务完成接下来交给下一层。四、Triton GPU IRGPU 硬件模型正式登场TTGIRTriton Tensor GPU IR是把抽象 tile 计算映射到 GPU 硬件执行模型的关键阶段。这一层开始引入真正的 GPU 概念GPU 硬件概念TTGIR 中的表示Threadttgpu.threadWarp32 线程一组ttgpu.warpBlockttgpu.sublaneShared Memoryttgpu.shared_memoryTensor Core / MMAttgpu.mma翻译成人话TTGIR 要决定哪些数据进 shared memory、哪些搁 register 里、用不用 Tensor Core、warp 怎么分配 tile。这个阶段的优化直接决定 kernel 跑得快不快Warp-level tiling大 tile 拆成 warp 小 tileCoalesced memory access保证在一个 warp 内相邻线程访问连续内存地址Shared memory bank conflict避免layout 要避开 bank conflictMMA layout 匹配数据排布要对齐 Tensor Core 要求的格式f16 / tf32 / int8说白了TTGIR 就是 Triton 性能的灵魂。如果这一层的 lowering 做不好后面 LLVM IR 再优化也救不回来。五、LLVM IR从「领域语言」到「通用中间码」到了 LLVM IR 这一层Triton 的自定义 Dialect 已经全部展开。ttgpu.thread、ttgpu.mma这些东西被翻译成 LLVM 的 SSA 指令。这层的职责是平台无关LLVM IR 是通用中间表示不绑定 NVIDIA优化LLVM 自身有大量通用优化 pass死代码消除、循环展开等代码生成交给 NVPTX 后端生成 PTX再由 NVIDIA driver 编译为 SASS模块化LLVM IR 可以链接、内联、做 LTOLowering 过程中会自动插入__syncthreads、lane_id、warp_id这些底层调用。最终生成的 LLVM IR 长这样call void llvm.nvvm.barrier() %v call 2 x float llvm.nvvm.mma.m16n8k16六、实战PyTorch Inductor 怎么调用 Triton前面讲的全是原理来点能跑的命令。PyTorch 2.0 的 Inductor 后端会自动把计算图编译成 Triton Kernel。想看生成的 Triton IR两步# 示例一简单示例TORCH_COMPILE_DEBUG1python example.py# 示例二跑 ResNet50TORCH_COMPILE_DEBUG1python resnet50.pyTORCH_COMPILE_DEBUG1这个环境变量一设PyTorch 会在/tmp/torch_compile_debug下 dump 出所有中间产物。你能看到每一层的 IR*.ttir→ Triton IR*.ttgir→ Triton GPU IR*.llir→ LLVM IR调试 Triton 编译问题时这些中间文件比报错信息有用一百倍。直接打开看哪一层 IR 出了问题比对着 traceback 瞎猜要高效得多。七、一张表总结三层 IRIR 层次抽象级别关心的东西典型操作Triton IR算法级Block / Tile / Tensorload、dot、store、fusionTriton GPU IR硬件映射级Warp / Shared Memory / MMA调度映射、bank conflict 优化LLVM IR指令级寄存器 / 内存 / 同步展开为 SSA、生成 PTX从 TTIR → TTGIR → LLVM IR 的 lowering 过程本质上是一个「逐步丧失抽象、逐步获取性能」的过程。每一层丢掉一些通用性换来更贴近硬件的执行效率。这也是 MLIR 的核心设计哲学——不是为了多套几层 IR 炫技而是每一步 lowering 都有明确的优化机会。写在最后Triton 的编译流程看起来复杂但拆开看其实非常线性Python → TTIRtile 语义TTIR → TTGIRGPU 映射TTGIR → LLVM IR指令生成LLVM IR → PTX → SASS交给 NVIDIA 工具链下次再遇到 Triton 编译报错别慌。拿TORCH_COMPILE_DEBUG1dump 出中间 IR顺着这三层的 lowering 管线排查问题大概率藏在 TTGIR 的 warp tiling 或者 shared memory layout 那里。反正我昨晚就是这样找到的。