【PyTorch深度学习实践】09_卷积神经网络基础

news2024/12/23 1:51:45

文章目录

    • 1.卷积操作
      • 1.1 卷积操作
      • 1.2 padding-填充
      • 1.3 stride-步长
      • 1.4 pooling-池化
      • 1.5 基础版CNN代码示例
      • 1.6 完整CNN代码示例

1.卷积操作

卷积神经网络概览
在这里插入图片描述

1.1 卷积操作

在这里插入图片描述

输入通道数=卷积核通道数,卷积核个数=输出通道数

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

1.2 padding-填充

padding是为了让源图像最外一圈或多圈像素(取决于kernel的尺寸),能够被卷积核中心取到。
这里有个描述很重要:想要使源图像(1,1)的位置作为第一个与kernel中心重合,参与计算的像素,想想看padding需要扩充多少层,这样就很好计算了
在这里插入图片描述
在这里插入图片描述

1.3 stride-步长

stride操作指的是每次kernel窗口滑动的步长,默认值是1
在这里插入图片描述
在这里插入图片描述

1.4 pooling-池化

以最大池化为例
在这里插入图片描述
卷积神经网络示例
在这里插入图片描述

1.5 基础版CNN代码示例

import torch
in_channels, out_channels=5, 10
width, height = 100, 100    # 图像大小
kernel_size = 3     # 卷积核大小
batch_size = 1  # 所有输入pytorch中的data必须是小批量

input = torch.randn(
    batch_size,  # B 表明是小批量第几个
    in_channels,  # n 输入通道数
    width,      # W 宽
    height      # H 高
)

conv_layer = torch.nn.Conv2d(
    in_channels,    # 输入通道数
    out_channels,   # 输出通道数
    kernel_size=kernel_size  # 内核大小
)

output = conv_layer(input)

print(input.shape)
print(output.shape)
print(conv_layer.weight.shape)

1.6 完整CNN代码示例

import matplotlib.pyplot as plt
import torch

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import torch.nn.functional as F
import torch.optim as optim  # (可有可无)



batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='../dataset/mnist/',
                               train=True,
                               download=True,
                               transform=transform
                               )
train_loader = DataLoader(dataset=train_dataset,
                          shuffle=True,
                          batch_size=batch_size,
                          )

test_dataset = datasets.MNIST(root='../dataset/mnist/',
                              train=False,
                              download=True,
                              transform=transform)

test_loader = DataLoader(dataset=test_dataset,
                         shuffle=False,
                         batch_size=batch_size,
                         )


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__() # 卷积层
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)   # 卷积层1
        self.conv2 = torch.nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)  # 卷积层2
        self.pooling = torch.nn.MaxPool2d(2)    # 池化层,没有涉及到权重,实例化一个就可以
        self.fc = torch.nn.Linear(320, 10)

    def forward(self, x):
        batch_size = x.size(0)  # 用于求维度
        x = F.relu(self.pooling(self.conv1(x)))
        x = F.relu(self.pooling(self.conv2(x)))
        x = x.view(batch_size, -1)     # 用于变成全连接
        x = self.fc(x)
        return x


model = Net()
# 选择是GPU CPU
#device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 表示把整个模型涉及到的权重迁移到GPU
#model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train(epoch):
    running_loss = 0.0
    for batch_index, (inputs, labels) in enumerate(train_loader, 0):
        # 迁移至GPU(模型数据要在同一块显卡上)
        # inputs, labels = inputs.to(device), labels.to(device)
        y_hat = model(inputs)
        loss = criterion(y_hat, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_size % 10 == 9:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_index + 1, running_loss / 300))


def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for (images, labels) in test_loader:
            # images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, pred = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()
    print('accuracy on test set: %d %%' % (100 * correct / total))
    return correct / total


if __name__ == '__main__':
    epoch_list = []
    acc_list = []

    for epoch in range(10):
        train(epoch)
        acc = test()
        epoch_list.append(epoch)
        acc_list.append(acc)

    plt.plot(epoch_list, acc_list)
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.show()
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/165098.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FPGA图像处理HLS实现三种图像缩放算法,线性插值、双线性插值、双三次插值,提供HLS工程和vivado工程源码

目录一、三种图像缩放算法介绍线性插值双线性插值双三次插值二、HLS实现线性插值图像缩放三、HLS实现双线性插值图像缩放四、HLS实现双三次插值图像缩放五、HLS在线仿真并导出IP六、其他FPGA型号HLS在线仿真并导出IP七、zynq7100开发板vivado工程八、上板调试验证九、福利&…

纪念QT可直接安装的离线版最后版本5.14.2

为什么说纪念呢?因为,这个版本之后再也没有可下载下来安装的版本了,因为我们以后再也没有这么方便了。为是很么说纪念呢?因为我们从QT还很柔弱的时候开始就是使用的离线版。 以前用c#来做组态,自定义控件开发起来也还…

基础知识一览2

这里写目录标题1.XML2.1 XML中的转义字符2.2 CDATA区2.3 如何去约束XMl:DTD2.3.1 xml文件内部引用DTD约束2.3.2 xml文件引用外部DTD约束2.3.3 xml文件引用公共DTD约束1.XML xml的文件后缀名是.xmlxml有且只有一个根标签xml的标签是尖括号包裹关键字成对出现的,有开…

如何做好banner设计(banner设计要点包括哪些)

