当前位置: 首页 > news >正文

拆解Transformer本源:350行源码吃透Attention底层原理

文章目录

    • 前言
    • 一、Scaled Dot-Product Attention:AI界的"查户口"大师
    • 二、Multi-Head Attention:一个人同时开八个脑洞
    • 三、Position-wise FFN:每个token的"健身房私教课"
    • 四、Positional Encoding:给token发"座位号"
    • 五、Layer Normalization:给数据穿"统一制服"
    • 六、Encoder Layer:自注意力+FFN的"组合拳"
    • 七、Decoder Layer:戴眼罩的"传话游戏"
    • 八、完整Transformer:组装高达的时刻
    • 九、写在最后:350行代码,八年AI霸权

P.S. 无意间发现了一个巨牛的人工智能教程,非常通俗易懂,对AI感兴趣的朋友强烈推荐去看看,传送门https://blog.csdn.net/HHX_01

前言

2017年,Google那帮大佬甩出一篇论文,标题叫《Attention Is All You Need》。翻译成人话就是:“Attention就够了,别的都是弟弟。“我当时一看,好家伙,这口气比我家楼下烧烤摊老板还大。老板至少还谦虚地说"我家羊肉串全市第二”,Google直接说"我只需要注意力”。

结果八年过去了,GPT、BERT、LLaMA全是从这玩意儿肚子里爬出来的。现在大模型卷得跟春运抢票似的,但你敢信吗?这祖宗的源码,纯PyTorch写出来,就350行。350行!我上次写个登录页面都不止350行。Google这帮人是真狠,用个博客文章的长度,把整个AI行业的地基给打好了。

今天我就当一回"源码拆弹专家",把这350行代码一行一行掰开揉碎。放心,不催眠,不念经,全程脱口秀节奏。你要是看完还犯困,那我……那我下次换个更吵的BGM。

一、Scaled Dot-Product Attention:AI界的"查户口"大师

Transformer的核心就一句话:Query问Key,Key回答Value。听起来像不像相亲?Query就是男方,问Key:"你有房吗?有车吗?存款几位数?"Key一一回答,然后Value就是女方实际的嫁妆——哦不,是实际的语义信息。

公式长这样:Attention(Q,K,V) = softmax(QK^T / √d_k) @ V。别跑!这玩意儿翻译成中文就是:先把Query和Key的点积算出来,除以一个√d_k,再softmax一下,最后跟Value乘一块儿。简单吧?就像你问相亲对象三个问题,打分,归一化,最后决定要不要继续聊。

那为什么非要除以√d_k呢?因为维度一高,点积的数值容易膨胀,softmax直接"社死"——梯度消失得比我的头发还快。除以√d_k就相当于给数值"减肥",保持身材匀称,训练才不会崩盘。这操作,跟我过年狂吃后上秤前先脱鞋脱外套一个逻辑。

核心代码

classScaledDotProductAttention(nn.Module):def__init__(self,dropout:float=0.1):super().__init__()self.dropout=nn.Dropout(dropout)defforward(self,Q,K,V,mask=None):d_k=Q.size(-1)scores=torch.matmul(Q,K.transpose(-2,-1))/math.sqrt(d_k)ifmaskisnotNone:scores=scores.masked_fill(mask==0,float('-inf'))attn_weights=F.softmax(scores,dim=-1)attn_weights=self.dropout(attn_weights)output=torch.matmul(attn_weights,V)returnoutput,attn_weights

看到masked_fill没?这就是"拉黑"操作。padding位置直接填-inf,softmax后权重归零,相当于在相亲现场把不符合条件的直接请出去。dropout更是狠,训练时随机"闭麦"一些注意力权重,防止模型过拟合——就像你同时聊十个相亲对象,突然随机断网几个,逼你认真跟剩下的谈。

复杂度是O(n²·d_k),n是序列长度。这也是Transformer被吐槽最多的地方:序列一长,计算量爆炸。GPT-4处理长文档时,那算力消耗,比我交房租时的心绞痛还真实。

二、Multi-Head Attention:一个人同时开八个脑洞

单头注意力就像你只用一只眼看世界,虽然能看,但立体感差点意思。Multi-Head Attention呢?相当于给你脑袋上装八个摄像头,同时从八个角度观察同一个对象。语法关系、语义关联、指代消解、情感倾向……每个头负责一块,最后把八份报告拼一起,交一份综合情报。

