笔记小结:现代卷积神经网络之批量归一化

news2024/11/15 17:17:29

本文为李沐老师《动手学深度学习》笔记小结,用于个人复习并记录学习历程,适用于初学者

训练深层神经网络是十分困难的,特别是在较短的时间内使他们收敛更加棘手。 本节将介绍批量规范化(batch normalization),这是一种流行且有效的技术,可持续加速深层网络的收敛速度。

从零开始实现

张量的批量规范化函数
import torch
from torch import nn


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data
批量规范化层
class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var
        # 复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y
使用批量规范化层作用于LeNet

批量规范化是在卷积层或全连接层之后、相应的激活函数之前应用的。

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))
训练
准备工作

和之前多篇文章中提到的一样,不再赘述,只给出代码

from IPython import display
import torchvision
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt

def load_data_fashion_mnist(batch_size, resize=None): 
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=0)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=0)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
 
def get_dataloader_workers():  
    """使用4个进程来读取数据"""
    return 4

batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1) #找出输入张量(tensor)中最大值的索引
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())
class Accumulator:  #@save
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n
 
    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]
 
    def reset(self):
        self.data = [0.0] * len(self.data)
 
    def __getitem__(self, idx):
        return self.data[idx]

import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline
 
def use_svg_display(): 
    """使⽤svg格式在Jupyter中显⽰绘图"""
    backend_inline.set_matplotlib_formats('svg')
 
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
     """设置matplotlib的轴"""
     axes.set_xlabel(xlabel)
     axes.set_ylabel(ylabel)
     axes.set_xscale(xscale)
     axes.set_yscale(yscale)
     axes.set_xlim(xlim)
     axes.set_ylim(ylim)
     if legend:
         axes.legend(legend)
     axes.grid()
 
class Animator:  #@save
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        use_svg_display()
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts
 
    def add(self, x, y):
        # 向图表中添加多个数据点
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

import time
class Timer:  #@save
    """记录多次运行时间"""
    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """启动计时器"""
        self.tik = time.time()

    def stop(self):
        """停止计时器并将时间记录在列表中"""
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """返回平均时间"""
        return sum(self.times) / len(self.times)

    def sum(self):
        """返回时间总和"""
        return sum(self.times)

    def cumsum(self):
        """返回累计时间"""
        return np.array(self.times).cumsum().tolist()

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

def try_gpu(i=0):  #@save
    """如果存在,则返回gpu(i),否则返回cpu()"""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')
训练

和以前一样,我们将在Fashion-MNIST数据集上训练网络。 这个代码与我们第一次训练LeNet时几乎完全相同,主要区别在于学习率大得多。

begin = time.time()
train_ch6(net, train_iter, test_iter, num_epochs, lr, try_gpu())
end = time.time()
print(end - begin)

这个结果,对比当时不用批量归一化层的LeNet,训练的收敛速度快了许多,loss变小了,train acc提高了许多,但是test acc没有提高太多,出现了过拟合。 

简洁实现

除了使用我们刚刚定义的BatchNorm,我们也可以直接使用深度学习框架中定义的BatchNorm。 该代码看起来几乎与我们上面的代码相同。

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))

下面,我们使用相同超参数来训练模型。 请注意,通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现。

begin = time.time()
train_ch6(net, train_iter, test_iter, num_epochs, lr, try_gpu())
end = time.time()

从结果可以看到,运行速度快了,并且过拟合也小了许多。 

小结

  • 在模型训练过程中,批量规范化利用小批量的均值和标准差,不断调整神经网络的中间输出,使整个神经网络各层的中间输出值更加稳定。
  • 批量规范化在全连接层和卷积层的使用略有不同。
  • 批量规范化层和暂退层一样,在训练模式和预测模式下计算不同。
  • 批量规范化有许多有益的副作用,主要是正则化。另一方面,”减少内部协变量偏移“的原始动机似乎不是一个有效的解释。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1944320.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3 + ts 报错:Parsing error: Unexpected token : eslint

报错:Parsing error: Unexpected token : eslint 解决: 在 .eslintrc.json 文件中加入 "parser": "babel/eslint-parser"配置 “parser”: “babel/eslint-parser” 告诉 ESLint 在检查代码之前,先使用 Babel 的解析器…

lua 游戏架构 之 游戏 AI (三)ai_attack

这段Lua脚本定义了一个名为 ai_attack 的类,继承自 ai_base 类。 lua 游戏架构 之 游戏 AI (一)ai_base-CSDN博客文章浏览阅读119次。定义了一套接口和属性,可以基于这个基础类派生出具有特定行为的AI组件。例如,可以…

深度学习:引领未来的人工智能技术(比喻)

深度学习:引领未来的人工智能技术 引言 随着人工智能(AI)的快速发展,深度学习(Deep Learning)作为其中最具革命性的技术之一,正在改变着各个行业。从自动驾驶到医疗诊断,从自然语言…

python—selenium爬虫

文章目录 Selenium与Requests对比一、工作原理二、功能特点三、性能表现 下载对应驱动1.首先我们需要打开edge浏览器,打开设置,找到“关于Microsoft Edge”,点击进入查看浏览器版本。2.查找版本之后,搜索edge驱动下载,…

Unity UGUI 之 ScrollBar与ScrollView

本文仅作学习笔记与交流,不作任何商业用途 本文包括但不限于unity官方手册,唐老狮,麦扣教程知识,引用会标记,如有不足还请斧正 1.什么是ScrollBar 滚动块:Unity - Manual: Scrollbar 2.重要参数 该笔记来源…

