基于知识蒸馏的两阶段去雨去雪去雾模型学习记录(三)之知识测试阶段与评估模块

news2024/11/17 21:21:55

去雨去雾去雪算法分为两个阶段,分别是知识收集阶段与知识测试阶段,前面我们已经学习了知识收集阶段,了解到知识阶段的特征迁移模块(CKT)与软损失(SCRLoss),那么在知识收集阶段的主要重点便是HCRLoss(硬损失),事实上,知识测试阶段要比知识收集阶段简单,因为这个模块只需要训练学生网络即可。

模型创新点

在进行知识测试阶段的代码学习之前,我们来回顾一下去雨去雪去雾网络的创新点:
首先是提出两阶段的知识蒸馏网络,即构建三个教师网络与一个学生网络,设置总训练次数为250,其中前125个epoch教师网络与学生网络一同训练,这里的训练是指将图像输入教师网络,随后将教师网络的输出结果与中间特征图保留,将其作为真值指导学生网络进行训练。
其次便是提出知识迁移模块(CKT)该模块的作用是将教师网络的特征迁移到学生网络。
随后便是软损失与硬损失计算了,这个其实是知识蒸馏中的概念。
总体来看去雨去雾去雪网络的设计虽然较为新颖,但事实上就是知识蒸馏网络的架构,本着这一点,程序理解起来也就容易多了。

在这里插入图片描述

接下来开始代码的学习:

小插曲(算力不足)

首先需要指出,前面将batch-size设置为4,但却会报错:

RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED

开始时博主以为是cuDNN与CUDA版本不匹配导致的,但后来一想不对呀,先前已经运行过呀,那么问题很可能便是batch出问题了,果然将batch改为3后就正常了,这是由于算力不足导致的,注意算力不足和显存不足还是有区别的。
将batch-size改为3后重新运行,开始知识测试阶段的探索。

知识测试阶段

事实上,知识测试阶段的实现与知识收集阶段几乎相同,并且要比知识收集阶段简单,其只是训练学生网络,并计算一个硬损失而已。
由于知识测试阶段与知识收集阶段几乎相同,因此有许多地方是重复的,这里博主便会简要介绍。
首先相同的是使用train_loader进行训练集的加载,并使用tqdm进行封装。
随后便是遍历过程,这个过程就要简单很多了,没有使用到教师网络,直接将图像输入学生网络进行预测即可,这里的学生网络与教师网络的构造是完全相同的,将结果分别计算L1损失与HCR_loss即可。不过需要注意的是由于该阶段不需要与教师网络进行特征迁移,因此就不需要返回中间特征图了,即设置return_feat=False

for target_images, input_images in pBar:
		if target_images is None: continue
		target_images = target_images.cuda()
		input_images = torch.cat(input_images).cuda()
		preds = model(input_images, return_feat=False)
		G_loss = criterion_l1(preds, target_images)
		HCR_loss = 0.2 * criterion_hcr(preds, target_images, input_images)
		total_loss = G_loss + HCR_loss

至于其他的基本就相同了,需要注意的是这里的batch设置为3。接下来记录一下数据的变化情况:

input_images:输入图像,torch.Size([3, 3, 224, 224])第一个3是指图像数量,第二个3是指通道维度
target_images:目标图像(真值),torch.Size([3, 3, 224, 224])第一个3是指图像数量,第二个3是指通道维度
preds:预测图像(去噪后的图像),torch.Size([3, 3, 224, 224])第一个3是指图像数量,第二个3是指通道维度

在这里插入图片描述

随后计算L1损失与HCRLoss,由于在学生网络中使用的事实上是混合数据集,即不区分去噪类型,因此输入图像等都是直接使用tesnor格式,而非list格式。

G_loss:tensor(0.5621, device='cuda:0', grad_fn=<L1LossBackward>)

HCRLoss

SCRLoss相同,HCRLoss也是先将图像进行特征转换后再计算损失的

