Pytorch深度学习笔记(十一)卷积神经网络CNN

news2025/1/2 11:43:55

目录

1.概述

2.单通道卷积

3.多通道卷积

4.卷积层常见的参数

5.代码实现(卷积神经网络训练MNIST数据集)


推荐课程:10.卷积神经网络(基础篇)_哔哩哔哩_bilibili

1.概述

全连接神经网络:完全由线性层串行连接起来的网络。在全连接神经网络中,我们会把图像像素映射为一个较长的张量,这样会丧失图像像素之间原始的空间结构

卷积神经网络:保留图像像素之间原始的空间结构的神经网络。

convolution卷积:会保留图像像素之间原始的空间结构

subsampling下采样:缩小图像,提取特征值。

卷积神经网络分为feature extraction特征提取和classification分类两部分。

卷积层每次需要拿出一块像素块进行操作:

Input Channel输入通道数、Output Channel输出通道数、卷积核的大小(长和宽)

2.单通道卷积

可以把卷积核以及卷积核扫描到图像的区域当作两个矩阵,卷积核依次与扫描到的区域做数乘

单通道卷积运算的过程: 

1.将对应位置的元素相乘

2.将得到的所有乘积做一个求和

3.放到输出矩阵的对应位置

3.多通道卷积

有多少个输入通道,卷积核就要有多少个通道数,数乘后,将所有得到的输出矩阵(张量)相加,得到最终的输出矩阵(张量)。

三个卷积核通道构成一个卷积核。

在RGB模型上的表现:

如果有n个输入通道m个输出通道:

1.卷积核的通道数输入通道数一致

2.卷积核的总个数输出通道数一致,输出张量的维度输出通道数一致

3.注意:输出张量的长和宽与卷积核的长和宽一定相同

(牢记这三条卷积规则)

最后将输出的张量摞叠起来。

一个卷积核的大小 = n*卷积核的宽度*卷积核的高度

卷积核总的大小 = m*n*卷积核的宽度*卷积核的高度 

卷积网络的权重为(输出通道数m, 输入通道数n, 卷积核的宽度, 卷积核的高度 )

import torch
# 输入、输出通道数
in_channels, out_channels = 5, 10
# 图像的宽高
width, height = 100, 100
# 卷积核大小
kernel_size = 3
# batch大小
batch_size = 1

# 设置输入张量
input = torch.randn(batch_size,
                    in_channels,
                    width,
                    height)
# 卷积层
# 设置卷积核,kernel_size则默认卷积核为3*3
conv_layer = torch.nn.Conv2d(in_channels,
                             out_channels,
                             kernel_size=kernel_size)

# 返回输出
output = conv_layer(input)

print(input.shape)
print(output.shape)
print(conv_layer.weight.shape)

4.卷积层常见的参数

padding(填充):在某些情况下,我们希望输出的张量变得更大一些,由于输出张量的大小与卷积核在输入图像上的扫描面积有关,因此我们只需在输入图像外围填充0即可。

stride(步长):卷积核在输入图像上的扫描间隔。可以有效降低输出张量的大小。Max pooling默认stride = 2。stride = 2意味着输出张量长和宽各减小一半。

注:输出矩阵的长/宽 = 输入矩阵的长/宽 - 卷积核长/宽 + 1

5.代码实现(卷积神经网络训练MNIST数据集)

Conv2d Layer为卷积层。Pooling Layer为最大池层,输出张量缩小一半。ReLU Layer为激活层。

参考之前的标黄的卷积规则,思考(1,28,28)—Conv2d(in = 1,out = 10,size = 5)—>(10,24,24)。

import torch
# 用于图像映射到矩阵中
from matplotlib import pyplot as plt
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

batch_size = 64

#…1.准备数据………………………………………………………………………………………………………………………………………#
# 把像素值0-255转化为图像张量0-1
transform = transforms.Compose([
    # transforms.ToTensor()转化张量,Normalize映射到[0,1]之间
    transforms.ToTensor(),
    # (均值,标准差)
    transforms.Normalize((0.1307, ), (0.381, ))
])

