当前位置: 首页 > news >正文

大模型学习(二、使用lora进行微调)

目录

🧩 1. 什么是 SFT(Supervised Fine-Tuning)

(1)定义

(2)LOSS的数学表示

(3)一个真实的例子解释LOSS

2.什么是 LoRA(Low-Rank Adaptation)

(1)定义

3.示例代码

(1)数据集格式

(2)代码

(3)运行结果

(4)读取lora参数,重新进行模型推理


一、lora和SFT的介绍

🧩 1. 什么是 SFT(Supervised Fine-Tuning)

(1)定义

SFT = 监督微调
本质是:

用「输入 → 标准输出」对模型做有监督学习

形式:

用户:问题 AI:标准答案

训练目标:

特点:

  • 数据:成对的(prompt, answer)

  • loss:交叉熵

  • 和分类任务本质一样,只是输出是文本

👉 SFT 解决的是:
“模型该学什么行为?”

(2)LOSS的数学表示

语言模型的训练目标:

(3)一个真实的例子解释LOSS

这个输出的巴黎是标签。之后拿到对应标签模型输入的概率,之后log求和。

2.什么是 LoRA(Low-Rank Adaptation)

(1)定义

LoRA = 一种参数高效微调方法(PEFT)

核心思想:

❌ 不改原模型参数
✅ 只在部分层插入小矩阵并训练它们

数学上:
原本权重:

LoRA 改为:

3.示例代码

(1)数据集格式

{"system": "你是一个名为沐雪的可爱AI女孩子", "conversation": [{"human": "如何集中精力做一件事情", "assistant": "首当其冲的肯定是选择一个合适的地方啦,比如说图书馆之类的,如果你不想出去,那就找一个安静的地方吧。然后扔掉手机这类会让你分心的东西,或者关掉通知,确保你不会突然被打扰。明确你要做的事情,把它细化成分几步去完成,设置期限,任务完成之后放松放松。如果你感觉到累了不行了就去外面转转吧,喝一杯咖啡,思考让你停下来的地方,然后活力满满地继续接下来的工作。"}]}

(2)代码

from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer from peft import LoraConfig, get_peft_model import torch import json # ==================================================== # ① 模型路径与数据路径 # ==================================================== model_dir = r"C:\Users\64292\Desktop\大模型学习\xiaozhi\weitiao\Qwen2.5-1.5B-Instruct" data_path = r"C:\Users\64292\Desktop\大模型学习\xiaozhi\weitiao\competition_train.jsonl" # ==================================================== # ② 加载数据集 # ==================================================== dataset = load_dataset("json", data_files=data_path) # ==================================================== # ③ 预处理函数:把 system + human 拼成 prompt # ==================================================== def format_example(example): conversations = example["conversation"] if not conversations or len(conversations) == 0: return None conv = conversations[0] system = example.get("system", "") human = conv.get("human", "") assistant = conv.get("assistant", "") # 构建输入与输出 prompt = f"系统:{system}\n用户:{human}\nAI:" output = assistant.strip() return {"prompt": prompt, "output": output} dataset = dataset.map(format_example) dataset = dataset.filter(lambda x: x["prompt"] is not None) # ==================================================== # ④ 加载分词器与模型 # ==================================================== tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( model_dir, torch_dtype=torch.float16, device_map="auto" ) # ==================================================== # ⑤ LoRA 配置(低显存训练) # ==================================================== lora_config = LoraConfig( r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM" ) model = get_peft_model(model, lora_config) # ==================================================== # ⑥ Tokenize 函数 # ==================================================== def preprocess(example): text = f"{example['prompt']}{example['output']}" tokenized = tokenizer(text, truncation=True, padding="max_length", max_length=512) tokenized["labels"] = tokenized["input_ids"].copy() return tokenized tokenized_ds = dataset.map(preprocess, remove_columns=dataset["train"].column_names) # ==================================================== # ⑦ 训练配置 # ==================================================== args = TrainingArguments( output_dir="./qwen2.5-1.5b-lora-muxue", per_device_train_batch_size=1, gradient_accumulation_steps=8, learning_rate=2e-4, num_train_epochs=3, fp16=True, logging_steps=10, save_steps=200, save_total_limit=2, report_to="none" ) # ==================================================== # ⑧ 训练启动 # ==================================================== trainer = Trainer( model=model, args=args, train_dataset=tokenized_ds["train"] ) trainer.train() # ==================================================== # ⑨ 保存权重 # ==================================================== model.save_pretrained("./qwen2.5-1.5b-lora-muxue") tokenizer.save_pretrained("./qwen2.5-1.5b-lora-muxue") print("✅ 微调完成!权重保存在 ./qwen2.5-1.5b-lora-muxue")

