深度学习 Day24——J3-1DenseNet算法实战与解析

news2024/10/2 12:24:29
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

文章目录

  • 前言
  • 1 我的环境
  • 2 pytorch实现DenseNet算法
    • 2.1 前期准备
      • 2.1.1 引入库
      • 2.1.2 设置GPU(如果设备上支持GPU就使用GPU,否则使用CPU)
      • 2.1.3 导入数据
      • 2.1.4 可视化数据
      • 2.1.4 图像数据变换
      • 2.1.4 划分数据集
      • 2.1.4 加载数据
      • 2.1.4 查看数据
    • 2.2 搭建densenet121模型
    • 2.3 训练模型
      • 2.3.1 设置超参数
      • 2.3.2 编写训练函数
      • 2.3.3 编写测试函数
      • 2.3.4 正式训练
    • 2.4 结果可视化
    • 2.4 指定图片进行预测
    • 2.6 模型评估
  • 3 知识点详解
    • 3.1 nn.Sequential和nn.Module区别与选择
      • 3.1.1 nn.Sequential
      • 3.1.2 nn.Module
      • 3.1.3 对比
      • 3.1.4 总结
    • 3.2 python中OrderedDict的使用
  • 总结


前言

关键字: pytorch实现DenseNet算法,nn.Sequential和nn.Module区别与选择,python中OrderedDict的使用

1 我的环境

  • 电脑系统:Windows 11
  • 语言环境:python 3.8.6
  • 编译器:pycharm2020.2.3
  • 深度学习环境:
    torch == 1.9.1+cu111
    torchvision == 0.10.1+cu111
    TensorFlow 2.10.1
  • 显卡:NVIDIA GeForce RTX 4070

2 pytorch实现DenseNet算法

2.1 前期准备

2.1.1 引入库


import torch
import torch.nn as nn
import time
import copy
from torchvision import transforms, datasets
from pathlib import Path
from PIL import Image
import torchsummary as summary
import torch.nn.functional as F
from collections import OrderedDict
import re
import torch.utils.model_zoo as model_zoo
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率
import warnings

warnings.filterwarnings('ignore')  # 忽略一些warning内容,无需打印

2.1.2 设置GPU(如果设备上支持GPU就使用GPU,否则使用CPU)

"""前期准备-设置GPU"""
# 如果设备上支持GPU就使用GPU,否则使用CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print("Using {} device".format(device))

输出

Using cuda device

2.1.3 导入数据

'''前期工作-导入数据'''
data_dir = r"D:\DeepLearning\data\BreastCancer"
data_dir = Path(data_dir)

data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[-1] for path in data_paths]
print(classeNames)

输出

['.DS_Store', '0', '1']

2.1.4 可视化数据

'''前期工作-可视化数据'''
subfolder = Path(data_dir) / "1"
image_files = list(p.resolve() for p in subfolder.glob('*') if p.suffix in [".jpg", ".png", ".jpeg"])
plt.figure(figsize=(10, 6))
for i in range(len(image_files[:12])):
    image_file = image_files[i]
    ax = plt.subplot(3, 4, i + 1)
    img = Image.open(str(image_file))
    plt.imshow(img)
    plt.axis("off")
# 显示图片
plt.tight_layout()
plt.show()

在这里插入图片描述

2.1.4 图像数据变换

'''前期工作-图像数据变换'''
total_datadir = data_dir

# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(  # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])
total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)
print(total_data)
print(total_data.class_to_idx)

输出

Dataset ImageFolder
    Number of datapoints: 13403
    Root location: D:\DeepLearning\data\BreastCancer
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
{'0': 0, '1': 1}

2.1.4 划分数据集

'''前期工作-划分数据集'''
train_size = int(0.8 * len(total_data))  # train_size表示训练集大小,通过将总体数据长度的80%转换为整数得到;
test_size = len(total_data) - train_size  # test_size表示测试集大小,是总体数据长度减去训练集大小。
# 使用torch.utils.data.random_split()方法进行数据集划分。该方法将总体数据total_data按照指定的大小比例([train_size, test_size])随机划分为训练集和测试集,
# 并将划分结果分别赋值给train_dataset和test_dataset两个变量。
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print("train_dataset={}\ntest_dataset={}".format(train_dataset, test_dataset))
print("train_size={}\ntest_size={}".format(train_size, test_size))

输出

train_dataset=<torch.utils.data.dataset.Subset object at 0x000001AB3AD06BE0>
test_dataset=<torch.utils.data.dataset.Subset object at 0x000001AB3AD06B20>
train_size=10722
test_size=2681

2.1.4 加载数据

'''前期工作-加载数据'''
batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=4)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=4)

2.1.4 查看数据

'''前期工作-查看数据'''
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

输出

Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

2.2 搭建densenet121模型

"""构建DenseNet网络"""
# 这里我们采用了Pytorch的框架来实现DenseNet,
# 首先实现DenseBlock中的内部结构,这里是BN+ReLU+1×1Conv+BN+ReLU+3×3Conv结构,最后也加入dropout层用于训练过程。
class _DenseLayer(nn.Sequential):
    """Basic unit of DenseBlock (using bottleneck layer) """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False)),
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)


# 实现DenseBlock模块,内部是密集连接方式(输入特征数线性增长):
class _DenseBlock(nn.Sequential):
    """DenseBlock """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


# 实现Transition层,它主要是一个卷积层和一个池化层:
class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


# 最后我们实现DenseNet网络:
class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 3 or 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=24, bn_size=4, compression=0.5, drop_rate=0,
                 num_classes=1000):
        super(DenseNet, self).__init__()

        # First Conv2d
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        ]))


        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features += num_layers * growth_rate
            if i != len(block_config) - 1:
                transition = _Transition(num_input_features=num_features,
                                         num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), transition)
                num_features = int(num_features * compression)

        # Final bn+relu
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
        self.features.add_module('relu5', nn.ReLU(inplace=True))

        # classification layer
        self.classifier = nn.Linear(num_features, num_classes)

        # params initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out



model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth'}


def densenet121(pretrained=False, **kwargs):
    """DenseNet121"""
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),	**kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but pervious _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model

"""搭建densenet121模型"""
# model = densenet121().to(device)  
model = densenet121(True).to(device)  # 使用预训练模型
print(model)
print(summary.summary(model, (3, 224, 224)))  # 查看模型的参数量以及相关指标
    

