pytorch-神经网络-手写数字分类任务

news2025/1/19 12:59:00

Mnist分类任务:

  • 网络基本构建与训练方法,常用函数解析

  • torch.nn.functional模块

  • nn.Module模块

  • 读取Mnist数据集

  • 会自动进行下载
    %matplotlib inline
    
    from pathlib import Path
    import requests
    
    DATA_PATH = Path("data")
    PATH = DATA_PATH / "mnist"
    
    PATH.mkdir(parents=True, exist_ok=True)
    
    URL = "http://deeplearning.net/data/mnist/"
    FILENAME = "mnist.pkl.gz"
    
    if not (PATH / FILENAME).exists():
            content = requests.get(URL + FILENAME).content
            (PATH / FILENAME).open("wb").write(content)
    
    import pickle
    import gzip
    
    with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
            ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
    
    from matplotlib import pyplot
    import numpy as np
    
    pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
    print(x_train.shape)

  • \

  • 注意数据需转换成tensor才能参与后续建模训练

    import torch
    
    x_train, y_train, x_valid, y_valid = map(
        torch.tensor, (x_train, y_train, x_valid, y_valid)
    )
    n, c = x_train.shape
    x_train, x_train.shape, y_train.min(), y_train.max()
    print(x_train, y_train)
    print(x_train.shape)
    print(y_train.min(), y_train.max())

    torch.nn.functional 很多层和函数在这里都会见到

    torch.nn.functional中有很多功能,后续会常用的。那什么时候使用nn.Module,什么时候使用nn.functional呢?一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些

    import torch.nn.functional as F
    
    loss_func = F.cross_entropy
    
    def model(xb):
        return xb.mm(weights) + bias
    bs = 64
    xb = x_train[0:bs]  # a mini-batch from x
    yb = y_train[0:bs]
    weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) 
    bs = 64
    bias = torch.zeros(10, requires_grad=True)
    
    print(loss_func(model(xb), yb))

    创建一个model来更简化代码

  • 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
  • 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
  • Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
    from torch import nn
    
    class Mnist_NN(nn.Module):
        def __init__(self):
            super().__init__()
            self.hidden1 = nn.Linear(784, 128)
            self.hidden2 = nn.Linear(128, 256)
            self.out  = nn.Linear(256, 10)
    
        def forward(self, x):
            x = F.relu(self.hidden1(x))
            x = F.relu(self.hidden2(x))
            x = self.out(x)
            return x
            
    net = Mnist_NN()
    print(net)
    

    for name, parameter in net.named_parameters():
        print(name, parameter,parameter.size())
    hidden1.weight Parameter containing:
    tensor([[ 0.0018,  0.0218,  0.0036,  ..., -0.0286, -0.0166,  0.0089],
            [-0.0349,  0.0268,  0.0328,  ...,  0.0263,  0.0200, -0.0137],
            [ 0.0061,  0.0060, -0.0351,  ...,  0.0130, -0.0085,  0.0073],
            ...,
            [-0.0231,  0.0195, -0.0205,  ..., -0.0207, -0.0103, -0.0223],
            [-0.0299,  0.0305,  0.0098,  ...,  0.0184, -0.0247, -0.0207],
            [-0.0306, -0.0252, -0.0341,  ...,  0.0136, -0.0285,  0.0057]],
           requires_grad=True) torch.Size([128, 784])
    hidden1.bias Parameter containing:
    tensor([ 0.0072, -0.0269, -0.0320, -0.0162,  0.0102,  0.0189, -0.0118, -0.0063,
            -0.0277,  0.0349,  0.0267, -0.0035,  0.0127, -0.0152, -0.0070,  0.0228,
            -0.0029,  0.0049,  0.0072,  0.0002, -0.0356,  0.0097, -0.0003, -0.0223,
            -0.0028, -0.0120, -0.0060, -0.0063,  0.0237,  0.0142,  0.0044, -0.0005,
             0.0349, -0.0132,  0.0138, -0.0295, -0.0299,  0.0074,  0.0231,  0.0292,
            -0.0178,  0.0046,  0.0043, -0.0195,  0.0175, -0.0069,  0.0228,  0.0169,
             0.0339,  0.0245, -0.0326, -0.0260, -0.0029,  0.0028,  0.0322, -0.0209,
            -0.0287,  0.0195,  0.0188,  0.0261,  0.0148, -0.0195, -0.0094, -0.0294,
            -0.0209, -0.0142,  0.0131,  0.0273,  0.0017,  0.0219,  0.0187,  0.0161,
             0.0203,  0.0332,  0.0225,  0.0154,  0.0169, -0.0346, -0.0114,  0.0277,
             0.0292, -0.0164,  0.0001, -0.0299, -0.0076, -0.0128, -0.0076, -0.0080,
            -0.0209, -0.0194, -0.0143,  0.0292, -0.0316, -0.0188, -0.0052,  0.0013,
            -0.0247,  0.0352, -0.0253, -0.0306,  0.0035, -0.0253,  0.0167, -0.0260,
            -0.0179, -0.0342,  0.0033, -0.0287, -0.0272,  0.0238,  0.0323,  0.0108,
             0.0097,  0.0219,  0.0111,  0.0208, -0.0279,  0.0324, -0.0325, -0.0166,
            -0.0010, -0.0007,  0.0298,  0.0329,  0.0012, -0.0073, -0.0010,  0.0057],
           requires_grad=True) torch.Size([128])
    hidden2.weight Parameter containing:
    tensor([[-0.0383, -0.0649,  0.0665,  ..., -0.0312,  0.0394, -0.0801],
            [-0.0189, -0.0342,  0.0431,  ..., -0.0321,  0.0072,  0.0367],
            [ 0.0289,  0.0780,  0.0496,  ...,  0.0018, -0.0604, -0.0156],
            ...,
            [-0.0360,  0.0394, -0.0615,  ...,  0.0233, -0.0536, -0.0266],
            [ 0.0416,  0.0082, -0.0345,  ...,  0.0808, -0.0308, -0.0403],
            [-0.0477,  0.0136, -0.0408,  ...,  0.0180, -0.0316, -0.0782]],
           requires_grad=True) torch.Size([256, 128])
    hidden2.bias Parameter containing:
    tensor([-0.0694, -0.0363, -0.0178,  0.0206, -0.0875, -0.0876, -0.0369, -0.0386,
             0.0642, -0.0738, -0.0017, -0.0243, -0.0054,  0.0757, -0.0254,  0.0050,
             0.0519, -0.0695,  0.0318, -0.0042, -0.0189, -0.0263, -0.0627, -0.0691,
             0.0713, -0.0696, -0.0672,  0.0297,  0.0102,  0.0040,  0.0830,  0.0214,
             0.0714,  0.0327, -0.0582, -0.0354,  0.0621,  0.0475,  0.0490,  0.0331,
            -0.0111, -0.0469, -0.0695, -0.0062, -0.0432, -0.0132, -0.0856, -0.0219,
            -0.0185, -0.0517,  0.0017, -0.0788, -0.0403,  0.0039,  0.0544, -0.0496,
             0.0588, -0.0068,  0.0496,  0.0588, -0.0100,  0.0731,  0.0071, -0.0155,
            -0.0872, -0.0504,  0.0499,  0.0628, -0.0057,  0.0530, -0.0518, -0.0049,
             0.0767,  0.0743,  0.0748, -0.0438,  0.0235, -0.0809,  0.0140, -0.0374,
             0.0615, -0.0177,  0.0061, -0.0013, -0.0138, -0.0750, -0.0550,  0.0732,
             0.0050,  0.0778,  0.0415,  0.0487,  0.0522,  0.0867, -0.0255, -0.0264,
             0.0829,  0.0599,  0.0194,  0.0831, -0.0562,  0.0487, -0.0411,  0.0237,
             0.0347, -0.0194, -0.0560, -0.0562, -0.0076,  0.0459, -0.0477,  0.0345,
            -0.0575, -0.0005,  0.0174,  0.0855, -0.0257, -0.0279, -0.0348, -0.0114,
            -0.0823, -0.0075, -0.0524,  0.0331,  0.0387, -0.0575,  0.0068, -0.0590,
            -0.0101, -0.0880, -0.0375,  0.0033, -0.0172, -0.0641, -0.0797,  0.0407,
             0.0741, -0.0041, -0.0608,  0.0672, -0.0464, -0.0716, -0.0191, -0.0645,
             0.0397,  0.0013,  0.0063,  0.0370,  0.0475, -0.0535,  0.0721, -0.0431,
             0.0053, -0.0568, -0.0228, -0.0260, -0.0784, -0.0148,  0.0229, -0.0095,
            -0.0040,  0.0025,  0.0781,  0.0140, -0.0561,  0.0384, -0.0011, -0.0366,
             0.0345,  0.0015,  0.0294, -0.0734, -0.0852, -0.0015, -0.0747, -0.0100,
             0.0801, -0.0739,  0.0611,  0.0536,  0.0298, -0.0097,  0.0017, -0.0398,
             0.0076, -0.0759, -0.0293,  0.0344, -0.0463, -0.0270,  0.0447,  0.0814,
            -0.0193, -0.0559,  0.0160,  0.0216, -0.0346,  0.0316,  0.0881, -0.0652,
            -0.0169,  0.0117, -0.0107, -0.0754, -0.0231, -0.0291,  0.0210,  0.0427,
             0.0418,  0.0040,  0.0762,  0.0645, -0.0368, -0.0229, -0.0569, -0.0881,
            -0.0660,  0.0297,  0.0433, -0.0777,  0.0212, -0.0601,  0.0795, -0.0511,
            -0.0634,  0.0720,  0.0016,  0.0693, -0.0547, -0.0652, -0.0480,  0.0759,
             0.0194, -0.0328, -0.0211, -0.0025, -0.0055, -0.0157,  0.0817,  0.0030,
             0.0310, -0.0735,  0.0160, -0.0368,  0.0528, -0.0675, -0.0083, -0.0427,
            -0.0872,  0.0699,  0.0795, -0.0738, -0.0639,  0.0350,  0.0114,  0.0303],
           requires_grad=True) torch.Size([256])
    out.weight Parameter containing:
    tensor([[ 0.0232, -0.0571,  0.0439,  ..., -0.0417, -0.0237,  0.0183],
            [ 0.0210,  0.0607,  0.0277,  ..., -0.0015,  0.0571,  0.0502],
            [ 0.0297, -0.0393,  0.0616,  ...,  0.0131, -0.0163, -0.0239],
            ...,
            [ 0.0416,  0.0309, -0.0441,  ..., -0.0493,  0.0284, -0.0230],
            [ 0.0404, -0.0564,  0.0442,  ..., -0.0271, -0.0526, -0.0554],
            [-0.0404, -0.0049, -0.0256,  ..., -0.0262, -0.0130,  0.0057]],
           requires_grad=True) torch.Size([10, 256])
    out.bias Parameter containing:
    tensor([-0.0536,  0.0007,  0.0227, -0.0072, -0.0168, -0.0125, -0.0207, -0.0558,
             0.0579, -0.0439], requires_grad=True) torch.Size([10])
  • 使用TensorDataset和DataLoader来简化

    from torch.utils.data import TensorDataset
    from torch.utils.data import DataLoader
    
    train_ds = TensorDataset(x_train, y_train)
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
    
    valid_ds = TensorDataset(x_valid, y_valid)
    valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
    def get_data(train_ds, valid_ds, bs):
        return (
            DataLoader(train_ds, batch_size=bs, shuffle=True),
            DataLoader(valid_ds, batch_size=bs * 2),
        )

  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
    import numpy as np
    
    def fit(steps, model, loss_func, opt, train_dl, valid_dl):
        for step in range(steps):
            model.train()
            for xb, yb in train_dl:
                loss_batch(model, loss_func, xb, yb, opt)
    
            model.eval()
            with torch.no_grad():
                losses, nums = zip(
                    *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
                )
            val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
            print('当前step:'+str(step), '验证集损失:'+str(val_loss))
    from torch import optim
    def get_model():
        model = Mnist_NN()
        return model, optim.SGD(model.parameters(), lr=0.001)
    def loss_batch(model, loss_func, xb, yb, opt=None):
        loss = loss_func(model(xb), yb)
    
        if opt is not None:
            loss.backward()
            opt.step()
            opt.zero_grad()
    
        return loss.item(), len(xb)

    三行搞定!

    train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
    model, opt = get_model()
    fit(25, model, loss_func, opt, train_dl, valid_dl)
    当前step:0 验证集损失:2.2796445930480957
    当前step:1 验证集损失:2.2440698066711424
    当前step:2 验证集损失:2.1889826164245605
    当前step:3 验证集损失:2.0985311767578123
    当前step:4 验证集损失:1.9517273582458496
    当前step:5 验证集损失:1.7341805934906005
    当前step:6 验证集损失:1.4719875366210937
    当前step:7 验证集损失:1.2273896869659424
    当前step:8 验证集损失:1.0362271406173706
    当前step:9 验证集损失:0.8963696184158325
    当前step:10 验证集损失:0.7927186088562012
    当前step:11 验证集损失:0.7141492074012756
    当前step:12 验证集损失:0.6529350900650024
    当前step:13 验证集损失:0.60417300491333
    当前step:14 验证集损失:0.5643046331882476
    当前step:15 验证集损失:0.5317994566917419
    当前step:16 验证集损失:0.5047958114624024
    当前step:17 验证集损失:0.4813900615692139
    当前step:18 验证集损失:0.4618900228500366
    当前step:19 验证集损失:0.4443243554592133
    当前step:20 验证集损失:0.4297310716629028
    当前step:21 验证集损失:0.416976597738266
    当前step:22 验证集损失:0.406348459148407
    当前step:23 验证集损失:0.3963301926612854
    当前step:24 验证集损失:0.38733808159828187
    
    
    https://gitee.com/code-wenjiahao/neural-network-practical-classification-and-regression-tasks/tree/master
    

     

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/979902.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RESTful风格介绍

