Pytorch入门实战 P2-CIFAR10彩色图片识别

news2024/9/20 0:56:25

目录

一、前期准备

1、数据集CIFAR10

2、判断自己的设备,是否可以使用GPU运行。

3、下载数据集,划分好训练集和测试集

4、加载训练集、测试集

5、取一个批次查看下

6、数据可视化

二、搭建简单的CNN网络模型

三、训练模型

1、设置超参数

2、编写训练函数

3、编写测试函数

4、正式训练

四、模型训练结果可视化

五、模型训练结果:


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

这周的实战内容,主要使用的数据集是CIFAR10数据集。用来验证彩色图片的识别。

一、前期准备

1、数据集CIFAR10

我们使用的数据集的文档地址:Datasets — Torchvision 0.17 documentation

简单介绍下CIFAR10数据集:

CIFAR-10数据集由60000张32 × 32彩色图像组成,分为10个类,每个类有6000张图像。

50000张训练图像10000张测试图像

2、判断自己的设备,是否可以使用GPU运行。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

3、下载数据集,划分好训练集和测试集

import torchvision.datasets

# 下载训练集
train_ds = torchvision.datasets.CIFAR10('data',
                                        train=True,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
# 下载测试集
test_ds = torchvision.datasets.CIFAR10('data',
                                       train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)

4、加载训练集、测试集

# 使用dataloader加载数据集,并设置好batch_size
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds,
                                       shuffle=True,
                                       batch_size=batch_size)
test_dl = torch.utils.data.DataLoader(test_ds,
                                      batch_size=batch_size)

5、取一个批次查看下

# 取一个批次,查看下数据
imgs,labels = next(iter(train_dl))
print(imgs.shape)   #  数据的shape为:[batch_size,channel,height,weight]  
'''
    对于CIFAR10,这里的shape是 [32,3,32,32],即 因为取得是train_dl的数据,batch_size为32;
    channel为3是因为,是彩色图片RGB的3通道,如果是黑白图片,则channel为1;剩下的32x32是高度和宽度;
'''

6、数据可视化

即:展示下取到的数据。

# 数据可视化
plt.figure(figsize=(20,5))
for i, imgs in enumerate(imgs[:20]):
    npimg = imgs.numpy().transpose((1,2,0))   
            #.numpy()用于将Tensor转换为一个Numpy数组。transpose是Numpy数组的一个方法,用于重新排列数组的维度。
    plt.subplot(2, 10, i+1)
    plt.imshow(npimg, cmap=plt.cm.binary)
    plt.axis('off')
plt.show()

运行结果展示: 

二、搭建简单的CNN网络模型

 CNN(卷积神经网络),需要注意其结构、层与层之间的连接关系以及各层的功能。

①卷积层:负责提取特征。(通常使用局部连接权值共享方式,这有助于减少网络的参数数量和计算复杂度。)

②池化层:负责降低数据的空间尺寸和计算复杂度。

③全连接层:负责将提取的特征映射到输出类别。

# 构建简单的CNN网络
num_classes = 10
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        # 特征提取
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.pool3 = nn.MaxPool2d(2)

        # 分类网络
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, num_classes)

    # 前向传播
    def forward(self,x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))

        x = torch.flatten(x, start_dim=1)  # 线性层+激活函数  是构建复杂模型的基础
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 打印并加载模型
model = Model().to(device)
print(model)

三、训练模型

1、设置超参数

# 1、设置超参数
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2   #学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)   # 定义一个随机梯度下降优化器,即SGD优化器。
                    # model.parameters() 返回模型中所有可训练的参数(通常是权重和偏置)

2、编写训练函数

# 2、编写训练函数
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset) # 数据集的大小,一共60000张图片
    num_batches = len(dataloader)  # 批次数目 1875 (60000/32 = 1875)

    train_loss, train_acc = 0, 0   # 初始化训练的损失和正确率
    for X,y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,y为真实值,计算二者差值,即为损失。

        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()  # 反向传播
        optimizer.step()  # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss

3、编写测试函数

# 3、编写测试函数
# 测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器。
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 数据集的大小,共10000张
    num_batches = len(dataloader)  # 批次数目 ,313( 10000/32 = 321.5 ,向上取整)

    test_loss, test_acc = 0, 0  # 初始化测试的损失和精确

    # 不进行训练时,停止梯度下降,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

        test_acc /= size
        test_loss /= num_batches
        return test_acc, test_loss

4、正式训练

# 4、正式训练
epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []

'''
     model.train()和model.eval() 是深度学习中常见的两个方法,它们用于设置模型的训练模式和评估模式。
        ①当你调用model.train()时,你正在告诉模型你即将进入训练阶段。通常意味着模型中的某些层(如Dropout层和BatchNormalization层)会改变它们的行为以适应训练过程。
            Dropout层:在训练模式下,Dropout层会随机将一部分神经元的输出设置为0,有助于防止过拟合。
            BatchNormalization层:在训练模式下,BatchNoralization层会使用当前批次的数据来更新其运行均值和方差,并应用这些统计量来标准化输入。
        ②当你调用model.eval()时,你正在告诉模型你即将进入评估或推断阶段。在这种模式下,模型的某些层会改变它们的行为,以确保在评估时模型给出一致的结果。
'''
for epoch in range(epochs):
    model.train()  # 进入训练阶段
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}'
    print(template.format(epoch+1, epoch_train_acc*100,epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Finish')

四、模型训练结果可视化

# 四、结果可视化
warnings.filterwarnings('ignore')   # 忽略警告信息
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100    # 分辨率

epochs_range = range(epochs)  # 生成从0到epoches-1的整数序列

plt.figure(figsize=(12,3))  # figsize=(12,3)  包含两个元素的元组,分别代表图形的宽度和高度,单位是英寸。

plt.subplot(1,2,1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validataion Loss')

# 在远程服务器上面跑代码,想要保存下,plt.show()的结果,打下下面的注释
# plt.savfig('想要保存的服务器的地址+图片的名称.png/jpg自行定义即可')  
# eg:plt.savefig('/data/jupyter/deepinglearning/resultImg.jpg')

plt.show()
print("画图结束。。。")

五、模型训练结果:

这周和上周的代码类似,但是,比起刚开始的时候,好多代码都清晰了很多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1516486.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Vue2】slot 插槽全家桶

插槽-默认插槽 插槽的基本语法 组件内需要定制的结构部分&#xff0c;改用<slot></slot>占位使用组件时, <MyDialog></MyDialog>标签内部, 传入结构替换slot给插槽传入内容时&#xff0c;可以传入纯文本、html标签、组件 插槽-默认值 封装组件时&am…

Nginx的日志怎么看,在哪看,access.log日志内容详解

Nginx 的日志文件通常位于服务器的文件系统中&#xff0c;具体位置可能因配置而异。以下是查看 Nginx 日志的几种方法&#xff1a; 1、查看访问日志&#xff1a;在默认配置下&#xff0c;Nginx 的访问日志文件路径为 /var/log/nginx/access.log。您可以通过命令 sudo cat /var…

创新营销的新篇章:企业如何通过VR虚拟发布会提升品牌影响力

在数字化转型的浪潮中&#xff0c;VR虚拟发布会作为一种新兴的营销手段&#xff0c;正逐渐成为企业品牌推广和产品发布的重要选择。通过利用虚拟现实技术&#xff0c;企业能够在虚拟空间中举办发布会&#xff0c;为参与者提供沉浸式的体验。 一、创新体验&#xff1a;虚拟空间的…

linux系统对于docker容器的监控

容器监控 容器监控原生命令操作问题 容器监控三剑客CAdvisorInfluxDBGranfana compose编排监控工具新建目录创建CIG.yml文件启动docker-compose测试 容器监控 CAdvisorInfluxDBGranfana 原生命令 操作 docker stats问题 通过docker stats命令可以很方便的看到当前宿主机上所…

【黑马程序员】Python文件操作

文章目录 文件操作文件编码什么是编码为什么要使用编码 文件的读取openmodel常用的三种基础访问模式读操作相关方法 文件的写入注意代码示例 文件操作 文件编码 什么是编码 编码就是一种规则集合&#xff0c;记录了内容和二进制间进行互相转换的规则 最常用的是UTF-8编码 …

魔法手链(burnside+矩阵优化+dp acwing 3134)

题目&#xff1a;3134. 魔法手链 - AcWing题库 思路&#xff1a; 代码&#xff1a; #define _CRT_SECURE_NO_WARNINGS #include<iostream> #include<string> #include<cstring> #include<cmath> #include<ctime> #include<algorithm> #i…

史上最全Spring教程,从零开始带你深入♂学习(三)—

减少数据处理量&#xff0c;提高查询效率 (一)使用Limit分页 –从第2个开始查询&#xff0c;每一页10个 select * from user limit 2,10 –从第0个开始查询&#xff0c;每一页10个 SELECT * from user limit 10; 领取资料 (二)使用Mybatis实现分页&#xff0c;核心SQL 1、编…

小文件问题及GlusterFS的瓶颈

01海量小文件存储的挑战 为了解决海量小文件的存储问题&#xff0c;必须采用分布式存储&#xff0c;目前分布式存储主要采用两种架构&#xff1a;集中式元数据管理架构和去中心化架构。 (1)集中式元数据架构&#xff1a; 典型的集中式元数据架构的分布式存储有GFS&#xff0…

一、NLP中的文本分类

目录 1.0 文本分类的应用场景 1.1 文本分类流程 ​编辑 1.2 判别式模型 1.3 生成式模型 1.4 评估 1.5 参考文献 NLP学习笔记系列&#xff0c;欢迎收藏交流&#xff1a; 零、自然语言处理开篇-CSDN博客 一、NLP中的文本分类-CSDN博客 二、NLP中的序列标注&#xff08;分…

413 Request Entity Too Large 问题如何解决

遇到“413 Request Entity Too Large”错误通常意味着你尝试上传或提交到服务器的数据量超过了服务器能够处理的限制。这个问题通常与Web服务器的配置相关&#xff0c;比如Nginx或Apache。这个问题出现在使用Nginx作为Web服务器的环境中。这里有几种解决方法&#xff1a; 1. 调…

openGauss学习笔记-242 openGauss性能调优-SQL调优-典型SQL调优点-SQL自诊断

文章目录 openGauss学习笔记-242 openGauss性能调优-SQL调优-典型SQL调优点-SQL自诊断242.1 SQL自诊断242.1.1 告警场景242.1.2 规格约束 openGauss学习笔记-242 openGauss性能调优-SQL调优-典型SQL调优点-SQL自诊断 SQL调优是一个不断分析与尝试的过程&#xff1a;试跑Query&…

流水账-20240314

目录 Linux系统删除文件后&#xff0c;磁盘大小没变化mysql事务和neo4j事务冲突误诊描述解决方法网上提供的方法重置Neo4j密码&#xff0c;成功解决问题高版本低版本 Linux系统删除文件后&#xff0c;磁盘大小没变化 lsof L1|grep 删除的文件名kill进程 mysql事务和neo4j事务…

面试题系列一之-css画三角形(原理解析)

用html写一个三角形的图标算是一个比较简单的,但是工作中用的还是比较多的&#xff0c;面试也可能会问&#xff0c;但了解背后的原理才能熟练使用 我们首先写一个div,设置边框 <body><div class"border"></div> </body> <style> .bo…

华宽通招商资源推介平台:一站式立体展示,招商资源尽在眼前

传统园区在招商引资推介过程中&#xff0c;主要以画册、PPT、视频等形式进行介绍&#xff0c;对于客商来说体验感不佳&#xff0c;难以通过地理信息、空间信息和图文信息结合的方式&#xff0c;更加直观和立体地呈现园区整体优势和每个载体资源的详细情况&#xff0c;导致客商无…

基于SpringBoot的“家政服务管理平台”的设计与实现(源码+数据库+文档+PPT)

基于SpringBoot的“家政服务管理平台”的设计与实现&#xff08;源码数据库文档PPT) 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringBoot 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 系统首页界面图 用户注册界面图 服务信息界面…

DeleteFile(szFilePath)失败,之后再对文件操作,造成崩溃

调用WINAPI函数DeleteFile(szFilePath1)之后&#xff1a; 1.如果不再对szFilePath1文件进行操作 DeleteFile()函数执行失败》也不会造成 软体崩溃&#xff01; 2.如果后续需要对szFilePath1文件进行操作 DeleteFile()函数执行失败》就会造成 软体崩溃&#xff01; 所以&…

【刷题训练】LeetCode:557. 反转字符串中的单词 III

557. 反转字符串中的单词 III 题目要求 示例 1&#xff1a; 输入&#xff1a;s “Let’s take LeetCode contest” 输出&#xff1a;“s’teL ekat edoCteeL tsetnoc” 示例 2: 输入&#xff1a; s “Mr Ding” 输出&#xff1a;“rM gniD” 思路&#xff1a; 第一步&am…

Clickhouse MergeTree 原理(一)

作者&#xff1a;俊达 MergeTree是Clickhouse里最核心的存储引擎。Clickhouse里有一系列以MergeTree为基础的引擎&#xff08;见下图&#xff09;&#xff0c;理解了基础MergeTree&#xff0c;就能理解整个系列的MergeTree引擎的核心原理。 本文对MergeTree的基本原理进行介绍…

目标检测——YOLOv3算法解读

论文&#xff1a;YOLOv3&#xff1a;An Incremental Improvement 作者&#xff1a;Joseph Redmon, Ali Farhadi 链接&#xff1a;https://arxiv.org/abs/1804.02767 代码&#xff1a;http://pjreddie.com/yolo/ YOLO系列其他文章&#xff1a; YOLOv1通俗易懂版解读SSD算法解读…

pgsql常用索引简写

文章来源&#xff1a;互联网博客文章&#xff0c;后续有时间再来细化整理。 在数据库查询中&#xff0c;合理的使用索引&#xff0c;可以极大提升数据库查询效率&#xff0c;充分利用系统资源。这个随着数据量的增加得到提升&#xff0c;越大越明显&#xff0c;也和业务线有关…