输出

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       BatchNorm2d-5           [-1, 64, 56, 56]             128
              ReLU-6           [-1, 64, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           8,192
       BatchNorm2d-8          [-1, 128, 56, 56]             256
              ReLU-9          [-1, 128, 56, 56]               0
           Conv2d-10           [-1, 32, 56, 56]          36,864
      BatchNorm2d-11           [-1, 96, 56, 56]             192
             ReLU-12           [-1, 96, 56, 56]               0
           Conv2d-13          [-1, 128, 56, 56]          12,288
      BatchNorm2d-14          [-1, 128, 56, 56]             256
             ReLU-15          [-1, 128, 56, 56]               0
           Conv2d-16           [-1, 32, 56, 56]          36,864
      BatchNorm2d-17          [-1, 128, 56, 56]             256
             ReLU-18          [-1, 128, 56, 56]               0
           Conv2d-19          [-1, 128, 56, 56]          16,384
      BatchNorm2d-20          [-1, 128, 56, 56]             256
             ReLU-21          [-1, 128, 56, 56]               0
           Conv2d-22           [-1, 32, 56, 56]          36,864
      BatchNorm2d-23          [-1, 160, 56, 56]             320
             ReLU-24          [-1, 160, 56, 56]               0
           Conv2d-25          [-1, 128, 56, 56]          20,480
      BatchNorm2d-26          [-1, 128, 56, 56]             256
             ReLU-27          [-1, 128, 56, 56]               0
           Conv2d-28           [-1, 32, 56, 56]          36,864
      BatchNorm2d-29          [-1, 192, 56, 56]             384
             ReLU-30          [-1, 192, 56, 56]               0
           Conv2d-31          [-1, 128, 56, 56]          24,576
      BatchNorm2d-32          [-1, 128, 56, 56]             256
             ReLU-33          [-1, 128, 56, 56]               0
           Conv2d-34           [-1, 32, 56, 56]          36,864
      BatchNorm2d-35          [-1, 224, 56, 56]             448
             ReLU-36          [-1, 224, 56, 56]               0
           Conv2d-37          [-1, 128, 56, 56]          28,672
      BatchNorm2d-38          [-1, 128, 56, 56]             256
             ReLU-39          [-1, 128, 56, 56]               0
           Conv2d-40           [-1, 32, 56, 56]          36,864
      BatchNorm2d-41          [-1, 256, 56, 56]             512
             ReLU-42          [-1, 256, 56, 56]               0
           Conv2d-43          [-1, 128, 56, 56]          32,768
        AvgPool2d-44          [-1, 128, 28, 28]               0
      BatchNorm2d-45          [-1, 128, 28, 28]             256
             ReLU-46          [-1, 128, 28, 28]               0
           Conv2d-47          [-1, 128, 28, 28]          16,384
      BatchNorm2d-48          [-1, 128, 28, 28]             256
             ReLU-49          [-1, 128, 28, 28]               0
           Conv2d-50           [-1, 32, 28, 28]          36,864
      BatchNorm2d-51          [-1, 160, 28, 28]             320
             ReLU-52          [-1, 160, 28, 28]               0
           Conv2d-53          [-1, 128, 28, 28]          20,480
      BatchNorm2d-54          [-1, 128, 28, 28]             256
             ReLU-55          [-1, 128, 28, 28]               0
           Conv2d-56           [-1, 32, 28, 28]          36,864
      BatchNorm2d-57          [-1, 192, 28, 28]             384
             ReLU-58          [-1, 192, 28, 28]               0
           Conv2d-59          [-1, 128, 28, 28]          24,576
      BatchNorm2d-60          [-1, 128, 28, 28]             256
             ReLU-61          [-1, 128, 28, 28]               0
           Conv2d-62           [-1, 32, 28, 28]          36,864
      BatchNorm2d-63          [-1, 224, 28, 28]             448
             ReLU-64          [-1, 224, 28, 28]               0
           Conv2d-65          [-1, 128, 28, 28]          28,672
      BatchNorm2d-66          [-1, 128, 28, 28]             256
             ReLU-67          [-1, 128, 28, 28]               0
           Conv2d-68           [-1, 32, 28, 28]          36,864
      BatchNorm2d-69          [-1, 256, 28, 28]             512
             ReLU-70          [-1, 256, 28, 28]               0
           Conv2d-71          [-1, 128, 28, 28]          32,768
      BatchNorm2d-72          [-1, 128, 28, 28]             256
             ReLU-73          [-1, 128, 28, 28]               0
           Conv2d-74           [-1, 32, 28, 28]          36,864
      BatchNorm2d-75          [-1, 288, 28, 28]             576
             ReLU-76          [-1, 288, 28, 28]               0
           Conv2d-77          [-1, 128, 28, 28]          36,864
      BatchNorm2d-78          [-1, 128, 28, 28]             256
             ReLU-79          [-1, 128, 28, 28]               0
           Conv2d-80           [-1, 32, 28, 28]          36,864
      BatchNorm2d-81          [-1, 320, 28, 28]             640
             ReLU-82          [-1, 320, 28, 28]               0
           Conv2d-83          [-1, 128, 28, 28]          40,960
      BatchNorm2d-84          [-1, 128, 28, 28]             256
             ReLU-85          [-1, 128, 28, 28]               0
           Conv2d-86           [-1, 32, 28, 28]          36,864
      BatchNorm2d-87          [-1, 352, 28, 28]             704
             ReLU-88          [-1, 352, 28, 28]               0
           Conv2d-89          [-1, 128, 28, 28]          45,056
      BatchNorm2d-90          [-1, 128, 28, 28]             256
             ReLU-91          [-1, 128, 28, 28]               0
           Conv2d-92           [-1, 32, 28, 28]          36,864
      BatchNorm2d-93          [-1, 384, 28, 28]             768
             ReLU-94          [-1, 384, 28, 28]               0
           Conv2d-95          [-1, 128, 28, 28]          49,152
      BatchNorm2d-96          [-1, 128, 28, 28]             256
             ReLU-97          [-1, 128, 28, 28]               0
           Conv2d-98           [-1, 32, 28, 28]          36,864
      BatchNorm2d-99          [-1, 416, 28, 28]             832
            ReLU-100          [-1, 416, 28, 28]               0
          Conv2d-101          [-1, 128, 28, 28]          53,248
     BatchNorm2d-102          [-1, 128, 28, 28]             256
            ReLU-103          [-1, 128, 28, 28]               0
          Conv2d-104           [-1, 32, 28, 28]          36,864
     BatchNorm2d-105          [-1, 448, 28, 28]             896
            ReLU-106          [-1, 448, 28, 28]               0
          Conv2d-107          [-1, 128, 28, 28]          57,344
     BatchNorm2d-108          [-1, 128, 28, 28]             256
            ReLU-109          [-1, 128, 28, 28]               0
          Conv2d-110           [-1, 32, 28, 28]          36,864
     BatchNorm2d-111          [-1, 480, 28, 28]             960
            ReLU-112          [-1, 480, 28, 28]               0
          Conv2d-113          [-1, 128, 28, 28]          61,440
     BatchNorm2d-114          [-1, 128, 28, 28]             256
            ReLU-115          [-1, 128, 28, 28]               0
          Conv2d-116           [-1, 32, 28, 28]          36,864
     BatchNorm2d-117          [-1, 512, 28, 28]           1,024
            ReLU-118          [-1, 512, 28, 28]               0
          Conv2d-119          [-1, 256, 28, 28]         131,072
       AvgPool2d-120          [-1, 256, 14, 14]               0
     BatchNorm2d-121          [-1, 256, 14, 14]             512
            ReLU-122          [-1, 256, 14, 14]               0
          Conv2d-123          [-1, 128, 14, 14]          32,768
     BatchNorm2d-124          [-1, 128, 14, 14]             256
            ReLU-125          [-1, 128, 14, 14]               0
          Conv2d-126           [-1, 32, 14, 14]          36,864
     BatchNorm2d-127          [-1, 288, 14, 14]             576
            ReLU-128          [-1, 288, 14, 14]               0
          Conv2d-129          [-1, 128, 14, 14]          36,864
     BatchNorm2d-130          [-1, 128, 14, 14]             256
            ReLU-131          [-1, 128, 14, 14]               0
          Conv2d-132           [-1, 32, 14, 14]          36,864
     BatchNorm2d-133          [-1, 320, 14, 14]             640
            ReLU-134          [-1, 320, 14, 14]               0
          Conv2d-135          [-1, 128, 14, 14]          40,960
     BatchNorm2d-136          [-1, 128, 14, 14]             256
            ReLU-137          [-1, 128, 14, 14]               0
          Conv2d-138           [-1, 32, 14, 14]          36,864
     BatchNorm2d-139          [-1, 352, 14, 14]             704
            ReLU-140          [-1, 352, 14, 14]               0
          Conv2d-141          [-1, 128, 14, 14]          45,056
     BatchNorm2d-142          [-1, 128, 14, 14]             256
            ReLU-143          [-1, 128, 14, 14]               0
          Conv2d-144           [-1, 32, 14, 14]          36,864
     BatchNorm2d-145          [-1, 384, 14, 14]             768
            ReLU-146          [-1, 384, 14, 14]               0
          Conv2d-147          [-1, 128, 14, 14]          49,152
     BatchNorm2d-148          [-1, 128, 14, 14]             256
            ReLU-149          [-1, 128, 14, 14]               0
          Conv2d-150           [-1, 32, 14, 14]          36,864
     BatchNorm2d-151          [-1, 416, 14, 14]             832
            ReLU-152          [-1, 416, 14, 14]               0
          Conv2d-153          [-1, 128, 14, 14]          53,248
     BatchNorm2d-154          [-1, 128, 14, 14]             256
            ReLU-155          [-1, 128, 14, 14]               0
          Conv2d-156           [-1, 32, 14, 14]          36,864
     BatchNorm2d-157          [-1, 448, 14, 14]             896
            ReLU-158          [-1, 448, 14, 14]               0
          Conv2d-159          [-1, 128, 14, 14]          57,344
     BatchNorm2d-160          [-1, 128, 14, 14]             256
            ReLU-161          [-1, 128, 14, 14]               0
          Conv2d-162           [-1, 32, 14, 14]          36,864
     BatchNorm2d-163          [-1, 480, 14, 14]             960
            ReLU-164          [-1, 480, 14, 14]               0
          Conv2d-165          [-1, 128, 14, 14]          61,440
     BatchNorm2d-166          [-1, 128, 14, 14]             256
            ReLU-167          [-1, 128, 14, 14]               0
          Conv2d-168           [-1, 32, 14, 14]          36,864
     BatchNorm2d-169          [-1, 512, 14, 14]           1,024
            ReLU-170          [-1, 512, 14, 14]               0
          Conv2d-171          [-1, 128, 14, 14]          65,536
     BatchNorm2d-172          [-1, 128, 14, 14]             256
            ReLU-173          [-1, 128, 14, 14]               0
          Conv2d-174           [-1, 32, 14, 14]          36,864
     BatchNorm2d-175          [-1, 544, 14, 14]           1,088
            ReLU-176          [-1, 544, 14, 14]               0
          Conv2d-177          [-1, 128, 14, 14]          69,632
     BatchNorm2d-178          [-1, 128, 14, 14]             256
            ReLU-179          [-1, 128, 14, 14]               0
          Conv2d-180           [-1, 32, 14, 14]          36,864
     BatchNorm2d-181          [-1, 576, 14, 14]           1,152
            ReLU-182          [-1, 576, 14, 14]               0
          Conv2d-183          [-1, 128, 14, 14]          73,728
     BatchNorm2d-184          [-1, 128, 14, 14]             256
            ReLU-185          [-1, 128, 14, 14]               0
          Conv2d-186           [-1, 32, 14, 14]          36,864
     BatchNorm2d-187          [-1, 608, 14, 14]           1,216
            ReLU-188          [-1, 608, 14, 14]               0
          Conv2d-189          [-1, 128, 14, 14]          77,824
     BatchNorm2d-190          [-1, 128, 14, 14]             256
            ReLU-191          [-1, 128, 14, 14]               0
          Conv2d-192           [-1, 32, 14, 14]          36,864
     BatchNorm2d-193          [-1, 640, 14, 14]           1,280
            ReLU-194          [-1, 640, 14, 14]               0
          Conv2d-195          [-1, 128, 14, 14]          81,920
     BatchNorm2d-196          [-1, 128, 14, 14]             256
            ReLU-197          [-1, 128, 14, 14]               0
          Conv2d-198           [-1, 32, 14, 14]          36,864
     BatchNorm2d-199          [-1, 672, 14, 14]           1,344
            ReLU-200          [-1, 672, 14, 14]               0
          Conv2d-201          [-1, 128, 14, 14]          86,016
     BatchNorm2d-202          [-1, 128, 14, 14]             256
            ReLU-203          [-1, 128, 14, 14]               0
          Conv2d-204           [-1, 32, 14, 14]          36,864
     BatchNorm2d-205          [-1, 704, 14, 14]           1,408
            ReLU-206          [-1, 704, 14, 14]               0
          Conv2d-207          [-1, 128, 14, 14]          90,112
     BatchNorm2d-208          [-1, 128, 14, 14]             256
            ReLU-209          [-1, 128, 14, 14]               0
          Conv2d-210           [-1, 32, 14, 14]          36,864
     BatchNorm2d-211          [-1, 736, 14, 14]           1,472
            ReLU-212          [-1, 736, 14, 14]               0
          Conv2d-213          [-1, 128, 14, 14]          94,208
     BatchNorm2d-214          [-1, 128, 14, 14]             256
            ReLU-215          [-1, 128, 14, 14]               0
          Conv2d-216           [-1, 32, 14, 14]          36,864
     BatchNorm2d-217          [-1, 768, 14, 14]           1,536
            ReLU-218          [-1, 768, 14, 14]               0
          Conv2d-219          [-1, 128, 14, 14]          98,304
     BatchNorm2d-220          [-1, 128, 14, 14]             256
            ReLU-221          [-1, 128, 14, 14]               0
          Conv2d-222           [-1, 32, 14, 14]          36,864
     BatchNorm2d-223          [-1, 800, 14, 14]           1,600
            ReLU-224          [-1, 800, 14, 14]               0
          Conv2d-225          [-1, 128, 14, 14]         102,400
     BatchNorm2d-226          [-1, 128, 14, 14]             256
            ReLU-227          [-1, 128, 14, 14]               0
          Conv2d-228           [-1, 32, 14, 14]          36,864
     BatchNorm2d-229          [-1, 832, 14, 14]           1,664
            ReLU-230          [-1, 832, 14, 14]               0
          Conv2d-231          [-1, 128, 14, 14]         106,496
     BatchNorm2d-232          [-1, 128, 14, 14]             256
            ReLU-233          [-1, 128, 14, 14]               0
          Conv2d-234           [-1, 32, 14, 14]          36,864
     BatchNorm2d-235          [-1, 864, 14, 14]           1,728
            ReLU-236          [-1, 864, 14, 14]               0
          Conv2d-237          [-1, 128, 14, 14]         110,592
     BatchNorm2d-238          [-1, 128, 14, 14]             256
            ReLU-239          [-1, 128, 14, 14]               0
          Conv2d-240           [-1, 32, 14, 14]          36,864
     BatchNorm2d-241          [-1, 896, 14, 14]           1,792
            ReLU-242          [-1, 896, 14, 14]               0
          Conv2d-243          [-1, 128, 14, 14]         114,688
     BatchNorm2d-244          [-1, 128, 14, 14]             256
            ReLU-245          [-1, 128, 14, 14]               0
          Conv2d-246           [-1, 32, 14, 14]          36,864
     BatchNorm2d-247          [-1, 928, 14, 14]           1,856
            ReLU-248          [-1, 928, 14, 14]               0
          Conv2d-249          [-1, 128, 14, 14]         118,784
     BatchNorm2d-250          [-1, 128, 14, 14]             256
            ReLU-251          [-1, 128, 14, 14]               0
          Conv2d-252           [-1, 32, 14, 14]          36,864
     BatchNorm2d-253          [-1, 960, 14, 14]           1,920
            ReLU-254          [-1, 960, 14, 14]               0
          Conv2d-255          [-1, 128, 14, 14]         122,880
     BatchNorm2d-256          [-1, 128, 14, 14]             256
            ReLU-257          [-1, 128, 14, 14]               0
          Conv2d-258           [-1, 32, 14, 14]          36,864
     BatchNorm2d-259          [-1, 992, 14, 14]           1,984
            ReLU-260          [-1, 992, 14, 14]               0
          Conv2d-261          [-1, 128, 14, 14]         126,976
     BatchNorm2d-262          [-1, 128, 14, 14]             256
            ReLU-263          [-1, 128, 14, 14]               0
          Conv2d-264           [-1, 32, 14, 14]          36,864
     BatchNorm2d-265         [-1, 1024, 14, 14]           2,048
            ReLU-266         [-1, 1024, 14, 14]               0
          Conv2d-267          [-1, 512, 14, 14]         524,288
       AvgPool2d-268            [-1, 512, 7, 7]               0
     BatchNorm2d-269            [-1, 512, 7, 7]           1,024
            ReLU-270            [-1, 512, 7, 7]               0
          Conv2d-271            [-1, 128, 7, 7]          65,536
     BatchNorm2d-272            [-1, 128, 7, 7]             256
            ReLU-273            [-1, 128, 7, 7]               0
          Conv2d-274             [-1, 32, 7, 7]          36,864
     BatchNorm2d-275            [-1, 544, 7, 7]           1,088
            ReLU-276            [-1, 544, 7, 7]               0
          Conv2d-277            [-1, 128, 7, 7]          69,632
     BatchNorm2d-278            [-1, 128, 7, 7]             256
            ReLU-279            [-1, 128, 7, 7]               0
          Conv2d-280             [-1, 32, 7, 7]          36,864
     BatchNorm2d-281            [-1, 576, 7, 7]           1,152
            ReLU-282            [-1, 576, 7, 7]               0
          Conv2d-283            [-1, 128, 7, 7]          73,728
     BatchNorm2d-284            [-1, 128, 7, 7]             256
            ReLU-285            [-1, 128, 7, 7]               0
          Conv2d-286             [-1, 32, 7, 7]          36,864
     BatchNorm2d-287            [-1, 608, 7, 7]           1,216
            ReLU-288            [-1, 608, 7, 7]               0
          Conv2d-289            [-1, 128, 7, 7]          77,824
     BatchNorm2d-290            [-1, 128, 7, 7]             256
            ReLU-291            [-1, 128, 7, 7]               0
          Conv2d-292             [-1, 32, 7, 7]          36,864
     BatchNorm2d-293            [-1, 640, 7, 7]           1,280
            ReLU-294            [-1, 640, 7, 7]               0
          Conv2d-295            [-1, 128, 7, 7]          81,920
     BatchNorm2d-296            [-1, 128, 7, 7]             256
            ReLU-297            [-1, 128, 7, 7]               0
          Conv2d-298             [-1, 32, 7, 7]          36,864
     BatchNorm2d-299            [-1, 672, 7, 7]           1,344
            ReLU-300            [-1, 672, 7, 7]               0
          Conv2d-301            [-1, 128, 7, 7]          86,016
     BatchNorm2d-302            [-1, 128, 7, 7]             256
            ReLU-303            [-1, 128, 7, 7]               0
          Conv2d-304             [-1, 32, 7, 7]          36,864
     BatchNorm2d-305            [-1, 704, 7, 7]           1,408
            ReLU-306            [-1, 704, 7, 7]               0
          Conv2d-307            [-1, 128, 7, 7]          90,112
     BatchNorm2d-308            [-1, 128, 7, 7]             256
            ReLU-309            [-1, 128, 7, 7]               0
          Conv2d-310             [-1, 32, 7, 7]          36,864
     BatchNorm2d-311            [-1, 736, 7, 7]           1,472
            ReLU-312            [-1, 736, 7, 7]               0
          Conv2d-313            [-1, 128, 7, 7]          94,208
     BatchNorm2d-314            [-1, 128, 7, 7]             256
            ReLU-315            [-1, 128, 7, 7]               0
          Conv2d-316             [-1, 32, 7, 7]          36,864
     BatchNorm2d-317            [-1, 768, 7, 7]           1,536
            ReLU-318            [-1, 768, 7, 7]               0
          Conv2d-319            [-1, 128, 7, 7]          98,304
     BatchNorm2d-320            [-1, 128, 7, 7]             256
            ReLU-321            [-1, 128, 7, 7]               0
          Conv2d-322             [-1, 32, 7, 7]          36,864
     BatchNorm2d-323            [-1, 800, 7, 7]           1,600
            ReLU-324            [-1, 800, 7, 7]               0
          Conv2d-325            [-1, 128, 7, 7]         102,400
     BatchNorm2d-326            [-1, 128, 7, 7]             256
            ReLU-327            [-1, 128, 7, 7]               0
          Conv2d-328             [-1, 32, 7, 7]          36,864
     BatchNorm2d-329            [-1, 832, 7, 7]           1,664
            ReLU-330            [-1, 832, 7, 7]               0
          Conv2d-331            [-1, 128, 7, 7]         106,496
     BatchNorm2d-332            [-1, 128, 7, 7]             256
            ReLU-333            [-1, 128, 7, 7]               0
          Conv2d-334             [-1, 32, 7, 7]          36,864
     BatchNorm2d-335            [-1, 864, 7, 7]           1,728
            ReLU-336            [-1, 864, 7, 7]               0
          Conv2d-337            [-1, 128, 7, 7]         110,592
     BatchNorm2d-338            [-1, 128, 7, 7]             256
            ReLU-339            [-1, 128, 7, 7]               0
          Conv2d-340             [-1, 32, 7, 7]          36,864
     BatchNorm2d-341            [-1, 896, 7, 7]           1,792
            ReLU-342            [-1, 896, 7, 7]               0
          Conv2d-343            [-1, 128, 7, 7]         114,688
     BatchNorm2d-344            [-1, 128, 7, 7]             256
            ReLU-345            [-1, 128, 7, 7]               0
          Conv2d-346             [-1, 32, 7, 7]          36,864
     BatchNorm2d-347            [-1, 928, 7, 7]           1,856
            ReLU-348            [-1, 928, 7, 7]               0
          Conv2d-349            [-1, 128, 7, 7]         118,784
     BatchNorm2d-350            [-1, 128, 7, 7]             256
            ReLU-351            [-1, 128, 7, 7]               0
          Conv2d-352             [-1, 32, 7, 7]          36,864
     BatchNorm2d-353            [-1, 960, 7, 7]           1,920
            ReLU-354            [-1, 960, 7, 7]               0
          Conv2d-355            [-1, 128, 7, 7]         122,880
     BatchNorm2d-356            [-1, 128, 7, 7]             256
            ReLU-357            [-1, 128, 7, 7]               0
          Conv2d-358             [-1, 32, 7, 7]          36,864
     BatchNorm2d-359            [-1, 992, 7, 7]           1,984
            ReLU-360            [-1, 992, 7, 7]               0
          Conv2d-361            [-1, 128, 7, 7]         126,976
     BatchNorm2d-362            [-1, 128, 7, 7]             256
            ReLU-363            [-1, 128, 7, 7]               0
          Conv2d-364             [-1, 32, 7, 7]          36,864
     BatchNorm2d-365           [-1, 1024, 7, 7]           2,048
            ReLU-366           [-1, 1024, 7, 7]               0
          Linear-367                 [-1, 1000]       1,025,000
================================================================
Total params: 7,978,856
Trainable params: 7,978,856
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 294.58
Params size (MB): 30.44
Estimated Total Size (MB): 325.59
----------------------------------------------------------------

2.3 训练模型

2.3.1 设置超参数

"""训练模型--设置超参数"""
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数,计算实际输出和真实相差多少,交叉熵损失函数,事实上,它就是做图片分类任务时常用的损失函数
learn_rate = 1e-4  # 学习率
optimizer1 = torch.optim.SGD(model.parameters(), lr=learn_rate)# 作用是定义优化器,用来训练时候优化模型参数;其中,SGD表示随机梯度下降,用于控制实际输出y与真实y之间的相差有多大
optimizer2 = torch.optim.Adam(model.parameters(), lr=learn_rate)  
lr_opt = optimizer2
model_opt = optimizer2
# 调用官方动态学习率接口时使用2
lambda1 = lambda epoch : 0.92 ** (epoch // 4)
# optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(lr_opt, lr_lambda=lambda1) #选定调整方法

2.3.2 编写训练函数

"""训练模型--编写训练函数"""
# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片
    num_batches = len(dataloader)  # 批次数目,1875(60000/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率

    for X, y in dataloader:  # 加载数据加载器,得到里面的 X(图片数据)和 y(真实标签)
        X, y = X.to(device), y.to(device) # 用于将数据存到显卡

        # 计算预测误差
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失

        # 反向传播
        optimizer.zero_grad()  # 清空过往梯度
        loss.backward()  # 反向传播,计算当前梯度
        optimizer.step()  # 根据梯度更新网络参数

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

2.3.3 编写测试函数

"""训练模型--编写测试函数"""
# 测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)  # 批次数目,313(10000/32=312.5,向上取整)
    test_loss, test_acc = 0, 0

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad(): # 测试时模型参数不用更新,所以 no_grad,整个模型参数正向推就ok,不反向更新参数
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()#统计预测正确的个数

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

2.3.4 正式训练

"""训练模型--正式训练"""
epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_test_acc=0

for epoch in range(epochs):
    milliseconds_t1 = int(time.time() * 1000)

    # 更新学习率(使用自定义学习率时使用)
    # adjust_learning_rate(lr_opt, epoch, learn_rate)

    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, model_opt)
    scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = lr_opt.state_dict()['param_groups'][0]['lr']

    milliseconds_t2 = int(time.time() * 1000)
    template = ('Epoch:{:2d}, duration:{}ms, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E}')
    if best_test_acc < epoch_test_acc:
        best_test_acc = epoch_test_acc
        #备份最好的模型
        best_model = copy.deepcopy(model)
        template = (
            'Epoch:{:2d}, duration:{}ms, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E},Update the best model')
    print(
        template.format(epoch + 1, milliseconds_t2-milliseconds_t1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss, lr))
# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)
print('Done')

输出最高精度为Test_acc:100%

Epoch: 1, duration:74420ms, Train_acc:83.7%, Train_loss:0.902, Test_acc:85.8%,Test_loss:0.345, Lr:1.00E-04,Update the best model
Epoch: 2, duration:72587ms, Train_acc:86.4%, Train_loss:0.329, Test_acc:85.5%,Test_loss:0.343, Lr:1.00E-04
Epoch: 3, duration:72941ms, Train_acc:87.9%, Train_loss:0.292, Test_acc:89.2%,Test_loss:0.262, Lr:1.00E-04,Update the best model
Epoch: 4, duration:74155ms, Train_acc:88.8%, Train_loss:0.279, Test_acc:89.7%,Test_loss:0.248, Lr:1.00E-04,Update the best model
Epoch: 5, duration:75123ms, Train_acc:89.1%, Train_loss:0.265, Test_acc:89.0%,Test_loss:0.277, Lr:1.00E-04
Epoch: 6, duration:74381ms, Train_acc:89.6%, Train_loss:0.255, Test_acc:90.5%,Test_loss:0.249, Lr:1.00E-04,Update the best model
Epoch: 7, duration:73710ms, Train_acc:90.2%, Train_loss:0.243, Test_acc:84.1%,Test_loss:0.369, Lr:1.00E-04
Epoch: 8, duration:73995ms, Train_acc:90.7%, Train_loss:0.230, Test_acc:89.5%,Test_loss:0.250, Lr:1.00E-04
Epoch: 9, duration:73017ms, Train_acc:90.7%, Train_loss:0.223, Test_acc:89.3%,Test_loss:0.263, Lr:1.00E-04
Epoch:10, duration:73960ms, Train_acc:91.2%, Train_loss:0.224, Test_acc:91.6%,Test_loss:0.209, Lr:1.00E-04,Update the best model
Epoch:11, duration:74113ms, Train_acc:91.2%, Train_loss:0.219, Test_acc:90.5%,Test_loss:0.225, Lr:1.00E-04
Epoch:12, duration:73573ms, Train_acc:91.5%, Train_loss:0.213, Test_acc:88.5%,Test_loss:0.273, Lr:1.00E-04
Epoch:13, duration:73206ms, Train_acc:92.2%, Train_loss:0.202, Test_acc:85.1%,Test_loss:0.377, Lr:1.00E-04
Epoch:14, duration:73540ms, Train_acc:92.1%, Train_loss:0.195, Test_acc:91.2%,Test_loss:0.225, Lr:1.00E-04
Epoch:15, duration:73378ms, Train_acc:92.3%, Train_loss:0.192, Test_acc:87.6%,Test_loss:0.796, Lr:1.00E-04
Epoch:16, duration:73195ms, Train_acc:92.5%, Train_loss:0.187, Test_acc:92.5%,Test_loss:0.197, Lr:1.00E-04,Update the best model
Epoch:17, duration:73737ms, Train_acc:93.1%, Train_loss:0.174, Test_acc:92.7%,Test_loss:0.186, Lr:1.00E-04,Update the best model
Epoch:18, duration:73884ms, Train_acc:93.4%, Train_loss:0.171, Test_acc:80.6%,Test_loss:0.463, Lr:1.00E-04
Epoch:19, duration:73239ms, Train_acc:93.2%, Train_loss:0.168, Test_acc:91.2%,Test_loss:0.221, Lr:1.00E-04
Epoch:20, duration:73386ms, Train_acc:93.7%, Train_loss:0.159, Test_acc:92.5%,Test_loss:0.196, Lr:1.00E-04

2.4 结果可视化

"""训练模型--结果可视化"""
epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

2.4 指定图片进行预测

def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # 展示预测的图片
    plt.show()

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)

    model.eval()
    output = model(img)

    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'预测结果是:{pred_class}')
 