😜作 者:是江迪呀✒️本文关键词:HTTP、RESTFul、请求☀️每日 一言:我已经习惯这样的失望了——《英雄联盟》奥恩 一、前言 为了更好地设计和构建分布式系统和网络应用,诞生了RESTful 风格,它…

后端SpringBoot+前端Vue前后端分离的项目(二)

前言:完成一个列表,实现表头的切换,字段的筛选,排序,分页功能。 目录 一、数据库表的设计 ​编辑二、后端实现 环境配置 model层 mapper层 service层 service层单元测试 controller层 三、前端实现 interface接…

Acwing算法心得——现代艺术(统计遍历)

大家好,我是晴天学长,先用两个一维数组维护数据,再统计遍历二维数组,需要的小伙伴请自取哦!💪💪💪 1 )现代艺术 2) .算法思路 现代艺术 1.两个数组维护行和列 2.遍历数组…

命名空间的详讲

本篇文章旨在讲解C中命名空间的概念以及其相关注意事项! C的介绍 C作为C语言的衍生,其对C语言中的一些缺陷进行了一些的补充和优化。但是C也对C语言具有兼容性! 本文旨在讲解C对C语言中当声明的变量与库函数的一些标识符,关键字…

2023年MySQL实战核心技术第二篇

目录 五 . 日志系统:一条SQL更新语句是如何执行的? 5.1 解释 5.2 重要的日志模块:redo log 5.2.1 解释 5.2.2 WAL(Write-Ahead Logging) 5.2.3 crash-safe。 5.3 重要的日志模块:binlog 5.3 .1 为什么会有…

