卷积网络:实现手写数字是识别50轮准确率97.3%

news2025/1/18 7:29:26

卷积网络:实现手写数字是识别50轮准确率

  • 1 导入必备库
  • 2 torchvision内置了常用数据集和最常见的模型
  • 3 数据批量加载
  • 4 绘制样例
  • 5 创建模型
  • 7 设置是否使用GPU
  • 8 设置损失函数和优化器
  • 9 定义训练函数
  • 10 定义测试函数
  • 11 开始训练
  • 12 绘制损失曲线并保存
  • 13 绘制准确率曲线并保存

1 导入必备库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
print(torch.__version__)

输出:

1.12.1+cu102

2 torchvision内置了常用数据集和最常见的模型

import torchvision
from torchvision.transforms import ToTensor
''' transforms.ToTensor    
    1.转化为一个 tensor
    2.转换到0-1之间
    3.会将channel放在第一维度上
'''
train_ds = torchvision.datasets.MNIST('data/',
                                      train=True,
                                      transform=ToTensor(),
                                      download=False
                                     )
test_ds = torchvision.datasets.MNIST('data/',
                                     train=False,
                                     transform=ToTensor(),
                                     download=False  
                                    )
print(len(train_ds),len(test_ds))

输出:

60000 10000

3 数据批量加载

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=256)

# iter方法创建生成器,next方法返回一个批次的图像,shape属性返回一批次张量形状
imgs, labels = next(iter(train_dl))
print(imgs.shape)
print(labels.shape)

输出:

torch.Size([64, 1, 28, 28])
torch.Size([64])

4 绘制样例

plt.figure(figsize=(10, 1))
for i, img in enumerate(imgs[:10]):
    npimg = img.numpy()
    npimg = np.squeeze(npimg)
    plt.subplot(1, 10, i+1)
    plt.imshow(npimg)
    plt.xticks([])
    plt.yticks([])
    plt.xlabel(labels[i].numpy())
    # plt.axis('off') #关闭显示坐标
    plt.savefig('pics/3.1.jpg', dpi=400)

1

5 创建模型

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)   
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.liner_1 = nn.Linear(16*4*4, 256)
        self.liner_2 = nn.Linear(256, 10)
    def forward(self, input):
        x = F.max_pool2d(F.relu(self.conv1(input)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 16*4*4)
        x = F.relu(self.liner_1(x))
        x = self.liner_2(x)
        return x

7 设置是否使用GPU

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

# 将模型移动到DEVICE
model = Model().to(device)
print(model)

输出:

Using cuda device
Model(
  (liner_1): Linear(in_features=784, out_features=120, bias=True)
  (liner_2): Linear(in_features=120, out_features=84, bias=True)
  (liner_3): Linear(in_features=84, out_features=10, bias=True)
)

8 设置损失函数和优化器

loss_fn = torch.nn.CrossEntropyLoss()  # 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

9 定义训练函数

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()
    train_loss /= size
    correct /= size
    return train_loss, correct

10 定义测试函数

def test(dataloader, model):
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    return test_loss, correct

11 开始训练

epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)
    epoch_test_loss, epoch_test_acc = test(test_dl, model)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    
    template = ("epoch:{:2d}/{:2d}, train_loss: {:.5f}, train_acc: {:.1f}% ," 
                "test_loss: {:.5f}, test_acc: {:.1f}%")
    print(template.format(
          epoch+1,epochs, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))
    
print("Done!")

输出:

epoch: 1/50, train_loss: 0.03559, train_acc: 24.1% ,test_loss: 0.00899, test_acc: 39.7%
epoch: 2/50, train_loss: 0.03413, train_acc: 51.0% ,test_loss: 0.00827, test_acc: 59.9%
epoch: 3/50, train_loss: 0.02756, train_acc: 62.9% ,test_loss: 0.00527, test_acc: 71.5%
······
epoch:48/50, train_loss: 0.00158, train_acc: 97.0% ,test_loss: 0.00037, test_acc: 97.1%
epoch:49/50, train_loss: 0.00155, train_acc: 97.0% ,test_loss: 0.00035, test_acc: 97.3%
epoch:50/50, train_loss: 0.00153, train_acc: 97.0% ,test_loss: 0.00035, test_acc: 97.3%
Done!