# 将参数加载到model当中
model.load_state_dict(torch.load(PATH, map_location=device))

"""指定图片进行预测"""
classes = list(total_data.class_to_idx)
# 预测训练集中的某张照片
predict_one_image(image_path=str(Path(data_dir) / "Cockatoo/001.jpg"),
                  model=model,
                  transform=train_transforms,
                  classes=classes)

在这里插入图片描述

输出

预测结果是:0

2.6 模型评估

"""模型评估"""
best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
# 查看是否与我们记录的最高准确率一致
print(epoch_test_acc, epoch_test_loss)


输出

预测结果是:0
0.9268929503916449 0.185508520431107

3 知识点详解

3.1 nn.Sequential和nn.Module区别与选择

3.1.1 nn.Sequential

torch.nn.Sequential是一个Sequential容器,模块将按照构造函数中传递的顺序添加到模块中。另外,也可以传入一个有序模块。 为了更容易理解,官方给出了一些案例:

# Sequential使用实例

model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# Sequential with OrderedDict使用实例
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

3.1.2 nn.Module

下面我们再用 Module 定义这个模型,下面是使用 Module 的模板

class 网络名字(nn.Module):
    def __init__(self, 一些定义的参数):
        super(网络名字, self).__init__()
        self.layer1 = nn.Linear(num_input, num_hidden)
        self.layer2 = nn.Sequential(...)
        ...

        定义需要用的网络层

    def forward(self, x): # 定义前向传播
        x1 = self.layer1(x)
        x2 = self.layer2(x)
        x = x1 + x2
        ...
        return x

