多层全连接网络:实现手写数字识别50轮准确率92.1%

news2025/1/19 20:21:28

多层全连接网络:实现手写数字识别50轮准确率92.1%

  • 1 导入必备库
  • 2 torchvision内置了常用数据集和最常见的模型
  • 3 数据批量加载
  • 4 绘制样例
  • 5 创建模型
  • 7 设置是否使用GPU
  • 8 设置损失函数和优化器
  • 9 定义训练函数
  • 10 定义测试函数
  • 11 开始训练
  • 12 绘制损失曲线并保存
  • 13 绘制准确率曲线并保存

1 导入必备库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
print(torch.__version__)

输出:

1.12.1+cu102

2 torchvision内置了常用数据集和最常见的模型

import torchvision
from torchvision.transforms import ToTensor
''' transforms.ToTensor    
    1.转化为一个 tensor
    2.转换到0-1之间
    3.会将channel放在第一维度上
'''
train_ds = torchvision.datasets.MNIST('data/',
                                      train=True,
                                      transform=ToTensor(),
                                      download=False
                                     )
test_ds = torchvision.datasets.MNIST('data/',
                                     train=False,
                                     transform=ToTensor(),
                                     download=False  
                                    )
print(len(train_ds),len(test_ds))

输出:

60000 10000

3 数据批量加载

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=256)

# iter方法创建生成器,next方法返回一个批次的图像,shape属性返回一批次张量形状
imgs, labels = next(iter(train_dl))
print(imgs.shape)
print(labels.shape)

输出:

torch.Size([64, 1, 28, 28])
torch.Size([64])

4 绘制样例

plt.figure(figsize=(10, 1))
for i, img in enumerate(imgs[:10]):
    npimg = img.numpy()
    npimg = np.squeeze(npimg)
    plt.subplot(1, 10, i+1)
    plt.imshow(npimg)
    plt.xticks([])
    plt.yticks([])
    plt.xlabel(labels[i].numpy())
    # plt.axis('off') #关闭显示坐标
    plt.savefig('pics/3.1.jpg', dpi=400)

1

5 创建模型

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.liner_1 = nn.Linear(28*28, 120)
        self.liner_2 = nn.Linear(120, 84)
        self.liner_3 = nn.Linear(84, 10)
    def forward(self, input):
        x = input.view(-1, 28*28)
        x = F.relu(self.liner_1(x))
        x = F.relu(self.liner_2(x))
        x = self.liner_3(x)
        return x

7 设置是否使用GPU

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

# 将模型移动到DEVICE
model = Model().to(device)
print(model)

输出:

Using cuda device
Model(
  (liner_1): Linear(in_features=784, out_features=120, bias=True)
  (liner_2): Linear(in_features=120, out_features=84, bias=True)
  (liner_3): Linear(in_features=84, out_features=10, bias=True)
)

8 设置损失函数和优化器

loss_fn = torch.nn.CrossEntropyLoss()  # 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

9 定义训练函数

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()
    train_loss /= size
    correct /= size
    return train_loss, correct

10 定义测试函数

def test(dataloader, model):
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    return test_loss, correct

11 开始训练

epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)
    epoch_test_loss, epoch_test_acc = test(test_dl, model)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    
    template = ("epoch:{:2d}/{:2d}, train_loss: {:.5f}, train_acc: {:.1f}% ," 
                "test_loss: {:.5f}, test_acc: {:.1f}%")
    print(template.format(
          epoch+1,epochs, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))
    
print("Done!")

输出:

epoch: 1/50, train_loss: 0.02354, train_acc: 68.0% ,test_loss: 0.00537, test_acc: 70.7%
epoch: 2/50, train_loss: 0.01930, train_acc: 71.9% ,test_loss: 0.00437, test_acc: 74.9%
epoch: 3/50, train_loss: 0.01598, train_acc: 76.0% ,test_loss: 0.00366, test_acc: 78.5%
······
epoch:48/50, train_loss: 0.00455, train_acc: 91.7% ,test_loss: 0.00111, test_acc: 92.0%
epoch:49/50, train_loss: 0.00451, train_acc: 91.8% ,test_loss: 0.00110, test_acc: 92.1%
epoch:50/50, train_loss: 0.00449, train_acc: 91.8% ,test_loss: 0.00109, test_acc: 92.1%
Done!

12 绘制损失曲线并保存

