知识蒸馏之交叉熵篇——代码实战
知识蒸馏之交叉熵篇——代码实战
。下述代码,总体上表示为把模型输出student_logits和真实答案labels做比较,计算一个“分类错误程度”的损失值,命名为ce_loss。
ce_loss=F.cross_entropy(student_logits,labels)那么问题来了------
1. 为什么用交叉熵?
因为这是分类任务里最常用的损失函数。比如模型要判断一张图是猫、狗、车。模型不会直接说“猫”,而是输出每个类别的分数:
student_logits=[2.1,0.3,-1.2]这些分数表示模型对每个类别的倾向。交叉熵适合衡量:
模型预测的类别概率与真实类别之间从概率上看差得有多远。
如果模型对正确类别非常自信,交叉熵小。如果模型对错误类别很自信,交叉熵大。
2. 交叉熵有什么作用?
它的作用是告诉模型:你错得有多离谱。举例,如果真实标签是“猫”。
现在模型 A 预测这个标签的概率分布结果为:
猫: 0.90, 狗: 0.08, 车: 0.02交叉熵损失很小,因为正确类别概率高。
模型 B 预测的结果为:
猫: 0.20, 狗: 0.70, 车: 0.10交叉熵损失较大,因为模型更相信“狗”。
训练时,神经网络会通过反向传播让这个损失变小。也就是让模型越来越倾向于给正确类别更高分。
3. 什么是ce_loss?有什么用处?
ce_loss是一个变量名,通常表示cross entropy loss,也就是交叉熵损失。
它一般是一个标量,比如:
tensor(0.7321)它的用途主要有两个:
ce_loss.backward()optimizer.step()ce_loss.backward()会计算梯度,告诉每个参数应该往哪个方向调整。optimizer.step()根据梯度更新模型参数。
所以ce_loss是训练模型时的核心指标之一:模型通过最小化它来学习。
4. 这个F是哪里定义的?里面大概都有些什么?
这里的F通常来自 PyTorch:
importtorch.nn.functionalasFF不是一个函数,而是一个模块,完整名字是:
torch.nn.functional里面有很多常用的神经网络函数,比如:
F.relu()F.softmax()F.cross_entropy()F.mse_loss()F.dropout()F.max_pool2d()F.one_hot()这些函数通常是“无状态”的,也就是只负责计算,不自己保存可训练参数。
比如:
F.relu(x)只是把小于 0 的数变成 0。
而类似:
nn.Linear(...)这种层会保存权重参数。
5.student_logits、labels分别代表什么?为什么定义这两个参数?
student_logits是学生模型的原始输出分数。
名字里有两个部分:student表示学生模型,logits表示还没有经过 softmax 的原始分类分数
例如一个 batch 有 2 条样本,每条样本分 3 类:
student_logits=torch.tensor([[2.1,0.3,-1.2],[0.1,1.5,0.4]])形状通常是:
[batch_size,num_classes]labels是真实类别标签:
labels=torch.tensor([0,1])意思是,第 1 个样本真实类别是第 0 类,第 2 个样本真实类别是第 1 类
定义这两个参数,是为了让损失函数知道:模型预测了什么?真实答案是什么?
有了这两个东西,才能计算模型错得多不多。
6. 这一整行代码是用来干什么的?
这一整行代码的作用是:
ce_loss=F.cross_entropy(student_logits,labels)把学生模型的输出student_logits和真实标签labels进行比较,计算分类损失,并保存到ce_loss变量里。
可以理解成:ce_loss = 模型预测结果 和 标准答案 之间的差距
在知识蒸馏代码里,它通常表示学生模型直接向真实标签学习的损失。比如总损失可能是:
loss=alpha*ce_loss+beta*distill_loss其中:ce_loss:学生模型向真实标签学习,distill_loss:学生模型向教师模型学习
附上实现mini知识蒸馏的代码:
importargparseimportrandomfrompathlibimportPathimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFfromsklearn.datasetsimportload_digitsfromtorch.utils.dataimportDataLoader,TensorDataset,random_splitfromtorchvisionimportdatasets,transformsclassTeacherCNN(nn.Module):def__init__(self):super().__init__()self.features=nn.Sequential(nn.Conv2d(1,32,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(32,64,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(64,128,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2),)self.classifier=nn.Sequential(nn.Flatten(),nn.Linear(128*7*7,256),nn.ReLU(),nn.Dropout(0.2),nn.Linear(256,10),)defforward(self,x):returnself.classifier(self.features(x))classStudentCNN(nn.Module):def__init__(self):super().__init__()self.features=nn.Sequential(nn.Conv2d(1,16,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(16,32,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2),)self.classifier=nn.Sequential(nn.Flatten(),nn.Linear(32*7*7,64),nn.ReLU(),nn.Linear(64,10),)defforward(self,x):returnself.classifier(self.features(x))defset_seed(seed):random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)defcount_params(model):returnsum(p.numel()forpinmodel.parameters()ifp.requires_grad)defbuild_loaders(data_dir,batch_size,dataset_name,seed):ifdataset_name=="digits":digits=load_digits()images=torch.tensor(digits.images,dtype=torch.float32).unsqueeze(1)/16.0images=F.interpolate(images,size=(28,28),mode="bilinear",align_corners=False)images=(images-0.5)/0.5labels=torch.tensor(digits.target,dtype=torch.long)dataset=TensorDataset(images,labels)train_size=int(0.8*len(dataset))test_size=len(dataset)-train_size generator=torch.Generator().manual_seed(seed)train_set,test_set=random_split(dataset,[train_size,test_size],generator=generator)train_loader=DataLoader(train_set,batch_size=batch_size,shuffle=True,num_workers=0)test_loader=DataLoader(test_set,batch_size=batch_size,shuffle=False,num_workers=0)returntrain_loader,test_loader transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,)),])train_set=datasets.MNIST(data_dir,train=True,download=True,transform=transform)test_set=datasets.MNIST(data_dir,train=False,download=True,transform=transform)train_loader=DataLoader(train_set,batch_size=batch_size,shuffle=True,num_workers=0)test_loader=DataLoader(test_set,batch_size=batch_size,shuffle=False,num_workers=0)returntrain_loader,test_loaderdefevaluate(model,loader,device):model.eval()correct=0total=0loss_total=0.0withtorch.no_grad():forimages,labelsinloader:images=images.to(device)labels=labels.to(device)logits=model(images)loss_total+=F.cross_entropy(logits,labels).item()*images.size(0)preds=logits.argmax(dim=1)correct+=(preds==labels).sum().item()total+=labels.size(0)returnloss_total/total,correct/totaldeftrain_supervised(model,train_loader,test_loader,device,epochs,lr,name):optimizer=torch.optim.Adam(model.parameters(),lr=lr)model.to(device)forepochinrange(1,epochs+1):model.train()running_loss=0.0forimages,labelsintrain_loader:images=images.to(device)labels=labels.to(device)optimizer.zero_grad()logits=model(images)loss=F.cross_entropy(logits,labels)loss.backward()optimizer.step()running_loss+=loss.item()*images.size(0)test_loss,test_acc=evaluate(model,test_loader,device)train_loss=running_loss/len(train_loader.dataset)print(f"{name}epoch{epoch}: train_loss={train_loss:.4f}test_loss={test_loss:.4f}test_acc={test_acc:.4f}")defdistillation_loss(student_logits,teacher_logits,labels,temperature,alpha):ce_loss=F.cross_entropy(student_logits,labels)soft_student_log_probs=F.log_softmax(student_logits/temperature,dim=1)soft_teacher_probs=F.softmax(teacher_logits/temperature,dim=1)kd_loss=F.kl_div(soft_student_log_probs,soft_teacher_probs,reduction="batchmean")returnalpha*ce_loss+(1-alpha)*(temperature**2)*kd_lossdeftrain_distilled(student,teacher,train_loader,test_loader,device,epochs,lr,temperature,alpha):optimizer=torch.optim.Adam(student.parameters(),lr=lr)teacher.to(device)student.to(device)teacher.eval()forepochinrange(1,epochs+1):student.train()running_loss=0.0forimages,labelsintrain_loader:images=images.to(device)labels=labels.to(device)optimizer.zero_grad()student_logits=student(images)withtorch.no_grad():teacher_logits=teacher(images)loss=distillation_loss(student_logits,teacher_logits,labels,temperature,alpha)loss.backward()optimizer.step()running_loss+=loss.item()*images.size(0)test_loss,test_acc=evaluate(student,test_loader,device)train_loss=running_loss/len(train_loader.dataset)print("student_kd "f"epoch{epoch}: train_loss={train_loss:.4f}test_loss={test_loss:.4f}"f"test_acc={test_acc:.4f}temperature={temperature}alpha={alpha}")defmain():parser=argparse.ArgumentParser()parser.add_argument("--data-dir",type=Path,default=Path("data"))parser.add_argument("--dataset",choices=["digits","mnist"],default="digits")parser.add_argument("--batch-size",type=int,default=128)parser.add_argument("--epochs-teacher",type=int,default=3)parser.add_argument("--epochs-student",type=int,default=3)parser.add_argument("--lr",type=float,default=1e-3)parser.add_argument("--temperature",type=float,default=4.0)parser.add_argument("--alpha",type=float,default=0.5,help="Weight for hard-label cross entropy.")parser.add_argument("--seed",type=int,default=42)args=parser.parse_args()set_seed(args.seed)device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")print(f"device={device}")train_loader,test_loader=build_loaders(args.data_dir,args.batch_size,args.dataset,args.seed)teacher=TeacherCNN()student_baseline=StudentCNN()student_kd=StudentCNN()print(f"teacher params={count_params(teacher):,}")print(f"student params={count_params(student_baseline):,}")print("\n== Train teacher ==")train_supervised(teacher,train_loader,test_loader,device,args.epochs_teacher,args.lr,"teacher")print("\n== Train student baseline ==")train_supervised(student_baseline,train_loader,test_loader,device,args.epochs_student,args.lr,"student_baseline")print("\n== Train student with knowledge distillation ==")train_distilled(student_kd,teacher,train_loader,test_loader,device,args.epochs_student,args.lr,args.temperature,args.alpha,)teacher_loss,teacher_acc=evaluate(teacher,test_loader,device)baseline_loss,baseline_acc=evaluate(student_baseline,test_loader,device)kd_loss,kd_acc=evaluate(student_kd,test_loader,device)print("\n== Final result ==")print(f"teacher: loss={teacher_loss:.4f}acc={teacher_acc:.4f}params={count_params(teacher):,}")print(f"student_baseline: loss={baseline_loss:.4f}acc={baseline_acc:.4f}params={count_params(student_baseline):,}")print(f"student_kd: loss={kd_loss:.4f}acc={kd_acc:.4f}params={count_params(student_kd):,}")if__name__=="__main__":main()直接运行命令为:python train_mnist_kd.py --epochs-teacher 3 --epochs-student 3
下载MNIST数据集后的运行命令为:python train_mnist_kd.py --dataset mnist --epochs-teacher 3 --epochs-student 3