注意的是,Module 里面也可以使用 Sequential,同时 Module 非常灵活,具体体现在 forward 中,如何复杂的操作都能直观的在 forward 里面执行

3.1.3 对比

为了方便比较,我们先用普通方法搭建一个神经网络。

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x
net1 = Net(1, 10, 1)

net2 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1)
)

打印这两个net

print(net1)
"""
Net (
  (hidden): Linear (1 -> 10)
  (predict): Linear (10 -> 1)
)
"""
print(net2)
"""
Sequential (
  (0): Linear (1 -> 10)
  (1): ReLU ()
  (2): Linear (10 -> 1)
)
"""

我们可以发现,打印torch.nn.Sequential会自动加入激励函数,
在 net1 中, 激励函数实际上是在 forward() 功能中被调用的,没有在init中定义,所以在打印网络结构时不会有激励函数的信息.

解析源码,在torch.nn.Sequential中:

def forward(self, input):
    for module in self:
        input = module(input)
    return input

可以看到,torch.nn.Sequential的forward只是简单的顺序传播,操作性有限.

3.1.4 总结

使用torch.nn.Module,我们可以根据自己的需求改变传播过程,如RNN等
如果你需要快速构建或者不需要过多的过程,直接使用torch.nn.Sequential即可。

参考链接:nn.Sequential和nn.Module区别与选择