HCRLoss(
  (vgg): Vgg19(
    (slice1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (slice2): Sequential(
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
    )
    (slice3): Sequential(
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
    )
    (slice4): Sequential(
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (17): ReLU(inplace=True)
      (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (20): ReLU(inplace=True)
    )
    (slice5): Sequential(
      (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (22): ReLU(inplace=True)
      (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (24): ReLU(inplace=True)
      (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (26): ReLU(inplace=True)
      (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (29): ReLU(inplace=True)
    )
  )
  (l1): L1Loss()
)
HCRLoss:tensor(0.3274, device='cuda:0', grad_fn=<MulBackward0>)

评估模块

至此,知识测试阶段便完成了,随后便是模型评估了。这里默认设置评估时的batch-size为1,即每次输入一张图像。
所谓的评估指的是对学生网络的评估,该模块其实与知识测试阶段类似,不同之处在于这里是需要计算SSIMPSNR的。至于其他则是完全相同,核心代码如下:

for target, image in pBar:
		if torch.cuda.is_available():
			image = image.cuda()
			target = target.cuda()
		pred = model(image)   		
		psnr_list.append(torchPSNR(pred, target).item())
		ssim_list.append(pytorch_ssim.ssim(pred, target).item())

由于batch-size设置为1,因此targettorch.Size([1, 3, 480, 640])image也为torch.Size([1, 3, 480, 640]),这里需要注意的是,在训练阶段(包含知识收集与知识测试阶段),数据集中的图像都要转换为224x224的大小,而在评估阶段则不需要进行转换了,即使用的是原图像的大小。
直接将输入图输入模型,获的去噪后的图像pred大小为torch.Size([1, 3, 480, 640])

pred = model(image)

在这里插入图片描述

随后将预测图像与真值图像进行计算PSNR与SSIM

psnr_list.append(torchPSNR(pred, target).item())
ssim_list.append(pytorch_ssim.ssim(pred, target).item())

PSNR计算

@torch.no_grad()
def torchPSNR(prd_img, tar_img):
	if not isinstance(prd_img, torch.Tensor):
		prd_img = torch.from_numpy(prd_img)
		tar_img = torch.from_numpy(tar_img)

	imdff = torch.clamp(prd_img, 0, 1) - torch.clamp(tar_img, 0, 1)
	rmse = (imdff**2).mean().sqrt()
	ps = 20 * torch.log10(1/rmse)
	return ps

SSIM计算

class SSIM(torch.nn.Module):
    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)
    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)            
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)        
            self.window = window
            self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)  
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)

将每个循环得到的psnrssim加入列表

在这里插入图片描述
最后的PSNRSSIM是对list中的所有值求平均:

print("PSNR: {:.3f}".format(np.mean(psnr_list)))
print("SSIM: {:.3f}".format(np.mean(ssim_list)))

至此,知识测试阶段与评估模块就讲解完成了,接下来博主将对该模型进行改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1065397.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript中的模块化编程,包括CommonJS和ES6模块的区别。

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 模块化编程概述⭐ CommonJS 模块⭐ ES6 模块⭐ 区别⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 欢迎来到前端入门之旅&#xff01;感兴趣的可以订阅本专栏哦&#xff01;这个专栏是为那些对Web开发感兴趣、…

苹果恢复微信聊天记录的3个实用方法!

愁死我了&#xff01;朋友们&#xff01;把手机借给了亲戚家的小孩玩&#xff0c;拿回手机后发现很重要的聊天记录丢失了&#xff0c;怎么办呀&#xff0c;有什么方法能够恢复回来吗&#xff1f; 微信是架起我们与家人、朋友、同事之间沟通的桥梁&#xff0c;无论是工作还是生活…

【软考】8.1 程序语言基本概念-成分-函数

《程序设计语言的基本概念》 汇编&#xff1a;将汇编语言翻译成目标程序执行编译&#xff1a;生成独立的可执行文件&#xff08;逻辑上与源程序等价的目标程序&#xff09;&#xff1b;直接运行&#xff1b;运行时无法控制源程序&#xff1b;效率高解释&#xff1a;不生成可执行…

mmap底层驱动实现(remap_pfn_range函数)

mmap底层驱动实现 myfb.c&#xff08;申请了128K空间&#xff09; #include <linux/init.h> #include <linux/tty.h> #include <linux/device.h> #include <linux/export.h> #include <linux/types.h> #include <linux/module.h> #inclu…

Mybatis 使用参数时$与#的区别

之前我们介绍了mybatis中参数的使用&#xff0c;本篇我们在此基础上介绍Mybatis中使用参数时$与#的区别。 如果您对mybatis中参数的使用不太了解&#xff0c;建议您先进行了解后再阅读本篇&#xff0c;可以参考&#xff1a; Mybatis参数(parameterType)https://blog.csdn.net…

知识图谱和大语言模型的共存之道

源自&#xff1a;开放知识图谱 “人工智能技术与咨询” 发布 导 读 01 知识图谱和大语言模型的历史 图1 图2 图3 图4 图5 02 知识图谱和大语言模型作为知识库的优缺点 图6 图7 表1 表2 图8 图9 03 知识图谱和大语言模型双知识平台融合 图10 图11 04 总结与展望 声明:公众号转…

C# OpenCvSharp Yolov8 Pose 姿态识别

效果 项目 代码 using OpenCvSharp; using OpenCvSharp.Dnn; using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Windows.Forms;namespace OpenC…

中国企业400电话在线申请办理

在当今竞争激烈的商业环境中&#xff0c;企业需要寻求各种方式来提升客户服务和市场竞争力。而拥有一个专属的400电话号码&#xff0c;不仅可以为企业带来更多的商机&#xff0c;还能提升企业形象和客户满意度。本文将介绍如何在线申请办理中国企业400电话&#xff0c;并提供一…

京东数据接口|电商运营中数据分析的重要性

在电商运营中&#xff0c;数据分析是非常重要的一环&#xff0c;它可以帮助电商企业更好地了解市场、了解消费者、了解产品、了解销售渠道等各种信息&#xff0c;从而制定更为科学有效的运营策略&#xff0c;提高销售效益。 数据方面用户可以直接选择使用数据接口来获取&#…

面试高频手撕算法 - 背包问题1

目录 1. 前言 2. 什么是 01 背包&#xff0c; 什么是完全背包 3. 01 背包 3.1 【模板】01背包 3.2 分割等和子集 3.3 分割等和子集 3.4 最后一块石头的重量 1. 前言 为什么要专门去搞一下这个背包问题呢 ? 因为作者已经在两场面试中吃了这个亏, 尤其是在面深信服的测开岗…

信创办公–基于WPS的EXCEL最佳实践系列 (条件格式)

信创办公–基于WPS的EXCEL最佳实践系列 &#xff08;设置条件格式&#xff09; 目录 应用背景操作步骤1、选用条件格式1.1 筛选出迟到次数超过3次的数据1.2 筛选出早退次数位于前三的数据1.3 个人加班时长在总体中所占的在的位置 2、删除条件格式2.1 清除规则2.2 管理规则 应用…

钡铼BL124PN:简单快速转换Profinet到Ethernet/IP

钡铼技术BL124PN是一款高性能的Profinet转Ethernet/IP网关设备。该网关专为工业自动化领域设计&#xff0c;用于实现不同协议之间的互连和通信。BL124PN采用可靠稳定的硬件和先进的通信技术&#xff0c;具有以下主要特点&#xff1a; 协议转换能力&#xff1a;BL124PN能够将Pr…

WIN10 查看端口占用情况

输入命令&#xff0c;其中 5082 为需要查看的端口 C:\Users\chenjian>netstat -ano|findstr "5082"TCP 0.0.0.0:5082 0.0.0.0:0 LISTENING 21708可以看到 5082 这个端口被 “21708”这个进程占用了。 输入命令查看进程的信息 C…

ST2110基础介绍(初识)

前言 随着超高清视频产业迅速发展&#xff0c;4K/8K超高清信号对带宽提出更高的要求&#xff0c;传统的基于SDI (数字串行接口)采集制作、调度分发的方式已经不能满足技术更新的需求。行业内的共识是采用基于ICT(网络和通信技术)技术的IP化架构&#xff0c;一方面解决高带宽信…

当长假来临,如何定向应用AI?科技力量变革您的假日生活!

“今夜月明人尽望&#xff0c;不知秋思落谁家。”中秋国庆的双节组合&#xff0c;让万千中国家庭迎来了难得的团圆欢庆时刻。长达八天的假期已经开启&#xff0c;现在的你是不是已经背上行囊&#xff0c;浪迹远方了呢&#xff1f; &#xff08;金秋时分&#xff0c;假日光景&am…

Java基于SpringBoot的社区医院管理服务

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W,Csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 文章目录 1、效果演示2、 前言介绍3. 技术栈4系统设计4.1数据库设计4.2系统整体设计4.2.1 系统设计思想4.2.…

网络安全工程师日常工作有哪些?初学者怎么适应

在往期的很多文章当中介绍过如何成为一名合格的网络安全工程师所需技能以及学习方法&#xff0c;今天给大家更新网络安全工程师的主要日常工作&#xff0c;相信刚学习的小伙伴一定很好奇&#xff0c;话不多说&#xff0c;往下看吧。 网络安全工程师分为哪些方向 网络安全工程…

SwiftUIArkUI-曲线动画Path和路径动画motionPath

OpenHarmony Path ArkUI 高性能 motionPath 动效 三次贝塞尔曲线 曲线动画 SwiftUI SwiftUI通过Path可以绘制路径动画&#xff0c;通过addCurve可用绘制三次贝塞尔曲线。 ArkUI是鸿蒙的核心UI布局框架&#xff0c;使用motionPath绘制路径动画&#xff0c;通过绘制路径可以自定…

开箱即用版本 满分室间质评之GATK Somatic SNV+Indel+CNV+SV

最近准备为sliverworkspace 图形化生信平台开发报告设计器&#xff0c;需要一个较为复杂的pipeline作为测试数据&#xff0c;就想起来把之前的 满分室间质评之GATK Somatic SNVIndelCNVSV&#xff08;下&#xff09;性能优化翻出来用一下。跑了一遍发现还是各种问题&#xff0c…

想用iPhone录视频同时拍照片?这篇教你!附送10个苹果手机技巧

如果你在使用苹果手机录制视频时还想要拍摄静态照片&#xff0c;打开录制功能后&#xff0c;可以点击苹果手机屏幕右侧的白色圈圈&#xff0c;这时候&#xff0c;静态的照片就会自动保存到相册啦&#xff01; 哈哈&#xff0c;如果你想要反过来操作——用苹果手机拍照时顺便录…