当前位置: 首页 > news >正文

大模型高效微调--P-Tuning v2

文章目录

      • P-Tuning v2 概述
      • 核心改进
      • 关键技术细节
      • 代码示例
      • 性能对比
      • 局限性

https://github.com/THUDM/P-tuning-v2

P-Tuning v2 概述

P-Tuning v2 是清华大学团队提出的一种参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)方法,旨在改进传统微调方法在大型预训练语言模型(如GPT、BERT)上的效率和性能。它是P-Tuning的升级版本,通过优化提示(Prompt)设计和参数更新策略,显著提升了模型在低资源场景下的表现。

核心改进

连续提示优化
P-Tuning v2 引入了可训练的连续提示(Continuous Prompts),取代了传统离散提示。这些提示以嵌入向量的形式插入到模型的输入层或中间层,通过梯度下降动态调整,避免了人工设计提示的局限性。

分层提示注入
与P-Tuning仅在输入层添加提示不同,P-Tuning v2 在模型的每一层(或关键层)注入提示向量,形成分层提示结构。这种设计能更深度地引导模型行为,尤其适合深层Transformer架构。

参数效率提升
P-Tuning v2 仅需微调少量额外参数(通常占模型总参数的0.1%-1%),大幅降低了计算和存储开销,同时保持了与全参数微调相近的性能。

关键技术细节

提示向量初始化
提示向量通常随机初始化或从任务相关词嵌入中采样。实验表明,合理的初始化能加速收敛并提升最终效果。

训练目标
P-Tuning v2 通过标准的下游任务损失(如交叉熵)优化提示参数,同时可结合适配器(Adapter)或LoRA等轻量级模块进一步减少可训练参数。

适用场景

  • 小样本学习(Few-shot Learning)
  • 多任务学习(通过不同提示区分任务)
  • 资源受限的设备部署

代码示例

P-Tuning v2的核心逻辑:

importtorchclassPrefixEncoder(torch.nn.Module):r''' The torch.nn model to encode the prefix Input shape: (batch-size, prefix-length) Output shape: (batch-size, prefix-length, 2*layers*hidden) '''def__init__(self,config):super().__init__()self.prefix_projection=config.prefix_projectionifself.prefix_projection:# Use a two-layer MLP to encode the prefixself.embedding=torch.nn.Embedding(config.pre_seq_len,config.hidden_size)self.trans=torch.nn.Sequential(torch.nn.Linear(config.hidden_size,config.prefix_hidden_size),torch.nn.Tanh(),torch.nn.Linear(config.prefix_hidden_size,config.num_hidden_layers*2*config.hidden_size))else:self.embedding=torch.nn.Embedding(config.pre_seq_len,config.num_hidden_layers*2*config.hidden_size)defforward(self,prefix:torch.Tensor):ifself.prefix_projection:prefix_tokens=self.embedding(prefix)past_key_values=self.trans(prefix_tokens)else:past_key_values=self.embedding(prefix)returnpast_key_values

  • https://github.com/THUDM/P-tuning-v2/blob/main/model/token_classification.py