3.2 python中OrderedDict的使用

很多人认为python中的字典是无序的,因为它是按照hash来存储的,但是python中有个模块collections(英文,收集、集合),里面自带了一个子类

OrderedDict,实现了对字典对象中元素的排序。请看下面的实例:

import collections
print "Regular dictionary"
d={}
d['a']='A'
d['b']='B'
d['c']='C'
for k,v in d.items():
    print k,v

print "\nOrder dictionary"
d1 = collections.OrderedDict()
d1['a'] = 'A'
d1['b'] = 'B'
d1['c'] = 'C'
d1['1'] = '1'
d1['2'] = '2'
for k,v in d1.items():
    print k,v

输出:

Regular dictionary
a A
c C
b B

Order dictionary
a A
b B
c C
1 1
2 2

可以看到,同样是保存了ABC等几个元素,但是使用OrderedDict会根据放入元素的先后顺序进行排序。所以输出的值是排好序的。

OrderedDict对象的字典对象,如果其顺序不同那么Python也会把他们当做是两个不同的对象,请看事例:

print 'Regular dictionary:'
d2={}
d2['a']='A'
d2['b']='B'
d2['c']='C'

d3={}
d3['c']='C'
d3['a']='A'
d3['b']='B'

print d2 == d3

print '\nOrderedDict:'
d4=collections.OrderedDict()
d4['a']='A'
d4['b']='B'
d4['c']='C'

