DFMN 代码解读

news2025/1/12 1:10:00

目录

0. 环境配置

1. 运行程序

2. 读代码的思路

1)model.py

!!  关于继承

 !!  关于网络结构组织

!! 关于 forward

2)  数据预处理

3)train.py


0. 环境配置

很简单,提示缺包xxx,pip install xxx 就可以了

1. 运行程序

从 DFNI/train.py 开始运行,不报错程序能正常运行就可以了

2. 读代码的思路

因为 DFMN (更名后的DFNI) 的主要贡献是网络结构的设计,因此,我看代码的步骤是:

1)model.py  2) train.py  3) 数据预处理相关代码文件

1)model.py

!!  关于继承

因为是自己设计的网络结构,因此继承了 nn.Module, 初始化代码如下。

可以看到定义的 DFNI 网络,继承了nn.Module。需要注意两个地方: 

  1. class DFNI(nn.Module)  子类的参数要写父类名
  2. super(DFNI, self).__init__()  子类初始化时要继承父类的所有方法和属性

与父类不同的属性和方法,就需要重新定义进行覆盖。比如:self.firstPart,  self.midPart1.

 
import torch
import torch.nn as nn


class DFNI(nn.Module):
    def __init__(self, upscale_factor):
        super(DFNI, self).__init__()      
        self.firstPart = nn.Sequential(
            ####
        )
        self.midPart1 = nn.Sequential(
            ####
        )

        self.midPart2 = nn.Sequential(
            ####
        )
        #
        # for p in self.parameters():
        #     p.requires_grad = False

        self.finalPart = nn.Sequential(
            ###
        )
        self.con1 = nn.Conv2d(1, 8, kernel_size=7, padding=7 // 2)
        self.con2 = nn.Conv2d(8, 16, kernel_size=5, padding=5 // 2)
        self.con3 = nn.Conv2d(16, 32, kernel_size=3, padding=3 // 2)
        self.con4 = nn.Conv2d(32, 64, kernel_size=1)
        self.lrelu = nn.LeakyReLU()

    def forward(self, x, xx):     
        #####
        #####
 

 !!  关于网络结构组织

包装一堆卷积和Relu:网络的组成部分最常见的就是卷积和Relu, 一堆卷积和非线性截断函数组成的模块,一般用nn.Sequential()封装成一个模块,并且重新起个名字,方便后续在 forward 调用。

 nn.Conv2d 的参数设置:nn.Conv2d(1, 8, kernel_size=7, padding=7 // 2),在这里, 1是输入通道数,8是输出通道数,7是卷积核的大小,默认步长为1,用0填充,尺寸为7 // 2, padding 的尺寸设置为1/2的核尺寸,是为了保证卷积后图像尺寸不变。

nn.Conv2d(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True))

in_channel: 输入数据的通道数,例RGB图片通道数为3;
out_channel: 输出数据的通道数,这个根据模型调整;
kennel_size: 卷积核大小,可以是int,或tuple;kennel_size=2,意味着卷积大小(2,2), kennel_size=(2,3),意味着卷积大小(2,3)即非正方形卷积
stride:步长,默认为1,与kennel_size类似,stride=2,意味着步长上下左右扫描皆为2, stride=(2,3),左右扫描步长为2,上下为3;
padding:零填充

self.firstPart = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7, padding=7 // 2),
            nn.LeakyReLU(inplace=True),   #???LeakyReLU
            nn.Conv2d(8, 16, kernel_size=5, padding=5 // 2),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, padding=3 // 2),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=1),
            nn.LeakyReLU(inplace=True)
        )

 nn.LeakyReLU(inplace=True)

ReLU是将所有的负值都设为零,相反,Leaky ReLU是给所有负值赋予一个非零斜率。 

!! 关于 forward

按照数据据的处理流程将定义的模块链接起来就可以了,前一模块的输出是后一模块的输入。

    def forward(self, x, xx):    #  xx 咋来的?
        res = x
        # x = self.firstPart(x)
        x1 = self.lrelu(self.con1(x))
        x2 = self.lrelu(self.con2(x1))
        x3 = self.lrelu(self.con3(x2))
        x4 = self.lrelu(self.con4(x3))
        x = torch.cat((x1, x2, x3, x4), dim=1)  # dim 为0、1、2 分别表示增加行、增加列、增加厚度三个方向  ??为什么dim 是1
        b1 = self.midPart1(x)
        b2 = self.midPart2(x)
        x = torch.cat((b1, b2), dim=1)
        x = torch.add(res, x) # 相加和cat有啥区别?在网络设计的维度上是如何设置的?
        x = self.finalPart(x)
        x = torch.add(x, xx)
        return x

在正向传播的时候,因为有级联模块,还有并联模块,残差模块,所以各层的通道数如何设置很困惑?以下是代码作者给出的解答。

x = torch.cat((x1, x2, x3, x4), dim=1)  拼接时,为什么dim 为1?模型输入的是一个四维张量,对应a,第一个相当于个数,为1;第二个为通道数,第三为行,第四为列。即dim=1实现通道数拼接。
x = torch.add(res, x) res的通道数会根据x的通道个数进行复制扩展,比如1->65,然后再相加

在设计网络初期,如何跟踪每一层输出数据的维度呢?保证每一层设置的参数正确  尤其是有拼接的时候?模型的forward写完之后,然后用随机张量当作输入,调用这个模型测试一下;如果有设计size的问题的话,可以在模型中用print(xx.shape),观察一下。

#写在模型定义外
model = DFNI(4)
input1= torch.randn(1, 1, 175, 63)
input2 = torch.randn(1, 1, 700, 252)
model.load_state_dict(torch.load('DFNI_4.pt',map_location='cpu'))
out = model(input1, input2)

#写在模型内
print(x4.shape)

2)  数据预处理

论文的数据是取自开源数据。下采样过程为规律下采样,非规律的会引入噪声。

模型的数据包括:输入(下采样数据,传统方法实现的插值结果<运行inputdataGet>),输出:原始数据。

我们需要准备训练数据集和测试数据集,在Pytorch中,读取数据集需要用到Dataset和DataLoader两个类,Dataset负责对数据的读取,读取的内容是每一个数据和它对应的标签;DataLoader负责对Dataset读取的数据进行打包,然后分批次送入神经网络。

在自定义数据集中,关键是要实现数据类型转换为Dataset,这样就可以调用DataLoader了。

本例子中实现了,npy到Dataset的类型转换。

# 定义了DatasetFromFolder,继承自Dataset,目的:将自定义的数据转为Dataset类
class DatasetFromFolder(Dataset):
    def __init__(self, input1, input2, target):
        super(DatasetFromFolder, self).__init__()
        self.input1 = input1
        self.input2 = input2
        self.target = target

    def __getitem__(self, index):
        return self.input1[index], self.input2[index], self.target[index]

    def __len__(self):
        return len(self.target)


# input1, input2, target均来自.npy 文件,是一个npy数据转为Dataset的范例
input1, input2, target = dataGet_2(num)
input1 = input1.astype(np.float32)  #变化数组类型
input2 = input2.astype(np.float32)
target = target.astype(np.float32)


trainSet = DatasetFromFolder(input1, input2, target)
trainDataLoader = DataLoader(dataset=trainSet)

3)train.py

本论文的实验,由于数据量少,所以没考虑验证集(???但在实验结果的可信度上我也不确定),或许最后的残差的加入(加入传统插值训练结果进行训练),能够保证实验有较好的效果。

标准的训练、验证、测试模板参考,写的很赞!Pytorch模型训练和模型验证_MoxiMoses的博客-CSDN博客_pytorch训练模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/146037.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NuxtJS服务器端入门

一、搜索引擎优化 1、什么是SEO 总结&#xff1a;seo是网站为了提高自已的网站排名&#xff0c;获得更多的流量&#xff0c;对网站的结构及内容进行调整和优化&#xff0c;以便搜索引擎 &#xff08;百度&#xff0c;google等&#xff09;更好抓取到优质网站的内容&#xff0c…

GCN图神经网络和LSTM的介绍和使用场景 中英文

GCN-LSTM 可以学习参考 英文内容部分源自youtube的教学视频 自己跟着英文敲的 给定一辆出租车行驶时在某个时间段的速度&#xff0c;下一个时刻速度会是多少&#xff1f;这是一个时间序列回归预测问题。获得了若干时间点的速度&#xff0c;目标是预测出租车速度序列中的下一个…

线性模型:AR、MA、ARMA、ARMAX、ARX、ARARMAX、OE、BJ等

目录 1 AR 1 2 MA 1 3 ARMA 1 4 ARMAX 2 5 ARX 2 6 ARARX 3 7 ARARMAX 3 8 OE 3 9 BJ 3 各种线性模型&#xff0c;这些模型算数学基础模型&#xff0c;不仅在计量经济学&#xff0c;也在工业控制等各领域有应用。包括AR、MA、ARMA、ARMAX、ARX、ARARMAX、OE、BJ等。 1 AR 自…

【疑难攻关】——floor报错注入

作者名&#xff1a;Demo不是emo 主页面链接&#xff1a;主页传送门创作初心&#xff1a;舞台再大&#xff0c;你不上台&#xff0c;永远是观众&#xff0c;没人会关心你努不努力&#xff0c;摔的痛不痛&#xff0c;他们只会看你最后站在什么位置&#xff0c;然后羡慕或鄙夷座右…

mysql多表查询30个经典案例

mysql多表查询30个经典案例插入两张表一个dept一个emp插入dept表数据插入emp表数据1.列出每个部门里面有那些员工及部门名称;2.运维部门的收入总和&#xff1b;3.HR部入职员工的员工号4.财务部门收入超过5000元的员工姓名5.找出销售部收入最低的员工的入职时间&#xff1b;6.找…

5G NR标准: 第20章 5G的演进

第20章 5G的演进 NR 的第一个版本&#xff0c;第 15 版&#xff0c;侧重于对 eMBB 的基本支持&#xff0c;在某种程度上&#xff0c;URLLC.1 如前几章所述&#xff0c;第 15 版是为即将发布的 NR 未来发展构建的基础 . NR 演进将带来额外的功能并进一步提升性能。 附加功能不…

Netty原理示图

1. AWT事件驱动 2. Websocket协议 3. 基于多个反应器的多线程模式 4. Netty Reactor 工作架构图 5. Bootstrap引导过程 Channel Channel是Java NIO的基础。它表示一个开放的连接&#xff0c;进行IO操作。基本的 I/O 操作&#xff08; bind() 、 connect() 、 read() 和 write(…

什么是异常?异常可以看作你敲出来的bug

异常异常的体系抛异常try -catchfinally自定义异常作为初学者&#xff0c;在刚开始写代码的时候&#xff0c;差不多写一行代码都要见一行红吧异常的体系 这里我们首先需要知道的一点是&#xff0c;所有的异常其实都是类。我们所有的异常都是继承于Throwable这个大类的&#xff…

喜讯!神策分析 Android SDK 入选数据安全“星熠”案例

随着《数据安全法》和《个人信息保护法》的相继出台实施&#xff0c;标志着数据安全保护法治时代的真正到来&#xff0c;国家对数据安全的重视达到了前所未有的高度。在此背景下&#xff0c;神策数据全面落地数据开发利用和数据安全领域的技术推广与产业创新&#xff0c;神策分…

有哪些视频素材网站值得推荐?

高质量视频素材网站&#xff0c;免费、可商用&#xff0c;建议收藏&#xff01; 1、菜鸟图库 https://www.sucai999.com/video.html?vNTYwNDUx 网站有超多视频素材&#xff0c;全部都是高清无水印&#xff0c;各种类型都有&#xff0c;像自然、城市、动物、科技、商业等等都…

【算法】哈希表

&#x1f600;大家好&#xff0c;我是白晨&#xff0c;一个不是很能熬夜&#x1f62b;&#xff0c;但是也想日更的人✈。如果喜欢这篇文章&#xff0c;点个赞&#x1f44d;&#xff0c;关注一下&#x1f440;白晨吧&#xff01;你的支持就是我最大的动力&#xff01;&#x1f4…

Unity3D打包Assetbundle丢失Shader问题

详情见&#xff1a;https://www.pianshen.com/article/5391338163/1、Unity3D在打包Assetbundle时&#xff0c;可能会遇到Shader丢失的问题&#xff0c;解决方法&#xff1a;打开Edit->Project Settings->Graphics&#xff0c;在Always Included Shaders列表添加上所需的…

微信小程序测试(简单项目测试)

Flex布局简介 布局的传统解决方案&#xff0c;基于盒状模型&#xff0c;依赖 display属性 position属性 float属性 什么是flex布局&#xff1f; Flex是Flexible Box的缩写&#xff0c;意为”弹性布局”&#xff0c;用来为盒状模型提供最大的灵活性。 任何一个容器都可以指…

小程序 - 起步:小程序的构成、宿主环境、协同工作和发布

小程序 - 起步:小程序的构成、宿主环境、协同工作和发布 Date: January 5, 2023 Sum: 小程序的构成、宿主环境、协同工作和发布 小程序简介 小程序与普通网页开发的区别 1. 运行环境不同 网页运行在浏览器环境中 小程序运行在微信环境中 2. API 不同 由于运行环境的不同…

P1308 [NOIP2011 普及组] 统计单词数————C++

题目 [NOIP2011 普及组] 统计单词数 题目描述 一般的文本编辑器都有查找单词的功能&#xff0c;该功能可以快速定位特定单词在文章中的位置&#xff0c;有的还能统计出特定单词在文章中出现的次数。 现在&#xff0c;请你编程实现这一功能&#xff0c;具体要求是&#xff1…

数字验证学习笔记——SystemVerilog芯片验证21 ——覆盖率类型

一、覆盖率类型 覆盖率是衡量设计验证完备性的一个通用词语。随着测试逐步覆盖各种合理的组合&#xff0c;仿真过程过程会慢慢勾画你的设计情况。覆盖率工具会在仿真过程中收集信息&#xff0c;然后进行后续处理并且得到覆盖率报告。通过这个报告找出覆盖之外的盲区&#xff0…

冒泡排序模拟qsort函数

欢迎来到 Claffic 的博客 &#x1f49e;&#x1f49e;&#x1f49e; 前言&#xff1a; 学习C语言&#xff0c;一般情况下都会接触到冒泡排序&#xff0c;你知道吗&#xff0c;用冒泡排序的思想可以模拟实现qsort函数&#xff08;库函数的一种&#xff0c;可以实现快排&#xff…

图解面试题:经典50题!掌握这些题,面试也太简单了!

已知有如下4张表&#xff1a;学生表&#xff1a;student(学号,学生姓名,出生年月,性别)成绩表&#xff1a;score(学号,课程号,成绩)课程表&#xff1a;course(课程号,课程名称,教师号)教师表&#xff1a;teacher(教师号,教师姓名)1.汇总分析-查询学生的总成绩并进行排名/* 【知…

CSS基础知识(盒子模型)

继承上一篇CSS的三大特性的优先级继续讲解。 1.1优先级 优先级注意点&#xff1a; 权重是有4组数字组成的&#xff0c;但是不会有进位。可以理解为类选择器永远大于元素选择器&#xff0c;id选择器永远大于类选择器以此类推。等级判断从左向右&#xff0c;如果某一位数值相同…

前端学习之CSS基础

前言 html标签就不说了&#xff0c;这次学习CSS样式&#xff0c;就是美化html标签。 快速了解什么是css 普通标签&#xff1a; 加了css样式&#xff1a; <img src"https://static.runoob.com/images/icon/mobile-icon.png" style"height:100px" /&…