代码里有个神操作:不是真的定义8组独立的Q/K/V投影层,而是各用一个Linear层,投影完再view拆成8份。数学上等价,但参数从3h个降到4个。Google这帮人是真会过日子,省下来的显存够我多跑两轮实验了。

多头注意力代码

classMultiHeadAttention(nn.Module):def__init__(self,d_model:int,n_heads:int,dropout:float=0.1):super().__init__()assertd_model%n_heads==0self.d_model=d_model self.n_heads=n_heads self.d_k=d_model//n_heads self.W_Q=nn.Linear(d_model,d_model,bias=False)self.W_K=nn.Linear(d_model,d_model,bias=False)self.W_V=nn.Linear(d_model,d_model,bias=False)self.W_O=nn.Linear(d_model,d_model,bias=False)self.attention=ScaledDotProductAttention(dropout)defforward(self,Q,K,V,mask=None):batch_size=Q.size(0)Q=self.W_Q(Q).view(batch_size,-1,self.n_heads,self.d_k).transpose(1,2)K=self.W_K(K).view(batch_size,-1,self.n_heads,self.d_k).transpose(1,2)V=self.W_V(V).view(batch_size,-1,self.n_heads,self.d_k).transpose(1,2)attn_output,attn_weights=self.attention(Q,K,V,mask)attn_output=attn_output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)output=self.W_O(attn_output)returnoutput

注意那个.contiguous(),transpose只是换了个"看数据的姿势",内存里还是老样子。你要是不contiguous一下,后面的view直接报错,PyTorch的脾气比你女朋友还难猜。W_O是最后的"总编辑",把八个头的输出揉巴揉巴,合成一份d_model维度的终稿。

mask的广播机制也很有意思。mask形状是(batch, 1, 1, seq_len),scores是(batch, n_heads, seq_len, seq_len)。PyTorch自动帮你广播,不用手动unsqueeze。这感觉就像你去餐厅吃饭,服务员主动问你要不要加辣——细节到位,体验丝滑。

三、Position-wise FFN:每个token的"健身房私教课"

注意力层处理完,每个token还得去FFN里"撸个铁"。FFN(x) = ReLU(xW_1 + b_1)W_2 + b_2。两层全连接,中间夹个ReLU,跟夹心饼干似的。关键是"position-wise"——同一个参数矩阵,给序列里每个token轮流用,公平得很,跟健身房私教同时带十个学员,但训练计划一模一样。

内部维度d_ff通常是d_model的四倍,512变2048。这就好比把数据从单人间塞进四人间,折腾一番再搬回单人间。折腾的过程就是非线性变换,给模型增加"表达能力"。论文实验说一层不够,三层多余,两层刚刚好。Google这帮人是懂中庸之道的,比我家楼下奶茶店"半糖"的拿捏还精准。

FFN代码

classPositionWiseFeedForward(nn.Module):def__init__(self,d_model:int,d_ff:int,dropout:float=0.1):super().__init__()self.linear1=nn.Linear(d_model,d_ff)self.linear2=nn.Linear(d_ff,d_model)self.dropout=nn.Dropout(dropout)defforward(self,x):returnself.linear2(self.dropout(F.relu(self.linear1(x))))

dropout放在ReLU之后、第二次线性之前,这是行业惯例。原始论文用ReLU,后来GPT系列换成了GELU。GELU更平滑,像给ReLU做了个SPA,从"硬切换"变成"软着陆"。不过ReLU胜在简单直接,就像直男表白,虽然生硬,但好歹把意思传达到了。

四、Positional Encoding:给token发"座位号"

Self-Attention有个致命bug:它分不清"张三打了李四"和"李四打了张三"。你把句子里的词随便换位置,它输出一模一样。这就像一个脸盲症患者,看谁都像同一个人,完全靠衣服颜色区分——但你要是给他换件衣服,他就彻底懵了。

所以必须给每个token发张"座位号",告诉它你在第几个位置。Google用的是正余弦编码,公式长得跟高数期末考最后一道大题似的。但核心思想就一条:不同位置的编码不同,而且相对位置可以通过线性变换推导出来。sin(α+Δ) = sinα·cosΔ + cosα·sinΔ,三角函数恒等式,高中数学的遗产,现在被Google拿来给AI指路。