12 绘制损失曲线并保存

plt.plot(range(1, epochs+1), train_loss, label='train_loss', lw=2)
plt.plot(range(1, epochs+1), test_loss, label='test_loss', lw=2, ls="--")
plt.xlabel('epoch')
plt.legend()
plt.savefig('pics/2-4-5.jpg', dpi=400)

输出:
在这里插入图片描述

13 绘制准确率曲线并保存

plt.plot(range(1, epochs+1), train_acc, label='train_acc', lw=2)
plt.plot(range(1, epochs+1), test_acc, label='test_acc', lw=2, ls="--")
plt.xlabel('epoch')
plt.legend()
plt.savefig('pics/2-4-6.jpg', dpi=400)

输出:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1008019.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Xilinx ZYNQ 7000学习笔记五(Xilinx SDK 烧写镜像文件)

概述 前面几篇讲了ZYNQ7000的启动过程,包括BootRom和FSBL的代码逻辑,其中关于FSBL代码对启动模式为JTAG被动启动没有进行分析,本篇将通过将JTAG的功能和通过Xilinx SDK烧写镜像文件到flash来顺道把FSBL中的JTAG代码部分给讲解下。 1.JTAG …

springboot+jxls复杂excel模板导出

JXLS 是基于 Jakarta POI API 的 Excel 报表生成工具,可以生成精美的 Excel 格式报表。它采用标签的方式,类似 JSP 标签,写一个 Excel 模板,然后生成报表,非常灵活,简单! Java 有一些用于创建 …

为什么不推荐使用Lombok?@Data不香吗?