【FPGA】通俗理解从VGA显示到HDMI显示

注:大部分参考内容来自“征途Pro《FPGA Verilog开发实战指南——基于Altera EP4CE10》2021.7.10(上)” 贴个下载地址: 野火FPGA-Altera-EP4CE10征途开发板_核心板 — 野火产品资料下载中心 文档 hdmi显示器驱动设计与验证 — …

10.1 直流电源的组成及各部分的作用

在电子电路及设备中,一般都需要稳定的直流电源供电。本章所介绍的直流电源为单相小功率电源,它将频率为 50 Hz 50\,\textrm {Hz} 50Hz、有效值为 220 V 220\,\textrm V 220V 的单相交流电压转换为幅值稳定、输出电流为几十安以下的直流电压。 单相交流…

机器学习训练,没有机器怎么办

google的cobal,免费提供15G显存。 https://colab.research.google.com/drive/

十五、MySQL(DCL)如何实现用户权限控制?

1、为什么要实现用户权限控制? 在日常工作中,会存在多个用户,为了避免某些用户对重要数据库进行“误操作”,从而导致严重后果,所以对用户进行权限控制是必须的。 2、常见的权限类型: ALL,ALL PRIVILEGES …

数字孪生产品:数字化时代的变革引擎

数字孪生技术,作为一项前沿的科技创新,正在不断改变我们的世界。它为各行各业的发展提供了无限的可能性,成为了当今数字化时代的一大亮点。数字孪生产品,作为数字孪生技术的具体应用,将在未来发挥越来越重要的作用。 数…