位置编码代码

classPositionalEncoding(nn.Module):def__init__(self,d_model:int,max_len:int=5000,dropout:float=0.1):super().__init__()self.dropout=nn.Dropout(dropout)pe=torch.zeros(max_len,d_model)position=torch.arange(0,max_len,dtype=torch.float).unsqueeze(1)div_term=torch.exp(torch.arange(0,d_model,2).float()*(-math.log(10000.0)/d_model))pe[:,0::2]=torch.sin(position*div_term)pe[:,1::2]=torch.cos(position*div_term)pe=pe.unsqueeze(0)self.register_buffer('pe',pe)defforward(self,x):x=x+self.pe[:,:x.size(1),:]returnself.dropout(x)

div_term用指数形式计算,是为了数值稳定性。你要是直接算10000^(2i/d_model),浮点数精度早崩了,跟用计算器算1除以3然后乘3,结果不是1一样让人抓狂。register_buffer让pe跟着模型到处跑(CPU、GPU随便切),但不会被优化器盯上——相当于公司里的保洁阿姨,到处都有她,但KPI考核里没她。

为什么不用可学习的位置嵌入?三个原因:能外推更长序列(训练时没见过5000长度,但编码天然支持)、省参数、相对位置有数学保证。Google这波叫"用数学省算力",跟我用优惠券点外卖一个思路,但人家省出来的是几台A100的电费。

五、Layer Normalization:给数据穿"统一制服"

Batch Norm在CV界混得风生水起,但到了NLP这儿就水土不服。为啥?序列长度不一样,batch大小也不稳定,Batch Norm统计的均值方差跟过山车似的。Layer Norm说:“算了,我不管batch了,我每个样本自己跟自己比。”

对每个样本的所有维度,先减均值、除标准差,再乘个γ加个β。γ和β是可学习的,相当于"制服虽然统一,但允许你微调尺寸"。这跟学校发校服一个道理:大家都穿蓝白相间,但胖瘦可以自己调。

LayerNorm代码

classLayerNorm(nn.Module):def__init__(self,d_model:int,eps:float=1e-6):super().__init__()self.gamma=nn.Parameter(torch.ones(d_model))self.beta=nn.Parameter(torch.zeros(d_model))self.eps=epsdefforward(self,x):mean=x.mean(dim=-1,keepdim=True)std=x.std(dim=-1,keepdim=True,unbiased=False)returnself.gamma*(x-mean)/(std+self.eps)+self.beta

unbiased=False用的是样本标准差,跟原始论文保持一致。eps=1e-6是防止除零的保险丝,虽然实际数据几乎不会遇到全零向量,但代码里不防一手,就跟开车不系安全带一样——大概率没事,但出事就是大事。生产环境直接用nn.LayerNorm就行,手写版纯粹是为了让你看清"校服是怎么裁剪的"。

六、Encoder Layer:自注意力+FFN的"组合拳"

Encoder层就是"自注意力打完,FFN补刀"。每个子层后面都跟一个残差连接和Layer Norm。残差连接x + sublayer(x)是深度网络的救命稻草——梯度可以沿着shortcut直接传回去,不用一层一层慢慢爬。这感觉就像你住30楼,电梯坏了,但旁边有个滑梯直通一楼。虽然滑梯有点陡,但好歹比爬楼梯快。

Encoder层代码

classEncoderLayer(nn.Module):def__init__(self,d_model,n_heads,d_ff,dropout=0.1):super().__init__()self.self_attn=MultiHeadAttention(d_model,n_heads,dropout)self.ffn=PositionWiseFeedForward(d_model,d_ff,dropout)self.norm1=LayerNorm(d_model)self.norm2=LayerNorm(d_model)self.dropout1=nn.Dropout(dropout)self.dropout2=nn.Dropout(dropout)defforward(self,x,mask=None):attn_output=self.self_attn(x,x,x,mask)x=x+self.dropout1(attn_output)x=self.norm1(x)ffn_output=self.ffn(x)x=x+self.dropout2(ffn_output)x=self.norm2(x)returnx

注意self_attn的三个参数都是x,这叫"自注意力"——Query、Key、Value全来自同一个序列,自己查自己,自己关注自己。有点像你深夜翻自己三年前的朋友圈,一边看一边自我剖析:“我当时怎么会发这个?”

