PyTorch KernelAgent 源码解读 ---(2)--- 总体流程
任务上达到 100% 正确率的开源智能体系统。KernelFalcon 代码库位于 github.com/meta-pytorch/KernelAgent附带文档与入门示例。0x01 引言1.1 背景与挑战KernelFalcon 面临的背景和挑战如下手写优化GPU kernel是部署瓶颈。编写优化的 GPU 内核仍是部署机器学习模型的瓶颈。团队很少有为每种形状、数据类型和硬件代际手工调优算子的带宽。随着模型演进问题加剧适用于 ResNet 的模式无法直接映射到 Mamba 的选择性状态或 MoE 的条件路由。传统编译器TorchInductor、TVM、XLA难以处理长尾场景。现代编译器取得实质进展但仍对长尾场景束手无策。TorchInductor 覆盖常见模式TVM 自动调度稠密内核XLA 针对动态形状做特化。然而异常算子、动态控制流和异构融合模式仍逃逸最优编译。NVIDIA 今年年初使用 DeepSeek-R1 配合推理时缩放在 KernelBench L1/L2 上取得强劲结果证明基于 LLM 的方法配合验证循环可媲美或超越传统方法——但未触及完整模型架构L3。LLM方法需要更好的架构设计。如果能在不扩充规则库或雇佣更多 GPU 专家的前提下自动合成保持 PyTorch 语义且逼近手工调优性能的 Triton 内核会怎样因此PyTorch推出了 KernelFalcon一个保持 PyTorch 语义的代码到代码系统可生成优化的 Triton 内核。它采用并行探索与基于执行的验证而非一次性生成——交付的内核真实运行在 GPU 上且与原始模型数值等价。注长尾问题原本是指在数据分布中少数数据“头部”的频率非常高而大多数数据“尾部”的频率非常低的一种现象。在这里主要想表达的是AI、算子相关的工作中总是由于创新、软硬件变化等缘故出现传统方案无法覆盖的场景导致功能不支持或者性能下降。1.2 设计理念KernelFalcon的核心目标是将PyTorch 程序自动转换为经过验证的 GPU内核kernel。比如它通过LLM生成 Triton 代码并在沙盒子进程中自动验证正确性迭代修复直到通过测试。KernelFalcon基于 Pytorch Module 自动生成GPU Kernel其设计理念如下保持Python语义支持if/else、while、动态shapeverification 优先循环编译测试候选kernel失败反馈端到端组合 verification算子融合kernel替换原ops全模型等价性检查项目分两个层次Fuser/编排层面向完整PyTorch问题进行编排包括AST分析路由、子图提取、内核分发、端到端组合。triton_kernel_agent/单内核合成工作池负责单个Triton内核的生成、验证和迭代改进是底层工作单元。0x02 架构为何选择 KernelFalcon为何是深度智能体传统静态、基于图的编译器依赖 IR 变换和每模式调度。追踪常常将控制流冻结为单一路径并在动态形状下失效。KernelFalcon 则走另一条路保持 Python 语义。KernelFalcon 停留在 PyTorch 代码到代码层面因此 if/else、while、数据依赖路由和动态形状依然有效。验证器优先循环。KernelAgent 编译并测试候选内核失败结果本地反馈KernelFalcon 在首个数值正确的内核上提前退出。端到端组合与验证。融合内核替换原始算子后进行整模型等价性检查才被接受。底层是一个深度智能体架构——多阶段系统通过结构化问题来减少 LLM 失效模式显式任务分解将模糊目标转化为精确、工具就绪的子问题将复杂任务分解为子任务并分配给专门Agent处理确定性编排将控制逻辑留在 Python让 LLM 专注于认知而非让LLM决定工作流程并行搜索与早停高效探索多样解多个worker并行探索不同kernel实现一旦某个worker生成并通过verification的kernel立即终止其他worker以节省计算资源基于真实工具的每一步验证都针对真实编译器和硬件使用Triton编译器、GPU执行等真实工具而非模拟结构化状态持久化提示、日志与产物用于可审计与断点续跑这不仅是更干净的实现这是不同范式。不再问“LLM 能否解决此问题”而是问“KernelFalcon 如何塑造任务使 LLM 可能成功”结果是更广覆盖与更现实性能——无需膨胀规则集或牺牲语义。2.1 KernelFalcon 架构图 1KernelFalcon 的深度智能体架构以 Orchestrator 为中心协调整个工作流。规划负责任务分解与预算分配。上下文工程提供结构约束模板、指南。子智能体处理专业任务提取融合边界、生成 Triton 内核、组合端到端模块并执行验证。持久化内存存储产物用于调试与续跑。Orchestrator 委派给专家接收结构化错误反馈并在整个执行过程中保持状态。该架构体现深度智能体原则分层委派Orchestrator 将高级任务融合此模型分解为精确子问题提取子图、生成内核、组合结果并分配给专家智能体确定性控制规划与编排逻辑是显式 Python 代码而非 LLM 驱动——worker生命周期、超时和成功条件均以编程方式定义基于真实执行每个智能体都针对真实工具Triton 编译器、PyTorch 参考、GPU 执行验证而非模拟或 LLM 判断结果持久化状态所有中间结果、提示、日志和产物持久化到磁盘用于可审计、调试和跨会话续跑结构约束上下文工程将规则编码为模板和策略使正确性要求结构强制执行而非依赖提示2.2 流水线数据如何流经系统图 2多阶段工作流图显示 PyTorch 输入流经 FuserAgent创建可融合子图、ExtractorAgent生成 JSON 规范、并行 KernelAgent worker三个 Triton 框显示并发生成和 ComposerAgent缝合已验证内核。箭头表示数据流并标注中间表示。流水线包含四个不同阶段FuserAgent – 保持 Python 语义的代码到代码融合ExtractorAgent – 形状推理与合约生成Dispatcher KernelAgent – 并行 Triton 内核合成与验证ComposerAgent – 端到端集成与验证2.3 模块依赖关系图系统中的模块依赖关系如下2.4 KernelAgent 调用时序图0x03 架构分阶段详解3.1 阶段 1FuserAgent – 代码到代码融合算子融合FuserAgent 其实是在做算子融合。直接在PyTorch源代码上操作保持控制流和Python语义输出带明确子图边界的干净PyTorch模块传统编译器在融合分析期间将 PyTorch 降级为静态 IR丢失使调试困难的信息并在动态控制流上失效。FuserAgent 直接在 PyTorch 源代码上操作。Orchestrator 管理融合工作流生成具有显式子图边界的干净 PyTorch 模块。输入任意复杂度的原始 PyTorch 模型class Model(nn.Module): def forward(self, x): if x.sum() 0: x self.conv(x) x self.bn(x) x torch.tanh(x) x F.max_pool2d(x, 2) return self.norm(x)过程解析与分析提取操作序列、数据依赖关系和控制流边界识别融合机会找到可融合且保持语义的算子组生成融合模块创建带显式测试的干净 PyTorch 函数增量验证在继续前独立测试每个融合子图输出具有子图函数的融合 PyTorch 模块控制流保持完整# Fused module with control flow preserved class FusedModel(nn.Module): def __init__(self, channels: int): super().__init__() self.branch ConvBnTanhMaxPool(channelschannels) self.norm ChannelwiseNorm(channelschannels) def forward(self, x: torch.Tensor) - torch.Tensor: if x.sum() 0: # Control flow intact x self.branch(x) return self.norm(x)优势: 保持Python语义不丢失调试信息支持动态控制流为何有效为何有效这是因为 Orchestrator 生成精确规范下游阶段可执行验证。控制流if x.sum() 0保留在 Python 中——KernelFalcon 从不尝试将其编译掉。具体而言通过停留在 Python 源码层KernelFalcon 保留变量名、注释和完整控制流上下文。大多数传统编译器式融合器假设它们在优化静态数据流图因此动态 Python 侧控制流要么在追踪期间被折叠要么需要大量手动工作显式编码控制流因此与KernelFalcon 保持 Python if 并在其中插入融合子模块的提示驱动方法不同传统基于编译器的融合往往特化为追踪期间的单分支或需要大量手动努力显式编码控制流。当 TorchScript 降级到 SSA 形式时您精心命名的 hidden_states 变成 t0。当 torch.fx 追踪条件时未采取的分支直接消失。即使有 TorchDynamo/torch.compile虽然通过图中断和守卫更好地处理控制流它仍然特化图为观察到的路径——您的 if x.sum() 0 变成守卫检查要么重用缓存图要么触发重新编译。FuserAgent 采取不同方法KernelFalcon 保留 Python if 语句但融合每个分支内的操作。您仍可获得内核融合收益每个分支内的操作变成优化的 Triton 内核但控制流本身保持可读的 Python。这对现代 ML 模式至关重要TreeLSTM 递归解析树、早退网络在自信时退出、混合专家路由到不同子网络。而且关键的是当调试出错——当您的内核产生 NaN 或融合失败——您希望阅读 Python而非 IR。您希望看到系统实际尝试融合的内容用您编写的语言。深度智能体原则确定性控制平面所有编排——worker生命周期、超时、产物路径和成功时早退——都用 Python 实现。LLM 生成候选代码与元数据融合模块、子图 JSON、Triton 内核、组合内核控制器执行并验证输出。工作流如下Orchestrator 生成 N 个worker携带类型化 WorkerConfig流式传输日志等待队列上的获胜者取消其他worker并打包产物worker迭代渲染提示 → 流式 LLM → 提取 Python 块 → 按 SHA 去重 → 执行候选 → 若通过则发出获胜信号否则保存错误并重试无需手动 AST 解析或基于规则的融合检测——LLM 直接通过prompts 来融合代码然后 Python 通过执行验证Signature去重构建stable signatureops序列 shapes weights包含input shape、output shape、权重、layout、数据类型相同signature合并累加count避免重复生成相同配置的kernel参考Fuser/orchestrator.py、Fuser/worker.py、Fuser/runner.py、Fuser/prompting.py3.2 阶段 2ExtractorAgent – 子图边界推理本阶段是 提取器使用 LLM 分析融合代码并识别具有形状合约的精确子图边界。输入来自阶段 1 的融合 PyTorch 模块提取过程运行 orchestrator首先从阶段 1 获取融合代码分析算子融合后的PyTorch代码识别每个唯一的子图提取每个子图的shape信息input/output/weights shape为生成Triton kernel提供精确的shape信息提示 LLM要求 LLM 识别不同的子图函数推断形状并编目操作生成 JSON 规范LLM 生成带类型规范包含操作序列、形状和权重元数据去重与合并按稳定签名算子 形状 权重对子图分组聚合计数即构建stable signature (ops shapes weights) 用于去重相同shape的子图只需生成一次kernel输出子图规范的 JSON 数组如下[ { id: sg_conv_bn_tanh_pool_1, type: Conv2d_BN_Tanh_MaxPool, data_layout: NCHW, dtype: float32, ops: [ {op: conv2d, kernel_size: [3, 3], stride: [1, 1], padding: [1, 1], dilation: [1, 1], groups: 1, bias: false}, {op: batch_norm, eps: 1e-5, momentum: 0.1}, {op: tanh}, {op: max_pool2d, kernel_size: [2, 2], stride: [2, 2]} ], input_shape: [B, C_in, H, W], output_shape: [B, C_out, H_out, W_out], weights_original: { conv.weight: [C_out, C_in, 3, 3], batch_norm.weight: [C_out], batch_norm.bias: [C_out], running_mean: [C_out], running_var: [C_out] }, weights_fused: null, count: 1, where: Model.forward conditional branch, source: { module: FusedConvBnTanhPool, code: def forward(self, x):\n x F.conv2d(x, self.conv_w, stride1, padding1)\n x F.batch_norm(x, self.bn_rm, self.bn_rv, self.bn_w, self.bn_b, trainingFalse, epsself.eps)\n x torch.tanh(x)\n return F.max_pool2d(x, 2) } } ]此 JSON 成为 KernelAgent 的合约——显式、带类型且可验证。每个子图包括操作序列与算子特定参数输入与输出的形状合约跟踪融合与原始参数的权重元数据位置信息在模型中的位置、源模块去重子图的计数Orchestrator 控制工作流LLM 生成形状感知元数据去重处理跨模型的相同模式。参考Fuser/subgraph_extractor.py3.3 阶段 3Dispatcher KernelAgent – 并行 Triton 生成Dispatcher 为每个子图规范协调并行 Triton 内核生成。对于每个子图它创建一个带有worker池默认 4 个worker的全新 TritonKernelAgent。协调多个并行worker每个worker独立生成kernel并进行verificationEarly Stop第一个成功的worker停止其他worker图 3KernelAgent 生成并行worker采用多样采样参数生成 Triton 内核。每个候选者经历验证阶段语法、编译、数值。失败的候选者触发隔离错误反馈至其源worker——无上下文污染。首个通过所有阶段的候选者立即部署并取消剩余worker。实现并行探索与隔离上下文及早退。并行方法用相同提示但不同温度设置0.8、0.9、1.0 等生成 N 个内核种子。生成 N 个worker默认 4 个每个在其自己的工作目录中运行隔离的精炼循环。不同温度导致worker探索不同的优化策略——有些保守有些探索性。多个worker并行探索不同解决方案第一个成功的worker立即停止其他worker减少上下文消耗降低延迟提高鲁棒性关键机制本地错误反馈防止上下文污染每个worker维护自己的工作目录与每轮历史。当worker 2 遇到编译错误时只有worker 2 的下一次迭代看到它——错误上下文保持本地。worker将 kernel.py 与 test_kernel.py 写入其自己的 workdir通过子进程执行测试并独立跟踪结果。其他worker继续以干净上下文运行。早退节省计算集中管理器监控结果队列以获取完成事件。任何worker报告成功测试子进程退出码 0时管理器设置共享成功事件以通知所有worker停止然后加入/终止它们。首个通过所有验证阶段的内核获胜剩余worker立即终止。深度智能体原则基于真实工具的执行worker在隔离子进程中执行真实 Python/Triton 代码。每个worker生成 Triton 内核实现及其验证驱动然后作为独立子进程运行验证。Triton 的 JIT 编译器在测试驱动首次调用时自动将内核编译为 PTX——编译在测试执行期间隐式发生因此任何语法或编译错误以测试失败形式出现退出码非零。验证驱动将内核输出与 PyTorch 参考实现进行比较。成功意味着子进程退出码 0失败捕获 stderr 用于下一轮精炼。框架不抽象判断正确性——它只是执行代码并报告发生的情况。这种基于真实执行的基础消除了“模拟判断”问题即 LLM 可能幻觉损坏的代码有效。参考Fuser/dispatch_kernel_agent.py、triton_kernel_agent/manager.py、triton_kernel_agent/worker.py、triton_kernel_agent/agent.py3.4 阶段 4ComposerAgent – 端到端内核缝合Composer 使用 LLM 获取已验证的 Triton 内核并将其集成为完整、可测试的模块。输入来自阶段 3 的已验证 Triton 内核集合、subgraphs.json 与原始问题组合过程提示 LLM提供原始问题代码、紧凑子图摘要与成功内核文件生成集成LLM 合成具有所需结构的端到端 Triton 实现可选验证执行组合内核并通过 PASS/哨兵检测验证打包产物将组合实现与验证元数据写入输出目录生成结构LLM 产生完整的 Python 模块包含一个或多个 Triton 内核每个用 triton.jit 装饰实现融合操作。例如一个内核可能处理 conv-bn-tanh-pool 融合而另一个处理归一化。顶级包装器函数命名为 kernel_function(...) 匹配原始模型的输入。此包装器分配输出张量配置网格维度并按顺序启动 Triton 内核编排它们之间的数据流。自测试驱动测试函数播种随机数生成器构建原始 PyTorch 参考调用组合内核函数并使用来自提示指导的容差通过 torch.allclose 验证等价性。这些特定于数据类型的容差考虑了每种精度级别固有的舍入误差累积匹配 PyTorch 自己的内部测试标准。成功时它打印 “PASS” 并以代码 0 退出。验证过程Composer 确保单独正确的内核正确组合——验证整体等于部分之和。Python 通过执行组合模块作为子进程并检查 stdout 中的 “PASS” 以及退出码 0 来验证。这种基于真实执行的基础而非模拟或 LLM 判断的正确性。输出产物成功记录验证状态、计时与产物路径。完整的组合模块成为最终交付物准备部署或进一步参考Fuser/compose_end_to_end.py0x04 KernelAgent 项目入口点KernelAgent 项目的入口点分为两个层面用户界面层面通过命令行工具pipeline.py和auto_agent.py和 GUI 界面编程接口层面通过TritonKernelAgent和AutoKernelRouter类提供 API 访问其中auto_agent.py是推荐的主要入口点因为它会根据问题复杂度自动选择最优路径。4.1 命令行入口点KernelAgent 项目提供了多个命令行入口点主要通过 Fuser 模块访问管道执行入口pipeline.py 是主要的端到端管道执行入口python -m Fuser.pipeline \ --problem /abs/path/to/problem.py \ --extract-model gpt-5 \ --dispatch-model o4-mini \ --compose-model o4-mini \ --workers 4 --max-iters 5 \ --verify自动路由入口auto_agent.py 是自动路由决策的入口python -m Fuser.auto_agent \ --problem /abs/path/to/KernelBench/level1/19_ReLU.py \ --verify各个管道组件入口subgraph_extractor.py - 子图提取dispatch_kernel_agent.py - 内核调度compose_end_to_end.py - 端到端合成4.2 UI 入口点不同级别的 UI 入口Triton KernelAgent UI: kernel-agent 或 python scripts/triton_ui.pyFuser 编排 UI: fuser-ui 或 python scripts/fuser_ui完整管道 UI: pipeline-ui 或 python scripts/pipeline_ui4.3 编程接口入口点TritonKernelAgent 类agent.py 中的 TritonKernelAgent 类提供了编程接口from triton_kernel_agent import TritonKernelAgent agent TritonKernelAgent(num_workers4, max_rounds8, model_namegpt-5) result agent.generate_kernel( problem_descriptionImplement ReLU over a contiguous 1D tensor of length 1024 )AutoKernelRouter 类auto_agent.py 中的 AutoKernelRouter 类提供了自动路由功能from Fuser.auto_agent import AutoKernelRouter router AutoKernelRouter() result router.solve(problem_path)4.4 管道执行流程完整管道执行顺序提取阶段subgraph_extractor.py - 从问题文件提取子图分派阶段dispatch_kernel_agent.py - 为每个子图生成 Triton 内核合成阶段compose_end_to_end.py - 将子内核合成为完整的端到端内核自动路由决策流程静态分析解析问题文件的 AST检查操作符模式复杂度评估评估是否存在难以融合的操作如注意力机制、转置卷积等决策根据复杂度决定使用 KernelAgent 直接路径还是完整的 Fuser 管道执行执行选定的路径并在失败时可选择回退到另一条路径4.5 系统架构入口Fuser 编排器orchestrator.py 管理融合重构过程是 Fuser 系统的核心编排组件。验证运行器runner.py 负责安全地执行候选程序并验证其正确性。0x05 pipeline.py 的作用分析pipeline.py 是 KernelAgent 系统中的主协调器将复杂的多步骤过程封装为简单的端到端管道。它提供了一个高级接口隐藏了底层的复杂性使得用户可以轻松执行从 PyTorch 模型到优化 Triton 内核的完整转换过程。5.1 执行流程pipeline.py 是 KernelAgent 系统中的核心管道文件实现了“extract → dispatch → compose”的端到端工作流程。这是一个一站式管道运行器将三个关键步骤整合到一个统一的执行流程中。问题文件.py ↓ [subgraph_extractor.py] - 提取子图 ↓ [dispatch_kernel_agent.py] - 分发到 KernelAgent 生成 Triton 内核 ↓ [compose_end_to_end.py] - 组合最终内核 ↓ 最终结果对应函数如下run_pipeline () # 主函数 ├─ extract_subgraphs_to_json () # 提取阶段 │ ├─ extract_subgraphs_to_json () │ └─ orchestrator.run () ├─ dispatch_run () # 分发阶段 │ └─ run () └─ compose () # 组合阶段流程图pipeline.py 的流程图如下。参数处理阶段解析输入参数问题路径、模型名称、工作进程数等如果未指定 dispatch_model则根据问题级别自动选择模型管道执行阶段调用 extract_subgraphs_to_json 提取子图调用 dispatch_run 分发子图到 KernelAgent 生成内核调用 compose 合成最终内核返回包含所有结果的字典具体如下数据流转pipeline.py 的数据流转如下输入数据流problem_path输入问题文件 ↓ subgraphs_path提取的子图 JSON ↓ kernels_summary_path内核生成摘要 ↓ composition_result组合结果输出数据流return { run_dir: str (run_dir), # 运行目录 subgraphs: str (subgraphs_path), # 子图路径 kernels_summary: str (summary_path), # 内核摘要 composition: comp_res, # 组合结果 }5.2 三大核心步骤提取阶段Extract功能描述从输入的问题文件中提取子图使用 extract_subgraphs_to_json 函数生成包含子图信息的 JSON 文件依赖模块subgraph_extractor.py - 提取子图并转换为 JSONorchestrator.py - 运行融合重构prompting.py - LLM 提示构建code_extractor.py - 代码提取runner.py - 代码执行输入输出输入问题文件路径输出子图描述 JSON 文件和运行目录分派阶段Dispatch功能描述将提取的子图分派给 TritonKernelAgent使用 dispatch_run 函数并行生成 Triton 内核并行处理支持并发处理多个子图可以自动匹配子图数量--dispatch-jobs auto依赖模块dispatch_kernel_agent.py - 调度子图到 KernelAgenttriton_kernel_agent - Triton 内核生成引擎platform_config.py - 平台配置数据流转输入subgraphs_path来自提取阶段输出summary_path内核生成摘要组合阶段Compose功能描述将验证过的内核组合成单个 Triton 程序使用compose函数创建最终的端到端内核验证功能可选的验证步骤--verify标志确保组合后的内核功能正确依赖模块compose_end_to_end.py - 组合 Triton 内核code_extractor.py - 代码提取runner.py - 代码执行验证参数配置模型选择# 自动选择 Level 2/3 的默认模型为 GPT-5 if is_12 or is_13: dispatch_model gpt-5 else: dispatch_model o4-mini并发控制dispatch-jobs控制并行处理的作业数量workers控制提取阶段的并发数支持auto模式自动匹配子图数量输出管理目录结构管道在.fuse//目录下组织所有工件.fuse/run_id/ ├── subgraphs.json # 子图描述 ├── kernels_out/ # 生成的内核 ├── summary.json # 每个子图的成功/失败状态 └── compose_out/ └── composed_kernel.py # 最终 Triton 程序 └── summary.json # 组合元数据验证结果验证通过时输出验证日志确保最终结果的正确性5.3 依赖关系上游从subgraph_extractor.py获取子图提取功能中游调用dispatch_kernel_agent.py进行内核生成下游使用compose_end_to_end.py进行最终组合# 核心模块依赖 from .subgraph_extractor import extract_subgraphs_to_json from .dispatch_kernel_agent import run as dispatch_run from .compose_end_to_end import compose # 平台配置依赖 from triton_kernel_agent.platform_config import get_platform_choices依赖关系如下pipeline.py ├─ subgraph_extractor.py │ ├─ config.py │ ├─ orchestrator.py │ ├─ prompting.py │ ├─ code_extractor.py │ ├─ runner.py │ └─ utils.providers ├─ dispatch_kernel_agent.py │ ├─ triton_kernel_agent │ │ ├─ agent.py │ │ ├─ manager.py │ │ ├─ worker.py │ │ ├─ prompt_manager.py │ │ └─ platform_config.py └─ compose_end_to_end.py ├─ code_extractor.py ├─ runner.py └─ utils.providers5.4 使用场景手动执行当需要对模型或并发性进行显式控制时使用python -m Fuser.pipeline \ --problem /path/to/problem.py \ --extract-model gpt-5 \ --dispatch-model o4-mini \ --compose-model o4-mini \ --workers 4 \ --max-iters 5 \ --verify平台支持支持 CUDA 平台默认支持 XPU 平台--target-platform xpu错误处理统一的异常处理机制系统退出码管理详细的错误信息输出5.5 总结