d5=collections.OrderedDict()
d5['c']='C'
d5['a']='A'
d5['b']='B'

print  d1==d2

输出:
Regular dictionary:
True

OrderedDict:
False

再看几个例子:

dd = {'banana': 3, 'apple':4, 'pear': 1, 'orange': 2}
#按key排序
kd = collections.OrderedDict(sorted(dd.items(), key=lambda t: t[0]))
print kd
#按照value排序
vd = collections.OrderedDict(sorted(dd.items(),key=lambda t:t[1]))
print vd

#输出
OrderedDict([('apple', 4), ('banana', 3), ('orange', 2), ('pear', 1)])
OrderedDict([('pear', 1), ('orange', 2), ('banana', 3), ('apple', 4)])

总结

  数据量越大,训练时间越长,在DataLoader中增加num_workers,即增加线程数量,可能会导致内存不足出现,Couldn‘t open shared file mapping或者Out of memery的错误,可尝试减小num_corkers。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1362364.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

静态网页设计——宠物狗狗网(HTML+CSS+JavaScript)

前言 声明&#xff1a;该文章只是做技术分享&#xff0c;若侵权请联系我删除。&#xff01;&#xff01; 感谢大佬的视频&#xff1a; https://www.bilibili.com/video/BV1nk4y1X74M/?vd_source5f425e0074a7f92921f53ab87712357b 使用技术&#xff1a;HTMLCSSJS&#xff08;…

