The Flash Attention implementation in transformers 5.7.0 has a bug: `s_aux` is `None`.

```
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 293, in forward
    hidden_states, _ = self.self_attn(
                       ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 231, in forward
    attn_output, attn_weights = attention_interface(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/integrations/flash_attention.py", line 84, in flash_attention_forward
    s_aux=s_aux.to(query.dtype),  # FA only accepts half precision
          ^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'to'
```

Quick fix: edit the YAML config

Open `/mnt/workspace/training_configs_1024/binding_paired.yaml` and change `flash_attn` from `fa2` to `sdpa`:

```yaml
# AMD ROCm MI300 single-GPU setup
flash_attn: sdpa  # fa2 has a bug in transformers 5.7.0, use sdpa instead
bf16: true        # bf16 now works normally
fp16: false
```

Why the change: in transformers 5.7.0, `flash_attention_forward` calls `s_aux.to(...)` without checking for `None`, and on ROCm this triggers the `NoneType` error above. `sdpa` (Scaled Dot Product Attention) is PyTorch's native implementation, which is stable and efficient on ROCm.

Rerun training:

```bash
llamafactory-cli train /mnt/workspace/training_configs_1024/binding_paired.yaml
```

If `sdpa` also causes problems, you can disable it completely:

```yaml
flash_attn: disabled  # disable entirely and use the default attention
```

Supplement: if you want Flash Attention 2, you need to patch the transformers 5.7.0 bug by hand.

```bash
# Inspect the buggy file and confirm the bug location
sed -n 80,90p /usr/local/lib/python3.12/dist-packages/transformers/integrations/flash_attention.py
```
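`sed -n 80,90p` only points at the right place if the installed file matches 5.7.0 exactly, because line numbers may shift between releases. An alternative is to search for every `s_aux.to(` call and classify each one as unpatched or patched. This is a minimal sketch. `find_s_aux_lines` is a hypothetical helper, and both regexes assume the exact spelling shown in the traceback and in the patch used in this guide.

```python
import re
from pathlib import Path

# Path taken from the traceback above; adjust for your environment.
FA_PATH = Path("/usr/local/lib/python3.12/dist-packages/transformers/integrations/flash_attention.py")

# Unguarded form: raises AttributeError when s_aux is None (assumed spelling).
BUGGY = re.compile(r"s_aux\s*=\s*s_aux\.to\(query\.dtype\)")
# Guarded form written by the patch in this guide.
GUARDED = re.compile(r"s_aux\s*=\s*\(s_aux\.to\(query\.dtype\) if s_aux is not None else None\)")


def find_s_aux_lines(path):
    """Return (line_number, status, stripped_text) for each line calling s_aux.to(."""
    results = []
    text = Path(path).read_text(encoding="utf-8")
    for lineno, line in enumerate(text.splitlines(), start=1):
        if "s_aux.to(" not in line:
            continue
        # Check the guarded form first so a patched line is never reported as buggy.
        if GUARDED.search(line):
            status = "patched"
        elif BUGGY.search(line):
            status = "buggy"
        else:
            status = "unknown"
        results.append((lineno, status, line.strip()))
    return results


if __name__ == "__main__":
    if FA_PATH.exists():
        for lineno, status, line in find_s_aux_lines(FA_PATH):
            print(f"line {lineno} [{status}]: {line}")
    else:
        print(f"not found: {FA_PATH}")
```

A `buggy` result means the patch still needs to be applied. A `patched` result means `flash_attn: fa2` should no longer hit the `NoneType` error.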
Patch it directly:

1. One-click fix

```bash
cd /usr/local/lib/python3.12/dist-packages/transformers/integrations/
# Back up the original file
cp flash_attention.py flash_attention.py.bak
# Fix the s_aux-is-None bug (matches only the code, so comment spacing does not matter)
sed -i 's/s_aux=s_aux\.to(query\.dtype),/s_aux=(s_aux.to(query.dtype) if s_aux is not None else None),/' flash_attention.py
# Check the result
sed -n 84,86p flash_attention.py
```

2. Verify the fix

After patching, line 84 should read:

```python
s_aux=(s_aux.to(query.dtype) if s_aux is not None else None),  # FA only accepts half precision
```

Once the patch is applied you can keep using `flash_attn: fa2`. I recommend running with `flash_attn: sdpa` first. After you confirm that training works, decide whether it is worth patching and switching back to fa2.

3. Rerun training

```bash
llamafactory-cli train /mnt/workspace/training_configs_1024/2.yaml
```

Leave `flash_attn: fa2` and `bf16: true` unchanged in the YAML config.

If the `sed` replacement did not take effect (for example, because of special characters), apply the fix with Python instead:

```bash
python3 -c '
path = "/usr/local/lib/python3.12/dist-packages/transformers/integrations/flash_attention.py"
with open(path, "r") as f:
    content = f.read()
old = "s_aux=s_aux.to(query.dtype),"
new = "s_aux=(s_aux.to(query.dtype) if s_aux is not None else None),"
if old in content:
    content = content.replace(old, new)
    with open(path, "w") as f:
        f.write(content)
    print("✅ Patch applied")
else:
    print("❌ Target code not found; it may already be fixed or the line does not match")
    print("Please check the file contents")
'
```

After patching, just run training.

You can also run the patch as a heredoc script, which finds the line with a regex:

```bash
python3 << 'EOF'
import re

path = "/usr/local/lib/python3.12/dist-packages/transformers/integrations/flash_attention.py"
with open(path, "r") as f:
    lines = f.readlines()

patched = False
for i, line in enumerate(lines):
    if re.search(r"s_aux\s*=\s*s_aux\.to\(query\.dtype\)", line):
        # Keep the original indentation
        indent = line[:len(line) - len(line.lstrip())]
        lines[i] = indent + "s_aux=(s_aux.to(query.dtype) if s_aux is not None else None),  # FA only accepts half precision\n"
        patched = True
        print(f"✅ Line {i+1} fixed:")
        print(f"  Before: {line.rstrip()}")
        print(f"  After:  {lines[i].rstrip()}")
        break

if patched:
    with open(path, "w") as f:
        f.writelines(lines)
    print("\nPatch applied; you can rerun training.")
else:
    print("❌ No matching line found; printing lines 80-90 for inspection:")
    for i in range(79, min(90, len(lines))):
        print(f"  Line {i+1}: {repr(lines[i])}")
EOF
```

If this still does not work, use the simplest option. Leave the code alone and switch the YAML config to `sdpa`:

```bash
# Edit the config
sed -i 's/flash_attn: fa2/flash_attn: sdpa/' /mnt/workspace/training_configs_1024/binding_paired.yaml
# Confirm the change
grep flash_attn /mnt/workspace/training_configs_1024/binding_paired.yaml
# Run training
llamafactory-cli train /mnt/workspace/training_configs_1024/binding_paired.yaml
```

`sdpa` also performs well on MI300X. The priority is to get training running first.

```
❌ No matching line found; printing lines 80-90 for inspection:
  Line 80: use_top_left_mask=_use_top_left_mask,\n
  Line 81: target_dtype=target_dtype,\n
  Line 82: attn_implementation=module.config._attn_implementation,\n
  Line 83: layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,\n
  Line 84: s_aux=(s_aux.to(query.dtype) if s_aux is not None else None),  # FA only accepts half precision\n
  Line 85: **kwargs,\n
  Line 86: )\n
  Line 87: \n
  Line 88: return attn_output, None\n
```

The patch is already in place. Look at line 84:

```python
s_aux=(s_aux.to(query.dtype) if s_aux is not None else None),  # FA only accepts half precision
```

This means one of the fix commands you ran earlier already succeeded. It just did not print a success message, possibly because of an encoding or output issue. Now simply rerun training:

```bash
llamafactory-cli train /mnt/workspace/training_configs_1024/binding_paired.yaml
```

Keep `flash_attn: fa2` and `bf16: true` in the config. The `s_aux` `NoneType` bug is fixed.
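If you later need to switch `flash_attn` between `fa2` and `sdpa` again (for example, falling back to `sdpa` if fa2 misbehaves), a small stdlib-only Python helper can make the same edit as the `sed` one-liner and keep any trailing comment. This is a minimal sketch and not part of LLaMA-Factory. `set_flash_attn` is a hypothetical helper. It assumes `flash_attn` is a top-level key on a single line, and it only accepts the three values used in this guide.

```python
import re
from pathlib import Path

# Config path from the commands above; adjust as needed.
CONFIG = Path("/mnt/workspace/training_configs_1024/binding_paired.yaml")

# Values mentioned in this guide (assumption: not an exhaustive list of what LLaMA-Factory accepts).
ALLOWED = {"fa2", "sdpa", "disabled"}

# A top-level "flash_attn: <value>" line; group 3 keeps any trailing comment.
FLASH_ATTN_LINE = re.compile(r"^(flash_attn:[ \t]*)(\S+)(.*)$", re.MULTILINE)


def set_flash_attn(path, value):
    """Rewrite the first top-level flash_attn value in a YAML file; return the old value."""
    if value not in ALLOWED:
        raise ValueError(f"unexpected flash_attn value: {value!r}")
    text = Path(path).read_text(encoding="utf-8")
    match = FLASH_ATTN_LINE.search(text)
    if match is None:
        raise KeyError(f"no top-level flash_attn line in {path}")
    # Use a function replacement so backslashes in the comment are not treated as escapes.
    new_text = FLASH_ATTN_LINE.sub(lambda m: m.group(1) + value + m.group(3), text, count=1)
    Path(path).write_text(new_text, encoding="utf-8")
    return match.group(2)
```

For example, `set_flash_attn(CONFIG, "sdpa")` has the same effect as the `sed -i 's/flash_attn: fa2/flash_attn: sdpa/'` command above. Unlike that command, it also works when the current value is not `fa2`. Any inline comment is kept unchanged, so you may want to update a comment that no longer matches the new value.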