Linux命令之文件管理

Linux命令之文件管理 创建文件删除文件移动文件拷贝文件查看文件文件统计信息的查看文件内容的查看文件的权限文件权限的介绍和表示文件权限的改变 文件的类型 查找文件 创建文件 创建文件的话,一般使用touch命令 touch file1(文件名字)删除文件 删除文件的话&…

QT QToolBox控件使用详解

本文详细的介绍了QToolBox控件的各种操作,例如:新建界面、添加页签、索引设置当前项、获取当前项的索引、获取当前项窗口、获取索引值是int的窗口、移除索引值项、获取项的数量、获取指定索引值、设置索引项是否激活、获取索引值项是否激活、设置项的图标…

可靠的可视化监控平台应用在那些场景?

可视化监控平台是一种用户友好的工具,可以帮助用户实时监控IT设备的运行状态和网络流量,以及监测安全性和性能指标。它们通常采用图形化界面,使得用户能够直观地了解设备和网络的状态。 以下是一些可视化监控平台常见的应用场景:…

R7 7840H和i7 1360p选哪个 R77840H和i71360p对比

i71360P采用10nm工艺 最高睿频 5GHz 十核心 十六线程 三级缓存 18MB热设计功耗(TDP) 28W 支持最大内存 64GB 内存类型 DDR4 3200MHzDDR5 5200MHz集成显卡 Intel Iris Xe Graphics 选i7 1360p还是r7 7840h这些点很重要看过你就懂了 http://www.adiannao.cn/dy r7 7840h采用4nm…

