知识蒸馏IRG算法实战:使用ResNet50蒸馏ResNet18

news2025/1/21 21:41:32

摘要

复杂度的检测模型虽然可以取得SOTA的精度,但它们往往难以直接落地应用。模型压缩方法帮助模型在效率和精度之间进行折中。知识蒸馏是模型压缩的一种有效手段,它的核心思想是迫使轻量级的学生模型去学习教师模型提取到的知识,从而提高学生模型的性能。已有的知识蒸馏方法可以分别为三大类:

  • 基于特征的(feature-based,例如VID、NST、FitNets、fine-grained feature imitation)
  • 基于关系的(relation-based,例如IRG、Relational KD、CRD、similarity-preserving knowledge distillation)
  • 基于响应的(response-based,例如Hinton的知识蒸馏开山之作。
    在这里插入图片描述
    今天我们就尝试用基于关系的IRG知识蒸馏算法完成这篇实战。IRG蒸馏是对模型里面的的Block和展平层做蒸馏,所以需要返回每个block层的值和展平层的值。所以我们对模型要做修改来适应IRG算法,并且为了使Teacher和Student的网络层之间的参数一致,我们这次选用ResNet50作为Teacher模型,选择ResNet18作为Student。

模型

模型没有用pytorch官方自带的,而是参照以前总结的ResNet模型修改的。ResNet模型结构如下图:
在这里插入图片描述

ResNet18, ResNet34

ResNet18, ResNet34模型的残差结构是一致的,结构如下:
在这里插入图片描述
代码如下:
resnet.py

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
# from torchsummary import summary


class ResidualBlock(nn.Module):
    """
    实现子module: Residual Block
    """

    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)


class ResNet(nn.Module):
    """
    实现主module:ResNet34
    ResNet34包含多个layer,每个layer又包含多个Residual block
    用子module来实现Residual block,用_make_layer函数来实现layer
    """

    def __init__(self, blocks, num_classes=1000):
        super(ResNet, self).__init__()
        self.model_name = 'resnet34'

        # 前几层: 图像转换
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1))

        # 重复的layer,分别有3,4,6,3个residual block
        self.layer1 = self._make_layer(64, 64, blocks[0])
        self.layer2 = self._make_layer(64, 128, blocks[1], stride=2)
        self.layer3 = self._make_layer(128, 256, blocks[2], stride=2)
        self.layer4 = self._make_layer(256, 512, blocks[3], stride=2)

        # 分类用的全连接
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        """
        构建layer,包含多个residual block
        """
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU()
        )

        layers = []
        layers.append(ResidualBlock(inchannel, outchannel, stride, shortcut))

        for i in range(1, block_num):
            layers.append(ResidualBlock(outchannel, outchannel))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.pre(x)
        l1_out = self.layer1(x)
        l2_out = self.layer2(l1_out)
        l3_out = self.layer3(l2_out)
        l4_out = self.layer4(l3_out)
        p_out = F.avg_pool2d(l4_out, 7)
        fea = p_out.view(p_out.size(0), -1)
        out=self.fc(fea)
        return l1_out,l2_out,l3_out,l4_out,fea,out

def ResNet18():
    return ResNet([2, 2, 2, 2])


def ResNet34():
    return ResNet([3, 4, 6, 3])


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = ResNet34()
    model.to(device)
    # summary(model, (3, 224, 224))

主要修改了输出结果,将每个block的结果输出出来。

RseNet50、 RseNet101、 RseNet152

这个三个模型的block是一致的,结构如下:
在这里插入图片描述
代码:
resnet_l.py

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101', 'ResNet152']


def Conv1(in_planes, places, stride=2):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_planes, out_channels=places, kernel_size=7, stride=stride, padding=3, bias=False),
        nn.BatchNorm2d(places),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )


