Pytorch从零开始实战01

news2024/10/6 8:33:44

Pytorch从零开始实战——MNIST手写数字识别

文章目录

  • Pytorch从零开始实战——MNIST手写数字识别
    • 环境准备
    • 数据集
    • 模型选择
    • 模型训练
    • 可视化展示

环境准备

本系列基于Jupyter notebook,使用Python3.7.12,Pytorch1.7.0+cu110,torchvision0.8.0,需读者自行配置好环境且有一些深度学习理论基础。

导入需要用到的包

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
import random
from time import time
import random
import numpy as np
import pandas as pd
import datetime
import gc
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

创建设备对象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type=‘cuda’)

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

数据集

本次实战使用MNIST数据集,这是一个包含了手写数字的灰度图像的数据集,每个图像都是28x28像素大小,并且标记了相应的数字,也是很多计算机视觉初学者第一个使用的数据集。

导入训练集与测试集,使用torchvision.datasets可以在线下载很多常见数据集,只需要将后面参数设置download=True即可直接下载,train=True为训练集,train=False为测试集

# 导入训练集和测试集
train_data = torchvision.datasets.MNIST('data', train=True, 
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True
                                       )
test_data = torchvision.datasets.MNIST('data', train=False, 
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True
                                      )

定义一个函数,随机查看5张图片

# 随机展示5个图片 data = torchvision.datasets....  需要接受tensor格式的对象
def plotsample(data):
    fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图
    for i in range(5):
        num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次
        #抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据
        #而展示图像用的imshow函数最常见的输入格式也是3通道
        npimg = torchvision.utils.make_grid(data[num][0]).numpy()
        nplabel = data[num][1] #提取标签 
        #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取
        axs[i].imshow(np.transpose(npimg, (1, 2, 0))) 
        axs[i].set_title(nplabel) #给每个子图加上标签
        axs[i].axis("off") #消除每个子图的坐标轴

plotsample(train_data)

在这里插入图片描述

使用DataLoder将它按照batch_size批量划分,并将训练集顺序打乱。

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

模型选择

由于数据集较为简单,所以本次实验使用简单的卷积神经网络。

第一次卷积和池化:
self.conv1 是第一个卷积层,将输入特征图的通道数从1增加到32,同时使用3x3的卷积核进行卷积。由于没有填充(padding)操作,卷积后的特征图大小减小为原来的大小减2(28x28 -> 26x26)。
self.pool1 是第一个最大池化层,将特征图的大小减半,从26x26变为13x13。
第二次卷积和池化:
self.conv2 是第二个卷积层,将输入特征图的通道数从32增加到64,同样使用3x3的卷积核进行卷积。由于没有填充操作,卷积后的特征图大小再次减小为原来的大小减2(13x13 -> 11x11)。
self.pool2 是第二个最大池化层,将特征图的大小再次减半,从11x11变为5x5。
全连接层:
在进入全连接层之前,需要将最后一个池化层的输出拉平成一个一维向量。这是通过 torch.flatten(x, start_dim=1) 完成的,它将5x5x64的三维张量转换为长度为5x5x64 = 1600的一维向量。
然后,self.fc1 是第一个全连接层,将1600个输入特征映射到64个输出特征。
最后进行10分类输出结果。

num_classes = 10 # 10分类
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)
        
        self.fc1 = nn.Linear(1600, 64)
        self.fc2 = nn.Linear(64, num_classes)
    
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        
        x = torch.flatten(x, start_dim=1) # 拉平
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

将模型转移到GPU中,并使用summary查看模型

from torchinfo import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model)

在这里插入图片描述

模型训练

定义损失函数、学习率、优化算法

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.01
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