(3)运行结果

(4)读取lora参数,重新进行模型推理

from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel import torch # ① 原始模型路径(基础模型) base_model_dir = r"C:\Users\64292\Desktop\xiaozhi\weitiao\Qwen2.5-1.5B-Instruct" # ② LoRA 权重路径(你的微调结果) lora_dir = r"./qwen2.5-1.5b-lora-muxue" # ③ 加载分词器 print("🚀 正在加载分词器和模型...") tokenizer = AutoTokenizer.from_pretrained(lora_dir, trust_remote_code=True) # ④ 加载基础模型 base_model = AutoModelForCausalLM.from_pretrained( base_model_dir, torch_dtype=torch.float16, device_map="auto" ) # ⑤ 加载 LoRA 微调权重 model = PeftModel.from_pretrained(base_model, lora_dir) model.eval() print("✅ 已加载 Qwen + LoRA 微调权重(人格:沐雪)!\n") # ⑥ 设定人格系统提示词 system_prompt = "你是一个名为世君同学的可爱AI女孩子,性格温柔、活泼、善解人意,说话要自然可爱。" chat_history = f"系统:{system_prompt}\n" # ⑦ 聊天循环 while True: user_input = input("👤 你:").strip() if user_input.lower() in ["exit", "quit", "q"]: print("👋 沐雪:再见呀~记得想我哦 💖") break # 将用户输入加入上下文 chat_history += f"用户:{user_input}\nAI:" # 编码输入 inputs = tokenizer(chat_history, return_tensors="pt").to(model.device) # 模型生成 outputs = model.generate( **inputs, max_new_tokens=200, temperature=0.7, top_p=0.9, do_sample=True, pad_token_id=tokenizer.eos_token_id ) # 解码生成文本 reply = tokenizer.decode(outputs[0], skip_special_tokens=True) # 提取模型新增部分(去掉历史) new_text = reply[len(chat_history):].strip() # 输出结果 print(f"🤖 沐雪:{new_text}\n") # 更新上下文 chat_history += new_text + "\n"
http://www.cnnetsun.cn/news/848382.html

相关文章:

  • 杰理之DAC 24bit 频响获取异常【篇】
  • springboot在线图书借阅平台系统设计实现
  • springboot幼儿园管理系统设计开发实现
  • 新药品管理法
  • P14968 Hoping that one Day题解
  • easyExcel 的动态列导出把文本转为数值格式,可以进行函数计算
  • 用户态/内核态 = 操作系统内核?
  • 从Vue到Spring Boot:一个Java全栈工程师的实战面试实录
  • java项目--智能无人机平台v3pro
  • 彻底爆了!阿里最新大模型,再次拿下第一!
  • 社会网络仿真软件:Gephi_(18).社会网络分析理论基础
  • ES6新增了哪些新特性
  • 目前全网唯一的Autosar TLS文章
  • 工作流程管理系统信息管理系统源码-SpringBoot后端+Vue前端+MySQL【可直接运行】
  • 我的思维模型 -- 5.工程学篇
  • 基于SpringBoot+Vue的社区养老服务平台管理系统设计与实现【Java+MySQL+MyBatis完整源码】
  • 基于SpringBoot+Vue的文理医院预约挂号系统管理系统设计与实现【Java+MySQL+MyBatis完整源码】
  • SQL注入知识要点总结
  • YOLO26手势识别项目实战3-石头剪刀布实时检测系统数据集说明(含训练代码、数据集和GUI交互界面)
  • 电容式三点式振荡电路/电感式三点振荡电路
  • BUCK降压电路Multisim电路仿真分析
  • 好用的PC电脑流程图软件无需下载在线绘制流程图模板大全
  • 基于SpringBoot+Vue的spring boot校园商铺管理系统管理系统设计与实现【Java+MySQL+MyBatis完整源码】
  • 企业级医药管理系统管理系统源码|SpringBoot+Vue+MyBatis架构+MySQL数据库【完整版】
  • 智能球机摄像头自带旋转355度视角
  • 科研人员新工具:gpt-oss-20b-WEBUI助力论文写作与分析
  • 前后端分离spring boot校园商铺管理系统系统|SpringBoot+Vue+MyBatis+MySQL完整源码+部署教程
  • 3分钟突破付费墙:Bypass Paywalls Clean让优质内容触手可及
  • 显存22GB以内搞定Qwen2.5-7B微调,4090D实测真香
  • Keil添加文件正确方式:针对STM32项目的通俗解释