RMSNorm:LLM 里的归一化为什么换成了这个
本文基于昇腾CANN和昇腾NPU,围绕 ops-transformer 仓库的相关技术展开。
LayerNorm 在大模型里被 RMSNorm 替换了。LayerNorm 做了减均值再除方差,RMSNorm 只除均方根——去掉了减均值那一步。少一次 Reduce 操作,在量产推理里省掉 15-20% 的归一化时间。
LayerNorm 的计算流程
# LayerNorm——先算均值、再算方差、再归一化deflayer_norm(x,gamma,beta,eps=1e-6):""" x: [batch, seq_len, hidden_dim] gamma: [hidden_dim] —— 可学习缩放 beta: [hidden_dim] —— 可学习偏移 """# Step 1: 算均值——一次 Reducemean=x.mean(dim=-1,keepdim=True)# [b, s, 1]# Step 2: 算方差——二次 Reducevar=((x-mean)**2).mean(dim=-1,keepdim=True)# [b, s, 1]# Step 3: 归一化x_norm=(x-mean)/torch.sqrt(var+eps)# 减均值再除方差# Step 4: 缩放 + 偏移returnx_norm*gamma+beta# 每次 LayerNorm 做 2 次全张量 Reduce + 1 次逐元素 Scale# hidden_dim=4096 时:每次需要读 4096 个值 2 次 + 写 1 次LayerNorm 去均值那一步在 NLP 里不是必须的——Transformer 的残差连接已经做了中心化。RMSNorm 砍掉这步,只做 Scale。
RMSNorm 的数学差异
# RMSNorm——只除 RMS,过均值defrms_norm(x,gamma,eps=1e-6):""" x: [batch, seq_len, hidden_dim] gamma: [hidden_dim] —— 可学习缩放(无 beta) RMSNorm(x) = x / RMS(x) * gamma RMS(x) = sqrt(mean(x^2) + eps) """# Step 1: 算均方——只有 1 次 Reducerms=torch.sqrt((x**2).mean(dim=-1,keepdim=True)+eps)# [b, s, 1]# Step 2: 归一化——不做减均值,直接除x_norm=x/rms# Step 3: 缩放——有 gamma,没有 betareturnx_norm*gamma# 跟 LayerNorm 的差异:# 1. 没有 mean = x.mean() → 省一次全张量 Reduce# 2. 没有 x - mean → 省一次逐元素减法# 3. 没有 beta → 省一次加法# 统计上:RMSNorm 收敛到跟 LayerNorm 同等精度Llama 全系列用 RMSNorm——Llama-3.1-405B 也不例外。用 RMSNorm 替代 LayerNorm 后,405B 模型单次 Forward 省掉 2 次大 Tensor 操作。
CANN 上的 RMSNorm 融合实现
// Ascend C 实现的 RMSNorm——融合了 Pow + Reduce + Sqrt + DivclassRMSNormKernel:publicAscendC::Kernel{__aicore__inlinevoidProcess()override{// 一次性搞清 RMSNorm 的 Tile 策略constinttile_size=1024;// 每次处理 1024 维constinttiles_per_block=hidden_dim/tile_size;AscendC::LocalTensor<float>x_local;AscendC::LocalAlloc<float>(x_local,tile_size);AscendC::LocalTensor<float>sq_local;AscendC::LocalAlloc<float>(sq_local,tile_size);// 分 Tile 计算 x^2 并在片上做部分累加// 这样不用把所有 x 搬完再算 RMS——减少 L1⇄DDR 往返floatpartial_sum=0.0f;for(intt=0;t<tiles_per_block;t++){// 搬一个 Tile 到 L1 Bufferinttile_offset=t*tile_size;AscendC::DataCopy(x_local,gm_x+tile_offset,tile_size);// x^2——用了 Vec 单元的通用计算指令AscendC::Mul(sq_local,x_local,x_local);// 片上的局部 ReduceSum——不走 DDRAscendC::ReduceAdd(partial_sum,sq_local,tile_size);}// RMS = sqrt(partial_sum / hidden_dim + eps)floatrms=sqrtf(partial_sum/hidden_dim+1e-6f);floatinv_rms=1.0f/rms;// 用乘法代替除法// 第二遍 Tile:x * inv_rms * gammafor(intt=0;t<tiles_per_block;t++){inttile_offset=t*tile_size;AscendC::DataCopy(x_local,gm_x+tile_offset,tile_size);// 加载 gamma 参数AscendC::LocalTensor<float>gamma_local;AscendC::LocalAlloc(gamma_local,tile_size);AscendC::DataCopy(gamma_local,gm_gamma+tile_offset,tile_size);// x / rms * gamma——一次合并完成AscendC::Mul(x_local,x_local,inv_rms);AscendC::Mul(x_local,x_local,gamma_local);// 写回AscendC::DataCopy(gm_out+tile_offset,x_local,tile_size);}}};比 LayerNorm 少了一个x - mean和一个 Reduce,多出来的算力可以给 Batch 里的下一个请求。实测 Llama-7B 上把 Norm 替换为 RMSNorm 后,Decode 速度从 28 tok/s 提到 32 tok/s。
参考仓库
RMSNorm 等 Transformer 算子
神经网络基础算子库
