当前位置: 首页 > news >正文

RMSNorm:LLM 里的归一化为什么换成了这个

本文基于昇腾CANN和昇腾NPU,围绕 ops-transformer 仓库的相关技术展开。

LayerNorm 在大模型里被 RMSNorm 替换了。LayerNorm 做了减均值再除方差,RMSNorm 只除均方根——去掉了减均值那一步。少一次 Reduce 操作,在量产推理里省掉 15-20% 的归一化时间。

LayerNorm 的计算流程

# LayerNorm——先算均值、再算方差、再归一化deflayer_norm(x,gamma,beta,eps=1e-6):""" x: [batch, seq_len, hidden_dim] gamma: [hidden_dim] —— 可学习缩放 beta: [hidden_dim] —— 可学习偏移 """# Step 1: 算均值——一次 Reducemean=x.mean(dim=-1,keepdim=True)# [b, s, 1]# Step 2: 算方差——二次 Reducevar=((x-mean)**2).mean(dim=-1,keepdim=True)# [b, s, 1]# Step 3: 归一化x_norm=(x-mean)/torch.sqrt(var+eps)# 减均值再除方差# Step 4: 缩放 + 偏移returnx_norm*gamma+beta# 每次 LayerNorm 做 2 次全张量 Reduce + 1 次逐元素 Scale# hidden_dim=4096 时:每次需要读 4096 个值 2 次 + 写 1 次

LayerNorm 去均值那一步在 NLP 里不是必须的——Transformer 的残差连接已经做了中心化。RMSNorm 砍掉这步,只做 Scale。

RMSNorm 的数学差异

# RMSNorm——只除 RMS,过均值defrms_norm(x,gamma,eps=1e-6):""" x: [batch, seq_len, hidden_dim] gamma: [hidden_dim] —— 可学习缩放(无 beta) RMSNorm(x) = x / RMS(x) * gamma RMS(x) = sqrt(mean(x^2) + eps) """# Step 1: 算均方——只有 1 次 Reducerms=torch.sqrt((x**2).mean(dim=-1,keepdim=True)+eps)# [b, s, 1]# Step 2: 归一化——不做减均值,直接除x_norm=x/rms# Step 3: 缩放——有 gamma,没有 betareturnx_norm*gamma# 跟 LayerNorm 的差异:# 1. 没有 mean = x.mean() → 省一次全张量 Reduce# 2. 没有 x - mean → 省一次逐元素减法# 3. 没有 beta → 省一次加法# 统计上:RMSNorm 收敛到跟 LayerNorm 同等精度

Llama 全系列用 RMSNorm——Llama-3.1-405B 也不例外。用 RMSNorm 替代 LayerNorm 后,405B 模型单次 Forward 省掉 2 次大 Tensor 操作。

CANN 上的 RMSNorm 融合实现

// Ascend C 实现的 RMSNorm——融合了 Pow + Reduce + Sqrt + DivclassRMSNormKernel:publicAscendC::Kernel{__aicore__inlinevoidProcess()override{// 一次性搞清 RMSNorm 的 Tile 策略constinttile_size=1024;// 每次处理 1024 维constinttiles_per_block=hidden_dim/tile_size;AscendC::LocalTensor<float>x_local;AscendC::LocalAlloc<float>(x_local,tile_size);AscendC::LocalTensor<float>sq_local;AscendC::LocalAlloc<float>(sq_local,tile_size);// 分 Tile 计算 x^2 并在片上做部分累加// 这样不用把所有 x 搬完再算 RMS——减少 L1⇄DDR 往返floatpartial_sum=0.0f;for(intt=0;t<tiles_per_block;t++){// 搬一个 Tile 到 L1 Bufferinttile_offset=t*tile_size;AscendC::DataCopy(x_local,gm_x+tile_offset,tile_size);// x^2——用了 Vec 单元的通用计算指令AscendC::Mul(sq_local,x_local,x_local);// 片上的局部 ReduceSum——不走 DDRAscendC::ReduceAdd(partial_sum,sq_local,tile_size);}// RMS = sqrt(partial_sum / hidden_dim + eps)floatrms=sqrtf(partial_sum/hidden_dim+1e-6f);floatinv_rms=1.0f/rms;// 用乘法代替除法// 第二遍 Tile:x * inv_rms * gammafor(intt=0;t<tiles_per_block;t++){inttile_offset=t*tile_size;AscendC::DataCopy(x_local,gm_x+tile_offset,tile_size);// 加载 gamma 参数AscendC::LocalTensor<float>gamma_local;AscendC::LocalAlloc(gamma_local,tile_size);AscendC::DataCopy(gamma_local,gm_gamma+tile_offset,tile_size);// x / rms * gamma——一次合并完成AscendC::Mul(x_local,x_local,inv_rms);AscendC::Mul(x_local,x_local,gamma_local);// 写回AscendC::DataCopy(gm_out+tile_offset,x_local,tile_size);}}};

