Week J4: Combining ResNet and DenseNet -- DPN (PyTorch version)


>- **🍨 This article is a learning log from the [🔗365天深度学习训练营] training camp**
>- **🍖 Original author: [K同学啊]**

📌 This week's tasks 📌

● Task type: independent exploration ⭐⭐

● Task difficulty: fairly hard

● Task description:

1. Based on the content of weeks J1-J3, explore on your own how ResNet and DenseNet might be combined

2. Can a new model architecture be built from the characteristics of the two networks?

3. Validate the improved model on any of the previous image recognition tasks

🏡 My environment:

  • Language: Python 3.8
  • Editor: Jupyter Notebook
  • Deep learning framework: PyTorch
    • torch==2.3.1+cu118
    • torchvision==0.18.1+cu118

I. Paper Overview

Paper: Dual Path Networks
Paper link: https://arxiv.org/abs/1707.01629
Code: https://github.com/cypw/DPNs
Trainable DPN code for the MXNet framework: https://github.com/miraclewkf/DPN

ResNet [2] and DenseNet [3] are the two most classic base networks in the shortcut-connection family. ResNet adds its input directly onto the convolution output via element-wise addition, while DenseNet concatenates each layer's output with the input of every subsequent layer.
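The contrast between the two connection types is easy to see on toy tensors; here is a minimal sketch (my own illustration, not from the original post):

import torch

x=torch.randn(1,64,56,56) #block input
f=torch.randn(1,64,56,56) #block output F(x), same shape for illustration

res_out=x+f #ResNet: element-wise addition, channel count stays at 64
dense_out=torch.cat([x,f],dim=1) #DenseNet: channel concatenation, channels grow to 128
print(res_out.shape,dense_out.shape)
#torch.Size([1, 64, 56, 56]) torch.Size([1, 128, 56, 56])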

DPN (Dual Path Networks) is an architecture that combines the strengths of DenseNet and ResNeXt. The idea is to exploit different properties of a neural network through different paths, improving both efficiency and performance.

DenseNet is characterized by its densely connected path, which lets the network keep exploring new features across layers; this connectivity allows it to learn richer feature representations without a large increase in parameters.
ResNeXt (residual networks with grouped convolutions), by contrast, reuses features through its residual path, which helps keep model size and complexity down.
DPN's design fuses these two ideas, passing information through two parallel paths:

One path works the DenseNet way: a densely connected path that keeps exploring new features.
The other path works the ResNeXt way: a residual path that reuses features.
In addition, DPN uses grouped convolutions to reduce computation, and it improves performance without changing the overall network layout, which makes it well suited as a new backbone for detection and segmentation tasks.

Summary: DPN fuses the core ideas of ResNeXt and DenseNet. A Dual Path Network takes ResNet as its main skeleton, which keeps feature redundancy low, and adds a very small DenseNet branch on top to generate new features.

So what exactly are DPN's advantages? Two points:
1. On model complexity, the paper states: "The DPN-92 costs about 15% fewer parameters than ResNeXt-101 (32×4d), while the DPN-98 costs about 26% fewer parameters than ResNeXt-101 (64×4d)."
2. On computational complexity, the paper states: "DPN-92 consumes about 19% less FLOPs than ResNeXt-101 (32×4d), and the DPN-98 consumes about 25% less FLOPs than ResNeXt-101 (64×4d)."

As the architecture table in the paper (Table 1) shows, DPN's overall structure is very similar to ResNeXt (and ResNet): a 7×7 convolution and a max-pooling layer at the start, then 4 stages, each containing several sub-stages (described below), followed by global average pooling and a fully connected layer, with softmax at the end. The key part is what happens inside each stage, which is the core of the DPN algorithm.

Since DPN is essentially ResNeXt and DenseNet fused into one network, before walking through the structure inside each DPN stage, let's briefly review the core of ResNet (ResNeXt's substructure is the same at the macro level) and DenseNet.

Panel (a) of the paper's structure-comparison figure shows part of one ResNet stage. The tall rectangle on the left of (a) stands for the input/output. An input x takes two branches: one branch is x itself; the other passes x through a 1×1 convolution, a 3×3 convolution and another 1×1 convolution (this three-layer combination is called a bottleneck). The two branches are then merged by element-wise addition (the plus sign in (a)), and the result becomes the input to the next identical module; several such modules chained together form a stage (e.g. conv3 in Table 1).

Panel (b) shows the core of DenseNet. The polygon on the left of (b) stands for the input/output. An input x takes only one branch: after a few convolution layers it is concatenated with x along the channel dimension (concat), and the result becomes the input of the next small module. Each module's input therefore keeps accumulating; for example, the second module's input contains both the first module's output and the first module's input, and so on.

So how does DPN do it? In short, it fuses the Residual Network and the Densely Connected Network. Panels (d) and (e) are equivalent, so let's follow (e); the rectangle and the polygon in (e) mean the same as before. Concretely, in code, an input x falls into one of two cases. If x is the output of the network's first convolution layer or of some stage, a convolution is applied to it and the result is sliced along the channel dimension into two parts, data_o1 and data_o2 (the rectangle and the polygon in (e)). Otherwise x is the output of a sub-stage inside the current stage and already consists of the two parts data_o1 and data_o2. Two branches then run. One branch keeps data_o1 and data_o2 as they are, as in ResNet. The other applies a 1×1, a 3×3 and another 1×1 convolution to x and slices the result into two parts, c1 and c2. Finally, c1 is added element-wise to data_o1 to give sum, like the ResNet operation, and c2 is concatenated with data_o2 along the channel dimension to give dense (so the next layer receives both this layer's output and its input). The module returns two values: sum and dense.
This whole process is one sub-stage of a DPN stage. Two details: the 3×3 convolution uses the group operation, as in ResNeXt, and at the head and tail of each sub-stage the dense part's channel count is widened.
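Stripped of the convolutions, the sum/concat core of a sub-stage is just a few tensor operations. A minimal sketch (my own illustration; the widths d and k are hypothetical example values):

import torch

d,k=256,16 #residual (sum) width and dense (concat) width
a=torch.randn(1,d+k,14,14) #previous sub-stage output: [data_o1 | data_o2]
x=torch.randn(1,d+k,14,14) #bottleneck output of this sub-stage: [c1 | c2]

s=a[:,:d]+x[:,:d] #ResNet path: element-wise sum, d channels
dense=torch.cat([a[:,d:],x[:,d:]],dim=1) #DenseNet path: concat, 2*k channels
out=torch.cat([s,dense],dim=1) #input to the next sub-stage: d+2*k channels
print(out.shape) #torch.Size([1, 288, 14, 14])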

The authors implemented DPN under the MXNet framework; the concrete symbols can be found at https://github.com/cypw/DPNs/tree/master/settings, documented in detail and easy to read.

Experimental results:
Table 2 compares DPN on ImageNet-1k against the strongest contemporary models: ResNet, ResNeXt and DenseNet. DPN comes out ahead on model size, GFLOPs and accuracy. In this comparison, though, DenseNet does not look quite as impressive as in its own paper, possibly because DenseNet needs more training tricks.

Figure 3 compares training speed and memory consumption. For model improvements nowadays, accuracy gains alone are hard to sell as a clear contribution because the margins are small, so most work also optimizes model size and computational cost; as long as accuracy still improves a little, that counts as progress.

Summary:
The proposed DPN can be understood as ResNeXt with DenseNet's core idea grafted on, which lets the model exploit features more fully. The principle is not hard to understand, the model turned out to be fairly easy to train in practice, and the paper's experiments show solid results on both classification and detection datasets.

References:

DPN (Dual Path Network) 算法详解 - CSDN blog

DPN网络 - CSDN blog

DPN详解 (Dual Path Networks) - 知乎 (zhihu.com)

解读 Dual Path Networks (DPN, 原创) - 知乎 (zhihu.com)

II. Preliminaries

1. Set up the GPU

Use the GPU if the device supports one; otherwise fall back to the CPU.

import warnings
warnings.filterwarnings("ignore") #suppress warnings

import torch
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

Output:

device(type='cuda')

2. Load the data

import pathlib
data_dir=r'D:\THE MNIST DATABASE\J-series\J1\bird_photos'
data_dir=pathlib.Path(data_dir)

img_count=len(list(data_dir.glob('*/*')))
print("图片总数为:",img_count)

Output:

Total number of images: 565

3. Inspect the dataset classes

data_paths=list(data_dir.glob('*'))
classNames=[path.name for path in data_paths] #each subfolder name is a class name
classNames

Output:

['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

4. View random images

Randomly pick 10 images from the dataset for a quick look.

import PIL,random
import matplotlib.pyplot as plt
from PIL import Image
plt.rcParams['font.sans-serif']=['SimHei'] #display Chinese labels correctly
plt.rcParams['axes.unicode_minus']=False #display minus signs correctly

data_paths2=list(data_dir.glob('*/*'))
plt.figure(figsize=(20,4))
plt.suptitle("OreoCC的案例",fontsize=15)
for i in range(10):
    plt.subplot(2,5,i+1)
    plt.axis("off")
    image=random.choice(data_paths2) #pick a random image
    plt.title(image.parts[-2]) #the parent folder name from the glob path is the class name
    plt.imshow(Image.open(str(image)))  #display the image

Output:

(a 2×5 grid of randomly sampled bird images, each titled with its class name)

5. Image preprocessing

from torchvision import transforms,datasets

train_transforms=transforms.Compose([
    transforms.Resize([224,224]), #resize all images to the same size
    transforms.RandomHorizontalFlip(), #random horizontal flip
    transforms.RandomRotation(0.2), #random rotation within ±0.2 degrees (the argument is in degrees, not radians)
    transforms.ToTensor(), #convert the image to a tensor
    transforms.Normalize(  #standardize with the ImageNet channel means/stds so the model converges more easily
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
])

total_data=datasets.ImageFolder(
    r"D:\THE MNIST DATABASE\J-series\J1\bird_photos",
    transform=train_transforms
)
total_data

Output:

Dataset ImageFolder
    Number of datapoints: 565
    Root location: D:\THE MNIST DATABASE\J-series\J1\bird_photos
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-0.2, 0.2], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
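Note that total_data applies this augmenting transform to every sample, so the test split created below is augmented as well. A stricter setup would give the test subset its own transform with only Resize, ToTensor and Normalize.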

Print the dataset's class-to-index mapping:

total_data.class_to_idx

Output:

{'Bananaquit': 0,
 'Black Skimmer': 1,
 'Black Throated Bushtiti': 2,
 'Cockatoo': 3}

6. Split the dataset

train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size

train_dataset,test_dataset=torch.utils.data.random_split(
    total_data,[train_size,test_size]
)
train_dataset,test_dataset

Output:

(<torch.utils.data.dataset.Subset at 0x270de0de310>,
 <torch.utils.data.dataset.Subset at 0x270de0de950>)

Check the number of samples in the training and test sets:

train_size,test_size

Output:

(452, 113)
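random_split draws a fresh random permutation on every run. If a reproducible split is wanted (my own addition, not in the original code), a seeded generator can be passed in:

train_dataset,test_dataset=torch.utils.data.random_split(
    total_data,[train_size,test_size],
    generator=torch.Generator().manual_seed(42) #fixed seed -> same split every run
)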

7. Load the dataset

batch_size=16
train_dl=torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1
)
test_dl=torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1
)

Take one batch from the training loader and check its shape:

for x,y in train_dl:
    print("Shape of x [N,C,H,W]:",x.shape)
    print("Shape of y:",y.shape,y.dtype)
    break

Output:

Shape of x [N,C,H,W]: torch.Size([16, 3, 224, 224])
Shape of y: torch.Size([16]) torch.int64
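With batch_size=16, the 452 training samples give ceil(452/16)=29 batches per epoch and the 113 test samples give ceil(113/16)=8 batches.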

III. Building the DPN Model by Hand

1. Build the DPN model

import torch
import torch.nn as nn

class Block(nn.Module):
    """
    in_channel -- number of input channels
    mid_channel -- number of channels in the intermediate (bottleneck) convolutions
    out_channel -- number of channels used by the ResNet part (the sum operation; its output still has out_channel channels)
    dense_channel -- number of channels used by the DenseNet part (the concat operation; its output has 2*dense_channel channels)
    groups -- number of groups for the grouped convolution in conv2
    is_shortcut -- whether to apply a projection shortcut before the ResNet addition
    """
    def __init__(self,in_channel,mid_channel,out_channel,dense_channel,stride,groups,is_shortcut=False):
        super(Block,self).__init__()
        
        self.is_shortcut=is_shortcut
        self.out_channel=out_channel
        self.conv1=nn.Sequential(
            nn.Conv2d(in_channel,mid_channel,kernel_size=1,bias=False),
            nn.BatchNorm2d(mid_channel),
            nn.ReLU()
        )
        
        self.conv2=nn.Sequential(
            nn.Conv2d(mid_channel,mid_channel,kernel_size=3,stride=stride,padding=1,groups=groups,bias=False),
            nn.BatchNorm2d(mid_channel),
            nn.ReLU()
        )
        
        self.conv3=nn.Sequential(
            nn.Conv2d(mid_channel,out_channel+dense_channel,kernel_size=1,bias=False),
            nn.BatchNorm2d(out_channel+dense_channel)
        )
        
        if self.is_shortcut:
            self.shortcut=nn.Sequential(
                nn.Conv2d(in_channel,out_channel+dense_channel,kernel_size=3,padding=1,stride=stride,bias=False),
                nn.BatchNorm2d(out_channel+dense_channel)
            )
            
        self.relu=nn.ReLU(inplace=True)
        
    def forward(self,x):
        a=x
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        if self.is_shortcut:
            a=self.shortcut(a)
            
        #a[:,:self.out_channel,:,:]+x[:,:self.out_channel,:,:] is the ResNet-style operation:
        #the feature maps are summed, so the channel count stays at out_channel
        #[a[:,self.out_channel:,:,:],x[:,self.out_channel:,:,:]] is the DenseNet-style operation:
        #the feature maps are concatenated along the channel dimension, giving 2*dense_channel channels
        x=torch.cat([a[:,:self.out_channel,:,:]+x[:,:self.out_channel,:,:],
                     a[:,self.out_channel:,:,:],x[:,self.out_channel:,:,:]],dim=1)
        x=self.relu(x)
        
        return x
    
class DPN(nn.Module):
    def __init__(self,cfg):
        super(DPN,self).__init__()
        
        self.group=cfg['group']
        self.in_channel=cfg['in_channel']
        mid_channels=cfg['mid_channels']
        out_channels=cfg['out_channels']
        dense_channels=cfg['dense_channels']
        num=cfg['num']
        
        self.conv1=nn.Sequential(
            nn.Conv2d(3,self.in_channel,7,stride=2,padding=3,bias=False,padding_mode='zeros'),
            nn.BatchNorm2d(self.in_channel),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0)
        )
        self.conv2=self._make_layers(mid_channels[0],out_channels[0],dense_channels[0],num[0],stride=1)
        self.conv3=self._make_layers(mid_channels[1],out_channels[1],dense_channels[1],num[1],stride=2)
        self.conv4=self._make_layers(mid_channels[2],out_channels[2],dense_channels[2],num[2],stride=2)
        self.conv5=self._make_layers(mid_channels[3],out_channels[3],dense_channels[3],num[3],stride=2)
        self.pool=nn.AdaptiveAvgPool2d((1,1))
        self.fc=nn.Linear(cfg['out_channels'][3]+(num[3]+1)*cfg['dense_channels'][3],cfg['classes']) #the fc input width is out_channels[3]+(num[3]+1)*dense_channels[3]
        
    def _make_layers(self,mid_channel,out_channel,dense_channel,num,stride):
        layers=[]
        """is_shortcut=True表示进行shortcut操作,则将浅层的特征进行一次卷积后与进行第三次卷积的特征图相加
        (ResNet方式)和concat(DenseNet方式)操作"""
        """第一次使用Block可以满足浅层特征的利用,后续重复的Block则不需要浅层特征,因此后续的Block的
        is_shortcut=False(默认值)"""
        layers.append(Block(self.in_channel,mid_channel,out_channel,dense_channel,
                            stride=stride,groups=self.group,is_shortcut=True))
        self.in_channel=out_channel+dense_channel*2
        for i in range(1,num):
            layers.append(Block(self.in_channel,mid_channel,out_channel,dense_channel,
                                stride=1,groups=self.group))
            """由于Block包含DenseNet在叠加特征图,所以第一次是2倍dense_channel,
            后面每次都会多出1倍dense_channel"""
            self.in_channel+=dense_channel
        return nn.Sequential(*layers)
    
    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        x=self.conv4(x)
        x=self.conv5(x)
        x=self.pool(x)
        x=torch.flatten(x,start_dim=1)
        x=self.fc(x)
        return x
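A quick check of the fc width for the DPN92 configuration defined below: stage conv5 has out_channels[3]=2048, dense_channels[3]=128 and num[3]=3, so the final feature width is 2048+(3+1)*128=2560, which matches Linear(in_features=2560) in the printed model. The per-Block bookkeeping can also be sanity-checked with a throwaway snippet (my own addition; the argument values are just examples):

blk=Block(in_channel=64,mid_channel=96,out_channel=256,dense_channel=16,
          stride=1,groups=32,is_shortcut=True)
y=blk(torch.randn(2,64,56,56))
print(y.shape) #torch.Size([2, 288, 56, 56]): 256 (sum path) + 2*16 (dense path)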

2. Build DPN92 and display the model structure

def DPN92(n_class=4):
    cfg={
        "group":32,
        "in_channel":64,
        "mid_channels":(96,192,384,768),
        "out_channels":(256,512,1024,2048),
        "dense_channels":(16,32,24,128),
        "num":(3,4,20,3),
        "classes":(n_class)
    }
    return DPN(cfg)

def DPN98(n_class=4):
    cfg={
        "group":40,
        "in_channel":96,
        "mid_channels":(160,320,640,1280),
        "out_channels":(256,512,1024,2048),
        "dense_channels":(16,32,32,128),
        "num":(3,6,20,3),
        "classes":(n_class)
    }
    return DPN(cfg)

model=DPN92().to(device)
model

Output:

DPN(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Block(
      (conv1): Sequential(
        (0): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(96, 272, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(272, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(64, 272, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(272, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): Block(
      (conv1): Sequential(
        (0): Conv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(96, 272, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(272, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (2): Block(
      (conv1): Sequential(
        (0): Conv2d(304, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(96, 272, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(272, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
  )
  (conv3): Sequential(
    (0): Block(
      (conv1): Sequential(
        (0): Conv2d(320, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(192, 544, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(320, 544, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): Block(
      (conv1): Sequential(
        (0): Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(192, 544, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (2): Block(
      (conv1): Sequential(
        (0): Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(192, 544, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (3): Block(
      (conv1): Sequential(
        (0): Conv2d(640, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(192, 544, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
  )
  (conv4): Sequential(
    (0): Block(
      (conv1): Sequential(
        (0): Conv2d(672, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(672, 1048, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): Block(
      (conv1): Sequential(
        (0): Conv2d(1072, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (2): Block(
      (conv1): Sequential(
        (0): Conv2d(1096, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (3): Block(
      (conv1): Sequential(
        (0): Conv2d(1120, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (4): Block(
      (conv1): Sequential(
        (0): Conv2d(1144, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (5): Block(
      (conv1): Sequential(
        (0): Conv2d(1168, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (6): Block(
      (conv1): Sequential(
        (0): Conv2d(1192, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (7): Block(
      (conv1): Sequential(
        (0): Conv2d(1216, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (8): Block(
      (conv1): Sequential(
        (0): Conv2d(1240, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (9): Block(
      (conv1): Sequential(
        (0): Conv2d(1264, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (10): Block(
      (conv1): Sequential(
        (0): Conv2d(1288, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (11): Block(
      (conv1): Sequential(
        (0): Conv2d(1312, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (12): Block(
      (conv1): Sequential(
        (0): Conv2d(1336, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (13): Block(
      (conv1): Sequential(
        (0): Conv2d(1360, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (14): Block(
      (conv1): Sequential(
        (0): Conv2d(1384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (15): Block(
      (conv1): Sequential(
        (0): Conv2d(1408, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (16): Block(
      (conv1): Sequential(
        (0): Conv2d(1432, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (17): Block(
      (conv1): Sequential(
        (0): Conv2d(1456, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (18): Block(
      (conv1): Sequential(
        (0): Conv2d(1480, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (19): Block(
      (conv1): Sequential(
        (0): Conv2d(1504, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(384, 1048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(1048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
  )
  (conv5): Sequential(
    (0): Block(
      (conv1): Sequential(
        (0): Conv2d(1528, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(768, 768, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(768, 2176, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(2176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(1528, 2176, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(2176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): Block(
      (conv1): Sequential(
        (0): Conv2d(2304, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(768, 2176, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(2176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (2): Block(
      (conv1): Sequential(
        (0): Conv2d(2432, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(768, 2176, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(2176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2560, out_features=4, bias=True)
)
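Reading down the printed structure confirms the channel bookkeeping: within conv2 the Block inputs grow 64 → 288 → 304 → 320, i.e. every Block after the first adds exactly dense_channels[0]=16 channels, and only the first Block of each stage carries the 3×3 projection shortcut.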

3. View model details

#count the model's parameters and other statistics
import torchsummary as summary
summary.summary(model,(3,224,224))

Output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 55, 55]               0
            Conv2d-5           [-1, 96, 55, 55]           6,144
       BatchNorm2d-6           [-1, 96, 55, 55]             192
              ReLU-7           [-1, 96, 55, 55]               0
            Conv2d-8           [-1, 96, 55, 55]           2,592
       BatchNorm2d-9           [-1, 96, 55, 55]             192
             ReLU-10           [-1, 96, 55, 55]               0
           Conv2d-11          [-1, 272, 55, 55]          26,112
      BatchNorm2d-12          [-1, 272, 55, 55]             544
           Conv2d-13          [-1, 272, 55, 55]         156,672
      BatchNorm2d-14          [-1, 272, 55, 55]             544
             ReLU-15          [-1, 288, 55, 55]               0
            Block-16          [-1, 288, 55, 55]               0
           Conv2d-17           [-1, 96, 55, 55]          27,648
      BatchNorm2d-18           [-1, 96, 55, 55]             192
             ReLU-19           [-1, 96, 55, 55]               0
           Conv2d-20           [-1, 96, 55, 55]           2,592
      BatchNorm2d-21           [-1, 96, 55, 55]             192
             ReLU-22           [-1, 96, 55, 55]               0
           Conv2d-23          [-1, 272, 55, 55]          26,112
      BatchNorm2d-24          [-1, 272, 55, 55]             544
             ReLU-25          [-1, 304, 55, 55]               0
            Block-26          [-1, 304, 55, 55]               0
           Conv2d-27           [-1, 96, 55, 55]          29,184
      BatchNorm2d-28           [-1, 96, 55, 55]             192
             ReLU-29           [-1, 96, 55, 55]               0
           Conv2d-30           [-1, 96, 55, 55]           2,592
      BatchNorm2d-31           [-1, 96, 55, 55]             192
             ReLU-32           [-1, 96, 55, 55]               0
           Conv2d-33          [-1, 272, 55, 55]          26,112
      BatchNorm2d-34          [-1, 272, 55, 55]             544
             ReLU-35          [-1, 320, 55, 55]               0
            Block-36          [-1, 320, 55, 55]               0
           Conv2d-37          [-1, 192, 55, 55]          61,440
      BatchNorm2d-38          [-1, 192, 55, 55]             384
             ReLU-39          [-1, 192, 55, 55]               0
           Conv2d-40          [-1, 192, 28, 28]          10,368
      BatchNorm2d-41          [-1, 192, 28, 28]             384
             ReLU-42          [-1, 192, 28, 28]               0
           Conv2d-43          [-1, 544, 28, 28]         104,448
      BatchNorm2d-44          [-1, 544, 28, 28]           1,088
           Conv2d-45          [-1, 544, 28, 28]       1,566,720
      BatchNorm2d-46          [-1, 544, 28, 28]           1,088
             ReLU-47          [-1, 576, 28, 28]               0
            Block-48          [-1, 576, 28, 28]               0
           Conv2d-49          [-1, 192, 28, 28]         110,592
      BatchNorm2d-50          [-1, 192, 28, 28]             384
             ReLU-51          [-1, 192, 28, 28]               0
           Conv2d-52          [-1, 192, 28, 28]          10,368
      BatchNorm2d-53          [-1, 192, 28, 28]             384
             ReLU-54          [-1, 192, 28, 28]               0
           Conv2d-55          [-1, 544, 28, 28]         104,448
      BatchNorm2d-56          [-1, 544, 28, 28]           1,088
             ReLU-57          [-1, 608, 28, 28]               0
            Block-58          [-1, 608, 28, 28]               0
           Conv2d-59          [-1, 192, 28, 28]         116,736
      BatchNorm2d-60          [-1, 192, 28, 28]             384
             ReLU-61          [-1, 192, 28, 28]               0
           Conv2d-62          [-1, 192, 28, 28]          10,368
      BatchNorm2d-63          [-1, 192, 28, 28]             384
             ReLU-64          [-1, 192, 28, 28]               0
           Conv2d-65          [-1, 544, 28, 28]         104,448
      BatchNorm2d-66          [-1, 544, 28, 28]           1,088
             ReLU-67          [-1, 640, 28, 28]               0
            Block-68          [-1, 640, 28, 28]               0
           Conv2d-69          [-1, 192, 28, 28]         122,880
      BatchNorm2d-70          [-1, 192, 28, 28]             384
             ReLU-71          [-1, 192, 28, 28]               0
           Conv2d-72          [-1, 192, 28, 28]          10,368
      BatchNorm2d-73          [-1, 192, 28, 28]             384
             ReLU-74          [-1, 192, 28, 28]               0
           Conv2d-75          [-1, 544, 28, 28]         104,448
      BatchNorm2d-76          [-1, 544, 28, 28]           1,088
             ReLU-77          [-1, 672, 28, 28]               0
            Block-78          [-1, 672, 28, 28]               0
           Conv2d-79          [-1, 384, 28, 28]         258,048
      BatchNorm2d-80          [-1, 384, 28, 28]             768
             ReLU-81          [-1, 384, 28, 28]               0
           Conv2d-82          [-1, 384, 14, 14]          41,472
      BatchNorm2d-83          [-1, 384, 14, 14]             768
             ReLU-84          [-1, 384, 14, 14]               0
           Conv2d-85         [-1, 1048, 14, 14]         402,432
      BatchNorm2d-86         [-1, 1048, 14, 14]           2,096
           Conv2d-87         [-1, 1048, 14, 14]       6,338,304
      BatchNorm2d-88         [-1, 1048, 14, 14]           2,096
             ReLU-89         [-1, 1072, 14, 14]               0
            Block-90         [-1, 1072, 14, 14]               0
           Conv2d-91          [-1, 384, 14, 14]         411,648
      BatchNorm2d-92          [-1, 384, 14, 14]             768
             ReLU-93          [-1, 384, 14, 14]               0
           Conv2d-94          [-1, 384, 14, 14]          41,472
      BatchNorm2d-95          [-1, 384, 14, 14]             768
             ReLU-96          [-1, 384, 14, 14]               0
           Conv2d-97         [-1, 1048, 14, 14]         402,432
      BatchNorm2d-98         [-1, 1048, 14, 14]           2,096
             ReLU-99         [-1, 1096, 14, 14]               0
           Block-100         [-1, 1096, 14, 14]               0
          Conv2d-101          [-1, 384, 14, 14]         420,864
     BatchNorm2d-102          [-1, 384, 14, 14]             768
            ReLU-103          [-1, 384, 14, 14]               0
          Conv2d-104          [-1, 384, 14, 14]          41,472
     BatchNorm2d-105          [-1, 384, 14, 14]             768
            ReLU-106          [-1, 384, 14, 14]               0
          Conv2d-107         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-108         [-1, 1048, 14, 14]           2,096
            ReLU-109         [-1, 1120, 14, 14]               0
           Block-110         [-1, 1120, 14, 14]               0
          Conv2d-111          [-1, 384, 14, 14]         430,080
     BatchNorm2d-112          [-1, 384, 14, 14]             768
            ReLU-113          [-1, 384, 14, 14]               0
          Conv2d-114          [-1, 384, 14, 14]          41,472
     BatchNorm2d-115          [-1, 384, 14, 14]             768
            ReLU-116          [-1, 384, 14, 14]               0
          Conv2d-117         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-118         [-1, 1048, 14, 14]           2,096
            ReLU-119         [-1, 1144, 14, 14]               0
           Block-120         [-1, 1144, 14, 14]               0
          Conv2d-121          [-1, 384, 14, 14]         439,296
     BatchNorm2d-122          [-1, 384, 14, 14]             768
            ReLU-123          [-1, 384, 14, 14]               0
          Conv2d-124          [-1, 384, 14, 14]          41,472
     BatchNorm2d-125          [-1, 384, 14, 14]             768
            ReLU-126          [-1, 384, 14, 14]               0
          Conv2d-127         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-128         [-1, 1048, 14, 14]           2,096
            ReLU-129         [-1, 1168, 14, 14]               0
           Block-130         [-1, 1168, 14, 14]               0
          Conv2d-131          [-1, 384, 14, 14]         448,512
     BatchNorm2d-132          [-1, 384, 14, 14]             768
            ReLU-133          [-1, 384, 14, 14]               0
          Conv2d-134          [-1, 384, 14, 14]          41,472
     BatchNorm2d-135          [-1, 384, 14, 14]             768
            ReLU-136          [-1, 384, 14, 14]               0
          Conv2d-137         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-138         [-1, 1048, 14, 14]           2,096
            ReLU-139         [-1, 1192, 14, 14]               0
           Block-140         [-1, 1192, 14, 14]               0
          Conv2d-141          [-1, 384, 14, 14]         457,728
     BatchNorm2d-142          [-1, 384, 14, 14]             768
            ReLU-143          [-1, 384, 14, 14]               0
          Conv2d-144          [-1, 384, 14, 14]          41,472
     BatchNorm2d-145          [-1, 384, 14, 14]             768
            ReLU-146          [-1, 384, 14, 14]               0
          Conv2d-147         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-148         [-1, 1048, 14, 14]           2,096
            ReLU-149         [-1, 1216, 14, 14]               0
           Block-150         [-1, 1216, 14, 14]               0
          Conv2d-151          [-1, 384, 14, 14]         466,944
     BatchNorm2d-152          [-1, 384, 14, 14]             768
            ReLU-153          [-1, 384, 14, 14]               0
          Conv2d-154          [-1, 384, 14, 14]          41,472
     BatchNorm2d-155          [-1, 384, 14, 14]             768
            ReLU-156          [-1, 384, 14, 14]               0
          Conv2d-157         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-158         [-1, 1048, 14, 14]           2,096
            ReLU-159         [-1, 1240, 14, 14]               0
           Block-160         [-1, 1240, 14, 14]               0
          Conv2d-161          [-1, 384, 14, 14]         476,160
     BatchNorm2d-162          [-1, 384, 14, 14]             768
            ReLU-163          [-1, 384, 14, 14]               0
          Conv2d-164          [-1, 384, 14, 14]          41,472
     BatchNorm2d-165          [-1, 384, 14, 14]             768
            ReLU-166          [-1, 384, 14, 14]               0
          Conv2d-167         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-168         [-1, 1048, 14, 14]           2,096
            ReLU-169         [-1, 1264, 14, 14]               0
           Block-170         [-1, 1264, 14, 14]               0
          Conv2d-171          [-1, 384, 14, 14]         485,376
     BatchNorm2d-172          [-1, 384, 14, 14]             768
            ReLU-173          [-1, 384, 14, 14]               0
          Conv2d-174          [-1, 384, 14, 14]          41,472
     BatchNorm2d-175          [-1, 384, 14, 14]             768
            ReLU-176          [-1, 384, 14, 14]               0
          Conv2d-177         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-178         [-1, 1048, 14, 14]           2,096
            ReLU-179         [-1, 1288, 14, 14]               0
           Block-180         [-1, 1288, 14, 14]               0
          Conv2d-181          [-1, 384, 14, 14]         494,592
     BatchNorm2d-182          [-1, 384, 14, 14]             768
            ReLU-183          [-1, 384, 14, 14]               0
          Conv2d-184          [-1, 384, 14, 14]          41,472
     BatchNorm2d-185          [-1, 384, 14, 14]             768
            ReLU-186          [-1, 384, 14, 14]               0
          Conv2d-187         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-188         [-1, 1048, 14, 14]           2,096
            ReLU-189         [-1, 1312, 14, 14]               0
           Block-190         [-1, 1312, 14, 14]               0
          Conv2d-191          [-1, 384, 14, 14]         503,808
     BatchNorm2d-192          [-1, 384, 14, 14]             768
            ReLU-193          [-1, 384, 14, 14]               0
          Conv2d-194          [-1, 384, 14, 14]          41,472
     BatchNorm2d-195          [-1, 384, 14, 14]             768
            ReLU-196          [-1, 384, 14, 14]               0
          Conv2d-197         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-198         [-1, 1048, 14, 14]           2,096
            ReLU-199         [-1, 1336, 14, 14]               0
           Block-200         [-1, 1336, 14, 14]               0
          Conv2d-201          [-1, 384, 14, 14]         513,024
     BatchNorm2d-202          [-1, 384, 14, 14]             768
            ReLU-203          [-1, 384, 14, 14]               0
          Conv2d-204          [-1, 384, 14, 14]          41,472
     BatchNorm2d-205          [-1, 384, 14, 14]             768
            ReLU-206          [-1, 384, 14, 14]               0
          Conv2d-207         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-208         [-1, 1048, 14, 14]           2,096
            ReLU-209         [-1, 1360, 14, 14]               0
           Block-210         [-1, 1360, 14, 14]               0
          Conv2d-211          [-1, 384, 14, 14]         522,240
     BatchNorm2d-212          [-1, 384, 14, 14]             768
            ReLU-213          [-1, 384, 14, 14]               0
          Conv2d-214          [-1, 384, 14, 14]          41,472
     BatchNorm2d-215          [-1, 384, 14, 14]             768
            ReLU-216          [-1, 384, 14, 14]               0
          Conv2d-217         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-218         [-1, 1048, 14, 14]           2,096
            ReLU-219         [-1, 1384, 14, 14]               0
           Block-220         [-1, 1384, 14, 14]               0
          Conv2d-221          [-1, 384, 14, 14]         531,456
     BatchNorm2d-222          [-1, 384, 14, 14]             768
            ReLU-223          [-1, 384, 14, 14]               0
          Conv2d-224          [-1, 384, 14, 14]          41,472
     BatchNorm2d-225          [-1, 384, 14, 14]             768
            ReLU-226          [-1, 384, 14, 14]               0
          Conv2d-227         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-228         [-1, 1048, 14, 14]           2,096
            ReLU-229         [-1, 1408, 14, 14]               0
           Block-230         [-1, 1408, 14, 14]               0
          Conv2d-231          [-1, 384, 14, 14]         540,672
     BatchNorm2d-232          [-1, 384, 14, 14]             768
            ReLU-233          [-1, 384, 14, 14]               0
          Conv2d-234          [-1, 384, 14, 14]          41,472
     BatchNorm2d-235          [-1, 384, 14, 14]             768
            ReLU-236          [-1, 384, 14, 14]               0
          Conv2d-237         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-238         [-1, 1048, 14, 14]           2,096
            ReLU-239         [-1, 1432, 14, 14]               0
           Block-240         [-1, 1432, 14, 14]               0
          Conv2d-241          [-1, 384, 14, 14]         549,888
     BatchNorm2d-242          [-1, 384, 14, 14]             768
            ReLU-243          [-1, 384, 14, 14]               0
          Conv2d-244          [-1, 384, 14, 14]          41,472
     BatchNorm2d-245          [-1, 384, 14, 14]             768
            ReLU-246          [-1, 384, 14, 14]               0
          Conv2d-247         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-248         [-1, 1048, 14, 14]           2,096
            ReLU-249         [-1, 1456, 14, 14]               0
           Block-250         [-1, 1456, 14, 14]               0
          Conv2d-251          [-1, 384, 14, 14]         559,104
     BatchNorm2d-252          [-1, 384, 14, 14]             768
            ReLU-253          [-1, 384, 14, 14]               0
          Conv2d-254          [-1, 384, 14, 14]          41,472
     BatchNorm2d-255          [-1, 384, 14, 14]             768
            ReLU-256          [-1, 384, 14, 14]               0
          Conv2d-257         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-258         [-1, 1048, 14, 14]           2,096
            ReLU-259         [-1, 1480, 14, 14]               0
           Block-260         [-1, 1480, 14, 14]               0
          Conv2d-261          [-1, 384, 14, 14]         568,320
     BatchNorm2d-262          [-1, 384, 14, 14]             768
            ReLU-263          [-1, 384, 14, 14]               0
          Conv2d-264          [-1, 384, 14, 14]          41,472
     BatchNorm2d-265          [-1, 384, 14, 14]             768
            ReLU-266          [-1, 384, 14, 14]               0
          Conv2d-267         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-268         [-1, 1048, 14, 14]           2,096
            ReLU-269         [-1, 1504, 14, 14]               0
           Block-270         [-1, 1504, 14, 14]               0
          Conv2d-271          [-1, 384, 14, 14]         577,536
     BatchNorm2d-272          [-1, 384, 14, 14]             768
            ReLU-273          [-1, 384, 14, 14]               0
          Conv2d-274          [-1, 384, 14, 14]          41,472
     BatchNorm2d-275          [-1, 384, 14, 14]             768
            ReLU-276          [-1, 384, 14, 14]               0
          Conv2d-277         [-1, 1048, 14, 14]         402,432
     BatchNorm2d-278         [-1, 1048, 14, 14]           2,096
            ReLU-279         [-1, 1528, 14, 14]               0
           Block-280         [-1, 1528, 14, 14]               0
          Conv2d-281          [-1, 768, 14, 14]       1,173,504
     BatchNorm2d-282          [-1, 768, 14, 14]           1,536
            ReLU-283          [-1, 768, 14, 14]               0
          Conv2d-284            [-1, 768, 7, 7]         165,888
     BatchNorm2d-285            [-1, 768, 7, 7]           1,536
            ReLU-286            [-1, 768, 7, 7]               0
          Conv2d-287           [-1, 2176, 7, 7]       1,671,168
     BatchNorm2d-288           [-1, 2176, 7, 7]           4,352
          Conv2d-289           [-1, 2176, 7, 7]      29,924,352
     BatchNorm2d-290           [-1, 2176, 7, 7]           4,352
            ReLU-291           [-1, 2304, 7, 7]               0
           Block-292           [-1, 2304, 7, 7]               0
          Conv2d-293            [-1, 768, 7, 7]       1,769,472
     BatchNorm2d-294            [-1, 768, 7, 7]           1,536
            ReLU-295            [-1, 768, 7, 7]               0
          Conv2d-296            [-1, 768, 7, 7]         165,888
     BatchNorm2d-297            [-1, 768, 7, 7]           1,536
            ReLU-298            [-1, 768, 7, 7]               0
          Conv2d-299           [-1, 2176, 7, 7]       1,671,168
     BatchNorm2d-300           [-1, 2176, 7, 7]           4,352
            ReLU-301           [-1, 2432, 7, 7]               0
           Block-302           [-1, 2432, 7, 7]               0
          Conv2d-303            [-1, 768, 7, 7]       1,867,776
     BatchNorm2d-304            [-1, 768, 7, 7]           1,536
            ReLU-305            [-1, 768, 7, 7]               0
          Conv2d-306            [-1, 768, 7, 7]         165,888
     BatchNorm2d-307            [-1, 768, 7, 7]           1,536
            ReLU-308            [-1, 768, 7, 7]               0
          Conv2d-309           [-1, 2176, 7, 7]       1,671,168
     BatchNorm2d-310           [-1, 2176, 7, 7]           4,352
            ReLU-311           [-1, 2560, 7, 7]               0
           Block-312           [-1, 2560, 7, 7]               0
AdaptiveAvgPool2d-313           [-1, 2560, 1, 1]               0
          Linear-314                    [-1, 4]          10,244
================================================================
Total params: 67,994,324
Trainable params: 67,994,324
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 489.24
Params size (MB): 259.38
Estimated Total Size (MB): 749.20
----------------------------------------------------------------
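The total of roughly 68M parameters is noticeably higher than the figure the paper reports for DPN-92 (about 38M). Much of the gap likely comes from the 3×3 projection shortcuts used in this implementation (Conv2d-289 alone holds ~29.9M parameters); the original uses cheaper 1×1 projections.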

IV. Training the Model

1. Write the training function

#training loop
def train(dataloader,model,loss_fn,optimizer):
    size=len(dataloader.dataset) #size of the training set
    num_batches=len(dataloader) #number of batches (size/batch_size, rounded up)
    
    train_loss,train_acc=0,0 #initialize training loss and accuracy
    
    for x,y in dataloader: #fetch a batch of images and labels
        x,y=x.to(device),y.to(device)
        
        #compute the prediction error
        pred=model(x) #network output
        loss=loss_fn(pred,y) #the loss is the gap between the network output pred and the ground truth y
        
        #backpropagation
        optimizer.zero_grad() #zero the gradients
        loss.backward() #backpropagate
        optimizer.step() #update the weights
        
        #record acc and loss
        train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss+=loss.item()
        
    train_acc/=size
    train_loss/=num_batches
    
    return train_acc,train_loss

2. Writing the Test Function

The test function is largely the same as the training function, but since it performs no gradient descent to update the network weights, no optimizer needs to be passed in.

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # size of the test set
    num_batches = len(dataloader)   # number of batches, i.e. ceil(size / batch_size)
    test_loss, test_acc = 0, 0      # accumulators for test loss and accuracy
    
    # disable gradient tracking during evaluation to save memory and compute
    with torch.no_grad():
        for imgs, target in dataloader:  # fetch a batch of images and labels
            imgs, target = imgs.to(device), target.to(device)
            
            # compute the error
            target_pred = model(imgs)  # network output
            # loss between the network output and the ground truth target
            loss = loss_fn(target_pred, target)
            
            # accumulate accuracy and loss
            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()
            
    test_acc /= size
    test_loss /= num_batches
    
    return test_acc, test_loss
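
As a quick sanity check, the function can be run once on the untrained network before training starts; with four classes, accuracy should sit near chance (about 25%). A minimal sketch, assuming test_dl from the data-loading section and the loss function defined as in the next section:

# an untrained 4-class model should score near 25% accuracy
model.eval()
init_acc, init_loss = test(test_dl, model, nn.CrossEntropyLoss())
print(f'initial test accuracy: {init_acc:.3f}, loss: {init_loss:.3f}')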

3. Running the Training

import copy

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()  # create the loss function

epochs = 40

train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0  # track the best test accuracy, used as the criterion for the best model

# release unused GPU memory so other GPU applications can use it
if hasattr(torch.cuda, 'empty_cache'):
    torch.cuda.empty_cache()
    
for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    # scheduler.step()  # update the learning rate (when using PyTorch's official LR scheduler interface)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    # keep a copy of the best model in best_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # read the current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                            epoch_test_acc*100, epoch_test_loss, lr))
    
PATH = r'D:\THE MNIST DATABASE\J-series\J4_best_model.pth'
torch.save(best_model.state_dict(), PATH)  # save the best model's weights, not the last epoch's

print('Done')
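
The scheduler.step() line is commented out in the loop above; it refers to PyTorch's built-in learning-rate scheduler interface (torch.optim.lr_scheduler). A minimal sketch of how one would slot into this loop, with illustrative StepLR hyperparameters rather than values used in this experiment:

from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # illustrative: halve the learning rate every 10 epochs

for epoch in range(epochs):
    model.train()
    train(train_dl, model, loss_fn, optimizer)
    scheduler.step()  # decay the learning rate once per epoch
    model.eval()
    test(test_dl, model, loss_fn)

With a scheduler attached, the Lr column in the log below would decay over time instead of staying fixed at 1.00E-04.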

Output:

Epoch: 1,Train_acc:60.4%,Train_loss:0.994,Test_acc:48.7%,Test_loss:1.224,Lr:1.00E-04
Epoch: 2,Train_acc:70.1%,Train_loss:0.830,Test_acc:79.6%,Test_loss:0.542,Lr:1.00E-04
Epoch: 3,Train_acc:79.2%,Train_loss:0.570,Test_acc:61.9%,Test_loss:1.461,Lr:1.00E-04
Epoch: 4,Train_acc:83.0%,Train_loss:0.491,Test_acc:81.4%,Test_loss:0.896,Lr:1.00E-04
Epoch: 5,Train_acc:84.7%,Train_loss:0.435,Test_acc:83.2%,Test_loss:0.541,Lr:1.00E-04
Epoch: 6,Train_acc:86.3%,Train_loss:0.428,Test_acc:90.3%,Test_loss:0.404,Lr:1.00E-04
Epoch: 7,Train_acc:88.5%,Train_loss:0.301,Test_acc:65.5%,Test_loss:1.831,Lr:1.00E-04
Epoch: 8,Train_acc:85.4%,Train_loss:0.406,Test_acc:73.5%,Test_loss:0.829,Lr:1.00E-04
Epoch: 9,Train_acc:90.3%,Train_loss:0.296,Test_acc:73.5%,Test_loss:1.233,Lr:1.00E-04
Epoch:10,Train_acc:91.2%,Train_loss:0.253,Test_acc:89.4%,Test_loss:0.512,Lr:1.00E-04
Epoch:11,Train_acc:92.9%,Train_loss:0.208,Test_acc:52.2%,Test_loss:2.842,Lr:1.00E-04
Epoch:12,Train_acc:95.6%,Train_loss:0.134,Test_acc:90.3%,Test_loss:0.292,Lr:1.00E-04
Epoch:13,Train_acc:91.2%,Train_loss:0.215,Test_acc:79.6%,Test_loss:0.806,Lr:1.00E-04
Epoch:14,Train_acc:93.4%,Train_loss:0.236,Test_acc:90.3%,Test_loss:0.511,Lr:1.00E-04
Epoch:15,Train_acc:90.0%,Train_loss:0.219,Test_acc:88.5%,Test_loss:0.388,Lr:1.00E-04
Epoch:16,Train_acc:95.8%,Train_loss:0.099,Test_acc:90.3%,Test_loss:0.422,Lr:1.00E-04
Epoch:17,Train_acc:98.0%,Train_loss:0.103,Test_acc:81.4%,Test_loss:0.491,Lr:1.00E-04
Epoch:18,Train_acc:95.8%,Train_loss:0.146,Test_acc:86.7%,Test_loss:0.475,Lr:1.00E-04
Epoch:19,Train_acc:95.1%,Train_loss:0.108,Test_acc:89.4%,Test_loss:1.873,Lr:1.00E-04
Epoch:20,Train_acc:98.5%,Train_loss:0.065,Test_acc:92.0%,Test_loss:0.940,Lr:1.00E-04
Epoch:21,Train_acc:95.4%,Train_loss:0.151,Test_acc:87.6%,Test_loss:0.398,Lr:1.00E-04
Epoch:22,Train_acc:96.7%,Train_loss:0.101,Test_acc:77.0%,Test_loss:0.911,Lr:1.00E-04
Epoch:23,Train_acc:97.8%,Train_loss:0.064,Test_acc:93.8%,Test_loss:0.249,Lr:1.00E-04
Epoch:24,Train_acc:97.6%,Train_loss:0.073,Test_acc:87.6%,Test_loss:0.883,Lr:1.00E-04
Epoch:25,Train_acc:98.2%,Train_loss:0.068,Test_acc:92.9%,Test_loss:0.245,Lr:1.00E-04
Epoch:26,Train_acc:98.0%,Train_loss:0.125,Test_acc:88.5%,Test_loss:0.323,Lr:1.00E-04
Epoch:27,Train_acc:96.0%,Train_loss:0.128,Test_acc:88.5%,Test_loss:0.403,Lr:1.00E-04
Epoch:28,Train_acc:97.3%,Train_loss:0.083,Test_acc:92.9%,Test_loss:0.356,Lr:1.00E-04
Epoch:29,Train_acc:96.9%,Train_loss:0.083,Test_acc:87.6%,Test_loss:0.478,Lr:1.00E-04
Epoch:30,Train_acc:96.5%,Train_loss:0.124,Test_acc:85.0%,Test_loss:0.595,Lr:1.00E-04
Epoch:31,Train_acc:94.9%,Train_loss:0.142,Test_acc:85.0%,Test_loss:0.533,Lr:1.00E-04
Epoch:32,Train_acc:95.6%,Train_loss:0.125,Test_acc:93.8%,Test_loss:0.313,Lr:1.00E-04
Epoch:33,Train_acc:97.8%,Train_loss:0.058,Test_acc:92.0%,Test_loss:0.349,Lr:1.00E-04
Epoch:34,Train_acc:95.4%,Train_loss:0.123,Test_acc:89.4%,Test_loss:0.547,Lr:1.00E-04
Epoch:35,Train_acc:97.6%,Train_loss:0.075,Test_acc:89.4%,Test_loss:0.722,Lr:1.00E-04
Epoch:36,Train_acc:96.5%,Train_loss:0.078,Test_acc:92.0%,Test_loss:0.254,Lr:1.00E-04
Epoch:37,Train_acc:98.9%,Train_loss:0.031,Test_acc:89.4%,Test_loss:0.289,Lr:1.00E-04
Epoch:38,Train_acc:99.1%,Train_loss:0.028,Test_acc:93.8%,Test_loss:0.191,Lr:1.00E-04
Epoch:39,Train_acc:98.7%,Train_loss:0.035,Test_acc:87.6%,Test_loss:0.673,Lr:1.00E-04
Epoch:40,Train_acc:98.9%,Train_loss:0.037,Test_acc:95.6%,Test_loss:0.182,Lr:1.00E-04
Done

四、 Visualizing the Results

1. Loss and Accuracy Curves

import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")              # suppress warning messages
plt.rcParams['font.sans-serif'] = ['SimHei']   # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False     # display the minus sign correctly
plt.rcParams['figure.dpi'] = 300               # figure resolution
 
epochs_range = range(epochs)
plt.figure(figsize=(12, 3))
 
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

Output: (figure: training and validation accuracy/loss curves)

2. Predicting a Specified Image

from PIL import Image
 
classes = list(total_data.class_to_idx)
 
def predict_one_image(image_path, model, transform, classes):
    
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)   # display the image being predicted
    
    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)  # add a batch dimension: (C,H,W) -> (1,C,H,W)
    
    model.eval()
    with torch.no_grad():  # no gradients needed for inference
        output = model(img)
    
    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'Predicted class: {pred_class}')

Predict an image:

# predict one image from the training set
predict_one_image(image_path=r'D:\THE MNIST DATABASE\J-series\J1\bird_photos\Black Skimmer\001.jpg',
                  model=model, transform=train_transforms, classes=classes)

Output:

Predicted class: Black Skimmer
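
The prediction above reuses the model object still held in memory. To predict from the weights saved to J4_best_model.pth in a fresh session, the architecture has to be rebuilt first; a minimal sketch, where DPN(num_classes=4) is a hypothetical stand-in for whatever constructor was defined in the model section:

# rebuild the network, then restore the saved weights
model = DPN(num_classes=4).to(device)  # hypothetical constructor; use the class defined earlier
model.load_state_dict(torch.load(PATH, map_location=device))
model.eval()                           # switch to inference mode before predicting
predict_one_image(image_path=r'D:\THE MNIST DATABASE\J-series\J1\bird_photos\Black Skimmer\001.jpg',
                  model=model, transform=train_transforms, classes=classes)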

五、Reflections

In this week's project I went through the process of building a DPN model by hand in the PyTorch environment, which deepened my understanding of the DPN model structure.
