基于知识蒸馏的两阶段去雨去雪去雾模型学习记录(二)之知识收集阶段

news2024/12/23 22:49:49

前面学习了模型的构建与训练过程,然而在实验过程中,博主依旧对数据集与模型之间的关系有些疑惑,首先是论文说这是一个混合数据集,但事实上博主在实验时是将三个数据集分开的,那么在数据读取时是如何混合的呢,是每个epoch使用同一个数据集,下一个epoch再换数据集,还是再epoch中随机取数据集中的一部分。
此外,教师模型总共有三个,其模型构造是完全相同的,不同之处在于三个教师模型是在不同的数据集训练得到的,即其权重参数是固定的,那么在训练过程中,从代码来看,原始的教师网络权重是不改变的,那么说如何更新学生网络呢?带着这些疑问,开始今天的学习。

数据集加载

首先需要明确的是数据集加载时是将三个数据集进行了合并,只不过会按照三个数据集进行区别,即生成list形式。train_loader的相关参数设置如下:

在这里插入图片描述

模型训练

模型的训练分为两个阶段,分别是知识收集阶段与知识检验阶段,即knowlwdge collect(kc)knowledge exam(ke)两阶段。

在这里插入图片描述

在开始前,需要声明必须要将batch-size设置为3以上,否则会无法加载数据集
首先是知识收集阶段:
声明损失函数,这里的损失函数有两个,分别是L1损失与通过VGG网络计算的软损失(SCRLoss)

criterion_l1, criterion_scr, _ = criterions

在这里插入图片描述

模型开启traineval,关于两者的区别:

model.train()的作用是启用 Batch NormalizationDropout。在train模式,Dropout层会按照设定的参数p设置保留激活单元的概率,如keep_prob=0.8,Batch Normalization层会继续计算数据的meanvar并进行更新。
model.eval()的作用是不启用 Batch NormalizationDropout。在eval模式下,Dropout层会让所有的激活单元都通过,而Batch Normalization层会停止计算和更新meanvar,直接使用在训练阶段已经学出的meanvar值。在使用model.eval()时就是将模型切换到测试模式,在这里,模型就不会像在训练模式下一样去更新权重。
但是需要注意的是model.eval()不会影响各层的梯度计算行为,即会和训练模式一样进行梯度计算和存储,只是不进行反向传播。

model.train()#  model开启train
ckt_modules.train()
for teacher_network in teacher_networks:#为教师网络开启eval()
	teacher_network.eval()

随后便进入核心代码模块了:这里包含模型运算,特征映射,损失计算等过程
这里我们对应论文的创新点来看代码。
首先是进度条加载,这里是对数据集加载train_load的封装

pBar = tqdm(train_loader, desc='Training')

遍历数据,判断数据是否为空,这里曾经困扰过博主一段时间,因为每次遍历时target_image都为空,只要将batch-size设置为3以上即可。

for target_images, input_images in pBar:
	if target_images is None: continue
	target_images = target_images.cuda()
	input_images = [images.cuda() for images in input_images]
	preds_from_teachers = []

可以看到,此时已经将输入图像,目标图像转换为tensor格式,其中input_imageslist形式,每张图像为torch.Size([1, 3, 224, 224])

在这里插入图片描述

而target_images为完全为tensor格式,shape为torch.Size([3, 3, 224, 224])

在这里插入图片描述

简要描述知识收集阶段

teacher_networks即为教师网络列表,单个的教师网络模型与学生网络是相同的,将数据输入教师网络时,由于需要使用教师网络的中间特征,因此return_feat为True,最终的输出结果为预测结果图与中间特征图,预测结果图会作为 “真值” 来训练学生网络,并计算软损失,中间特征图会与学生网络进行映射到同一特征域来进行特征转移,并将教师网络的预测结果与学生网络的预测结果求SCRLoss。

preds_from_teachers = []
features_from_each_teachers = []
with torch.no_grad():
for i in range(len(teacher_networks)):
	preds, features = teacher_networks[i](input_images[i], return_feat=True)
	preds_from_teachers.append(preds)
	features_from_each_teachers.append(features)		

随后将图像输入教师模型,教师模型不更新权重,只是用模型输出的特征来帮助学生网络来训练,称为软损失。核心代码如下:

preds, features = teacher_networks[i](input_images[i], return_feat=True)