Chrome扩展开发实战:网页图片抓取,打造专属自己的效率插件

🏆作者简介,黑夜开发者,CSDN领军人物,全栈领域优质创作者✌,CSDN博客专家,阿里云社区专家博主,2023年6月csdn上海赛道top4。 🏆数年电商行业从业经验,历任核心研发工程师…

软件测试框架的面试题讲解

主要对测试框架的面试题讲解。 1.测试一个杯子怎么写测试用例? 界面:杯子外观 安全性:杯子有没有毒或细菌 可靠性:杯子从不同高度落下的损坏程度;杯子放水放置12个小时或者24小时,是否漏水 可移植性&#…

Redis基础特性及应用练习-php

redis持久化(persistence) redis支持两种方式的持久化,可以单独使用或者结合起来使用。 第一种:RDB方式(redis默认的持久化方式) rdb方式的持久化是通过快照完成的,当符合一定条件时redis会自…

Ansible playbook简介与初步实战,实现批量机器应用下载与安装

一.Ansible playbook简介 playbook是ansible用于配置,部署,和管理被节点的剧本通过playbook的详细描述,执行其中的一些列tasks,可以让远端的主机达到预期的状态。playbook就像ansible控制器给被控节点列出的一系列to-do-list&…

webpack(四)plugin

定义 和loader的区别 loader:文件加载器,能够加载资源,并对这些文件进行一些处理,诸如编译、压缩等,最终一起打包到指定的文件中。plugin:赋予了webpack各种灵活的功能,例如打包优化、资源管理、环境变量注入等&…

【c++ debug】cmake编译报错 No such file or directory

1. 报错:error while loading shared libraries: libprotoc.so.24: cannot open shared object file: No such file or directory 问题原因:找不到动态库 解决方法:添加动态库路径 export LD_LIBRARY_PATH$LD_LIBRARY_PATH:/your/protobuf/l…