classBertPrefixForTokenClassification(BertPreTrainedModel):def__init__(self,config):super().__init__(config)self.num_labels=config.num_labels self.bert=BertModel(config,add_pooling_layer=False)self.dropout=torch.nn.Dropout(config.hidden_dropout_prob)self.classifier=torch.nn.Linear(config.hidden_size,config.num_labels)from_pretrained=Falseiffrom_pretrained:self.classifier.load_state_dict(torch.load('model/checkpoint.pkl'))forparaminself.bert.parameters():param.requires_grad=Falseself.pre_seq_len=config.pre_seq_len self.n_layer=config.num_hidden_layers self.n_head=config.num_attention_heads self.n_embd=config.hidden_size//config.num_attention_heads self.prefix_tokens=torch.arange(self.pre_seq_len).long()self.prefix_encoder=PrefixEncoder(config)bert_param=0forname,paraminself.bert.named_parameters():bert_param+=param.numel()all_param=0forname,paraminself.named_parameters():all_param+=param.numel()total_param=all_param-bert_paramprint('total param is {}'.format(total_param))# 9860105defget_prompt(self,batch_size):prefix_tokens=self.prefix_tokens.unsqueeze(0).expand(batch_size,-1).to(self.bert.device)past_key_values=self.prefix_encoder(prefix_tokens)# bsz, seqlen, _ = past_key_values.shapepast_key_values=past_key_values.view(batch_size,self.pre_seq_len,self.n_layer*2,self.n_head,self.n_embd)past_key_values=self.dropout(past_key_values)past_key_values=past_key_values.permute([2,0,3,1,4]).split(2)returnpast_key_valuesdefforward(self,input_ids=None,attention_mask=None,token_type_ids=None,position_ids=None,head_mask=None,inputs_embeds=None,labels=None,output_attentions=None,output_hidden_states=None,return_dict=None,):return_dict=return_dictifreturn_dictisnotNoneelseself.config.use_return_dict batch_size=input_ids.shape[0]past_key_values=self.get_prompt(batch_size=batch_size)prefix_attention_mask=torch.ones(batch_size,self.pre_seq_len).to(self.bert.device)attention_mask=torch.cat((prefix_attention_mask,attention_mask),dim=1)outputs=self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,past_key_values=past_key_values,)sequence_output=outputs[0]sequence_output=self.dropout(sequence_output)logits=self.classifier(sequence_output)attention_mask=attention_mask[:,self.pre_seq_len:].contiguous()loss=NoneiflabelsisnotNone:loss_fct=CrossEntropyLoss()# Only keep active parts of the lossifattention_maskisnotNone:active_loss=attention_mask.view(-1)==1active_logits=logits.view(-1,self.num_labels)active_labels=torch.where(active_loss,labels.view(-1),torch.tensor(loss_fct.ignore_index).type_as(labels))loss=loss_fct(active_logits,active_labels)else:loss=loss_fct(logits.view(-1,self.num_labels),labels.view(-1))ifnotreturn_dict:output=(logits,)+outputs[2:]return((loss,)+output)iflossisnotNoneelseoutputreturnTokenClassifierOutput(loss=loss,logits=logits,hidden_states=outputs.hidden_states,attentions=outputs.attentions,)

性能对比

在SuperGLUE基准测试中,P-Tuning v2 仅微调0.5%参数时,性能可达全参数微调的90%以上,同时训练速度提升3-5倍。对于超大规模模型(如百亿参数),其优势更加显著。

局限性

  • 提示长度和层数需通过实验调优
  • 对某些需要全局参数调整的任务(如文本生成)可能需结合其他PEFT方法

参考: https://github.com/zejunwang1/chatglm_tuning/blob/main/train_ptuning.py

http://www.cnnetsun.cn/news/169991.html

相关文章:

  • 7、PowerShell代码签名:保障脚本安全的全面指南
  • 12、网络带宽与 Windows Server 2003 相关技术解析
  • 17、Windows Server 2003 Active Directory 部署与管理全解析
  • Linly-Talker支持多语言输出:全球化数字人布局利器
  • 4、PowerShell 深入解析与实践指南
  • Linly-Talker在金融客服中的应用案例分享
  • Linly-Talker在远程教学中的创新应用场景
  • Linly-Talker深度解析:语音克隆与表情同步技术揭秘
  • 如何提升数字人真实感?Linly-Talker多模态融合策略
  • 用Linly-Talker构建虚拟主播:实时交互不是梦
  • Linly-Talker深度评测:AI数字人对话系统的未来已来
  • 提升客户体验:Linly-Talker在智能客服中的实践
  • Linly-Talker用户案例分享:某银行数字客服上线实录
  • 15、Windows 7 个性化设置与系统维护指南
  • Linly-Talker支持按部门分配算力资源吗?
  • 开发者必看:Linly-Talker API接口调用详解
  • Day 45 图像数据与显存
  • 18、Windows Vista 离线文件使用指南
  • Linly-Talker镜像提供API调用频次统计功能
  • 18、工作流开发:强类型活动与CAG的应用
  • Linly-Talker多场景适配:客服/导览/教学全面覆盖
  • Linly-Talker在展览馆展品解说中的创新实践
  • Linly-Talker能否生成宠物医生形象进行养宠科普?
  • Linly-Talker能否用于房地产楼盘介绍虚拟销售?
  • Linly-Talker如何优化弱网环境下的音画同步?
  • 26、虚拟机迁移配置全解析
  • Linly-Talker支持RBAC权限控制系统吗?
  • 计算机毕业设计springboot基于JavaWeb的宠物寄养系统设计与实现 基于SpringBoot的宠物托管服务平台的设计与实现 JavaWeb视角下的宠物临时照护系统构建与研发
  • 汇编语言全接触-39.获得结果
  • 经典算法题型之编辑距离(二)