这是Post-LN模式:先残差再归一化。后来有些变体改成Pre-LN,先归一化再残差,训练更稳定。但原始论文是Post-LN,咱们尊重经典,就像吃北京烤鸭必须配甜面酱,虽然有人爱蘸白糖,但传统不能丢。

七、Decoder Layer:戴眼罩的"传话游戏"

Decoder比Encoder多一个子层,叫Cross-Attention。Encoder把输入序列的信息压缩成一份"参考手册",Decoder一边看自己之前生成的token,一边翻这份手册,决定下一个词输出啥。这像极了我写代码时的状态:一边回忆自己上一行写了啥,一边查Stack Overflow。

但Decoder有个特殊规矩:自注意力层必须戴眼罩,只能看当前位置及之前的token,不能偷看未来。这叫causal mask,下三角矩阵,上三角全填-inf。为什么?因为翻译时你还没生成后面的词,要是让模型提前看答案,跟考试作弊有什么区别?GPT就是这么"自律"地长大的,虽然它后来学会了不少作弊技巧(比如背题库)。

Decoder层代码

classDecoderLayer(nn.Module):def__init__(self,d_model,n_heads,d_ff,dropout=0.1):super().__init__()self.self_attn=MultiHeadAttention(d_model,n_heads,dropout)self.cross_attn=MultiHeadAttention(d_model,n_heads,dropout)self.ffn=PositionWiseFeedForward(d_model,d_ff,dropout)self.norm1=LayerNorm(d_model)self.norm2=LayerNorm(d_model)self.norm3=LayerNorm(d_model)self.dropout1=nn.Dropout(dropout)self.dropout2=nn.Dropout(dropout)self.dropout3=nn.Dropout(dropout)defforward(self,x,enc_output,src_mask=None,tgt_mask=None):attn_output=self.self_attn(x,x,x,tgt_mask)x=x+self.dropout1(attn_output)x=self.norm1(x)attn_output=self.cross_attn(x,enc_output,enc_output,src_mask)x=x+self.dropout2(attn_output)x=self.norm2(x)ffn_output=self.ffn(x)x=x+self.dropout3(ffn_output)x=self.norm3(x)returnx

Cross-Attention的Q来自Decoder,K和V来自Encoder。Decoder每生成一个词,就拿着这个词去Encoder的"手册"里查:"前面输入的句子,哪个部分跟我现在最相关?"这机制让翻译准确率直接起飞,比传统RNN的"传话游戏"强太多了。RNN传话传到最后一个词,第一个词的信息早就失真得跟谣言一样了。

八、完整Transformer:组装高达的时刻

最后一步,把N层Encoder和N层Decoder堆起来,加上嵌入层、位置编码、输出投影,一台完整的Transformer就组装完毕。论文里N=6,d_model=512,n_heads=8,d_ff=2048。这些数字不是拍脑袋定的,是Google烧了不少TPU试出来的"黄金比例"。

完整Transformer代码

classTransformer(nn.Module):def__init__(self,src_vocab,tgt_vocab,d_model=512,n_heads=8,d_ff=2048,n_layers=6,dropout=0.1,max_len=5000):super().__init__()self.encoder_embed=nn.Embedding(src_vocab,d_model)self.decoder_embed=nn.Embedding(tgt_vocab,d_model)self.pos_encoding=PositionalEncoding(d_model,max_len,dropout)self.encoder_layers=nn.ModuleList([EncoderLayer(d_model,n_heads,d_ff,dropout)for_inrange(n_layers)])self.decoder_layers=nn.ModuleList([DecoderLayer(d_model,n_heads,d_ff,dropout)for_inrange(n_layers)])self.fc_out=nn.Linear(d_model,tgt_vocab)defforward(self,src,tgt,src_mask=None,tgt_mask=None):src_emb=self.pos_encoding(self.encoder_embed(src))forlayerinself.encoder_layers:src_emb=layer(src_emb,src_mask)tgt_emb=self.pos_encoding(self.decoder_embed(tgt))forlayerinself.decoder_layers:tgt_emb=layer(tgt_emb,src_emb,src_mask,tgt_mask)returnself.fc_out(tgt_emb)

nn.ModuleList确保每一层的参数都被PyTorch登记在册,不会变成"黑户"。Encoder和Decoder各自有独立的嵌入层,虽然理论上可以共享,但分开更灵活。fc_out把d_model投影到词表大小,输出就是下一个token的概率分布——相当于给词典里每个词打个分,分最高的就是"天选之子"。