目录 一、前言 二、源码跟踪 三、总结 一、前言 之前写项目遇到的一个Bug,下面是模拟代码。 新建一个springboot的项目,Person一个实体类,定义一个方法传一个JSON数据 Data public class Person {private String name;private String a…

el-table表格动态设置最大高度 高度根据窗口可视高度大小改变自适应

由于表格内容过多,如果不给高度限制,每页100条数据的情况下,去操作底部的分页或者其他功能都需要划到数据最底部操作,用户体验性较差。解决方法是让表格一屏展示,超出部分滚动展示。 1.效果及思路图: 思路是…

【uniapp】小程序开发,初始化项目vscode

使用uniapp开发小程序可以实现一份代码打包成多个不同平台的小程序。 这里使用uniapp官方的项目模板作为示例,采用vue3ts开发,并使用vscode作为开发工具 一、通过命令行创建项目并运行 1、通过以下命令创建模板项目 参考 官方说明 npx degit dcloudi…

Trinitycore学习之windows上用cmake生成vs项目并尝试在windows上启动服务

0:参考 https://trinitycore.info/en/install/requirements/windows 参考该文章安装相关的工具,主要有boost,openssl,cmake,mysql,vs2022自己电脑已经安装。 1:安装mysql 用zip进行安装的方式&#xff…

又一款国产 Web 防火墙工具,开源了?

众所周知,Web 网站是当今互联网上最主流的业务形态,随着开源 Web 框架和各种建站工具的兴起,搭建网站已经是一件成本非常低的事情,但是网站的安全性很少有人关注,WAF 这个品类也鲜为人知。 WAF 是什么 WAF 是网站的防…

多合一小程序商城系统源码完整版 源码开源 支持多行业多门店

分享一个多合一小程序商城系统源码完整版,源码开源,支持多端和多行业适用,将多个小程序商城的功能整合到一个系统中,商家只需通过一个系统就能管理多个小程序商城,一个后台控制7端,支持微信小程序支付宝小程…

Transformers-Bert家族系列算法汇总

🤗 Transformers 提供 API 和工具,可轻松下载和训练最先进的预训练模型。使用预训练模型可以降低计算成本、碳足迹,并节省从头开始训练模型所需的时间和资源。这些模型支持不同形式的常见任务,例如: 📝 自…

智慧城市道路通行时间预测(笔记未完成版)

数据与任务目标分析 数据 道路通行时间 当前道路在该时间段内有车通行的时间 道路长宽情况 道路连接情况 任务 基于历史数据预测某个时间段内,如预测未来一个月travel_time, 每2分钟内通行时间。 构建时间序列,基于时间序列预测 预测高峰点&…

简单介绍Rope Crystal(类似Roop)项目

文章目录 (一)关于 Rope Crystal(二)安装 Rope Crystal(三)运行 Rope Crystal(3.1)选择目录(3.2)加载目录(3.3)选择并替换&#xff08…

MySQL安装validate_password_policy插件

功能介绍 validate_password_policy 是插件用于验证密码强度的策略。该参数可以设定三种级别:0代表低,1代表中,2代表高。 validate_password_policy 主要影响密码的强度检查级别: 0/LOW:只检查密码长度。 1/MEDIUM&am…

行业追踪,2023-09-13

自动复盘 2023-09-13 凡所有相,皆是虚妄。若见诸相非相,即见如来。 k 线图是最好的老师,每天持续发布板块的rps排名,追踪板块,板块来开仓,板块去清仓,丢弃自以为是的想法,板块去留让…

2023下半年创业风口项目:实景自动无人直播!揭秘3大好处!

实景自动无人直播就是2023下半年的创业风口项目,你踩中过风口吗?如果你还没有踩中过风口啊,就缺这么一个机会,那你要注意把握机遇了,建议你看完这篇文章。 为什么说实景自动无人直播将是2023下半年的创业风口项目呢&am…

golang面试官:for select时,如果通道已经关闭会怎么样?如果select中只有一个case呢?

问题 for循环select时,如果通道已经关闭会怎么样?如果select中的case只有一个,又会怎么样? 怎么答 for循环select时,如果其中一个case通道已经关闭,则每次都会执行到这个case。如果select里边只有一个ca…

【2023年Google 开发者大会】武侠风格讲述Gloud

文章目录 Google Cloud 如何加速创新,加强信息安全Google Cloud 如何加强信息安全?1.高级安全防护2.强大的身份验证和访问控制3.基于机器学习的威胁检测 Google Cloud 的 3 个 AI 重点发展领域,了解生成式 AI 功能如何助推创意落地Vertex AIV…

Leetcode刷题_贪心相关_c++版

(1)455分发饼干–简单 假设你是一位很棒的家长,想要给你的孩子们一些小饼干。但是,每个孩子最多只能给一块饼干。 对每个孩子 i,都有一个胃口值 g[i],这是能让孩子们满足胃口的饼干的最小尺寸&#xff1b…

Web之tomcat

[TOC](文章目录) 1.程序架构 1.C/S(client/server) 比如:QQ、 微信、 LOL 优点:有一部分代码写在客户端, 用户体验比较好。 缺点: 服务器更新,客户端也要随着更新。 占用资源大。 2. B/S(brows…

【盘点】设计师更偏爱Telerik Kendo UI界面库的4个理由!

就像许多开发人员(错误地)认为设计软件和工具对他们没有任何用处一样,许多设计人员也错误地认为,当涉及到以开发人员为中心的软件和工具时,对他们没有任何价值。事实上,如果双方都愿意走出自己的舒适区去探索,他们都会…

在微信小程序上怎么实现多门店管理功能

微信小程序已经成为连接线上与线下的重要工具,尤其对于拥有多家门店的企业来说,通过微信小程序可以实现多门店管理,提高管理效率和用户体验。下面,我将为大家详细介绍如何在微信小程序上实现多门店管理功能。 一、确定多门店管理功…