比 LayerNorm 少了一个x - mean和一个 Reduce,多出来的算力可以给 Batch 里的下一个请求。实测 Llama-7B 上把 Norm 替换为 RMSNorm 后,Decode 速度从 28 tok/s 提到 32 tok/s。

参考仓库

RMSNorm 等 Transformer 算子

神经网络基础算子库

http://www.cnnetsun.cn/news/2534026.html

相关文章:

  • Midjourney颗粒感失控?3分钟定位根源:从--stylize参数误用到--quality陷阱的9个致命误区
  • 政府科技管理部门如何推动区域创新?
  • TIPTOP ERP二次开发实战:从服务器拉取程序到本地Genero Studio调试的完整流水线
  • Boss-Key:职场隐私保护终极指南,一键隐藏窗口的智能解决方案
  • 专业级EdgeRemover配置指南:5种高效部署方案深度解析
  • ROS2 TurtleBot3仿真SLAM导航:RVIZ不显示机器人模型的终极排查与修复指南
  • Node.js后端服务如何集成多模型能力并管理API成本
  • 告别内存爆炸!用UNETR搞定3D医学图像分割,保姆级PyTorch+MONAI复现教程
  • 别再死记硬背!用Python+NetworkX可视化理解拉普拉斯矩阵的5个核心性质
  • 深度解析:xiaozhi-esp32-server语音交互系统的架构设计与工程实践
  • 用C语言指针实战分析双色球历史数据:一个C语言初学者的趣味项目
  • 独立开发者如何借助 Taotoken 低成本实验多种大模型
  • 【收藏干货】2026 版大模型推理底层原理拆解!吃透 Prefill/Decode 与 vLLM 核心优化
  • Qt QLineEdit的editingFinished信号为啥按回车会触发两次?一个弹窗引发的‘血案’与三种修复方案
  • HLK-LD1125H-24G雷达模块配置避坑指南:手把手教你调参实现最佳检测效果
  • 别再傻傻分不清了!一文搞懂Windows 11/10下搜狗/微软拼音输入法的全角半角切换(含快捷键设置)
  • Windows右键菜单终极清理指南:用ContextMenuManager告别杂乱,重获高效桌面
  • 从POS机到你的钱包:拆解一次刷卡背后的ISO8583协议‘暗语’
  • 从‘最大熵’到‘瑞丽熵’:手把手推导RDP公式,理解差分隐私的理论进化
  • 开始转到拼多多上面销售APP
  • 爬虫/API调用老出错?可能是你没用好requests库的raise_for_status方法
  • 从激光雷达到PET扫描:拆解SiPM在不同应用场景下的电路设计“避坑”指南
  • 不止于下载:用Charles抓包分析微信视频号的传输协议与缓存策略
  • 教育AI Agent部署失败率高达63%?(一线校长不愿公开的7个致命盲区)
  • 分享今日日常
  • 别再手动刷新了!用HomePage的YAML配置打造你的智能服务仪表盘
  • STM32F103C8T6上实现INA3221三路电流电压监控(附完整LL库驱动代码)
  • CANN-昇腾NPU-推理服务高可用-怎么做到99.99%可用性
  • 使用Taotoken聚合API为创业团队优化AI开发成本与效率
  • AI采购决策再不能靠感觉!Claude ROI模型实测数据:平均12.7天回本,但93%团队用错了基准线