plt.plot(range(1, epochs+1), train_loss, label='train_loss', lw=2)
plt.plot(range(1, epochs+1), test_loss, label='test_loss', lw=2, ls="--")
plt.xlabel('epoch')
plt.legend()
plt.savefig('pics/2-4-3.jpg', dpi=400)

输出:
在这里插入图片描述

13 绘制准确率曲线并保存

plt.plot(range(1, epochs+1), train_acc, label='train_acc', lw=2)
plt.plot(range(1, epochs+1), test_acc, label='test_acc', lw=2, ls="--")
plt.xlabel('epoch')
plt.legend()
plt.savefig('pics/2-4-4.jpg', dpi=400)

输出:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1007666.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++ std::future

std::future是用来接收一个线程的执行结果的,并且是一次性的。 共享状态shared state future可以关联一个共享状态,共享状态是用来储存要执行结果的。这个结果是async、promise、packaged_task设置的,且这个结果只能设置一次。 创建future …

【rtp-benchmarks】读取本地文件基于uvgRtp实现多线程发送

input 文件做内存映射 : get_mem D:\XTRANS\soup\uvg-rtp-dev\rtp-benchmarks\util\util.cc 文件中读取chunksize 到 vector 里作为chunks 创建多个线程进行发送 std::vector<std::thread*> threads;

C++数据结构X篇_12_树的基本概念和存储

学习二叉树之前先学习树的概念。 文章目录 1. 树的基本概念1.1 树的定义1.2 树的特点1.3 若干术语 2. 树的表示法2.1 图形表示法2.2 广义表表示法 3. 树的存储3.1 双亲表示法&#xff1a;保存父节点关系3.2 孩子表示法3.3 左孩子右兄弟表示法 1. 树的基本概念 之前所学均为线性…

22 相交链表

相交链表 题解1 快慢双指针改进 (acb bca)题解2 哈希表(偷懒) 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点&#xff0c;返回 null 。 题目数据 保证 整个链式结构中不存在环。 注意&#xff…

Golang gorm manytomany 多对多 更新、删除、替换

Delete 移除 只删除中间表的数据 删除原有的 var a Article1db.Preload("Tag1s").Take(&a, 1)fmt.Printf("%v", a) {1 k8s [{1 cloud []} {2 linux []}]}mysql> select * from article1; ------------ | id | title | ------------ | 1 | k8s …

导数公式及求导法则

目录 基本初等函数的导数公式 求导法则 有理运算法则 复合函数求导法 隐函数求导法 反函数求导法 参数方程求导法 对数求导法 基本初等函数的导数公式 基本初等函数的导数公式包括&#xff1a; C0(x^n)nx^(n-1)(a^x)a^x*lna(e^x)e^x(loga(x))1/(xlna)(lnx)1/x(sinx)cos…

十大排序算法及Java中的排序算法

文章目录 一、简介二、时间复杂度三、非线性时间比较类排序冒泡排序&#xff08;Bubble Sort&#xff09;排序过程代码实现步骤拆解演示复杂度 选择排序&#xff08;Selection Sort&#xff09;排序过程代码实现步骤拆解演示复杂度 插入排序&#xff08;Insertion Sort&#xf…

用冒泡排序完成库函数qsort的作用

Hello&#xff0c;今天分享的是我们用冒泡函数实现qsort&#xff0c;也就是快排&#xff0c;之前我们也讲过库函数qsort的使用方法&#xff0c;今天我们尝试用冒泡函数实现一下&#xff0c;当然我们也见过qsort&#xff0c;后面也会继续完善的。这几天我是破防大学生&#xff0…

使用python制作一个简单的任务管理器

本篇文章教大家 使用 Python 创建一个简单的任务管理器应用程序。这个项目将帮助你练习 Python 编程的许多方面&#xff0c;包括文件操作、用户输入处理和基本的命令行界面设计。在这篇文章中&#xff0c;我将指导你创建一个基本的命令行任务管理器。 目录 任务管理器的用途任务…

NLP机器翻译全景:从基本原理到技术实战全解析

目录 一、机器翻译简介1. 什么是机器翻译 (MT)?2. 源语言和目标语言3. 翻译模型4. 上下文的重要性 二、基于规则的机器翻译 (RBMT)1. 规则的制定2. 词典和词汇选择3. 限制与挑战4. PyTorch实现 三、基于统计的机器翻译 (SMT)1. 数据驱动2. 短语对齐3. 评分和选择4. PyTorch实现…