MMROTATE的混淆矩阵confusion matrix生成

mmdetection中加入了混淆矩阵生成并可视化的功能,具体的代码在tools/analysis_tools/confusion_matrix.py。 mmrotate由于主流遥感数据集中的DOTA数据集标注格式问题,做了一些修改,所以我们如果是做遥感图像检测的Dota数据集的混淆矩阵&…

Elasticsearch介绍、安装以及IK分词器 --学习笔记

Elasticsearch 是什么? Elasticsearch 是一个高度可扩展的开源全文搜索和分析引擎。它允许你以极快的速度存储、搜索和分析大量数据。Elasticsearch 基于 Apache Lucene 构建,提供了一个分布式、多租户能力的全文搜索引擎,带有 HTTP web 接口…

centos系统mysql数据库压缩备份与恢复

文章目录 压缩备份一、安装 xtrabackup二、数据库中创建一些数据三、进行压缩备份四、模拟数据丢失,删库五、解压缩六、数据恢复 压缩备份 一、安装 xtrabackup 确保已经安装了 xtrabackup 工具。可以从 Percona 的官方网站 获取并安装适合你系统的版本。 # 添加…

2024在线PHP加密网站源码

源码介绍 2024在线PHP加密网站源码 更新内容: 1.加强算法强度 2.优化模版UI 加密后的代码示例截图 源码下载 https://download.csdn.net/download/huayula/89568335

学习日志:JVM垃圾回收

文章目录 前言一、堆空间的基本结构二、内存分配和回收原则对象优先在 Eden 区分配大对象直接进入老年代长期存活的对象将进入老年代主要进行 gc 的区域空间分配担保 三、死亡对象判断方法引用计数法可达性分析算法引用类型总结1.强引用(StrongReference…

Python+Flask+MySQL+日线指数与情感指数预测的股票信息查询系统【附源码,运行简单】

PythonFlaskMySQL日线指数与情感指数预测的股票信息查询系统【附源码,运行简单】 总览 1、《股票信息查询系统》1.1 方案设计说明书设计目标工具列表 2、详细设计2.1 登录2.2 程序主页面2.3 个人中心界面2.4 基金详情界面2.5 其他功能贴图 总览 自己做的项目&#…

【教程】在 VS Code 集成终端中解决 Node.js 环境变量识别问题

背景 外部命令,如 node 在外部的终端中可以识别到,但是在vscode的终端中不能识别到错误:node : 无法将“node”项识别为 cmdlet、函数、脚本文件或可运行程序的名称也就是环境变量其实是有 node 的,但是 vscode 的集成终端中就是…

【Django】在vscode中新建Django应用并新增路由

文章目录 打开一个终端输入新建app命令在app下的views.py内写一个视图app路由引入该视图项目路由引入app路由项目(settings.py)引入app(AntappConfig配置类)运行项目 打开一个终端 输入新建app命令 python manage.py startapp antapp在app下的views.py内…

let、var、const 的区别 --js面试题

作用域 ES5中的作用域有:全局作用域、函数作用域,ES6中新增了块级作用域。块作用域由 { } 包括,if 语句和 for 语句里面的 { } 也属于块作用域。 var 1.没有块级作用域的概念,但具有函数全局作用域、函数作用域的概念 {var a …

交易积累-MACD

MACD(Moving Average Convergence Divergence,即移动平均收敛发散指标)是由Gerald Appel于1970年代后期发明的一种趋势跟踪动量指标。MACD显示了两个不同周期(通常是较长和较短周期)的移动平均线之间的差异。这个指标旨…

PCIe 以太网芯片 RTL8125B 的 spec 和 Linux driver 分析备忘

1,下载 RTL8125B driver 下载页: https://www.realtek.com/Download/List?cate_id584 2,RTL8125B datasheet下载 下载页: https://file.elecfans.com/web2/M00/44/D8/poYBAGKHVriAHnfWADAT6T6hjVk715.pdf3, 编译driver 解压: $ tar xj…

Android APP CameraX应用(02)预览流程

说明:camera子系统 系列文章针对Android12.0系统,主要针对 camerax API框架进行解读。 1 CameraX简介 1.1 CameraX 预览流程简要解读 CameraX 是 Android 上的一个 Jetpack 支持库,它提供了一套统一的 API 来处理相机功能,无论 …

Redis-10大数据类型理解与测试

Redis10大数据类型 我要打10个1.redis字符串(String)2.redis列表(List)3.redis哈希表(Hash)4.redis集合(Set)5.redis有序集合(ZSet)6.redis地理空间(GEO)7.redis基数统计(HyperLogLog)8.redis位图(bitmap)9.redis位域(bitfield)10.redis流(Stream) 官网地址Redis 键(key)常用案…

OpenCV图像滤波(1)双边滤波函数bilateralFilter的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 功能描述 bilateralFilter是图像处理和计算机视觉领域中的一种高级图像滤波技术,特别设计用于在去除噪声的同时保留图像的边缘和细节。相比于传…

NSSCTF-2021年SWPU联合新生赛

[SWPUCTF 2021 新生赛]finalrce 这道题目考察tee命令和转义符\ 这题主要是,遇到一种新的符号,"\"—转义符。我理解的作用就是在一些控制字符被过滤的时候,可以用转义符,让控制符失去原本的含义,变为字面量…