将图像 i 输入对应的教师网络 i,这里的i指的是教师网络的索引,这里博主开始曾经有过疑惑,此时的batch_size为3,刚好与教师网络数量对应,因此可以使用该网络,那如果batch_size为6,9时呢,后面的岂不是都无法输入模型了吗,随后博主将batch_size改为6,发现此时的input_image依旧是list形式,但每个list中的内容已经发生了改变,可以看到其是按照不同的数据集类型做了区分,这就是为何input_image要使用listtarget_imagetensor的原因了。现在之前的疑惑也就消失了。
在这里插入图片描述

随后获得输出结果pred,即预测结果,也就是恢复后的图像。可以看到其与输入图像的维度是一致的,对于第一个网络的第一组输入图像,都为:torch.Size([3, 3, 224, 224])

在这里插入图片描述
而返回的中间特征图像如图所示,可以看到输出的不同大小的特征图,总共有4组,即4组不同大小的特征图,每组3张图像,通道数,宽高则不相同。
第 1 组数据集(教师网络)的中间特征图:

在这里插入图片描述
第 3 组数据集(教师网络)的中间特征图:

在这里插入图片描述

随后经过三个网络模型的运算,将结果加入列表:

preds_from_teachers.append(preds)
features_from_each_teachers.append(features)

在这里插入图片描述
在这里插入图片描述

随后将教师网络的预测值转换为tensor格式,因为在最终学生网络的输出是tensor

preds_from_teachers = torch.cat(preds_from_teachers)

原本list变为tensor
在这里插入图片描述
接下来这段是对feature按照特征图大小进行分组,现在的特征图是按照数据集划分为3组,为方便后面做特征映射,将其按照特征图大小分为四组。

for layer in range(len(features_from_each_teachers[0])):
	features_from_teachers.append([features_from_each_teachers[i][layer] for i in range(len(teacher_networks))])

在这里插入图片描述

随后便是将输入图像输入学生网络输出结果与中间特征图,这里是不区分数据集的,完全是混合的

preds_from_student, features_from_student = model(torch.cat(input_images), return_feat=True)

由于博主将batch设置为6会报显存溢出,因此这里改为4,可以看到中间特征图依旧是四组,不过每组的第一个值由6变为了4,其余都没有改变。
可以看到list为4组,代表4组不同尺度特征图,每组里面又有一个list,每个list中包含不同数据集(教师网络的特征图)分别是2,1,1。

在这里插入图片描述
同理输出结果也是由6变4。

在这里插入图片描述

CKT模块(特征转移)

随后便是中间特征图映射了,其过程其实也很简单,即将教师网络特征如与学生网络特征图同时输入CKT模型中,并获得输出结果,将输出结果做损失即可。
在这里插入图片描述

PFE_loss, PFV_loss = 0., 0.
for i, (s_features, t_features) in enumerate(zip(features_from_student, features_from_teachers)):
	t_proj_features, t_recons_features, s_proj_features = ckt_modules[i](t_features, s_features)
	PFE_loss += criterion_l1(s_proj_features, torch.cat(t_proj_features))
	PFV_loss += 0.05 * criterion_l1(torch.cat(t_recons_features), torch.cat(t_features))

可以看到输入的教师网络特征与学生网络特征也不是相同格式的:

在这里插入图片描述
输入值:
经过遍历后,学生网络的特征图分为四组,分别对应不同尺度的特征图,但没有区分数据集,因为本身学生网络就是不区分数据集的。
在这里插入图片描述
而教师网络却是list形式,每个数据集分别对应2,1,1个图像数量
在这里插入图片描述
CKT网络定义:

class CKTModule(nn.Module):
    def __init__(self, channel_t, channel_s, channel_h, n_teachers):
        super().__init__()
        self.teacher_projectors = TeacherProjectors(channel_t, channel_h, n_teachers)
        self.student_projector = StudentProjector(channel_s, channel_h)
    def forward(self, teacher_features, student_feature):
        teacher_projected_feature, teacher_reconstructed_feature = self.teacher_projectors(teacher_features)
        student_projected_feature = self.student_projector(student_feature)
        return teacher_projected_feature, teacher_reconstructed_feature, student_projected_feature

具体结构如下,CKT模块共有4个,即对应不同尺度的特征图,注意功能便是进行一系列的特征映射与转换。

