浏览器内的推理引擎WASM 端侧 AI 推理的架构与实现一、云端推理的延迟困境为什么 AI 需要跑到用户设备上大模型推理依赖云端 GPU 集群每次请求需要经过网络传输、排队等待和 GPU 计算三个阶段。在弱网环境下仅网络往返延迟就可能超过 500ms对于实时交互场景语音助手、AR 滤镜、智能输入法而言这个延迟不可接受。此外将用户数据上传到云端推理存在隐私风险——医疗影像、金融文档等敏感数据不应离开用户设备。端侧推理是解决延迟和隐私问题的直接方案但面临两个工程挑战模型体积和计算性能。一个 7B 参数的量化模型仍需 4GB 存储浏览器环境无法承受而 JavaScript 的动态类型和 JIT 编译无法提供足够的计算吞吐。WebAssembly 提供了第三条路径接近原生的执行性能、沙箱化的安全隔离、跨浏览器的标准化运行时。通过 WASM WebGPU 的组合可以在浏览器中运行量化后的轻量模型实现端侧 AI 推理。二、WASM 线性内存与计算沙箱端侧推理的运行时基础2.1 WASM 执行模型WASM 模块运行在虚拟机提供的线性内存Linear Memory中这是一块可增长的字节数组所有数据通过偏移量访问。WASM 代码无法直接访问宿主环境浏览器的内存所有交互通过导入/导出函数完成。这种沙箱模型保证了 WASM 代码的安全性——即使推理引擎存在漏洞也无法越权访问用户数据。graph TB subgraph 浏览器宿主环境 A[JavaScript 运行时] --|调用导出函数| B[WASM 模块] B --|写入线性内存| C[Linear Memory] A --|读取线性内存| C B --|调用导入函数| D[WebGL/WebGPU API] D --|GPU 计算| E[GPU Shader] end subgraph WASM 沙箱内部 B -- F[模型权重br/f16 量化] B -- G[推理计算图] B -- H[中间激活 Buffer] F -- C G -- C H -- C end subgraph 数据流 I[用户输入] --|JS 传入| C C --|WASM 计算| G G --|结果写回| C C --|JS 读出| J[推理结果] end2.2 Rust 到 WASM 的编译链Rust 通过wasm32-unknown-unknown目标直接编译为 WASM 字节码无需 C 中间层。wasm-bindgen工具生成 JS 绑定代码处理类型映射Rust 的VecT对应 JS 的Float32Array、内存管理和函数导出。wasm-pack将编译产物打包为 npm 包支持直接在 Web 项目中引用。2.3 内存管理的关键约束WASM 线性内存的增长通过memory.grow指令实现每次增长一个 Page64KB。频繁的内存增长会导致性能抖动因此推理引擎应在初始化时预分配足够的内存。Rust 的默认分配器dlmalloc在 WASM 中可用但性能一般推荐使用wee_alloc减小代码体积或使用lol_alloc获得更快的分配速度。三、端侧推理引擎的 WASM 实现3.1 模型加载与量化推理use wasm_bindgen::prelude::*; /// WASM 端侧推理引擎 /// 编译目标wasm32-unknown-unknown #[wasm_bindgen] pub struct WasmInferenceEngine { // 模型权重使用 f16 量化减少内存占用 weights: Vecu16, // 推理中间激活 Buffer预分配避免运行时增长 activations: Vecf32, // 模型配置 hidden_size: usize, num_layers: usize, } #[wasm_bindgen] impl WasmInferenceEngine { /// 从二进制权重数据创建推理引擎 /// 权重在 JS 端通过 fetch 加载后传入 #[wasm_bindgen(constructor)] pub fn new( weights_data: [u8], hidden_size: usize, num_layers: usize, ) - ResultWasmInferenceEngine, JsValue { // 预分配激活 Buffer避免推理过程中的内存增长 let activation_size hidden_size * 4; // 4 倍隐藏层大小用于中间结果 let activations vec![0.0f32; activation_size * num_layers]; // 将字节流解码为 f16 权重 if weights_data.len() % 2 ! 0 { return Err(JsValue::from_str(权重数据长度必须是偶数字节)); } let weights: Vecu16 weights_data .chunks_exact(2) .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) .collect(); Ok(Self { weights, activations, hidden_size, num_layers, }) } /// 执行单步推理 /// input_ids: 输入 token ID 序列 /// 返回下一个 token 的概率分布 pub fn forward(mut self, input_ids: [u32]) - ResultVecf32, JsValue { if input_ids.is_empty() { return Err(JsValue::from_str(输入序列不能为空)); } let seq_len input_ids.len(); let vocab_size self.weights.len() / (self.hidden_size * self.num_layers); // Token Embedding从权重中查找对应向量 // 使用 f16 到 f32 的转换进行计算 let mut hidden_state vec![0.0f32; self.hidden_size]; for token_id in input_ids { let embed_offset token_id as usize * self.hidden_size; if embed_offset self.hidden_size self.weights.len() { return Err(JsValue::from_str(Token ID 超出词表范围)); } for i in 0..self.hidden_size { hidden_state[i] f16_to_f32(self.weights[embed_offset i]); } } // 逐层 Transformer 计算 for layer in 0..self.num_layers { let layer_offset layer * self.hidden_size * 4; // QKVO 四组权重 self.attention_forward( mut hidden_state, layer_offset, )?; } // 输出投影隐藏状态 - 词表概率分布 let mut logits vec![0.0f32; vocab_size]; // 矩阵乘法hidden_state output_weight for v in 0..vocab_size.min(32000) { // 限制词表大小防止内存溢出 let weight_offset v * self.hidden_size; let mut sum 0.0f32; for h in 0..self.hidden_size { let w f16_to_f32(self.weights[weight_offset h]); sum hidden_state[h] * w; } logits[v] sum; } // Softmax 归一化 let max_logit logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max); let exp_sum: f32 logits.iter().map(|x| (x - max_logit).exp()).sum(); for logit in logits.iter_mut() { *logit (*logit - max_logit).exp() / exp_sum; } Ok(logits) } /// 获取引擎的内存占用字节 pub fn memory_usage(self) - usize { self.weights.len() * 2 self.activations.len() * 4 } } impl WasmInferenceEngine { /// 单头注意力计算 fn attention_forward( self, hidden: mut [f32], weight_offset: usize, ) - Result(), JsValue { let hs self.hidden_size; if weight_offset hs * 4 self.weights.len() { return Err(JsValue::from_str(权重偏移量超出范围)); } // Q hidden W_q let mut q vec![0.0f32; hs]; for i in 0..hs { q[i] hidden[i] * f16_to_f32(self.weights[weight_offset i]); } // K hidden W_k let mut k vec![0.0f32; hs]; for i in 0..hs { k[i] hidden[i] * f16_to_f32(self.weights[weight_offset hs i]); } // V hidden W_v let mut v vec![0.0f32; hs]; for i in 0..hs { v[i] hidden[i] * f16_to_f32(self.weights[weight_offset hs * 2 i]); } // Attention: softmax(Q * K^T / sqrt(d)) * V let scale (hs as f32).sqrt().recip(); let mut attn_weights vec![0.0f32; hs]; let mut max_val f32::NEG_INFINITY; for i in 0..hs { attn_weights[i] q[i] * k[i] * scale; max_val max_val.max(attn_weights[i]); } let exp_sum: f32 attn_weights.iter().map(|x| (x - max_val).exp()).sum(); for w in attn_weights.iter_mut() { *w (*w - max_val).exp() / exp_sum; } // 输出 attn_weights * V for i in 0..hs { hidden[i] attn_weights[i] * v[i]; } Ok(()) } } /// IEEE 754 半精度浮点数转单精度 /// 在 WASM 中使用软件模拟因为 WASM 规范不原生支持 f16 fn f16_to_f16(half: u16) - f32 { let sign (half 15) 1; let exponent (half 10) 0x1F; let mantissa half 0x3FF; match exponent { 0 { if mantissa 0 { // 零 if sign 1 { -0.0 } else { 0.0 } } else { // 非正规数 let f (mantissa as f32) * 2.0f32.powi(-24); if sign 1 { -f } else { f } } } 31 { if mantissa 0 { // 无穷 f32::INFINITY.copysign(if sign 1 { -1.0 } else { 1.0 }) } else { // NaN f32::NAN } } _ { // 正规数 let f 2.0f32.powi(exponent as i32 - 15) * (1.0 mantissa as f32 / 1024.0); if sign 1 { -f } else { f } } } } // 修正函数名 fn f16_to_f32(half: u16) - f32 { f16_to_f16(half) }3.2 JS 端集成与内存交互// JavaScript 端加载与调用 import init, { WasmInferenceEngine } from ./pkg/inference_engine.js; async function runInference() { // 初始化 WASM 模块 await init(); // 从服务器加载量化模型权重 const response await fetch(/models/tinyllama-q4.wasm); const weightsBuffer await response.arrayBuffer(); const weightsData new Uint8Array(weightsBuffer); // 创建推理引擎 const engine new WasmInferenceEngine( weightsData, 512, // hidden_size 4, // num_layers ); // 执行推理 const inputIds new Uint32Array([1, 1543, 566, 1024]); const logits engine.forward(inputIds); // 解码输出取概率最高的 token let maxIdx 0; let maxVal -Infinity; for (let i 0; i logits.length; i) { if (logits[i] maxVal) { maxVal logits[i]; maxIdx i; } } console.log(预测下一个 token: ${maxIdx}, 概率: ${maxVal}); console.log(引擎内存占用: ${engine.memory_usage()} 字节); }四、WASM 推理的性能边界与架构妥协WASM 端侧推理存在几个硬性限制需要在架构设计时充分考虑。计算性能天花板。WASM 的计算性能约为原生代码的 70%-90%纯 CPU 计算密集型场景但缺乏 SIMD 的完整支持WASM SIMD 128 目前已广泛支持但 256/512 位宽尚在提案阶段。对于 Transformer 模型中的矩阵乘法SIMD 向量化是性能的关键128 位 SIMD 意味着每次处理 4 个 f32而 AVX-512 可以处理 16 个。这导致 WASM 推理的吞吐量约为原生代码的 1/3 到 1/2。内存限制。浏览器对 WASM 线性内存有上限Chrome 默认 4GBFirefox 约 2GB。一个 1.5B 参数的 Q4 量化模型需要约 1GB 权重内存加上推理中间激活总内存接近 2GB。更大的模型在浏览器中无法运行。解决方案是模型切片将权重分片按需加载或更激进的量化Q2、1.5-bit但都会影响推理质量。多线程限制。WASM 的SharedArrayBuffer要求页面配置 COOP/COEP 安全头许多网站无法满足。没有共享内存WASM 只能使用单线程计算无法利用多核 CPU。WebGPU 可以部分弥补这个缺陷将矩阵运算卸载到 GPU但 WebGPU 的计算 Shader 性能也低于原生 CUDA。适用边界。WASM 端侧推理最适合以下场景模型参数量 1.5B 的轻量级任务文本分类、NER、小规模生成、对延迟极度敏感且无法容忍网络往返的交互场景、隐私敏感数据不能离开用户设备的合规要求。不适合的场景包括大参数量模型的生成任务、需要 GPU 级吞吐量的批量推理、对计算精度有严格要求的科学计算。五、总结WASM 端侧推理通过浏览器沙箱和线性内存模型在保护用户隐私的前提下实现了接近原生的推理性能。本文从 WASM 执行模型出发实现了基于 Rust 编译的端侧推理引擎包含 f16 量化权重管理、注意力计算和 JS 端集成。落地路线建议第一步使用wasm-pack build --target web编译推理引擎为 npm 包在本地浏览器中验证基础推理功能第二步对计算热点矩阵乘法、Softmax使用 WASM SIMD intrinsics 优化通过wasm-simd特性门启用第三步引入 WebGPU 后端将矩阵运算卸载到 GPU使用wgpucrate 的 WASM 后端第四步对模型进行 Q4 量化压缩使用safetensors格式存储权重通过fetchReadableStream实现流式加载。