CANN-RotaryEmbedding-昇腾NPU上位置编码为什么该融进Attention
RoPE(Rotary Position Embedding)在 Llama、Qwen 这些主流大模型里是标配。大部分人写模型的时候,RoPE 是一个独立的算子:先算旋转角度,再对 Q、K 做复数旋转,然后才进 Attention。在昇腾NPU上跑,这个"独立算子"每次都要把 Q、K 从 Cube 单元搬出来、做完旋转再搬回去——白跑两趟显存。ops-transformer 仓库里的 RotaryEmbedding 融合算子,直接把旋转操作塞进 Attention kernel,省掉这两趟搬运。
RoPE 的标准实现为什么慢
标准 RoPE 的计算过程:
1. 根据 position 和 freq 计算旋转角度 cos/sin → Vector 单元 2. 把 Q 拆成相邻两两一对 → 内存搬运 3. 对每对做复数旋转:[q0, q1] → [q0·cos-q1·sin, q0·sin+q1·cos] → Vector 单元 4. 对 K 重复步骤 2-3 → 内存搬运 5. 把旋转后的 Q、K 送入 Attention → Cube 单元步骤 2 和 4 的"拆对"操作,以及 Q/K 在 Vector 和 Cube 之间的来回搬运,是性能黑洞。在序列长度 8K、hidden_dim 4096 的场景下,Q 和 K 各有 32MB,来回搬运一遍就是 128MB 的显存读写。昇腾NPU的 HBM 带宽虽然不低,但跟 Cube 单元的计算速度比起来,搬运永远是瓶颈。
融合算子做了什么
把旋转操作嵌入 Attention 的分块计算中。
FlashAttention 本身就是把 Q、K、V 切块后在片上缓存计算。RotaryEmbedding 融合算子的做法是:在 Q、K 的分块进入 Softmax 之前,直接在片上缓存里完成旋转。
原来:Q_tile → 写回显存 → 读出 → 旋转 → 写回 → 再读出 → Attention 现在:Q_tile → 片上缓存直接旋转 → Attention旋转计算本身很简单,就是几组乘加操作,Vector 单元一个 cycle 就搞定。关键是省掉了那两次显存读写。
另外,cos/sin 的预计算也做了优化。标准实现里,每个序列位置的旋转角度是θ_i = 1 / (10000^(2i/d)),需要逐位置计算。融合算子把这些值预计算好存在常量缓存里,不同 batch 共享同一份。这在小 batch 推理时特别有用——batch=1 的场景下,预计算的开销从 O(seq_len × head_dim) 降到几乎为零。
收益数据
在 Atlas 800I A2 上,Llama2-7B 的推理性能对比:
| 配置 | 首 token 延迟 (ms) | 吞吐 (tokens/s) |
|---|---|---|
| RoPE 独立算子 + FlashAttention | 86 | 2,840 |
| RotaryEmbedding 融合 + FlashAttention | 72 | 3,190 |
首 token 延迟降了 16%。这个数字看起来不算炸裂,但要知道 RoPE 本身只占 Attention 计算的 5-8%,融合带来的收益主要来自减少显存搬运,而不是减少计算量。
长序列场景收益更明显。序列长度从 2K 拉到 16K,融合算子的延迟增长是线性的,独立算子因为显存搬运量翻倍,延迟增长接近 O(N^1.3)。
怎么用
PyTorch 场景下不需要改模型代码,CANN 的框架适配器自动处理:
importtorch_npu# CANN 自动把 RoPE + SDPA 路由到融合算子# 前提:模型用 F.scaled_dot_product_attention 接口q_rotated=apply_rotary_emb(q,cos,sin)# 这里会被融合进 Attentionout=torch.nn.functional.scaled_dot_product_attention(q_rotated,k_rotated,v)如果你手动调用了torch_npu.npu.flash_attention,需要显式传入 cos/sin:
out=torch_npu.npu.flash_attention(q,k,v,cos=cos_tensor,# 旋转角度余弦,提前算好sin=sin_tensor,# 旋转角度正弦rotary_mode=1# 1=融合旋转,0=不融合)踩坑提醒
融合算子要求 cos/sin 的 dtype 和 Q 一致。Llama 官方权重里的 RoPE 参数是 float32,但推理时 Q 通常转成了 float16。dtype 不匹配不会报错——它会默默 fallback 到独立算子。检查日志里有没有rotary_fusion_fallback关键字。
还有一个:融合算子目前不支持动态序列长度。如果你的推理服务支持变长输入(比如 batch 里不同样本长度不同),需要 padding 到统一长度再传入。CANN 8.5 加了变长支持,但只限 ATB 路径,PyTorch 路径还没跟上。
如果你的 Llama 推理服务首 token 延迟偏高,先看 RoPE 是不是独立算子在跑。融合方式很简单,但 fallback 是静默的,不查日志你根本不知道。仓库在这里:
https://atomgit.com/cann/ops-transformer