# 训练集
train_dataset = datasets.MNIST(root="../dataset/mnist",
                           train=True,
                           download=False,
                           transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_dataset = datasets.MNIST(root="../dataset/mnist",
                           train=False,
                           download=False,
                           transform=transform)

test_loader = DataLoader(test_dataset,
                          batch_size=batch_size,
                          shuffle=False)


#…2.设计模型………………………………………………………………………………………………………………………………………#
# 继承torch.nn.Module,定义自己的计算模块,neural network
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 卷积层(输入通道数,输出通道数,卷积核大小(长/宽))
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        # 最大池层,池化,卷积核大小(长/宽)减小一半
        self.pooling = torch.nn.MaxPool2d(2)
        # 从320维降到10维
        self.fc = torch.nn.Linear(320, 10)

    def forward(self, x):
        # flatten data from (n,1,28,28) to (n, 784)
        batch_size = x.size(0)
        # 激活
        x = F.relu(self.pooling(self.conv1(x)))
        x = F.relu(self.pooling(self.conv2(x)))
        # 调整张量维度为320
        x = x.view(batch_size, -1)  # -1 此处自动算出的是320
        # 降维
        x = self.fc(x)

        return x

#……3.构造损失函数和优化器………………………………………………………………………………………………………#
model = Net()
# 实例化损失函数,返回损失值
criterion = torch.nn.CrossEntropyLoss()
# 优化器,momentum冲量
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

#……4.训练和测试……………………………………………………………………………………………………………………………#
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        # 1.准备数据
        inputs, labels = data
        optimizer.zero_grad()
        # 2.正向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # 3.反向传播
        loss.backward()
        # 4.更新权重w
        optimizer.step()
        # 损失求和
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

def test():
    correct = 0
    total = 0
    # with torch.no_grad():内部代码不会再计算梯度
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            # 内部权重已更新完毕,测试时直接使用即可
            outputs = model(images)
            # dim沿着第一个纬度(行)找最大值,返回(最大值,最大值下标)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            # 预测值与标签对比,正确则加1
            correct += (predicted == labels).sum().item()
    # 输出精确率
    print('Accuracy on test set: %d %%' % (100 * correct / total))
    return correct / total


if __name__ == '__main__':
    epoch_list = []
    acc_list = []

    for epoch in range(10):
        train(epoch)
        acc = test()
        epoch_list.append(epoch)
        acc_list.append(acc)

#…………5.绘图……………………………………………………………………#
    plt.plot(epoch_list, acc_list)
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.show()

 训练结果:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/458496.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

最佳实践|如何写出简单高效的 Flink SQL?

摘要:本文整理自阿里巴巴高级技术专家、Apache Flink PMC 贺小令,在 Flink Forward Asia 2022 生产实践专场的分享。本篇内容主要分为三个部分: 1. Flink SQL Insight 2. Best Practices 3. Future Works Tips:点击「阅读原文」查…

android之 Launcher改造仿桌面排版的效果

一,背景 1.1 新接手一个灯光控制项目,其页面和效果还是比交复杂的,其中一个功能就是仿苹果桌面来排版灯具,支持拖拽,分组,分页。 拖动图标的时候判断是否空白位置还是已经有占位了,有的话就把…

pikachu靶场-RCE

RCE漏洞概述 可以让攻击者直接向后台服务器远程注入操作系统命令或者代码,从而控制后台系统。 远程系统命令执行 命令执行漏洞(Command Execution)即黑客可以直接在Web应用中执行系统命令,从而获取敏感信息或者拿下shell权限 更…

Linux离线状态下安装cuda、cudnn、cudatoolkit

目录 1. 下载与安装说明2. CUDA安装3. cuDNN安装4. cudatoolkit安装5. 测试安装成功 1. 下载与安装说明 工具包下载地址 CUDA历史版本下载地址:https://developer.nvidia.com/cuda-toolkit-archivecuDNN历史版本下载地址:https://developer.nvidia.com/r…

logback日志框架集成方式

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、logback是什么?二、使用步骤1.使用方式控制台输出配置文件输出配置html输出配置定期删除配置方式 总结 前言 提示:这里可以添加本文…

C++每日一练:最长递增区间 阿波罗的魔力宝石 投篮

文章目录 前言一、最长递增区间二、阿波罗的魔力宝石三、投篮总结 前言 今天的题太简单,甚至 “最长递增区间” 和 “投篮” 就是一个问题。实在没事干,也给做了!直接上代码算了… 提示:以下是本篇文章正文内容 一、最长递增区间…

LSSANet:一种用于肺结节检测的长、短切片感知网络

文章目录 LSSANet: A Long Short Slice-Aware Network for Pulmonary Nodule Detection摘要方法Long Short Slice GroupingLong Short Slice-Aware Network 实验结果 LSSANet: A Long Short Slice-Aware Network for Pulmonary Nodule Detection 摘要 提出了一个长短片感知网…

【JAVA程序设计】(C00130)基于SpringBoot的社区养老医疗综合服务系统

基于SpringBoot的社区养老医疗综合服务系统 项目简介项目获取开发环境项目技术运行截图 项目简介 基于基于SpringBoot的社区养老医疗综合服务系统共分为三个角色:系统管理员、医生、用户 管理员角色包含以下功能: 用户管理、角色管理、部门管理、字典管…

【Java EE】-JavaScript详解

作者:学Java的冬瓜 博客主页:☀冬瓜的主页🌙 专栏:【JavaEE】 分享: 且视他人如盏盏鬼火,大胆地去走你的道路。——史铁生《病隙碎笔》 主要内容:HTML中引入JS的三种方式。JS语法分析,JS是动态弱…

【Linux高级篇】什么是shell脚本,什么是shell变量

目录 🍁什么是shell 🍂什么是shell脚本 🍂shell脚本能做什么 🍂学习shell需要哪些知识 🍂shell基本规范 🍂shell脚本五种运行方式 🍁shell变量 🍂变量命名规范 🍂shell变…

远程登录--SSH 你值得拥有

目录 一:SSH服务详解 1.什么是SSH 2.SSH服务认证类型 1)基于口令认证 2)基于密钥认证 3.SSH安装 二: 配置ssh服务端 1.ssh配置文件 2. ssh配置文件主要条目介绍 三:使用ssh客户端程序 1.使用ssh命令远程登录 ​2.使用scp远程复制 …