网页设计的Banner作为表达网站价值或者传达广告信息的视觉主体,一直在根据网络环境的变化而变化着,从表现形式到尺寸大小,再到创意的多元化,因此更需要我们网页设计师们对其设计创意进行丰富和完善,才能真正达到宣传的…

Elasticsearch入门——Elasticsearch7.8.0版本和Kibana7.8.0版本的下载、安装(win10环境)

目录一、Elasticsearch7.8.0版本下载、安装1.1、官网下载地址1.2、下载步骤1.3、安装步骤(需要jdk11及以上版本支持)1.4、启动后,控制台中文乱码问题解决二、Node下载、安装(安装Kibana之前需要先安装Node)2.1、Node官网下载地址2.2、Node下载…

Linux文字处理和文件编辑(三)

1、Linux里的配置文件: /etc/bashrc文件:该配置文件在root用户下,权限很高。~/.bashrc文件:只有当前用户登录时才会执行该配置文件。每次打开终端,都会自动执行配置文件里的代码。比如,alias md‘mkdir’就…

《2022年终总结》

2022年终总结 笔者成为社畜的一年,整整打了一年工! 之前都说每年都有点变化,今年的变化可能就是更加懒散了,玩了更多的手机 就是运动的坚持更加多了,收入也增加了,哈哈! 其实今年的变化不大&am…

41. 【农产品溯源项目前后端Demo】后端目录结构

本节介绍下后端代码的目录结构。 1. 实现用户管理、菜单管理、角色管理、代码自动生成等服务,归结为系统管理,是若依框架提供的能力。 2. ruoyi-traces实现农产品溯源应用的代码,如果要引入其他Java包,修改本模块的pom.xml文件。 1)config包加载配置文件数据,配置文件路…

FPGA:IIC验证镁光EEPROM仿真模型(纯Verilog)

目录日常唠嗑一、程序设计二、镁光模型仿真验证三、testbench文件四、完整工程下载日常唠嗑 IIC协议这里就不赘述了,网上很多,这里推荐两个,可以看看【接口时序】6、IIC总线的原理与Verilog实现 ,还有IIC协议原理以及主机、从机Ve…

基于SpringBoot的车牌识别系统(附项目地址)

yx-image-recognition: 基于spring boot maven opencv 实现的图像深度学习Demo项目,包含车牌识别、人脸识别、证件识别等功能,贯穿样本处理、模型训练、图像处理、对象检测、对象识别等技术点 介绍 spring boot maven 实现的车牌识别及训练系统 基于…

3-1存储系统-存储器概述主存储器

文章目录一.存储器概述(一)存储器分类1.按在计算机中的作用(层次)分类2.按存储介质分类3.按存取方式分类4.按信息的可保存性分类(二)存储器的性能指标二.主存储器(一)基本组成1.译码…

6 个必知必会高效 Python 编程技巧

编写更好的Python 代码需要遵循Python 社区制定的最佳实践和指南。遵守这些标准可以使您的代码更具可读性、可维护性和效率。 本文将展示一些技巧,帮助您编写更好的 Python 代码 文章目录遵循 PEP 8 风格指南1.遵守 PEP 8 命名约定2. 使用描述性的和有意义的变量名…

读书笔记--- ggplot2:数据分析与图形艺术

最近看了这本书《ggplot2:数据分析与图形艺术》(第2版),实际上网页在线版本已经更新到第3版了(https://ggplot2-book.org/)。 这本书页数不多,但是整体还是值得阅读,不愧是Hadley W…

【Proteus仿真】【STM32单片机】酒精浓度检测系统设计

文章目录一、功能简介二、软件设计三、实验现象联系作者一、功能简介 本项目使用Proteus8仿真STM32单片机控制器,使用LCD1602显示模块、按键模块、LED和蜂鸣器、MQ-3酒精传感器模块等。 主要功能: 系统运行后,LCD1602显示酒精浓度值和阈值&…

插入排序.

根据找插入位置的方法分为: ①、顺序法定位插入位置——直接插入排序 ②、二分法定位插入位置——二分插入排序 ③、缩小增量多遍插入排序——希尔排序 一、直接插入排序(以升序为例) 先背模板! void insert_sort(int *a,int le…

远程服务器(恒源云)上使用NNI进行训练调参的详细流程

远程服务器(恒源云)上使用NNI进行训练调参的详细流程 一、环境配置 pip下载安装nni,(可使用豆瓣源,可快速下载,在安装命令后加 -i http://pypi.douban.com/simple --trusted-host pypi.douban.com&#x…

VUE|后台管理项目——动态路由权限管理

公共数据复用1.1 为什么要公共数据复用?因为我们只有把导航和路由的数据公共的提出来,我们才能告知后端人员需要返回什么数据。1.2 怎么数据复用呢?首先,我们可以在utils文件夹里新建一个navDate.js的文件:把我们需要的…

go入门知识

step1:去https://go.dev下载golang step2:下载jetbrains的Goland编译器(安装的过程会自动帮你配置好环境变量) 一个最简单的go程序 package mainimport ("fmt" )func main() {fmt.Printf("Hello World")}1.定义变量: …

蓝桥杯C51(试题内容学习)

因为C51只有一组数码管,但是我们需要显示的东西有很多,所以通过按键切换是我们必须要知道的 按键之间有嵌套,切换,计数,对于按键的使用我们是必须知道的 1. HC573锁存器的选择 我们在之前的基础上对其进行了优化&…