App Inventor 2 实现Ascii码转换(Ascii编码与解码)

之前有同学问&#xff0c;App Inventor 2 字符及Ascii码如何进行转换&#xff0c;经过调查&#xff0c;其原生的组件和内置块无法完成这个功能&#xff0c;网上也有利用Web客户端组件执行js代码来进行转换&#xff0c;不过逻辑稍复杂效率还不高。这里介绍一个拓展可以非常方便的…

【算法系列 | 8】深入解析查找算法之—二分查找

序言 心若有阳光&#xff0c;你便会看见这个世界有那么多美好值得期待和向往。 决定开一个算法专栏&#xff0c;希望能帮助大家很好的了解算法。主要深入解析每个算法&#xff0c;从概念到示例。 我们一起努力&#xff0c;成为更好的自己&#xff01; 今天第8讲&#xff0c;讲一…

CrossOver 23 正式发布:可在 Mac 上运行部分 DX12 游戏

CodeWeivers 公司于今年 6 月发布了 CrossOver 23 测试版&#xff0c;重点添加了对 DirectX 12 支持&#xff0c;从而在 Mac 上更好地模拟运行 Windows 游戏。 该公司今天发布新闻稿&#xff0c;表示正式发布 CrossOver 23 稳定版&#xff0c;在诸多新增功能中&#xff0c;最值…

【经典小练习】JavaSE—拷贝文件夹

&#x1f38a;专栏【Java小练习】 &#x1f354;喜欢的诗句&#xff1a;天行健&#xff0c;君子以自强不息。 &#x1f386;音乐分享【如愿】 &#x1f384;欢迎并且感谢大家指出小吉的问题&#x1f970; 文章目录 &#x1f384;效果&#x1f33a;代码&#x1f6f8;讲解&#x…

Java学习之--类和对象

&#x1f495;粗缯大布裹生涯&#xff0c;腹有诗书气自华&#x1f495; 作者&#xff1a;Mylvzi 文章主要内容&#xff1a;Java学习之--类和对象 类和对象 类的实例化&#xff1a; 1.什么叫做类的实例化 利用类创建一个具体的对象就叫做类的实例化&#xff01; 当我们创建了…

预算有限但想改善客户服务?教你几招轻松解决~

这里有一个常见的误解&#xff1a;只有大公司需要客户服务。事实是&#xff0c;无论行业规模大小&#xff0c;出色的客户服务对每个企业都至关重要。事实上&#xff0c;企业规模越小&#xff0c;客户服务就越重要&#xff0c;因为他们无法承受失去客户的后果。 不仅如此&#…

连nil切片和空切片一不一样都不清楚?那BAT面试官只好让你回去等通知了。

连nil切片和空切片一不一样都不清楚&#xff1f;那BAT面试官只好让你回去等通知了。 问题 package mainimport ("fmt""reflect""unsafe" )func main() {var s1 []ints2 : make([]int,0)s4 : make([]int,0)fmt.Printf("s1 pointer:%v, s2 p…

【UE虚幻引擎】UE源码版编译、Andorid配置、打包

首先是要下载源码版的UE&#xff0c;我这里下载的是5.2.1 首先要安装Git 在你准备放代码的文件夹下右键点击Git Bash Here 然后可以直接git clone https://github.com/EpicGames/UnrealEngine 不行的话可以直接去官方的Github上下载Zip压缩包后解压 运行里面的Setup.bat&a…

浅谈C++|STL之vector篇

一.vector的基本概念 vector是C标准库中的一种动态数组容器&#xff0c;提供了动态大小的数组功能&#xff0c;能够在运行时根据需要自动扩展和收缩。vector以连续的内存块存储元素&#xff0c;可以快速访问和修改任意位置的元素。 以下是vector的基本概念和特点&#xff1a; 动…

第27章_瑞萨MCU零基础入门系列教程之freeRTOS实验

本教程基于韦东山百问网出的 DShanMCU-RA6M5开发板 进行编写&#xff0c;需要的同学可以在这里获取&#xff1a; https://item.taobao.com/item.htm?id728461040949 配套资料获取&#xff1a;https://renesas-docs.100ask.net 瑞萨MCU零基础入门系列教程汇总&#xff1a; ht…