CKTModule(
    (teacher_projectors): TeacherProjectors(
      (PFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (IPFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (student_projector): StudentProjector(
      (PFP): Sequential(
        (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
        (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )

特征转移实际上也是通过损失函数来进行的,即通过一个网学习特征,从而达到特征转移的效果。
最终获得三个结果,分别是教师网络结构特征,教师网络重构特征,学生网络结构特征。核心代码如下:

teacher_projected_feature, teacher_reconstructed_feature = self.teacher_projectors(teacher_features)
student_projected_feature = self.student_projector(student_feature)
return teacher_projected_feature, teacher_reconstructed_feature, student_projected_feature

输出值:
与输入值一样,学生网络结构特征的输出值为tensor形式

在这里插入图片描述
而教师网络特征与教师网络重构特征的输出值依旧为list形式。

在这里插入图片描述

在这里插入图片描述

随后求特征损失与重构损失即可。

PFE_loss += criterion_l1(s_proj_features, torch.cat(t_proj_features))
PFV_loss += 0.05 * criterion_l1(torch.cat(t_recons_features), torch.cat(t_features))

在这里插入图片描述

在这里插入图片描述

最终求总损失与SCR损失即可,值得注意的是SCR损失需要使用VGG网络做特征变换后再计算。
L1损失较为简单,输入为学生网络预测值与教师网络预测值

T_loss = criterion_l1(preds_from_student, preds_from_teachers)
SCR_loss = 0.1 * criterion_scr(preds_from_student, target_images, torch.cat(input_images))

关于criterion_l1函数,其实际上是首先使用VGG网络进行特征变换,其输入数据分别是学生网络预测值,目标图像以及输入图像。
SCRLoss定义如下:根据在forward中的代码可知,其首先将输入值分别输入VGG网络进行特征变换,随后在将输出值计算L1损失。
其中,detch方法是返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_gradfalse,得到的这个tensor永远不需要计算其梯度,不具有grad。即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()tensor就会停止,不能再继续向前进行传播。
最终乘以对应的权重,返回最后的损失。

class SCRLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = Vgg19().cuda()
        self.l1 = nn.L1Loss()
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
    def forward(self, a, p, n):
        a_vgg, p_vgg, n_vgg = self.vgg(a), self.vgg(p), self.vgg(n)
        loss = 0
        d_ap, d_an = 0, 0
        for i in range(len(a_vgg)):
            d_ap = self.l1(a_vgg[i], p_vgg[i].detach())
            d_an = self.l1(a_vgg[i], n_vgg[i].detach())
            contrastive = d_ap / (d_an + 1e-7)
            loss += self.weights[i] * contrastive
        return loss

可以看到最后的损失值是Tensor形式的。

在这里插入图片描述

至此,知识收集阶段便完成了。接下来便是知识测试阶段。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1063168.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【前端开发---Vue3】前段开发之详细的Vue3入门教程,特别适合小白系统学习,入门到熟练使用Vue看这一篇就够了!

前言: 这篇文章更加侧重的是Vue3不同于Vue2的知识点,如果学习Vue2请看下面这篇文章 Vue2详细系统入门教程 11.2 Vue3 声明:图片资源来自于黑马程序员公开学习资料 本人在学习当中,详细整理了笔记,供大家参考学习 1…

基于最近电平逼近的开环MMC逆变器Simulink仿真模型

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

重新认识AUTOSAR Interface

核心: S/R interface: data elementC/S interface: operation (with arguement)M/S interface: mode group (macro) 其实 data element,operation,mode group 才是核心,他们可以看作是用户自定义的变量类…

柯桥实用口语学习,韩语口头禅系列短句-恋爱篇

사랑해.我爱你。 너한테 미치겠어.我为你疯狂。 난 니거야.我是你的。 넌 내거야.你是我的。 너 잘 생겼어.你很帅。 네가 뽀뽀/키스 해도 돼? 我可以吻你吗?

基于虚拟阻抗的下垂控制——孤岛双机并联Simulink仿真

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

[Spring] Spring5——事务简介

目录 一、事务概述 1、什么是事务 2、事务的四个特性(ACID) 二、搭建事务操作环境 1、dao、service 两层结构 2、示例 3、模拟异常(事务场景引入) 三、Spring 事务管理 1、事务管理介绍 2、声明式事务管理——注解方式 …

c++ 学习 之 继承的基本语法

继承可以减少重复的代码 语法&#xff1a; class 子类 : 继承方式 父类子类 也称为 派生类 父类 也称为 基类 class BasePage { public:void header(){cout << "首页&#xff0c;公开课&#xff0c;登录&#xff0c;注册。。。&#xff08;公共头部&#xff09…

C/C++——内存管理

1.为什么存在动态内存分配 灵活性 静态内存分配是在编译时确定的&#xff0c;程序执行过程中无法改变所分配的内存大小&#xff1b;动态内存分配可以根本程序的运行环境来动态分配和释放空间&#xff0c;提供了更大的灵活性 动态数据结构 有些数据结构的大小和结构在编译时…

input输入多行文本:删除“首先 其次 此外 总的来说”

input允许多行输入 233.3表示停止输入input输入多行文本文本 &#xff08;空行&#xff09; &#xff08;空行&#xff09; &#xff08;空行&#xff09; 正文 &#xff08;空行&#xff09; &#xff08;空行&#xff09; &#xff08;空行&#xff09; 正文 &#xff08;空行…

国庆节:不仅仅是庆祝,更是成长与体验

目录 国庆节&#xff1a;不仅仅是庆祝&#xff0c;更是成长与体验引言第一部分&#xff1a;旅途风景目的地选择旅行亮点与国庆的联系 技术主题完成的博文国庆与技术 第三部分&#xff1a;回家的路为什么回家艰难险阻家与国庆 结论 国庆节&#xff1a;不仅仅是庆祝&#xff0c;更…

【Spring笔记02】Spring中的IOC容器和DI依赖注入介绍

这篇文章&#xff0c;主要介绍一下Spring中的IOC容器和DI依赖注入两个概念。 目录 一、IOC控制反转 1.1、什么是IOC 1.2、两种IOC容器 &#xff08;1&#xff09;基于BeanFactory的IOC容器 &#xff08;2&#xff09;基于ApplicationContext的IOC容器 二、DI依赖注入 2.…

Vue MVVM 模型

一、什么事MVVM 模型 MVVM 是 Model-View-ViewModel 的缩写&#xff0c;它是一种软件架构风格 Model&#xff1a;模型&#xff0c; 数据对象&#xff08;data 函数&#xff09;&#xff0c;如下图 View&#xff1a;视图&#xff0c;模板页面&#xff08;用于渲染数据&#xf…

掌握Mac菜单栏,尽在Bartender 5!菜单栏图标管理软件的终极推荐!

作为Mac用户&#xff0c;菜单栏是我们每天使用电脑时最常接触的区域之一。然而&#xff0c;随着我们安装越来越多的应用程序&#xff0c;菜单栏上的图标往往变得拥挤不堪&#xff0c;给我们的工作和生活带来了不便。 幸运的是&#xff0c;有了Bartender 5这款强大的菜单栏图标…

数据结构与算法(Python)

数据结构与算法 算法基础时间复杂度空间复杂度 递归实例&#xff1a;汉诺塔问题 查找顺序查找&#xff08;线性查找&#xff09;二分查找&#xff08;折半查找&#xff09;比较 排序冒泡排序选择排序插入排序快速排序快排和冒泡的时间比较 堆排序树堆堆的向下调整 堆排序过程时…

除静电设备的工作原理及应用

除静电设备主要包括静电消除器、静电接地装置、静电消除风机等&#xff0c;它们的工作原理和应用如下&#xff1a; 静电消除器&#xff1a;静电消除器的工作原理是利用电离和电击的原理来中和电荷。它包括一个金属板和一个高压电源。当静电消除器接通电源后&#xff0c;金属板…

Redis最常见应用场景

缓存&#xff08;Cache&#xff09; Redis的第一个应用场景是Redis作为缓存对象来加速Web应用的访问。 在该场景下&#xff0c;有一些存储于数据库中的数据会被频繁访问&#xff0c;如果频繁的访问数据库&#xff0c;数据库负载会升高&#xff0c;同时由于数据库IO比较慢&…

阿里云服务器更换公网IP地址的方法流程

阿里云服务器可以更换IP地址吗&#xff1f;可以的&#xff0c;创建6小时以内的云服务器ECS可以免费更换三次公网IP地址&#xff0c;超过6小时的云服务器&#xff0c;可以将公网固定IP地址转成弹性EIP&#xff0c;然后通过换绑EIP的方式来更换IP地址。阿里云服务器网分享阿里云服…

阿里云服务器地域节点怎么选择合适?啥是可用区?

阿里云服务器地域和可用区怎么选择&#xff1f;地域是指云服务器所在物理数据中心的位置&#xff0c;地域选择就近选择&#xff0c;访客距离地域所在城市越近网络延迟越低&#xff0c;速度就越快&#xff1b;可用区是指同一个地域下&#xff0c;网络和电力相互独立的区域&#…

基于遗传算法的新能源电动汽车充电桩与路径选择(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

JavaScript系列从入门到精通系列第十六篇:JavaScript使用函数作为属性以及枚举对象中的属性

文章目录 前言 1&#xff1a;对象属性可以是函数 2&#xff1a;对象属性函数被称为方法 一&#xff1a;枚举对象中的属性 1&#xff1a;for...in 枚举对象中的属性 前言 1&#xff1a;对象属性可以是函数 对象的属性值可以是任何的数据类型&#xff0c;也可以是函数。 v…