class Bottleneck(nn.Module):
    def __init__(self, in_places, places, stride=1, downsampling=False, expansion=4):
        super(Bottleneck, self).__init__()
        self.expansion = expansion
        self.downsampling = downsampling

        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=in_places, out_channels=places, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=places, out_channels=places * self.expansion, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places * self.expansion),
        )

        if self.downsampling:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_places, out_channels=places * self.expansion, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(places * self.expansion)
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.bottleneck(x)

        if self.downsampling:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, blocks, num_classes=1000, expansion=4):
        super(ResNet, self).__init__()
        self.expansion = expansion

        self.conv1 = Conv1(in_planes=3, places=64)

        self.layer1 = self.make_layer(in_places=64, places=64, block=blocks[0], stride=1)
        self.layer2 = self.make_layer(in_places=256, places=128, block=blocks[1], stride=2)
        self.layer3 = self.make_layer(in_places=512, places=256, block=blocks[2], stride=2)
        self.layer4 = self.make_layer(in_places=1024, places=512, block=blocks[3], stride=2)

        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(2048, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def make_layer(self, in_places, places, block, stride):
        layers = []
        layers.append(Bottleneck(in_places, places, stride, downsampling=True))
        for i in range(1, block):
            layers.append(Bottleneck(places * self.expansion, places))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        l1_out = self.layer1(x)
        l2_out = self.layer2(l1_out)
        l3_out = self.layer3(l2_out)
        l4_out = self.layer4(l3_out)
        p_out = self.avgpool(l4_out)
        fea = p_out.view(p_out.size(0), -1)
        out = self.fc(fea)
        return l1_out, l2_out, l3_out, l4_out, fea, out


def ResNet50():
    return ResNet([3, 4, 6, 3])


def ResNet101():
    return ResNet([3, 4, 23, 3])


def ResNet152():
    return ResNet([3, 8, 36, 3])


if __name__ == '__main__':
    # model = torchvision.models.resnet50()
    model = ResNet50()
    print(model)

    input = torch.randn(1, 3, 224, 224)
    out = model(input)
    print(out.shape)

同上,将每个block都输出出来。

数据准备

数据使用我以前在图像分类任务中的数据集——植物幼苗数据集,先将数据集转为训练集和验证集。执行代码:

import glob
import os
import shutil

image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
    print('true')
    #os.rmdir(file_dir)
    shutil.rmtree(file_dir)#删除再建立
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(train_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(val_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

训练Teacher模型

Teacher选用ResNet50。

步骤

新建teacher_train.py,插入代码:

导入需要的库

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from model.resnet_l import ResNet50

import json
import os

定义训练和验证函数

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        l1_out,l2_out,l3_out,l4_out,fea, out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

Best_ACC=0
# 验证过程
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            l1_out,l2_out,l3_out,l4_out,fea, out = model(data)
            loss = criterion(out, target)
            _, pred = torch.max(out.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc

定义全局参数

if __name__ == '__main__':
    # 创建保存模型的文件夹
    file_dir = 'CoatNet'
    if os.path.exists(file_dir):
        print('true')

        os.makedirs(file_dir, exist_ok=True)
    else:
        os.makedirs(file_dir)

    # 设置全局参数
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

图像预处理与增强

 # 数据预处理7
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])

    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])


读取数据

使用pytorch默认读取数据的方式。

    # 读取数据
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # 导入数据
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型和Loss

    # 实例化模型并且移动到GPU
    criterion = nn.CrossEntropyLoss()

    model_ft = ResNet50()
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 12)
    model_ft.to(DEVICE)
    # 选择简单暴力的Adam优化器,学习率调低
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # 训练
    val_acc_list= {}
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, DEVICE, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc=val(model_ft, DEVICE, test_loader)
        val_acc_list[epoch]=acc
        with open('result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'CoatNet/model_final.pth')

完成上面的代码就可以开始训练Teacher网络了。

学生网络

学生网络选用ResNet18,是一个比较小一点的网络了,模型的大小有40M。训练100个epoch。

步骤

新建student_train.py,插入代码:

导入需要的库

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from model.resnet import ResNet18

import json
import os

定义训练和验证函数

# 定义训练过程

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        l1_out,l2_out,l3_out,l4_out,fea,out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

Best_ACC=0
# 验证过程
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            l1_out,l2_out,l3_out,l4_out,fea,out = model(data)
            loss = criterion(out, target)
            _, pred = torch.max(out.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc


定义全局参数

if __name__ == '__main__':
    # 创建保存模型的文件夹
    file_dir = 'resnet'
    if os.path.exists(file_dir):
        print('true')

        os.makedirs(file_dir, exist_ok=True)
    else:
        os.makedirs(file_dir)

    # 设置全局参数
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

图像预处理与增强

 # 数据预处理7
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])

    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])


读取数据

