Pytorch从零开始实战14

news2025/1/21 4:54:00

Pytorch从零开始实战——DenseNet + SENet算法实战

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——DenseNet + SENet算法实战
    • 环境准备
    • 数据集
    • 模型选择
    • 开始训练
    • 可视化
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch2.0.1+cu118,torchvision0.15.2,需读者自行配置好环境且有一些深度学习理论基础。本次实验的目的是使用DenseNet+SENet模型。
第一步,导入常用包

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import random
from time import time
import numpy as np
import pandas as pd
import datetime
import gc
import os
import copy
import warnings
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

检查设备对象

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

检查设备对象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device, torch.cuda.device_count() # # (device(type='cuda'), 2)

数据集

本次实验继续使用猴痘病数据集,使用pathlib查看类别,本次类别只有0,1两种类别分别代表患病和不患病。

import pathlib
data_dir = './data/ill/'
data_dir = pathlib.Path(data_dir) # 转成pathlib.Path对象
data_paths = list(data_dir.glob('*')) 
classNames = [str(path).split("/")[2] for path in data_paths]
classNames # ['Monkeypox', 'Others']

使用transforms对数据集进行统一处理,并且根据文件夹名映射对应标签

all_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
])

total_data = datasets.ImageFolder("./data/ill/", transform=all_transforms)
total_data.class_to_idx # {'Monkeypox': 0, 'Others': 1}

随机查看5张图片

def plotsample(data):
    fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图
    for i in range(5):
        num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次
        #抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据
        #而展示图像用的imshow函数最常见的输入格式也是3通道
        npimg = torchvision.utils.make_grid(data[num][0]).numpy()
        nplabel = data[num][1] #提取标签 
        #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取
        axs[i].imshow(np.transpose(npimg, (1, 2, 0))) 
        axs[i].set_title(nplabel) #给每个子图加上标签
        axs[i].axis("off") #消除每个子图的坐标轴

plotsample(total_data)

在这里插入图片描述
根据8比2划分数据集和测试集,并且利用DataLoader划分批次和随机打乱

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_ds, test_ds = torch.utils.data.random_split(total_data, [train_size, test_size])

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds,
                                        batch_size=batch_size,
                                        shuffle=True,
                                      )
test_dl = torch.utils.data.DataLoader(test_ds,
                                        batch_size=batch_size,
                                        shuffle=True,
                                     )

len(train_dl.dataset), len(test_dl.dataset) # (1713, 429)

模型选择

本次实验使用DenseNet + SENet模型,DenseNet的设计核心思想是通过密集连接来增强神经网络的信息流动,促进梯度的传播,以及提高参数的共享和重复使用。采用跨通道concat的形式来连接,会连接前面所有层作为输入。
核心公式为:
在这里插入图片描述
DenseNet中的基本组成单元是DenseBlock,它由多个密集连接的DenseLayer组成。每个DenseLayer都接收所有前面的DenseLayer特征作为输入,将其连接到自己的输出上,并传递给后续的层。如图所示,这是一个基本的DenseBlock模块。
在这里插入图片描述
整体网络架构图如下所示,借用K同学的图片
在这里插入图片描述

为了控制模型的复杂度并减少特征图的大小,DenseNet引入了Transition Block。过渡块包括批归一化、ReLU激活和 1x1 卷积,以减小特征图的通道数,并通过池化操作降低空间维度。
在这里插入图片描述
首先对DenseLayer类定义,本次实验使用add_module函数,默认是用于向类中添加一个子模块,第一个参数为模块名,第二个参数为模块实例,其实相当于加到父类的nn.Sequential里面,所以调用的时候使用super().forward(x),这段的核心是将输入 x 与新特征 t 进行通道维度上的连接,完成密集连接。

class DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super().__init__()
        self.add_module("norm1", nn.BatchNorm2d(num_input_features))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False))
        self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        t = super().forward(x)
        if self.drop_rate > 0:
            t = F.dropout(t, p=self.drop_rate, training=self.training)
        return torch.cat([x, t], 1)

