pytorch-多分类实战之手写数字识别

news2024/10/5 20:25:07

目录

  • 1. 网络设计
  • 2. 代码实现
    • 2.1 网络代码
    • 2.2 train
  • 3. 完整代码

1. 网络设计

输入是手写数字图片28x28,输出是10个分类0~9,有两个隐藏层,如下图所示:
在这里插入图片描述

2. 代码实现

2.1 网络代码

第一层将784降维到200,第二次使用200不降维,输出层200降维到10,每一层之后加一个激活函数relu,每一层都需要梯度信息所以requires_grad=True;
forward函数最后不要加softmax,因为后面CrossEntropyLoss中包含了softmax操作。
在这里插入图片描述

2.2 train

优化目标是w1、b1、w2、b2、w3、b3,使用SGD优化器,使用CrossEntropyLoss计算loss
在这里插入图片描述

3. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms


batch_size=200
learning_rate=0.01
epochs=10

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)



w1, b1 = torch.randn(200, 784, requires_grad=True),\
         torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True),\
         torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True),\
         torch.zeros(10, requires_grad=True)

# torch.nn.init.kaiming_normal_(w1)
# torch.nn.init.kaiming_normal_(w2)
# torch.nn.init.kaiming_normal_(w3)


def forward(x):
    x = x@w1.t() + b1
    x = F.relu(x)
    x = x@w2.t() + b2
    x = F.relu(x)
    x = x@w3.t() + b3
    x = F.relu(x)
    return x



optimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criteon = nn.CrossEntropyLoss()

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)

        logits = forward(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        logits = forward(data)
        test_loss += criteon(logits, target).item()

        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

如下图:
未使用torch.nn.init.kaiming_normal_(w1)初始化参数的情况,可以看出Loss在2.302585后就不下降了。
在这里插入图片描述
如下图:使用了torch.nn.init.kaiming_normal_(w1)初始化参数的情况下,Loss下降还是比较快的。
在这里插入图片描述
因此使用好的初始化参数对网络的训练起到至关重要的作用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1590236.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux的学习之路:7、yum与git

摘要 本章主要是说一下yum和git的操作 目录 摘要 一、什么是yum 二、yum三板斧 1、list 2、install 3、remove 三、怎么创建仓库 四、git三板斧 1、add 2、commit 3、push 4、pull 五、思维导图 一、什么是yum YUM是Yellowdog Updater Modified的简称&#xf…

三方库移植之NAPI开发(三)通过IDE开发NAPI工程

在三方库移植之NAPI开发[1]—Hello OpenHarmony NAPI一文中,笔者开发的是一个rom包的napi工程。该工程需要编译烧录固件,C 的动态库会集成到开发板的ROM中。在本篇文章中,笔者使用三方库移植之NAPI开发[1]—Hello OpenHarmony NAPI中一样的he…

zabbix监控配置(添加主机、主机组和添加监控项等)

zabbix监控配置 文章目录 zabbix监控配置1.添加主机组2.添加主机(linux)3.添加主机(windows)4.监控项配置(通过模板添加)5.监控项配置(手动添加) 1.添加主机组 2.添加主机&#xff0…

【Github】PwGen用户友好的Web应用密码生成器

弱密码问题一直是网络安全领域的一个重大挑战。许多人为了方便记忆,倾向于使用简单、常见的密码,如“123456”、“password”或者他们的生日等,这些密码很容易被猜测或通过暴力破解方法攻破。弱密码的使用大大增加了账户被黑客入侵的风险&…

【深入解析spring cloud gateway】13 Reactive Feign的使用

问题引入 在gateway中如果使用feignClient的话,会报如下错误 java.lang.IllegalStateException: block()/blockFirst()/blockLast() are blocking, which is not supported in thread reactor-http-nio-3at reactor.core.publisher.BlockingSingleSubscriber.bloc…

C++实现一个自定义字符串类(string)

本博客将详细介绍如何在C中实现一个自定义的字符串类 string,这个类模仿了标准库中 std::string 的关键功能。这个过程将涵盖从声明到定义的每一步,重点介绍内存管理、操作符重载以及提供一些关键的实现细节。 首先:我们采用函数的声明与定义…

python--字符串对象和

1、找出10000以内能被5或6整除,但不能被两者同时整除的数(函数) def Divisible_by_5_6(x:int)->list:arr[]for i in range(1,x1):if (i % 5 0 or i % 6 0 ):if i % 5 0 and i % 6 0:continue #利用continue跳过能被5和6整除的数else:a…

Spring Boot 学习(4)——开发环境升级与项目 jdk 升级

各种版本都比较老,用起来也是常出各样的问题,终于找到一个看来不错的新教程,是原先那个教程的升级。遂决定升级一下开发环境,在升级遇到一些问题,摸索将其解决,得些体会记录备查。 最终确定开发环境约束如下…

ActiveMQ 01 消息中间件jmsMQ

消息中间件之ActiveMQ 01 什么是JMS MQ 全称:Java MessageService 中文:Java 消息服务。 JMS 是 Java 的一套 API 标准,最初的目的是为了使应用程序能够访问现有的 MOM 系 统(MOM 是 MessageOriented Middleware 的英文缩写&am…

QLoRa 低秩分解+权重量化的微调

QLoRa的核心思想是首先使用低秩分解技术降低参数的数量,然后对这些低秩表示的参数应用量化技术,进一步减少所需的存储空间和计算量。 https://arxiv.org/abs/2305.14314 低秩分解 低秩分解(Low-Rank Factorization):…

Elasticsearch初步了解学习记录

目录 前言 一、ElasticSearch是什么? 二、使用步骤(python版) 1.引入包 2.连接数据库 3.创建索引 4.写入数据 5.查询数据 三、相关工具介绍 1.ES浏览器插件 总结 前言 随着数据量的不断增加,传统的查询检索在速度上遇…

Android使用shape属性绘制边框内渐变色

目录 先上效果图实现方法shape属性介绍代码结果 先上效果图 这是使用AndroidStudio绘制的带有渐变色的边框背景色 实现方法 项目中由于UI设计需求,需要给按钮、控件设置带有背景色效果的。以下是UI效果图。 这里我们使用shape属性来绘制背景效果。 shape属性介…

机器学习-09-图像处理02-PIL+numpy+OpenCV实践

总结 本系列是机器学习课程的系列课程,主要介绍机器学习中图像处理技术。 参考 【人工智能】PythonOpenCV图像处理(一篇全) 一文讲解方向梯度直方图(hog) 【杂谈】计算机视觉在人脸图像领域的十几个大的应用方向&…

清远某国企IBM服务器Board故障上门维修

接到一台来自广东清远市清城区某水利大坝国企单位报修一台IBM System x3650 M4服务器无法开机的故障,分享给大家,同时也方便有需要的朋友能及时找到我们快速解决服务器问题。 故障服务器型号:IBM System x3650 M4 服务器使用单位:…

只要0.74元的双通道数字隔离器,1T1R,增强型ESD-3.0 kV ,150Kbps数字隔离器

前言: 做和电源打交道的设备,通信时非常危险,一定要使用隔离的USB-232转换器,或者你设备的串口与市电直连的设备通信时,现在需要使用数字隔离器,早期的一般都使用光耦,在这种情况下,速度不搞的…

正则表达式:量词(三)

正则表达式中的量词有以下几种:1. *: 匹配前面的字符0次或多次。2. : 匹配前面的字符1次或多次。3.?: 匹配前面的字符0次或1次。4. {n}: 匹配前面的字符恰好n次。5. {n,}: 匹配前面的字符至少n次。6. {n,m}:匹配前面的字符至少n次,但不超过m次。 以下是使用Python的…

Unity TextMeshProUGUI 获取文本尺寸·大小

一般使用ContentSizeFitter组件自动变更大小 API 渲染前 Vector2 GetPreferredValues(string text)Vector2 GetPreferredValues(string text, float width, float height)Vector2 GetPreferredValues(float width, float height) 渲染后 Vector2 GetRenderedValues()Vector…

idea 卡怎么办

设置内存大小 清缓存重启 idea显示内存全用情况 右下角

解决CSS中鼠标移入到某个元素其子元素被遮挡的问题

我们在开发中经常遇到一种场景,就是给元素加提示信息,就是鼠标移入到盒子上面时,会出现提示信息这一功能,如果我们给盒子加了hover,当鼠标移入到盒子上时,让他往上移动5px,即transform: transla…

网盘——搜索用户

目录 1、搜索用户 1.1、在friend.h里面定义槽函数 1.2、关联槽函数 1.3、搜索用户的时候,会弹出一个对话框来,在friend.cpp里面引入下面的头文件,专门用来输入数据的 1.4、获取输入信息,并使用Qstring来接收它 1.5、将上述代码打包&…