使用pytorch默认读取数据的方式。

    # 读取数据
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # 导入数据
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型和Loss

	 # 实例化模型并且移动到GPU
    criterion = nn.CrossEntropyLoss()

    model_ft = ResNet18()
    print(model_ft)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 12)
    model_ft.to(DEVICE)
    # 选择简单暴力的Adam优化器,学习率调低
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # 训练
    val_acc_list= {}
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, DEVICE, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc=val(model_ft, DEVICE, test_loader)
        val_acc_list[epoch]=acc
        with open('result_student.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'resnet/model_final.pth')

完成上面的代码就可以开始训练Student网络了。

蒸馏学生网络

学生网络继续选用ResNet18,使用Teacher网络蒸馏学生网络,训练100个epoch。
IRG知识蒸馏的脚本详见:
https://wanghao.blog.csdn.net/article/details/127802486?spm=1001.2014.3001.5502。
代码如下:
irg.py

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


class IRG(nn.Module):
	'''
	Knowledge Distillation via Instance Relationship Graph
	http://openaccess.thecvf.com/content_CVPR_2019/papers/
	Liu_Knowledge_Distillation_via_Instance_Relationship_Graph_CVPR_2019_paper.pdf

	The official code is written by Caffe
	https://github.com/yufanLIU/IRG
	'''
	def __init__(self, w_irg_vert, w_irg_edge, w_irg_tran):
		super(IRG, self).__init__()

		self.w_irg_vert = w_irg_vert
		self.w_irg_edge = w_irg_edge
		self.w_irg_tran = w_irg_tran

	def forward(self, irg_s, irg_t):
		fm_s1, fm_s2, feat_s, out_s = irg_s
		fm_t1, fm_t2, feat_t, out_t = irg_t

		loss_irg_vert = F.mse_loss(out_s, out_t)

		irg_edge_feat_s = self.euclidean_dist_feat(feat_s, squared=True)
		irg_edge_feat_t = self.euclidean_dist_feat(feat_t, squared=True)
		irg_edge_fm_s1  = self.euclidean_dist_fm(fm_s1, squared=True)
		irg_edge_fm_t1  = self.euclidean_dist_fm(fm_t1, squared=True)
		irg_edge_fm_s2  = self.euclidean_dist_fm(fm_s2, squared=True)
		irg_edge_fm_t2  = self.euclidean_dist_fm(fm_t2, squared=True)
		loss_irg_edge = (F.mse_loss(irg_edge_feat_s, irg_edge_feat_t) +
						 F.mse_loss(irg_edge_fm_s1,  irg_edge_fm_t1 ) +
						 F.mse_loss(irg_edge_fm_s2,  irg_edge_fm_t2 )) / 3.0

		irg_tran_s = self.euclidean_dist_fms(fm_s1, fm_s2, squared=True)
		irg_tran_t = self.euclidean_dist_fms(fm_t1, fm_t2, squared=True)
		loss_irg_tran = F.mse_loss(irg_tran_s, irg_tran_t)

		# print(self.w_irg_vert * loss_irg_vert)
		# print(self.w_irg_edge * loss_irg_edge)
		# print(self.w_irg_tran * loss_irg_tran)
		# print()

		loss = (self.w_irg_vert * loss_irg_vert +
				self.w_irg_edge * loss_irg_edge +
				self.w_irg_tran * loss_irg_tran)

		return loss

	def euclidean_dist_fms(self, fm1, fm2, squared=False, eps=1e-12):
		'''
		Calculating the IRG Transformation, where fm1 precedes fm2 in the network.
		'''
		if fm1.size(2) > fm2.size(2):
			fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))
		if fm1.size(1) < fm2.size(1):
			fm2 = (fm2[:,0::2,:,:] + fm2[:,1::2,:,:]) / 2.0

		fm1 = fm1.view(fm1.size(0), -1)
		fm2 = fm2.view(fm2.size(0), -1)
		fms_dist = torch.sum(torch.pow(fm1-fm2, 2), dim=-1).clamp(min=eps)

		if not squared:
			fms_dist = fms_dist.sqrt()

		fms_dist = fms_dist / fms_dist.max()

		return fms_dist

	def euclidean_dist_fm(self, fm, squared=False, eps=1e-12): 
		'''
		Calculating the IRG edge of feature map. 
		'''
		fm = fm.view(fm.size(0), -1)
		fm_square = fm.pow(2).sum(dim=1)
		fm_prod   = torch.mm(fm, fm.t())
		fm_dist   = (fm_square.unsqueeze(0) + fm_square.unsqueeze(1) - 2 * fm_prod).clamp(min=eps)

		if not squared:
			fm_dist = fm_dist.sqrt()

		fm_dist = fm_dist.clone()
		fm_dist[range(len(fm)), range(len(fm))] = 0
		fm_dist = fm_dist / fm_dist.max()

		return fm_dist

	def euclidean_dist_feat(self, feat, squared=False, eps=1e-12):
		'''
		Calculating the IRG edge of feat.
		'''
		feat_square = feat.pow(2).sum(dim=1)
		feat_prod   = torch.mm(feat, feat.t())
		feat_dist   = (feat_square.unsqueeze(0) + feat_square.unsqueeze(1) - 2 * feat_prod).clamp(min=eps)

		if not squared:
			feat_dist = feat_dist.sqrt()

		feat_dist = feat_dist.clone()
		feat_dist[range(len(feat)), range(len(feat))] = 0
		feat_dist = feat_dist / feat_dist.max()

		return feat_dist