下面是DenseBlock的实现,通过循环创建了多个DenseLayer。其中的 num_input_features + i * growth_rate 用于指定输入通道的数量,确保每个DenseLayer的输入通道数逐渐增加。将新创建的DenseLayer添加为 DenseBlock 的子模块。循环结束后,DenseBlock 就包含了多个DenseLayer,每个DenseLayer都具有逐渐增加的输入通道数量。

class DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super().__init__()
        for i in range(num_layers):
            layer = DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module("denselayer%d" % (i + 1), layer)

下面是Transition,实现过渡的功能,是在块之间降低通道数量和空间维度。

class Transition(nn.Sequential):
    def __init__(self, num_input_feature, num_output_features):
        super().__init__()
        self.add_module("norm", nn.BatchNorm2d(num_input_feature))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features, kernel_size=1, stride=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(2, stride=2))

SENet是一种深度神经网络结构,它的核心思想是允许网络在训练期间对每个通道进行自适应的加权,以使网络能够更加关注对任务有用的通道,并抑制对任务无关的通道。这有助于提高网络对输入数据的敏感性,并提升网络性能。SENet的结构包括两个主要组件:Squeeze 操作和 Excitation 操作。

Squeeze 操作(Global Average Pooling):通过全局平均池化,将每个通道的空间维度降为1。这样,对于每个通道,都得到一个单一的数值,反映了该通道对整个特征图的重要性。

Excitation 操作(通道注意力):在 Squeeze 操作后,通过一个小型的多层感知机(MLP)来学习通道之间的关系。这个小型MLP包含一个压缩操作和一个激励操作)。最后,利用学到的权重对每个通道的特征图进行加权,得到加权后的特征表示。
在这里插入图片描述
下面是SENet的实现,首先,通过全局平均池化层对输入特征图进行平均池化,将每个通道的空间维度降为1。然后,通过全连接层序列 fc 对降维后的特征进行处理,得到每个通道的注意力权重。最后,将得到的注意力权重通过 view 操作还原为与输入特征图相同的形状,并将其与输入特征图相乘,得到应用了注意力机制的特征图。

from torch.nn import init
class SEAttention(nn.Module):

    def __init__(self, channel=512, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

整体模型实现,self.features 是一个包含多个层的序列,包括初始卷积块、多个DenseBlock和Transition,以及最后的全局平均池化和分类器。遍历 block_config 中的配置,创建DenseBlock和Transition。参数初始化部分使用了 Kaiming 初始化和常数初始化。

其中,OrderedDict是Python中的一种有序字典数据结构,它保留了元素添加的顺序。在神经网络中,我们可以使用OrderedDict来指定模型的层次结构。

在进行平均池化之前,进入到SENet进行学习通道注意力权重从而提高网络的表征能力。

from collections import OrderedDict

class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
                 bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(3, stride=2, padding=1))
        ]))

        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features += num_layers * growth_rate
            if i != len(block_config) - 1:
                transition = Transition(num_features, int(num_features * compression_rate))
                self.features.add_module("transition%d" % (i + 1), transition)
                num_features = int(num_features * compression_rate)

        self.features.add_module("norm5", nn.BatchNorm2d(num_features))
        self.features.add_module("relu5", nn.ReLU(inplace=True))
        self.se = SEAttention(channel=1024, reduction=8)
        self.classifier = nn.Linear(num_features, num_classes)
        

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    def forward(self, x):
        features = self.features(x)
        out = self.se(features)
        out = F.avg_pool2d(features, 7, stride=1)
        out = out.view(features.size(0), -1)
        out = self.classifier(out)
        return out

使用summary查看网络

from torchsummary import summary
model = DenseNet().to(device)
summary(model, input_size=(3, 224, 224))

在这里插入图片描述

开始训练

定义训练函数

def train(dataloader, model, loss_fn, opt):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_acc, train_loss = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss

定义测试函数

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_acc, test_loss = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss = loss_fn(pred, y)
    
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
            test_loss += loss.item()

    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss

定义学习率、损失函数、优化算法

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.0001
opt = torch.optim.Adam(model.parameters(), lr=learn_rate)

开始训练,epoch设置为20

import time
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []

T1 = time.time()

best_acc = 0
best_model = 0

for epoch in range(epochs):

    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    
    model.eval() # 确保模型不会进行训练操作
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"
          % (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))

T2 = time.time()
print('程序运行时间:%s秒' % (T2 - T1))

PATH = './best_model.pth'  # 保存的参数文件名
if best_model is not None:
    torch.save(best_model.state_dict(), PATH)
    print('保存最佳模型')
print("Done")

在这里插入图片描述

可视化

可视化训练过程与测试过程

import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

总结

SE模块引入了通道注意力机制,使得网络在学习过程中能够更加自适应地关注对任务有用的通道,抑制对任务无关的通道。这有助于提高网络的特征表达能力。当前也可以与各种其他的深度神经网络结构集成。因此,可以在不改变整体网络架构的情况下,通过引入通道注意力机制来增强网络性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1339643.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

排列组合算法(升级版)

前言 在上一期博客中我们分享了一般的排列组合算法(没看的话点这里哦~),但是缺点很明显,没法进行取模运算,而且计算的范围十分有限,而今天分享的排列组合升级版算法能够轻松解决这些问题,话不多…

Bean 生命周期 和 SpringMVC 执行过程

这里简单记录下 Bean 生命周期的过程,方便自己日后面试用。源码部分还没看懂,这里先贴上结论 源码 结论

PAT 乙级 1028 人口普查

解题思路:此题我想到俩种方法,一种是排序方法,一种是不排序的方法,首先都是看是否是有效年龄,然后一种是排序,另一种是维护最大值和最小值的变量,一定要注意如果有效数字是0那就只输出0就可以了…

【PostgreSQL内核学习(二十)—— 数据库中的遗传算法】

数据库中的遗传算法 概述个体的编码方式及种群初始化geqo 函数 适应值geqo_eval 函数gimme_tree 函数 父体选择策略geqo_selection 函数 杂交算子边重组杂交 ERX ( edge recombination crossover)gimme_edge_table 函数gimme_tour 函数 变异算子geqo_mutation 函数 声明&#x…

在Ubuntu20.04配置PX4环境

目录 1.下载PX4源码2.安装PX4所有工具链3.编译PX4工程1.下载PX4源码 打开Ubuntu,Ctrl+Alt+T打开终端输入下面代码: git clone https://github.com/PX4/PX4-Autopilot.git --recursive出现上图中出现“Command ‘git’ not found, but can be installed with”,使用以下代码…

企业私有云容器化架构运维实战