定义训练函数,返回一个epoch的模型的准确率和损失

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_loss, train_acc = 0, 0
    
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        
        pred = model(X)
        loss = loss_fn(pred, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
        
    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss

定义测试函数,与训练函数类似,只是停止梯度更新,节省计算内存消耗

def test (dataloader, model, loss_fn):
    size = len(dataloader.dataset) 
    num_batches = len(dataloader)         
    test_loss, test_acc = 0, 0
    
    with torch.no_grad():
        for X, target in dataloader:
            X, target = X.to(device), target.to(device)
            
            pred = model(X)
            loss = loss_fn(pred, target)
            
            test_acc += (pred.argmax(1) == target).type(torch.float).sum().item()
            test_loss += loss.item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

开始训练,一共进行了5轮epoch,最后在训练集准确率可达97.7%,测试集准确率可达98.1%

epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    
    model.eval() # 确保模型不会进行训练操作
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"
          % (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print("Done")

可视化展示

使用matplotlib进行训练、测试的可视化

plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/977318.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python综合案例(动态柱状图)

一、基础柱状图 基本代码: """ 演示基础柱状图的开发 """ from pyecharts.charts import Bar from pyecharts.options import LabelOpts # 使用Bar构建基础柱状图 bar Bar() # 添加x轴的数据 bar.add_xaxis(["中国", &q…

谷歌浏览器打开白屏 后台还有还有很多google chrome进程在运行

环境: Win10 专业版 谷歌浏览器 版本 116.0.5845.141(正式版本) (64 位) L盾加密终端 问题描述: 谷歌浏览器打开白屏 后台还有还有很多google chrome进程在运行,要全部结束谷歌浏览器进程&…

pear admin 新增模块流程

pear admin 新增模块流程 一、界面新增模块二、增加路由情况三、增加前端页面四、增加db Module配置 一、界面新增模块 增加主菜单 增加子菜单 对应底层表:rt_power 二、增加路由情况 增加路由代码 from flask import render_template from common.utils.righ…

如何实现24/7客户服务自动化?

传统的客服制胜与否的法宝在于人,互联网时代,对于产品线广的大型企业来说:单靠人力,成本大且效率低,相对于产品相对单一的中小型企业来说:建设传统客服系统的成本难以承受,企业客户服务的转型已…

计算机网络初识

目录 1、计算机网络背景 网络发展 认识 "协议" 2、网络协议初识 OSI七层模型 TCP/IP五层(或四层)模型 3、网络传输基本流程 网络传输流程图 数据包封装和分用 4、网络中的地址管理 认识IP地址 认识MAC地址 1、计算机网络背景 网络发展 在之前呢&…

尼康D90使用心得

文章目录 规格参数快速指南相机机身模式拨盘控制面板取景器拍摄信息展示 核心功能指令拨盘拍摄模式自动模式场景模式快门速度和光圈 固件、软件、驱动升级更多细节参考 规格参数 型号尼康D90发布日期2008年08月机身特性APS-C规格数码单反产品定位中端单反传感器类型CMOS传感器…

Redis Redis的数据结构 - 通用命令 - String类型命令 - Hash类型命令

目录 Redis的数据结构: Redis命令: 通用命令:(通用指令是部分数据类型的,都可以使用的指令) KEYS查询命令: DEL删除命令: EXISTS判断命令: EXPIPE有效期设置命令&…

Ubuntu系统安装JDK1.8(附网盘链接)

这里写目录标题 1.下载JDK:2.将压缩包上传至服务器:3.安装JDK:4.配置环境变量:5.配置生效:6.检查JDK版本: 1.下载JDK: 方式一:[官网链接](https://www.oracle.com/java/technologie…

vue3+vant4封装日期时间组件(年月日时分秒)

vant4目前无法直接使用vant3自带的年月日时分秒组件&#xff0c;综合考虑下&#xff0c;决定自己封装一个&#xff01; vue3vant4封装日期时间组件&#xff08;年月日时分秒&#xff09; 效果图代码片段核心组件代码引入 效果图 代码片段 核心组件代码 <template><!…

软件评测师之码制

目录 一、机器数二、码制三、数的表示范围 一、机器数 机器数就是一个数在计算机中的二进制表示&#xff0c;计算机中机器数的最高位是符号位&#xff0c;正数符号位为0&#xff0c;负数符号位为1&#xff0c;机器数包含原码、反码和补码三种表示形式。 二、码制 表现形式数…

Flink基础

Flink architecture job manager is master task managers are workers task slot is a unit of resource in cluster, number of slot is equal to number of cores(超线程则slot2*cores), slot一组内存一些线程共享CPU when starting a cluster,job manager will allocate a …

【个人博客系统网站】我的博客列表页 · 增删改我的博文 · 退出登录 · 博客详情页 · 多线程应用

【JavaEE】进阶 个人博客系统&#xff08;4&#xff09; 文章目录 【JavaEE】进阶 个人博客系统&#xff08;4&#xff09;1. 增加博文1.1 预期效果1.1 约定前后端交互接口1.2 后端代码1.3 前端代码1.4 测试 2. 我的博客列表页2.1 期待效果2.2 显示用户信息以及博客信息2.2.1…

文件能做二维码吗?多种文件格式在线转二维码

怎么把文件做成二维码&#xff1f;在使用电脑办公时&#xff0c;必不可少的经常会使用word、excel、ppt等文件格式&#xff0c;那么当需要将文件生成二维码使用时&#xff0c;如何操作才能快速制作二维码呢&#xff1f;可以使用二维码生成器来在线制作二维码&#xff0c;与使用…

知识储备--基础算法篇-子串

1.子串 1.1第560题-和为k的子数组 给你一个整数数组 nums 和一个整数 k &#xff0c;请你统计并返回 该数组中和为 k 的连续子数组的个数 。 示例 1&#xff1a; 输入&#xff1a;nums [1,1,1], k 2 输出&#xff1a;2 一开始想用滑动窗口&#xff0c;但是在运行过程中碰…

定时任务管理器(xxl-job)

文章目录 xxl-job简介安装使用拉取xxl-job项目导入数据库表启动 admin 服务端Spring Boot 整合 xxl-job修改执行器新建定时任务 xxl-job简介 XXL-JOB是一个分布式任务调度平台&#xff0c;其核心设计目标是开发迅速、学习简单、轻量级、易扩展。开箱即用。 admin &#xff1a;…

uni-app 可视化创建的项目 移动端安装调试插件vconsole

可视化创建的项目&#xff0c;在插件市场找不到vconsole插件了。 又不好npm install vconsole 换个思路&#xff0c;先创建一个cli脚手架脚手架的uni-app项目&#xff0c;然后再此项目上安装vconsole cli脚手架创建uni-app项目 安装插件 项目Terminal运行命令&#xff1a;npm…

商城开发:店铺管理系统应具备哪些功能?

电子商务的迅猛发展&#xff0c;越来越多的企业选择在线商城作为业务拓展的重要渠道。而要实现一个成功的在线商城&#xff0c;一个强大而高效的店铺管理系统是不可或缺的。店铺管理系统作为商城的核心管理工具&#xff0c;应具备一系列功能&#xff0c;以提供卓越的用户体验和…

游戏海外运营需要准备什么?

游戏海外运营需要充分的准备和计划&#xff0c;以确保游戏在目标市场中取得成功。以下是一些游戏海外运营需要准备的关键方面。 游戏平台 游戏出海必不可少的就是游戏平台&#xff0c;而且要注意的是&#xff0c;海外游戏平台的搭建和国内有所不同&#xff0c;对于支付方式和语…

zabbix监控网络设备和zabbix proxy

监控linux主机 [rootrocky8 conf]# yum -y install net-snmp vim /etc/snmp/snmpd.conf com2sec notConfigUser default 123456##修改此行,设置团体密码,默认为public,此处 改为123456 view systemview included .1. ##添加此行,自定义授权,否则 zabbix 无法获取数据 [rootr…

【Redis】NoSQL之Redis的配置及优化

关系数据库与非关系数据库 关系型数据库 关系型数据库是一个结构化的数据库&#xff0c;创建在关系模型&#xff08;二维表格模型&#xff09;基础上&#xff0c;一般面向于记录。 SQL 语句&#xff08;标准数据查询语言&#xff09;就是一种基于关系型数据库的语言&a…