深度学习实战基础案例——卷积神经网络(CNN)基于MobileNetV3的肺炎识别|第3例

news2024/11/20 12:20:21

文章目录

  • 前言
  • 一、数据集介绍
  • 二、前期工作
  • 三、数据集读取
  • 四、构建CA注意力模块
  • 五、构建模型
  • 六、开始训练

前言

Google公司继MobileNetV2之后,在2019年发表了它的改进版本MobileNetV3。而MobileNetV3共有两个版本,分别是MobileNetV3-Large和MobileNetV2-Small。改进后的MobileNetV3,在ImageNet数据集的分类精度上,它的MobileNetV3-Large版本相较于MobileNetV2提升了大概3.2%的精度同时延迟减少了20%,而MobileNetV3-Small则提升了6.6%的精度,减少了大概23%的延迟。

今天,我们用MobileNetV3来进行肺炎的识别,同时我们用CA注意力机制替换了原模型中的SE注意力模块。


我的环境:

  • 基础环境:python3.7
  • 编译器:jupyter notebook
  • 深度学习框架:pytorch

一、数据集介绍

ChestXRay2017数据集共包含5856张胸腔X射线透视图,诊断结果(即分类标签)主要分为正常和肺炎,其中肺炎又可以细分为:细菌性肺炎和病毒性肺炎。

胸腔X射线图像选自广州市妇幼保健中心的1至5岁儿科患者的回顾性研究。所有胸腔X射线成像都是患者常规临床护理的一部分。

为了分析胸腔X射线图像,首先对所有胸腔X光片进行了筛查,去除所有低质量或不可读的扫描,从而保证图片质量。然后由两名专业医师对图像的诊断进行分级,最后为降低图像诊断错误, 还由第三位专家检查了测试集。

主要分为train和test两大子文件夹,分别用于模型的训练和测试。在每个子文件内又分为了NORMAL(正常)和PNEUMONIA(肺炎)两大类。

在PNEUMONIA文件夹内含有细菌性和病毒性肺炎两类,可以通过图片的命名格式进行判别。
在这里插入图片描述

二、前期工作

from torch import nn
import torch.utils.data as Data
from torchvision.transforms import transforms
import torchvision
import torchsummary

# 设置device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

三、数据集读取

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

train_data=torchvision.datasets.ImageFolder(root=r"ChestXRay2017/chest_xray/train",transform=data_transform["train"])
train_dataloader=Data.DataLoader(train_data,batch_size=48,shuffle=True)

test_data=torchvision.datasets.ImageFolder(root=r"ChestXRay2017/chest_xray/test",transform=data_transform["val"])
test_dataloader=Data.DataLoader(test_data,batch_size=48,shuffle=True)

四、构建CA注意力模块

我们都知道注意力机制在各种计算机视觉任务中都是有帮助,如图像分类和图像分割。其中最为经典和被熟知的便是SENet,它通过简单地squeeze每个2维特征图,进而有效地构建通道之间的相互依赖关系。
在这里插入图片描述

SE Block虽然近2年来被广泛使用;然而,它只考虑通过建立通道之间的关系来重新衡量每个通道的重要性,而忽略了位置信息,但是位置信息对于生成空间选择性attention maps是很重要的。因此就有人引入了一种新的注意块,它不仅仅考虑了通道间的关系还考虑了特征空间的位置信息,即CA(Coordinate Attention)注意力机制。

在这里插入图片描述

class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.relu6 = nn.ReLU6()

    def forward(self, x):
        return x * self.relu6(x + 3) / 6