8086汇编之DIV除法指令

2023年4月22日,周六晚上。 今晚写汇编作业的时候,遇到了DIV指令,于是把学到的知识记录成一篇博客。此外,刚刚已经写了一篇关于MUL指令的博客了。 除数有8位和16位种,存放在寄存器或者内存中。 当除数为8位&#xff1a…

Linux离线状态下的Anaconda安装与Python环境创建

1 下载与安装说明 下载 下载地址:https://repo.anaconda.com/archive/版本:此处以版本为2020.11的anaconda作示例,其携带的python版本为3.8.5。下载:在上述链接查找下载 Anaconda3-2020.11-Linux-x86_64.sh 文件,也可以…

时序预测 | MATLAB实现WOA-LSTM鲸鱼算法优化长短期记忆网络时间序列预测

时序预测 | MATLAB实现WOA-LSTM鲸鱼算法优化长短期记忆网络时间序列预测 目录 时序预测 | MATLAB实现WOA-LSTM鲸鱼算法优化长短期记忆网络时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 MATLAB实现WOA-LSTM鲸鱼算法优化长短期记忆网络时间序列预测 基于鲸鱼…

图论-匈牙利算法学习

本文讲述的是匈牙利算法,即图论中寻找最大匹配的算法。解决的问题是从二分图中找到尽量多的匹配。 原题-华为-HJ28 素数伴侣 描述 题目描述 若两个正整数的和为素数,则这两个正整数称之为“素数伴侣”,如2和5、6和13,它们能应用…

【Vue】学习笔记-初始化脚手架

初始化脚手架 初始化脚手架说明具体步骤脚手架文件结构 初始化脚手架 说明 Vue脚手架是vue官方提供的标准化开发工具(开发平台)最新版本是4.x文档Vue CLI 具体步骤 如果下载缓慢请配置npm淘宝镜像 npm config set registry http://registry.npm.taoba…

有关态势感知(SA)的卷积思考

卷积是一种数学运算,其本质是将两个函数进行操作,其中一个函数是被称为卷积核或滤波器的小型矩阵,它在另一个函数上滑动并产生新的输出。在计算机视觉中,卷积通常用于图像处理和特征提取,它可以通过滤波器对输入图像进…

《Spring MVC》 第六章 MVC类型转换器、格式化器

前言 介绍MVC类型转换器、格式化器 1、使用场景 <form th:action"{/user/register}" method"post">用户名&#xff1a;<input type"text" name"userName"/><br/>密码&#xff1a;<input type"password&q…

对于Ubuntu服务器杀毒的一次记录

概述&#xff1a;叮咚&#xff01;您的主机有异常登录地&#xff0c;登录ip来自人类文明的标杆美丽国的加利福尼亚州&#xff0c;请注意排查。可恶的老美啊&#xff0c;又来入侵我华夏主机了&#xff0c;美帝亡我之心不死啊&#xff08;当然也有可能是境内中国人通过VPN操作境外…

【搭建博客】宝塔面板部署Typecho博客,并发布上线访问

目录 前言 1.安装环境 2.下载Typecho 3.创建站点 4.访问Typecho 5.安装cpolar 6.远程访问Typecho 7.固定远程访问地址 8.配置typecho 前言 Typecho是由type和echo两个词合成的&#xff0c;来自于开发团队的头脑风暴。Typecho基于PHP5开发&#xff0c;支持多种数据库&…