【Java EE初阶六】多线程案例(单例模式)

1. 单例模式 单例模式是一种设计模式&#xff0c;设计模式是我们必须要掌握的一个技能&#xff1b; 1.1 关于框架和设计模式 设计模式是软性的规定&#xff0c;且框架是硬性的规定&#xff0c;这些都是技术大佬已经设计好的&#xff1b; 一般来说设计模式有很多种&#xff0c;…

ROS-arbotix安装

方式一&#xff1a;命令行输入&#xff1a; sudo apt-get install ros-melodic-arbotix如果ROS为其他版本&#xff0c;可将melodic替换为对应版本。 方式二&#xff1a; 先从 github 下载源码&#xff0c;然后调用 catkin_make 编译 git clone https://github.com/vanadiumla…

如何把硬盘(分区)一分为二?重装系统的小伙伴不可不看

注意事项&#xff1a;本教程操作不当会导致数据丢失 请谨慎操作 请谨慎操作 请谨慎操作 前言 相信各位小伙伴都会切土豆吧&#xff0c;本教程就是教大家如何切土豆切得好的教程。 啊哈哈哈&#xff0c;开玩笑的。 比如你有一个D盘是200GB&#xff0c;想要把它变成两个100G…

码农的周末日常---2024/1/6

上周总结 按照规划进行开发&#xff0c;处事不惊&#xff0c;稳稳前行 2024.1.6 天气晴 温度适宜 AM 睡觉前不建议做决定是真的&#xff0c;昨天想着睡到中午&#xff0c;今天九点多醒了&#xff0c;得了&#xff0c;不想睡了 日常三连吧&#xff0c;…