步骤

新建kd_train.py,插入代码:

导入需要的库

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets

from model.resnet import ResNet18

import json
import os

from irg import IRG

定义训练和验证函数

# 定义训练过程
def train(s_net,t_net, device, criterionCls,criterionKD,train_loader, optimizer, epoch):
    s_net.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        l1_out_s,l2_out_s,l3_out_s,l4_out_s,fea_s, out_s = s_net(data)
        cls_loss = criterionCls(out_s, target)
        l1_out_t,l2_out_t,l3_out_t,l4_out_t,fea_t, out_t = t_net(data)  # 训练出教师的 teacher_output
        kd_loss = criterionKD([l3_out_s, l4_out_s, fea_s, out_s],
                              [l3_out_t.detach(),
                               l4_out_t.detach(),
                               fea_t.detach(),
                               out_t.detach()]) * lambda_kd
        loss = cls_loss + kd_loss
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

Best_ACC=0
# 验证过程
@torch.no_grad()
def val(model, device,criterionCls, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            l1_out_s, l2_out_s, l3_out_s, l4_out_s, fea_s, out_s = model(data)
            loss = criterionCls(out_s, target)
            _, pred = torch.max(out_s.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc


定义全局参数

if __name__ == '__main__':
    # 创建保存模型的文件夹
    file_dir = 'resnet_kd'
    if os.path.exists(file_dir):
        print('true')

        os.makedirs(file_dir, exist_ok=True)
    else:
        os.makedirs(file_dir)

    # 设置全局参数
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    w_irg_vert=0.1
    w_irg_edge=5.0
    w_irg_tran=5.0
    lambda_kd=1.0

图像预处理与增强

 # 数据预处理7
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])

    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])


读取数据