九、写在最后:350行代码,八年AI霸权

你看完这350行,可能会觉得:就这?GPT-4、Claude、LLaMA这些动辄千亿参数的怪物,祖宗居然这么简洁?没错,伟大的架构往往简单到让人怀疑人生。就像爱因斯坦的E=mc²,就五个字符,但改变了整个物理学。

理解了这些基础组件,你再去看GPT系列"只用Decoder"、BERT系列"只用Encoder"、LLaMA把ReLU换成SwiGLU、把LayerNorm换成RMSNorm——这些变体就不再是黑魔法,而是"在祖宗的基础上装修房子"。有人拆墙,有人加隔断,但地基永远是这350行。

所以下次有人跟你吹"大模型多神秘",你可以淡定地喝口咖啡,说:“神秘啥?我看过它祖宗的源码,就350行,还没我微信聊天记录长。”

当然,看完这篇你要是还写不出Transformer,那很正常。我看完《舌尖上的中国》也没学会做佛跳墙。但起码,你再打开GitHub上那些开源大模型的代码时,不会一脸懵了。这,就是"拆穿底裤"的意义。

P.S. 无意间发现了一个巨牛的人工智能教程,非常通俗易懂,对AI感兴趣的朋友强烈推荐去看看,传送门https://blog.csdn.net/HHX_01

http://www.cnnetsun.cn/news/2742698.html

相关文章:

  • ECU软件迭代后,A2L文件地址飘了怎么办?ASAP2 Studio增量更新实战指南
  • 告别Redis?用C++手把手教你玩转LMDB这个嵌入式内存数据库
  • Agent 并不是越聪明越好:企业场景下的模型蒸馏与小模型应用
  • Navicat Premium无限试用解决方案:告别14天限制的智能重置工具
  • JSP+Servlet学生信息管理系统完整课程设计包(含数据库脚本、Eclipse工程与论文文档)
  • Kimi K2.6 vs GLM-5.1:开发者真实编程任务选型指南
  • AirSim Python API避坑指南:多旋翼控制、图像采集与天气模拟的实战心得
  • Mysql中事务(tp binlog日志,pos模式需要完整事件的起始)
  • 本科毕设可用的车牌识别系统:带GUI界面、预训练模型和完整演示素材
  • 会议管理系统
  • Thermacell 推出 Liv 2.0 智能驱蚊系统:覆盖更广、能驱蠓虫,但价格翻倍还需专业安装!
  • 高效玩赚营销!autoAGC海报搞定电商全场景引流
  • ROS参数服务器避坑指南:从launch文件到C++/Python代码,详解命名空间那些容易踩的坑
  • Gemini 3.1 Pro长对话认知退化实测与抗衰减工程实践
  • Gemma 2本地部署实战:消费级硬件上的安全可控推理指南
  • Qoder 明确标注 Kimi-K2.5:长上下文与结构化输出的工程级落地
  • GPT-5.5并不存在:AI模型版本命名规范与事实核查指南
  • CAPL脚本数据处理避坑指南:整型数组与Hex字符串互转的实战函数库
  • 055、角度环与角速度环的串级PID实现
  • 微信小程序智慧物业系统源码包:支持云开发与本地部署,含报修投票、装修申请等完整功能
  • 怎么做决策:做树状脉络分析利弊(重在思考失去,不要不珍惜现在),拉长时间线
  • 2026陕西省官方授权CPPM注册职业采购经理培训机构选择指南
  • 【技术架构】2026企业级AI落地实践:从RPA到AI Agent的原生CRM重构!
  • 告别裸机画点线:在STM32H743上为4.3寸屏移植STemWin GUI库的完整流程与内存优化技巧
  • 《逃离玫瑰岛》小说|下载|txt
  • 从芯片到场景:BOS半导体以Physical AI定义车载AI Box新范式
  • NarratoAI完整教程:三步掌握AI视频解说制作神器
  • Tatai 3.0:让任意服务器上的 Java 应用,拥有云原生级的高可用体验
  • 基于 Harmony 6.0 应用的校园失物招领系统首页实现
  • 你的旧笔记本别扔!巧用闲置MiniPCIe接口,低成本变身4G物联网网关或监控终端