class CoordAtt(nn.Module):
    def __init__(self, inp, oup, groups=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // groups)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.conv2 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.relu = h_swish()

    def forward(self, x):
        identity = x
        n,c,h,w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.relu(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        x_h = self.conv2(x_h).sigmoid()
        x_w = self.conv3(x_w).sigmoid()
        x_h = x_h.expand(-1, -1, h, w)
        x_w = x_w.expand(-1, -1, h, w)

        y = identity * x_w * x_h
        # y=x_w * x_h

        return y

class CA_SA(nn.Module):
    def __init__(self,inchannel,outchannel):
        super(CA_SA, self).__init__()
        self.CA=CoordAtt(inchannel,outchannel)
        self.SA=Spatial_Attention_Module(7)
    def forward(self,x):
        y=self.CA(x)
        z=self.SA(x)
        return x*y*z

五、构建模型

import torch.nn as nn
import torch
import torchsummary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 定义h-swith激活函数
class HardSwish(nn.Module):
    def __init__(self, inplace=True):
        super(HardSwish, self).__init__()
        self.relu6 = nn.ReLU6()

    def forward(self, x):
        return x * self.relu6(x + 3) / 6


# DW卷积
def ConvBNActivation(in_channels, out_channels, kernel_size, stride, activate):
    # 通过设置padding达到当stride=2时,hw减半的效果。此时不与kernel_size有关,所实现的公式为: padding=(kernel_size-1)//2
    # 当kernel_size=3,padding=1时: stride=2 hw减半, stride=1 hw不变
    # 当kernel_size=5,padding=2时: stride=2 hw减半, stride=1 hw不变
    # 从而达到了使用 stride 来控制hw的效果, 不用去关心kernel_size的大小,控制单一变量
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                  padding=(kernel_size - 1) // 2, groups=in_channels),
        nn.BatchNorm2d(out_channels),
        nn.ReLU6() if activate == 'relu' else HardSwish()
    )


class Inceptionnext(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, activate):
        super(Inceptionnext, self).__init__()

        gc = int(in_channels * 1 / 4)  # channel number of a convolution branch

        # self.dwconv_hw = nn.Conv2D(gc, gc, kernel_size,stride=stride,padding=(kernel_size-1)//2,groups=gc)

        self.dwconv_hw1 = nn.Conv2d(gc, gc, (1, kernel_size), stride=stride, padding=(0, (kernel_size - 1) // 2),
                                    groups=gc)
        self.dwconv_hw2 = nn.Conv2d(gc, gc, (kernel_size, 1), stride=stride, padding=((kernel_size - 1) // 2, 0),
                                    groups=gc)
        self.dwconv_hw = nn.Sequential(
            nn.Conv2d(gc, gc, (1, kernel_size), stride=stride, padding=(0, (kernel_size - 1) // 2), groups=gc),
            nn.Conv2d(gc, gc, (kernel_size, 1), stride=stride, padding=((kernel_size - 1) // 2, 0), groups=gc)
            )

        # self.dwconv_hw = nn.Sequential(
        #     nn.Conv2d(gc,gc//2,kernel_size=1,stride=1),
        #     nn.Conv2d(gc//2, gc//2, (1, kernel_size), stride=stride, padding=(0, (kernel_size - 1) // 2), groups=gc//2),
        #     nn.Conv2d(gc//2, gc//2, (kernel_size, 1), stride=stride, padding=((kernel_size - 1) // 2, 0), groups=gc//2)
        #     )


        self.dwconv_w = nn.Conv2d(gc, gc, kernel_size=(1, 11), stride=stride, padding=(0, 11 // 2), groups=gc)

        self.dwconv_h = nn.Conv2d(gc, gc, kernel_size=(11, 1), stride=stride, padding=(11 // 2, 0), groups=gc)

        self.batch2d = nn.BatchNorm2d(out_channels)

        self.activate = nn.ReLU6() if activate == 'relu' else HardSwish()

        self.split_indexes = (gc, gc, gc, in_channels - 3 * gc)

        self.cheap=nn.Sequential(
            nn.Conv2d(gc // 2, gc // 2, (1, 3), stride=stride, padding=(0, (3 - 1) // 2),
                      groups=gc//2),
            nn.Conv2d(gc // 2, gc // 2, (3, 1), stride=stride, padding=((3 - 1) // 2, 0), groups=gc//2)
        )

    def forward(self, x):
        # B, C, H, W = x.shape
        x_hw, x_w, x_h, x_id = torch.split(x, self.split_indexes, dim=1)

        x = torch.cat(
            (self.dwconv_hw(x_hw),
             self.dwconv_w(x_w),
             self.dwconv_h(x_h),
             x_id),
            dim=1)

        # x = torch.cat(
        #     (torch.cat((self.dwconv_hw(x_hw),self.cheap(self.dwconv_hw(x_hw))),dim=1),
        #      self.dwconv_w(x_w),
        #      self.dwconv_h(x_h),
        #      x_id),
        #     dim=1)

        x = self.batch2d(x)
        x = self.activate(x)
        return x

# PW卷积(接全连接层)
def Conv1x1BN(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1),
        nn.BatchNorm2d(out_channels)
    )


class SqueezeAndExcite(nn.Module):
    def __init__(self, in_channels, out_channels, se_kernel_size, divide=4):
        super(SqueezeAndExcite, self).__init__()
        mid_channels = in_channels // divide   # 维度变为原来的1/4

        # 将当前的channel平均池化成1
        self.pool = nn.AvgPool2d(kernel_size=se_kernel_size,stride=1)

        # 两个全连接层 最后输出每层channel的权值
        self.SEblock = nn.Sequential(
            nn.Linear(in_features=in_channels, out_features=mid_channels),
            nn.ReLU6(),
            nn.Linear(in_features=mid_channels, out_features=out_channels),
            HardSwish(),
        )

    def forward(self, x):
        a=x.shape
        b, c, h, w = a[0],a[1],a[2],a[3]
        out = self.pool(x)       # 不管当前的 h,w 为多少, 全部池化为1
        out = out.reshape([b, -1])    # 打平处理,与全连接层相连
        # 获取注意力机制后的权重
        out = self.SEblock(out)
        # out是每层channel的权重,需要扩维才能与原特征矩阵相乘
        out = out.reshape([b, c, 1, 1])  # 增维
        return out * x

# # 普通的1x1卷积
# class Conv1x1BNActivation(nn.Module):
#     def __init__(self,inchannel,outchannel,activate):
#         super(Conv1x1BNActivation, self).__init__()
#         self.first=nn.Sequential(
#             nn.Conv2d(inchannel,outchannel//2,kernel_size=1,stride=1),
#             nn.Conv2d(outchannel//2,outchannel//2,kernel_size=3,stride=1,padding=1,groups=outchannel//2)
#                                 )
#         self.second=nn.Conv2d(outchannel//2,outchannel//2,kernel_size=3,stride=1,padding=1,groups=outchannel//2)
#         self.BN=nn.BatchNorm2d(outchannel)
#         self.act=nn.ReLU6() if activate == 'relu' else HardSwish()
#     def forward(self,x):
#         x=self.first(x)
#         y=torch.cat((x,self.second(x)),dim=1)
#         y=self.BN(y)
#         y=self.act(y)
#         return y
def Conv1x1BNActivation(in_channels,out_channels,activate):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU6() if activate == 'relu' else HardSwish()
    )

class SEInvertedBottleneck(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, kernel_size, stride, activate, use_se,
                 se_kernel_size=1):
        super(SEInvertedBottleneck, self).__init__()
        self.stride = stride
        self.use_se = use_se
        self.in_channels = in_channels
        self.out_channels = out_channels
        # mid_channels = (in_channels * expansion_factor)

        # 普通1x1卷积升维操作
        self.conv = Conv1x1BNActivation(in_channels, mid_channels, activate)

        # DW卷积 维度不变,但可通过stride改变尺寸 groups=in_channels
        if stride == 1:
            self.depth_conv = Inceptionnext(mid_channels, mid_channels, kernel_size, stride, activate)
        else:
            self.depth_conv = ConvBNActivation(mid_channels, mid_channels, kernel_size, stride, activate)
        # self.depth_conv = ConvBNActivation(mid_channels, mid_channels, kernel_size,stride,activate)

        # 注意力机制的使用判断
        if self.use_se:
            # self.SEblock = SqueezeAndExcite(mid_channels, mid_channels, se_kernel_size)
            # self.SEblock = CBAM.CBAMBlock("FC", 5, channels=mid_channels, ratio=9)
            self.SEblock = CoordAtt(mid_channels,mid_channels)
            # self.SEblock = CAblock.CA_SA(mid_channels, mid_channels)
        # PW卷积 降维操作
        self.point_conv = Conv1x1BN(mid_channels, out_channels)

        # shortcut的使用判断
        if self.stride == 1:
            self.shortcut = Conv1x1BN(in_channels, out_channels)

    def forward(self, x):
        # DW卷积
        out = self.depth_conv(self.conv(x))
        # 当 use_se=True 时使用注意力机制
        if self.use_se:
            out = self.SEblock(out)
        # PW卷积
        out = self.point_conv(out)
        # 残差操作
        # 第一种: 只看步长,步长相同shape不一样的输入输出使用1x1卷积使其相加
        # out = (out + self.shortcut(x)) if self.stride == 1 else out
        # 第二种: 同时满足步长与输入输出的channel, 不使用1x1卷积强行升维
        out = (out + x) if self.stride == 1 and self.in_channels == self.out_channels else out

        return out

class MobileNetV3(nn.Module):
    def __init__(self, num_classes=8, type='large'):
        super(MobileNetV3, self).__init__()
        self.type = type

        # 224x224x3 conv2d 3 -> 16 SE=False HS s=2
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            HardSwish(),
        )
        # torch.Size([1, 16, 112, 112])

        # MobileNetV3_Large 网络结构
        if type == 'large':
            self.large_bottleneck = nn.Sequential(
                # torch.Size([1, 16, 112, 112]) 16 -> 16 -> 16 SE=False RE s=1
                SEInvertedBottleneck(in_channels=16, mid_channels=16, out_channels=16, kernel_size=3, stride=1,
                                     activate='relu', use_se=False),
                # torch.Size([1, 16, 112, 112]) 16 -> 64 -> 24 SE=False RE s=2
                SEInvertedBottleneck(in_channels=16, mid_channels=64, out_channels=24, kernel_size=3, stride=2,
                                     activate='relu', use_se=False),
                # torch.Size([1, 24, 56, 56])   24 -> 72 -> 24 SE=False RE s=1
                SEInvertedBottleneck(in_channels=24, mid_channels=72, out_channels=24, kernel_size=3, stride=1,
                                     activate='relu', use_se=False),
                # torch.Size([1, 24, 56, 56])   24 -> 72 -> 40 SE=True RE s=2
                SEInvertedBottleneck(in_channels=24, mid_channels=72, out_channels=40, kernel_size=5, stride=2,
                                     activate='relu', use_se=True, se_kernel_size=28),
                # torch.Size([1, 40, 28, 28])   40 -> 120 -> 40 SE=True RE s=1
                SEInvertedBottleneck(in_channels=40, mid_channels=120, out_channels=40, kernel_size=5, stride=1,
                                     activate='relu', use_se=True, se_kernel_size=28),
                # torch.Size([1, 40, 28, 28])   40 -> 120 -> 40 SE=True RE s=1
                SEInvertedBottleneck(in_channels=40, mid_channels=120, out_channels=40, kernel_size=5, stride=1,
                                     activate='relu', use_se=True, se_kernel_size=28),
                # torch.Size([1, 40, 28, 28])   40 -> 240 -> 80 SE=False HS s=1
                SEInvertedBottleneck(in_channels=40, mid_channels=240, out_channels=80, kernel_size=3, stride=1,
                                     activate='hswish', use_se=False),
                # torch.Size([1, 80, 28, 28])   80 -> 200 -> 80 SE=False HS s=1
                SEInvertedBottleneck(in_channels=80, mid_channels=200, out_channels=80, kernel_size=3, stride=1,
                                     activate='hswish', use_se=False),
                # torch.Size([1, 80, 28, 28])   80 -> 184 -> 80 SE=False HS s=2
                SEInvertedBottleneck(in_channels=80, mid_channels=184, out_channels=80, kernel_size=3, stride=2,
                                     activate='hswish', use_se=False),
                # torch.Size([1, 80, 14, 14])   80 -> 184 -> 80 SE=False HS s=1
                SEInvertedBottleneck(in_channels=80, mid_channels=184, out_channels=80, kernel_size=3, stride=1,
                                     activate='hswish', use_se=False),
                # torch.Size([1, 80, 14, 14])   80 -> 480 -> 112 SE=True HS s=1
                SEInvertedBottleneck(in_channels=80, mid_channels=480, out_channels=112, kernel_size=3, stride=1,
                                     activate='hswish', use_se=True, se_kernel_size=14),
                # torch.Size([1, 112, 14, 14])  112 -> 672 -> 112 SE=True HS s=1
                SEInvertedBottleneck(in_channels=112, mid_channels=672, out_channels=112, kernel_size=3, stride=1,
                                     activate='hswish', use_se=True, se_kernel_size=14),
                # torch.Size([1, 112, 14, 14])  112 -> 672 -> 160 SE=True HS s=2
                SEInvertedBottleneck(in_channels=112, mid_channels=672, out_channels=160, kernel_size=5, stride=2,
                                     activate='hswish', use_se=True, se_kernel_size=7),
                # torch.Size([1, 160, 7, 7])    160 -> 960 -> 160 SE=True HS s=1
                SEInvertedBottleneck(in_channels=160, mid_channels=960, out_channels=160, kernel_size=5, stride=1,
                                     activate='hswish', use_se=True, se_kernel_size=7),
                # torch.Size([1, 160, 7, 7])    160 -> 960 -> 160 SE=True HS s=1
                SEInvertedBottleneck(in_channels=160, mid_channels=960, out_channels=160, kernel_size=5, stride=1,
                                     activate='hswish', use_se=True, se_kernel_size=7),
            )

            # torch.Size([1, 160, 7, 7])
            # 相比MobileNetV2,尾部结构改变,,变得更加的高效
            self.large_last_stage = nn.Sequential(
                nn.Conv2d(in_channels=160, out_channels=960, kernel_size=1, stride=1),
                nn.BatchNorm2d(960),
                HardSwish(),
                nn.AvgPool2d(kernel_size=7, stride=1),
                nn.Conv2d(in_channels=960, out_channels=1280, kernel_size=1, stride=1),
                HardSwish(),
            )

        # MobileNetV3_Small 网络结构
        if type == 'small':
            self.small_bottleneck = nn.Sequential(
                # torch.Size([1, 16, 112, 112]) 16 -> 16 -> 16 SE=False RE s=2
                SEInvertedBottleneck(in_channels=16, mid_channels=16, out_channels=16, kernel_size=3, stride=2,
                                     activate='relu', use_se=True, se_kernel_size=56),
                # torch.Size([1, 16, 56, 56])   16 -> 72 -> 24 SE=False RE s=2
                SEInvertedBottleneck(in_channels=16, mid_channels=72//2, out_channels=24, kernel_size=3, stride=2,
                                     activate='relu', use_se=False),
                # torch.Size([1, 24, 28, 28])   24 -> 88 -> 24 SE=False RE s=1
                SEInvertedBottleneck(in_channels=24, mid_channels=88//2, out_channels=24, kernel_size=3, stride=1,
                                     activate='relu', use_se=False),
                # torch.Size([1, 24, 28, 28])   24 -> 96 -> 40 SE=True RE s=2
                SEInvertedBottleneck(in_channels=24, mid_channels=96//2, out_channels=40, kernel_size=5, stride=2,
                                     activate='hswish', use_se=True, se_kernel_size=14),
                # torch.Size([1, 40, 14, 14])   40 -> 240 -> 40 SE=True RE s=1
                SEInvertedBottleneck(in_channels=40, mid_channels=240//2, out_channels=40, kernel_size=5, stride=1,
                                     activate='hswish', use_se=True, se_kernel_size=14),
                # torch.Size([1, 40, 14, 14])   40 -> 240 -> 40 SE=True RE s=1
                SEInvertedBottleneck(in_channels=40, mid_channels=240//2, out_channels=40, kernel_size=5, stride=1,
                                     activate='hswish', use_se=True, se_kernel_size=14),
                # torch.Size([1, 40, 14, 14])   40 -> 120 -> 48 SE=True RE s=1
                SEInvertedBottleneck(in_channels=40, mid_channels=120//2, out_channels=48, kernel_size=5, stride=1,
                                     activate='hswish', use_se=True, se_kernel_size=14),
                # torch.Size([1, 48, 14, 14])   48 -> 144 -> 48 SE=True RE s=1
                SEInvertedBottleneck(in_channels=48, mid_channels=144//2, out_channels=48, kernel_size=5, stride=1,
                                     activate='hswish', use_se=True, se_kernel_size=14),
                # torch.Size([1, 48, 14, 14])   48 -> 288 -> 96 SE=True RE s=2
                SEInvertedBottleneck(in_channels=48, mid_channels=288//2, out_channels=96, kernel_size=5, stride=2,
                                     activate='hswish', use_se=True, se_kernel_size=7),
                # torch.Size([1, 96, 7, 7])     96 -> 576 -> 96 SE=True RE s=1
                SEInvertedBottleneck(in_channels=96, mid_channels=576//2, out_channels=96, kernel_size=5, stride=1,
                                     activate='hswish', use_se=True, se_kernel_size=7),
                # torch.Size([1, 96, 7, 7])     96 -> 576 -> 96 SE=True RE s=1
                SEInvertedBottleneck(in_channels=96, mid_channels=576//2, out_channels=96, kernel_size=5, stride=1,
                                     activate='hswish', use_se=True, se_kernel_size=7),
            )
            

            # torch.Size([1, 96, 7, 7])
            # 相比MobileNetV2,尾部结构改变,,变得更加的高效
            self.small_last_stage = nn.Sequential(
                nn.Conv2d(in_channels=96, out_channels=576, kernel_size=1, stride=1),
                nn.BatchNorm2d(576),
                HardSwish(),
                nn.AvgPool2d(kernel_size=7, stride=1),
                nn.Conv2d(in_channels=576, out_channels=1280, kernel_size=1, stride=1),
                HardSwish(),
            )
        self.dorpout = nn.Dropout(0.5)
        self.classifier =nn.Linear(in_features=1280, out_features=num_classes)

        # self.init_params()

    def forward(self, x):
        x = self.first_conv(x)  # torch.Size([1, 16, 112, 112])
        if self.type == 'large':
            x = self.large_bottleneck(x)  # torch.Size([1, 160, 7, 7])
            x = self.large_last_stage(x)  # torch.Size([1, 1280, 1, 1])
        if self.type == 'small':
            x = self.small_bottleneck(x)  # torch.Size([1, 96, 7, 7])
            x = self.small_last_stage(x)  # torch.Size([1, 1280, 1, 1])
        x = x.reshape((x.shape[0], -1))  # torch.Size([1, 1280])
        x = self.dorpout(x)
        x = self.classifier(x)  # torch.Size([1, 5])
        return x
if __name__ == '__main__':
    models = MobileNetV3(8,type='large').to(device)
    input = torch.randn(size=[1, 3, 224, 224]).to(device)
    out = models(input)
    print(out.shape)
    torchsummary.summary(models,input_size=(3,224,224))

六、开始训练

import numpy
models = MobileNetV3(8,type='large').to('cuda')
# 设置优化器
optim = torch.optim.Adam(lr=0.001, params=models.parameters())
# 设置损失函数
loss_fn = torch.nn.CrossEntropyLoss().to('cuda')
bestacc=0
for epoch in range(20):
    train_data=0
    acc_data=0
    loss_data=0
    models.train()
    for batch_id, data in enumerate(train_dataloader):
        x_data,label=data
        predicts=models(x_data.to('cuda'))
        loss=loss_fn(predicts, label.to('cuda'))
        acc=numpy.sum(numpy.argmax(predicts.cpu().detach().numpy(), axis=1)==label.numpy())
        train_data+=len(x_data)
        acc_data+=acc
        loss_data+=loss
        # callbacks.step(loss)
        loss.backward()
        optim.step()
        optim.zero_grad()
    accuracy=acc_data/train_data
    all_loss=loss_data/batch_id
    print(f"train:eopch:{epoch} train: acc:{accuracy} loss:{all_loss.item()}",end=' ')
    if epoch+1:
        models.eval()
        test_data=0
        acc_data=0
        for batch_id, data in enumerate(test_dataloader):
            x_data,label=data
            predicts=models(x_data.to('cuda'))
            acc=numpy.sum(numpy.argmax(predicts.cpu().detach().numpy(), axis=1)==label.numpy())
            test_data+=len(x_data)
            acc_data+=acc
        accuracy=acc_data/test_data
        print(f"test: acc:{accuracy}")
        if accuracy > bestacc:
            torch.save(models.state_dict(), "best.pth")
            bestacc = accuracy
    print("Done")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1062154.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

不容易解的题10.5

31.下一个排列 31. 下一个排列 - 力扣(LeetCode)https://leetcode.cn/problems/next-permutation/?envTypelist&envIdZCa7r67M会做就不算难题,如果没做过不知道思路,这道题将会变得很难。 这道题相当于模拟cpp的next_permu…

Windows系统上使用CLion远程开发Linux程序

CLion远程开发Linux程序 情景说明Ubuntu配置CLion配置同步 情景说明 在Windows系统上使用CLion开发Linux程序,安装CLion集成化开发环境时会自动安装cmake、mingw,代码提示功能也比较友好。 但是在socket开发时,包含sys/socket.h头文件时&am…

浅谈CDN内容分发与全局负载均衡

CDN简介 CDN的全称是Content Delivery Network,即内容分发网络。CDN是构建在现有网络基础之上的智能虚拟网络,依靠部署在各地的边缘服务器,通过中心平台的负载均衡、内容分发、调度等功能模块,使用户就近获取所需内容&#xff0c…

软件设计师_数据结构与算法_学习笔记

文章目录 6.1 数组与矩阵6.1.1 数组6.1.2 稀疏矩阵 6.2 线性表6.2.1 数据结构的定义6.2.2 顺序表与链表6.2.2.1 定义6.2.2.2 链表的操作 6.2.3 顺序存储和链式存储的对比6.2.4 队列、循环队列、栈6.2.4.2 循环队列队空与队满条件6.2.4.3 出入后不可能出现的序列练习 6.2.5 串 6…

C语言学习系列->联合体and枚举

文章目录 前言联合体概述联合体的特点联合体大小的计算优点练习 枚举概述优点使用 前言 在上一篇文章中,小编将结构体的学习笔记整理出来了。现在,小编将枚举和联合体笔记分享给大家。枚举和联合体与结构体一样,都是自定义类型,在…

竞赛 机器视觉的试卷批改系统 - opencv python 视觉识别

文章目录 0 简介1 项目背景2 项目目的3 系统设计3.1 目标对象3.2 系统架构3.3 软件设计方案 4 图像预处理4.1 灰度二值化4.2 形态学处理4.3 算式提取4.4 倾斜校正4.5 字符分割 5 字符识别5.1 支持向量机原理5.2 基于SVM的字符识别5.3 SVM算法实现 6 算法测试7 系统实现8 最后 0…

MySQL之逻辑备份与恢复

逻辑备份简介: 备份的是建表、建库、插入等操作所执行SQL语句,适用于中小型数据库,效率相对较低。 本质:导出的是SQL语句文件 优点:不论是什么存储引擎,都可以用mysqldump备成SQL语句 缺点:速度…

输入一个大写字母,程序根据输入字符在字母表的顺序位置n,输出一个高度为n的金字塔图形

python字母金字塔根据输入的字母输出一个字母金字塔输入一个大写字母,程序根据输入字符在字母表的顺序位置n,输出一个高度为n的金字塔图形,比如输入E时,此时 字母金字塔 # A # ABA # ABCBA # ABCDCBA # ABCDEDCBA 看到…

php单独使用think-rom数据库 | thinkphp手动关闭数据库连接

背景(think-orm2.0.61) 由于需要长时间运行一个php脚本,而运行过程并不是需要一直与数据库交互,但thinkphp主要是为web站点开发的框架,而站点一般都是数据获取完则进程结束,所以thinkphp没提供手动关闭数据…

Trie树(字典树)C++详解

字典树的定义 字典树是一个用来快速查找和存储字符串集合的数据结构。 字典树的形状 假设我们字典树里有以下5个单词: akio,akno,cspj,csps,trie 那么字典树长这样: trie 的结构非常好懂,我们…

软考高级之系统架构师之设计模式

概述 设计模式是一种通用的设计方法,实际开发中可能不止23种。为方便理解和应用,一般分为3类: 创建型,通过采用抽象类所定义的接口,封装系统中对象如何创建、组合等信息。工厂方法模式、抽象工厂模式、单例模式、建造…

堆的初步认识

在学习本节文章前要先了解:大顶堆与小顶堆: (优先级队列_加瓦不加班的博客-CSDN博客) 堆实现 计算机科学中,堆是一种基于树的数据结构,通常用完全二叉树实现。 什么叫完全二叉树? 答&#x…

8.Vue_Element

1 Ajax 1.1 Ajax介绍 1.1.1 Ajax概述 我们前端页面中的数据,如下图所示的表格中的学生信息,应该来自于后台,那么我们的后台和前端是互不影响的2个程序,那么我们前端应该如何从后台获取数据呢?因为是2个程序&#xf…

JavaEE-文件IO操作

构造方法 一般方法,有很多,我们以下只是列举几个经常使用的 注意在上述的操作过程中,无论是绝对路径下的这个文件还是相对路径下的这个文件,都是不存在的 Reader 使用 --> 文本文件 FileReader类所涉及到的一些方法 Fil…

Covert Communication 与选择波束(毫米波,大规模MIMO,可重构全息表面)

Covert Communication for Spatially Sparse mmWave Massive MIMO Channels 2023 TOC abstract 隐蔽通信,也称为低检测概率通信,旨在为合法用户提供可靠的通信,并防止任何其他用户检测到合法通信的发生。出于下一代通信系统安全链路的强烈…

C#学习系列相关之多线程(一)----常用多线程方法总结

一、多线程的用途 在介绍多线程的方法之前首先应当知道什么是多线程, 在一个进程内部可以执行多个任务,而这每一个任务我们就可以看成是一个线程。是程序使用CPU的基本单位。进程是拥有资源的基本单位, 线程是CPU调度的基本单位。多线程的作用…

iStoreOS搭建主路由有什么好处

iStoreOS 作为一种功能强大的软路由系统,搭建主路由可以带来多种好处。本文泪雪网将详细介绍 iStoreOS 搭建主路由的好处,包括增强网络安全性、提供更灵活的网络管理、实现高级功能和提升性能等方面。 一、增强网络安全性 iStoreOS 搭建主路由可以增强网…

计算机毕业设计 基于协调过滤算法的绿色食品推荐系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

YOLOV7改进-添加EIOU,SIOU,AlphaIOU,FocalEIOU

打开utils->general.py 找到bbox_iou(),345行左右,将下面的与源码进行替换 def bbox_iou(box1, box2, x1y1x2y2True, GIoUFalse, DIoUFalse, CIoUFalse, SIoUFalse, EIoUFalse, WIoUFalse, FocalFalse, alpha1, gamma0.5, sc…

【智慧校园源码】中小学智慧班牌系统,实现校园信息化交流建设,提高班级管理效率

智慧班牌系统源码 电子班牌原生小程序源码 智慧校园云平台系统源码 智慧班牌可以通过以云平台为基础,结合互联网、物联网系统进行校园管理,实现学校数据、教学资源共享,推进校园信息化交流建设。而展示在班牌终端的信息可以随时更改和上传新的…