使用pytorch默认读取数据的方式。

    # 读取数据
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # 导入数据
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型和Loss

 model_ft = ResNet18()
    print(model_ft)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 12)
    model_ft.to(DEVICE)
    # 选择简单暴力的Adam优化器,学习率调低
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    teacher_model=torch.load('./CoatNet/best.pth')
    teacher_model.eval()
    # 实例化模型并且移动到GPU
    criterionKD = IRG(w_irg_vert, w_irg_edge, w_irg_tran)
    criterionCls = nn.CrossEntropyLoss()
    # 训练
    val_acc_list= {}
    for epoch in range(1, EPOCHS + 1):
        train(model_ft,teacher_model, DEVICE,criterionCls,criterionKD, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc=val(model_ft,DEVICE,criterionCls , test_loader)
        val_acc_list[epoch]=acc
        with open('result_kd.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'resnet_kd/model_final.pth')

完成上面的代码就可以开始蒸馏模式!!!

结果比对

加载保存的结果,然后绘制acc曲线。

import numpy as np
from matplotlib import pyplot as plt
import json
teacher_file='result.json'
student_file='result_student.json'
student_kd_file='result_kd.json'
def read_json(file):
    with open(file, 'r', encoding='utf8') as fp:
        json_data = json.load(fp)
        print(json_data)
    return json_data

teacher_data=read_json(teacher_file)
student_data=read_json(student_file)
student_kd_data=read_json(student_kd_file)


x =[int(x) for x in  list(dict(teacher_data).keys())]
print(x)

plt.plot(x, list(teacher_data.values()), label='teacher')
plt.plot(x,list(student_data.values()), label='student without KD')
plt.plot(x, list(student_kd_data.values()), label='student with KD')

plt.title('Test accuracy')
plt.legend()

plt.show()

总结

本文重点讲解了如何使用IRG知识蒸馏算法对Student模型进行蒸馏。希望能帮助到大家,如果觉得有用欢迎收藏、点赞和转发;如果有问题也可以留言讨论。
本次实战用到的代码和数据集详见:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/14644.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Golang入门笔记(11)—— 包

使用包的原因&#xff1a; 1.不可能把所有的不同业务功能的函数都放在一个源文件中&#xff0c;这样不便于管理。通常的做法是&#xff1a;我们会把具有相同一些功能和业务的维度的函数&#xff0c;分门别类的放在不同的源文件中。 2.不同的包名&#xff0c;可以解决两个函数…

Java面对对象的特征之二:继承性 :why?

一、继承性的好处&#xff1a; 减少了代码的冗余&#xff0c;提高了代码便于功能的扩展为之后的多态性的使用&#xff0c;提高了前提 二、继承性的格式&#xff1a;class A extends B{} A:子类、派生类、subclass B&#xff1a;父类、超类、基类、superclass 提现&#xff1…

数据中台解决方案-最新全套文件

数据中台解决方案-最新全套文件一、建设背景面临的挑战1、数据孤岛2、管理困难3、感知不强4、融合不足5、响应滞后二、思路架构三、建设方案四、获取 - 数据中台全套最新解决方案合集一、建设背景 在数字化转型及大数据战略布局背景下&#xff0c;建设大数据平台及数据应用。其…

idea连接kerberos认证的hive

其实用dbeaver连接hive就可以了。但是呢&#xff0c;idea也有这个功能&#xff0c;本着研究下的想法就试试。 结果最后成功了 最后记录下。 参考文章。感觉不太行 PyCharm,idea通过插件database连接带Kerberos的hive_不饿同学的博客-CSDN博客 里面提到了两个解决办法&#…

D. Non-zero Segments(前缀和)

Problem - 1426D - Codeforces 题意: 科利亚得到一个整数数组a1,a2,...,an。这个数组既可以包含正整数也可以包含负整数&#xff0c;但是Kolya不喜欢0&#xff0c;所以这个数组不包含任何零。 Kolya不喜欢他的数组中某些子段的总和为0&#xff0c;子段是数组中一些连续的元素…

IDR 学习笔记

Multiview Neural Surface Reconstruction by Disentangling Geometry and Appearance 主页&#xff1a;https://lioryariv.github.io/idr/ 论文&#xff1a;https://arxiv.org/abs/2003.09852 代码&#xff1a;https://github.com/lioryariv/idr 效果展示 idr_fountain效果图…

【面试题】循环队列队列实现栈栈实现队列

1️⃣设计循环队列OJ链接 2️⃣用队列实现栈OJ链接 3️⃣用栈实现队列OJ链接 这几道面试题需要栈和队列的知识&#xff0c;它们的相关知识可以看我的上一篇文章 1️⃣设计循环队列 先来了解一下环形队列&#xff0c;这也是循环队列的思想&#xff0c;空间是固定的&#xff0c;数…

Kafka分区策略

默认分区器DefaultPartitioner &#xff08;1&#xff09;指明partition的情况下&#xff0c;直 接将指明的值作为partition值&#xff1b; &#xff08;2&#xff09;没有指明partition值但有key的情况下&#xff0c;将key的hash值与topic的 partition数进行取余得到partiti…

代谢组学——最接近生物表型的组学

■ 什么是代谢组学 在基于基因组-转录组-蛋白质组-代谢组的系统生物学框架内&#xff0c;代谢组学 (metabolomics/metabonomics) 处于最下游&#xff0c;最接近生物表型&#xff0c;主要通过考察生物体系在某一特定时期内受到刺激或扰动前后所有小分子代谢物 (分子量小于 1500…

信创国产化大背景下,应用性能体验如何保障?

信创产业是拉动中国经济增长不可或缺的重要抓手。从2020年我国迈入信创发展元年&#xff0c;到2022年信创开始向行业“深水区”迈进&#xff0c;信创产业得到了国家相关政策的大力支持。今年9月底国家下发79号文&#xff0c;全面指导国资信创产业的发展和进度&#xff0c;明确要…

bootstrap导航窗格响应式二级菜单

这次碰到的需求是响应式二级导航窗格&#xff0c;默认的导航窗格只有点击下拉框的二级窗格&#xff0c;会有如下问题&#xff1a;一级菜单无法添加超链接&#xff0c;二级菜单展示要多点一下。 实现目标&#xff1a; 1.滑动到指定区域&#xff0c;展示二级菜单。 2.一级菜单和…

Vue3 - 响应式工具函数(使用教程)

前言 您需要对 ref()、reactive() 有所了解&#xff0c;否则要先学习这些。 Vue3 为响应式提供了一些工具函数&#xff0c;辅助开发&#xff1a; API说明isRef()检查某个值是否为 ref。isProxy()检查一个对象是否是由 reactive()、readonly()、shallowReactive() 或 shallowRe…

前端国际化如何对中文——>英文自动化翻译小demo

非专业的国际化语言。 需求是把zh.js文件中的对象的值转换为en.js&#xff08;也就是实现中英文翻译&#xff09; 结果&#xff1a; 话不多说&#xff0c;上技巧&#xff01; 首先找个免费翻译的API接口&#xff0c;我找的百度翻译的API接口。百度翻译开放平台看百度翻译技术…

仅此一招,再无消息乱序的烦恼

1. 概览 RocketMQ 早已提供了一组最佳实践&#xff0c;但工作在一线的伙伴却很少知道&#xff0c;项目中的各种随性代码经常导致消息错乱问题&#xff0c;严重影响业务的准确性。为了保障最佳实践的落地&#xff0c;降低一线伙伴的使用成本&#xff0c;统一 MQ 使用规范&#…

AF488 NHS,AF488 活性酯,Alexa Fluor488 NHS,水溶性小分子绿色荧光标记染料

AF488 NHS通过引入两个磺酸根离子&#xff0c;AF488的水溶性大大增强&#xff0c;荧光强度增加&#xff0c;pH稳定性&#xff0c;光稳定性也提高&#xff0c;但是它的激发和发射谱图基本保持不变。不像荧光素类染料&#xff0c;AF488的荧光在较宽的pH范围内(4 – 10)保持不变。…

ATF源码篇(八):docs文件夹-Components组件(7)固件配置框架

7、固件配置框架 fconf/索引 本文档概述了固件配置框架 7.1 固件配置框架是什么&#xff1f; 1 介绍 固件配置框架&#xff08;|FCONF|&#xff09;是平台特定数据的抽象层&#xff0c;允许查询“属性”并检索值&#xff0c;而请求实体不知道使用什么后备存储来保存数据。 …

Java接口(Interface)

文章目录接口语法注意事项和细节实现接口VS.继承类接口的多态特性小练习usb插槽就是现实中的接口。 你可以把手机,相机,u盘都插在usb插槽上,而不用担心那个插槽是专门插哪个的,原因是做usb插槽的厂家和做各种设备的厂家都遵守了统一的规定包括尺寸&#xff0c;排线等等。 首先创…

ISP-Gamma

参考:https://blog.csdn.net/lxy201700/article/details/24929013 http://www.cambridgeincolour.com/tutorials/gamma-correction.htm 1. 什么是Gamma Gamma是一种指数曲线&#xff0c;显示器用这个指数曲线来调整真实输出到显示屏幕上的颜色值&#xff0c;以此更好的适应人…

卷?这份Java后端架构指南首次公开就摘星百万,肝完直接60K+

最近和各位小伙伴儿私下聊的比较多&#xff0c;各个阶段的朋友都有&#xff1b;因为大环境的内卷&#xff0c;导致大家在求学、求职、提升自己的各个方面都多多少少有些迷茫焦虑&#xff1b; 这些其实是一个非常普遍且正常的现象&#xff0c;会焦虑的人&#xff0c;往往都是对…

大学生简单个人静态HTML网页设计作品 HTML+CSS制作我的家乡杭州 DIV布局个人介绍网页模板代码 DW学生个人网站制作成品下载 HTML5期末大作业

常见网页设计作业题材有 个人、 美食、 公司、 学校、 旅游、 电商、 宠物、 电器、 茶叶、 家居、 酒店、 舞蹈、 动漫、 服装、 体育、 化妆品、 物流、 环保、 书籍、 婚纱、 游戏、 节日、 戒烟、 电影、 摄影、 文化、 家乡、 鲜花、 礼品、 汽车、 其他等网页设计题目, A…