商品加入购物篮后就一定能买到?

笔者今天在某网站购买商品时&#xff0c;将商品添加到购物车之后&#xff0c;耽搁了几分钟之后&#xff0c;就发现支付时提示库存不足&#xff0c;我想很多朋友肯定知道平台为啥这样设计&#xff1f;但是我还是想借着这件事来简单介绍一下&#xff0c;有不足之处&#xff0c;还…

HPM6750开发笔记《DMA接收和发送数据UART例程深度解析》

目录 概述&#xff1a; 端口设置&#xff1a; 代码分析&#xff1a; 运行现象&#xff1a; 概述&#xff1a; DMA&#xff08;Direct Memory Access&#xff09;是一种计算机系统中的数据传输技术&#xff0c;它允许数据在不经过中央处理器&#xff08;CPU&#xff09;的直…

虾皮长尾词工具:如何使用关键词工具优化Shopee产品的长尾关键词

在Shopee&#xff08;虾皮&#xff09;平台上&#xff0c;卖家们都希望能够吸引更多的潜在买家&#xff0c;提高产品的曝光率和转化率。而要实现这一目标&#xff0c;了解和使用长尾关键词是非常重要的。本文将介绍长尾关键词的定义、重要性以及如何使用关键词工具来优化Shopee…

【LeetCode:11. 盛最多水的容器 | 双指针】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

数字孪生在增强现实(AR)中的应用

数字孪生在增强现实&#xff08;Augmented Reality&#xff0c;AR&#xff09;中的应用可以提供更丰富、交互性更强的现实世界增强体验。以下是数字孪生在AR中的一些应用&#xff0c;希望对大家有所帮助。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&#xff…

【C++】STL 算法 ④ ( 函数对象与谓词 | 一元函数对象 | “ 谓词 “ 概念 | 一元谓词 | find_if 查找算法 | 一元谓词示例 )

文章目录 一、函数对象与谓词1、一元函数对象2、" 谓词 " 概念3、find_if 查找算法 二、一元谓词示例1、代码示例 - 一元谓词示例2、执行结果 一、函数对象与谓词 1、一元函数对象 " 函数对象 " 是通过 重载 函数调用操作符 () 实现的 operator() , 函数对…

VMware 安装 macOS虚拟机(附工具包)

VMware 安装 macOS虚拟机&#xff0c;在Windows上体验苹果macOS系统&#xff01; 安装教程&#xff1a;VMware 安装 macOS虚拟机VMware Workstation Pro 是一款强大的虚拟机软件&#xff0c;可让您在 Windows 电脑上运行 macOS 系统。只需简单几步操作&#xff0c;即可轻松安装…

什么是Alibaba Cloud Linux?完全兼容CentOS,详细介绍

Alibaba Cloud Linux是基于龙蜥社区OpenAnolis龙蜥操作系统Anolis OS的阿里云发行版&#xff0c;针对阿里云服务器ECS做了大量深度优化&#xff0c;Alibaba Cloud Linux由阿里云官方免费提供长期支持和维护LTS&#xff0c;Alibaba Cloud Linux完全兼容CentOS/RHEL生态和操作方式…

企业档案集中式管理什么意思?企业档案集中式管理的特点

企业档案集中式管理是指将企业所有的档案资料集中存放、管理和维护的一种方式。在集中式管理中&#xff0c;企业将所有的档案资料集中存放在一个统一的档案中心或档案馆中&#xff0c;通过专门的档案管理人员负责对档案资料进行分类、整理、存储和检索&#xff0c;确保档案资料…

【Java】实验七

实验要求: 1、编写有复制文本文件功能的记事本程序,界面参考下图,窗口中放置文本区(JTextArea)组件: 当点击“复制文件”菜单项后,出现下面的文件对话框,选择要复制的文件。 点击“打开”按钮后,将选中的文件显示在记事本的文本区,并将该文件复制到同一目录下的“cop…

刚学C/C++,使用的是CLion,想要在同一个项目里面运行多个相互独立脚本?

前言&#xff1a; 正常来说&#xff0c;一般一个项目只会有一个程序入口点。C和C程序的入口点是main函数。在一个项目中&#xff0c;只能有一个main函数&#xff0c;否则编译器会不知道从哪个main函数开始执行。 但是&#xff0c;作为初学者&#xff0c;我就是想用CLio…

Apple M2 Pro芯片 + docker-compose up + mysql、elasticsearch pull失败问题的解法

背景 &#xff08;1&#xff09;从github上git clone了一个基于Spring Boot的Java项目&#xff0c;查看readme&#xff0c;发现要在项目的根目录下&#xff0c;执行“docker-compose up”。&#xff08;2&#xff09;执行“docker-compose up”的前提是&#xff0c;在macos上要…

vue实现代码编辑器,无坑使用CodeMirror

vue实现代码编辑器,无坑使用CodeMirror vue实现代码编辑器,使用codemirror5 坑&#xff1a;本打算cv一下网上的&#xff0c;结果发现网上的博客教程都是错的&#xff0c;而且博客已经是几年前的了&#xff0c;我重新看了github上的&#xff0c;发现安装的命令都已经不一样了。我…

sublime text 3 分屏和关闭分屏

有时候需要编辑多个地方的代码&#xff0c;开多个编辑器又太麻烦&#xff0c;那么Sublime自带的分屏快捷键可以解决烦恼。 Altshift2 分为2列 Altshift3 分为3列 Altshift4 分为4列 Altshift5 分为2行2列 Altshift8 分为2行 Altshift9 分为3行 取消分屏&#xff1a;Alts…

【JAVA】OPENGL+TIFF格式图片,不同阈值旋转效果

有些科学研究领域会用到一些TIFF格式图片&#xff0c;由于是多张图片相互渐变&#xff0c;看起来比较有意思&#xff1a; import java.io.IOException; import java.text.SimpleDateFormat; import java.util.Date; import java.util.logging.*;/*** 可以自已定义日志打印格式…