企业私有云容器化架构运维实战 了解 什么是虚拟化: 虚拟化(Virtualization)技术最早出现在 20 世纪 60 年代的 IBM 大型机系统,在70年代的 System 370 系列中逐渐流行起来,这些机器通过一种叫虚拟机监控器(Virtual M…

【线性代数】通过矩阵乘法得到的线性方程组和原来的线性方程组同解吗?

一、通过矩阵乘法得到的线性方程组和原来的线性方程组同解吗? 如果你进行的矩阵乘法涉及一个线性方程组 Ax b,并且你乘以一个可逆矩阵 M,且产生新的方程组 M(Ax) Mb,那么这两个系统是等价的;它们具有相同的解集。这…

Selenium自动化教程02:浏览器options配置及常用的操作方法

1.配置Chrome浏览器的选项 # Author : 小红牛 # 微信公众号:WdPython options webdriver.ChromeOptions() # 创建配置对象 options.add_argument(langzh_CN.UTF-8) # 设置中文 options.add_argument(--headless) # 无头参数,浏览器隐藏在后台运行 options.add_…

Win10 华硕笔记本只有飞行模式 WIFI 消失(仅供参考)

一、问题描述 下班,将电脑设置为休眠模式,回家,然后就出现:只有飞行模式,WIFI 消失 虽然有线可以用,但是不爽啊! 在“网络和Internet设置中” ,只有“飞行模式”的开关 &#xff0c…

java球队信息管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 java Web球队信息管理系统是一套完善的java web信息管理系统,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发,数据库为Mysql5…

基于MATLAB的泊松分布,正态分布与伽玛分布(附完整代码与例题)

目录 一. 泊松分布 1.1 理论部分 1.2 MATLAB函数模型 1.3 例题 二. 正态分布 2.1 理论部分 2.2 MATLAB函数模型 2.3 例题 三. 伽玛分布 3.1 理论部分 3.2 MATLAB函数模型 3.3 例题 一. 泊松分布 1.1 理论部分 Poisson分布是离散的,其x值只能取自然数。…

uniapp Vue3 面包屑导航 带动态样式

上干货 <template><view class"bei"><view class"container"><view class"indicator"></view><!-- 遍历路由列表 --><view v-for"(item, index) in routes" :key"index" :class&quo…

卷积神经网络 反向传播

误差的计算 softmax 经过softmax处理后所有输出节点概率和为1 损失&#xff08;激活函数&#xff09; 多分类问题&#xff1a;输出只可能归于某一个类别&#xff0c;不可能同时归于多个类别。 误差的反向传播 求w的误差梯度 权值的更新 首先是更新输出层和隐藏层之间的权重…

RustDesk连接客户端提示key不匹配 Key Mismatch无法连接(已解决)

环境: RustDesk1.1.9 服务端docker部署 问题描述: RustDesk连接客户端提示key不匹配 Key Mismatch无法连接 解决方案: 1.docker部署RustDesk服务检查配置 networks:rustdesk-net:external: falsevolumes:hbbr:hbbs:services:hbbs:container_name: rustdesk-hbbsport…

webstrom 快速创建typescript 语法检测的Vue3项目

webstrom 快速创建typescript 语法检测的Vue3项目 若您想为您的Vue 3项目添加TypeScript支持&#xff0c;您需要进行以下步骤&#xff1a; 安装 typescript 和 vitejs/plugin-vue 作为开发依赖项&#xff1a; npm install --save-dev typescript vitejs/plugin-vue创建一个…

Cucumber-JVM的示例和运行解析

Cucumber-JVM 是一个支持 Behavior-Driven Development (BDD) 的 Java 框架。在 BDD 中&#xff0c;可以编写可读的描述来表达软件功能的行为&#xff0c;而这些描述也可以作为自动化测试。 Cucumber-JVM 的最小化环境 Cucumber-JVM是BDD的框架&#xff0c; 提供了GWT语法的相…

andriod安卓水果商城系统课设

​ 一、目的及任务要求 随着当今社会经济的快速发展和网络的迅速普及&#xff0c;手机基本成为了每个人都随身携带的电子产品。传统的购物方式已经满足不了现代人日益追求便利及高效率的购物心理&#xff0c;而通过移动手机上的在线购物系统&#xff0c;可以便捷地甚至足不出…

Vue 自定义ip地址输入组件

实现效果&#xff1a; 组件代码 <template><div class"ip-input flex flex-space-between flex-center-cz"><input type"text" v-model"value1" maxlength"3" ref"ip1" :placeholder"placeholder"…

Win10 + 4090显卡配置深度学习环境 + gaussian-splatting配置 + 实测自己的场景

目录 1 安装Anaconda 2023.09版本 2 安装CUDA11.8 3 安装深度学习库Cudnn8.6.0 4 安装VSCODE2019 5 安装Colmap3.8 6 安装git 7 安装Python3.10 Pytorch2.0.0 7 安装项目 8 采集数据 8.1 IPhone 14 pro 拍摄30张照片左右 做预处理 8.2 生成colmap位姿等信息 8.3 开…

starrocks集群fe/be节点进程守护脚本

自建starrocks集群&#xff0c;有时候服务会挂掉&#xff0c;无法自动拉起服务&#xff0c;于是采用supervisor进行进程守护。可能是版本的原因&#xff0c;supervisor程序总是异常&#xff0c;无法对fe//be进行守护。于是写了个简易脚本。 #!/bin/bash AppNameFecom.starrock…