【pytorch20】多分类问题

news2024/11/27 18:39:30

网络结构以及示例

在这里插入图片描述
该网络的输出不是一层或两层的,而是一个十层的代表有十分类
在这里插入图片描述

新建三个线性层,每个线性层都有w和b的tensor

首先输入维度是784,第一个维度是ch_out,第二个维度才是ch_in(由于后面要转置),没有经过softmax函数和sigmoid,即logits

上图已经完成了网络的参数的定义和网络的前向传播过程

在这里插入图片描述
nn.CrossEntropyLoss()F.cross_entropy()是一样的功能,都包含softmax和log和F.nll_loss

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

batch_size = 200
learning_rate = 0.01
epochs = 10

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)

w1, b1 = torch.randn(200, 784, requires_grad=True), \
    torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True), \
    torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True), \
    torch.zeros(10, requires_grad=True)


def forward(x):
    x = x @ w1.t() + b1
    x = F.relu(x)
    x = x @ w2.t() + b2
    x = F.relu(x)
    x = x @ w3.t() + b3
    x = F.relu(x)
    return x


# train
optimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criten = nn.CrossEntropyLoss()

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28 * 28)

        logits = forward(data)
        loss = criten(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        logits = forward(data)
        test_loss += criten(logits, target).item()

        # 每一行的最大值对应的索引
        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

在这里插入图片描述
创建的loss一直不变,为什么会出现这个问题,这个网络的结构层次并不深,数据集也比较简单,但这里出现了梯度弥散的情况,因为loss长时间得不到更新,因为梯度信息几乎接近于0

为什么会出现梯度为0?
影响训练的因素,除了有loss,学习率过大,还有初始化的问题,把初始化代码加上

torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)

为什么b不初始化,因为已经初始化为0了

但是w也初始化,只是它们使用的是高斯分布进行初始化,即使是用高斯分布初始化后,结果也不满意,所以用了何凯明的初始化
在这里插入图片描述
可以看出loss直接到0.4了,准确率也达到了80%,而且这里还没运行完,运行完效果会更好

可以看出对于分类问题,初始化参数非常关键

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1903257.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于java+springboot+vue实现的校园外卖服务系统(文末源码+Lw)292

摘 要 传统信息的管理大部分依赖于管理人员的手工登记与管理,然而,随着近些年信息技术的迅猛发展,让许多比较老套的信息管理模式进行了更新迭代,外卖信息因为其管理内容繁杂,管理数量繁多导致手工进行处理不能满足广…

11.常见的Bean后置处理器

CommonAnnotationBeanPostProcessor (Resource PostConstructor PreDestroy) AutowiredAnnotationBeanPostProcessor (Autowired Value) GenericApplicationContext是一个干净的容器,它没有添加任何的PostProcessor处理器。 调用GenericApplicationContext.refre…

MSPM0G3507——编码器控制速度

绿色设置的为目标值100,红色为编码器实际数据 。 最后也是两者合在了一起,PID调试成功。 源码直接分享,用的是CCStheia,KEIL打不开。大家可以看一下源码的思路,PID部分几乎不用改 链接:https://pan.baid…

CSS【详解】长度单位 ( px,%,em,rem,vw,vh,vmin,vmax,ex,ch )

px 像素 pixel 的缩写,即电子屏幕上的1个点,以分辨率为 1024 * 768 的屏幕为例,即水平方向上有 1024 个点,垂直方向上有 768 个点,则 width:1024px 即表示元素的宽度撑满整个屏幕。 随屏幕分辨率不同,1px …

HCIE之IPV6和OSPFv6(十四)

IPV6 1、IPv6基础1.1 Ipv6地址静态配置、Eui 641.1.1 Ipv6地址静态配置1.1.2、Ipv6地址计算总结1.1.2.1、IEEE eui 64计算1.1.2.1.1、作用1.1.2.1.2、计算方法1.1.2.1.3、计算过程 1.1.2.2、被请求加入的组播组地址计算(三层)1.1.2.2.1、 作用1.1.2.2.2、…

数据结构与算法笔记:实战篇 - 剖析微服务接口鉴权限流背后的数据结构和算法

概述 微服务是最近几年才兴起的概念。简单点将,就是把复杂的大应用,解耦成几个小的应用 。这样做的好处有很多。比如,这样有利于团队组织架构的拆分,比较团队越大协作的难度越大;再比如,每个应用都可以独立…

BAT-致敬精简

什么是bat bat是windows的批处理程序,可以批量完成一些操作,方便快速。 往往我们可以出通过 winR键来打开指令窗口,这里输入的就是bat指令 这里就是bat界面 节约时间就是珍爱生命--你能想象以下2分钟的操作,bat只需要1秒钟 我…

深入理解JS逆向代理与环境监测

博客文章:深入理解JS逆向代理与环境监测 1. 引言 首先要明确JavaScript(JS)在真实网页浏览器环境和Node.js环境中有很多使用特性的区别。尤其是在环境监测和对象原型链的检测方面。本文将探讨如何使用JS的代理(Proxy&#xff09…

分数的表示和运算方法fractions.Fraction()

【小白从小学Python、C、Java】 【考研初试复试毕业设计】 【Python基础AI数据分析】 分数的表示和运算方法 fractions.Fraction() 选择题 以下代码三次输出的结果分别是? from fractions import Fraction a Fraction(1, 4) print(【显示】a ,a) b Fraction(1, 2…

免费的鼠标连点器电脑版教程!官方正版!专业鼠标连点器用户分享教程!2024最新

电脑技术的不断发展,许多用户在日常工作和娱乐中,需要用到各种辅助工具来提升效率或简化操作,而电脑办公中,鼠标连点器作为一种能够模拟鼠标点击的软件,受到了广大用户的青睐。本文将为大家介绍一款官方正版的免费鼠标…

C++_STL---list

list的相关介绍 list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭代。 list的底层是带头双向循环链表结构,链表中每个元素存储在互不相关的独立节点中,在节点中通过指针指向其前一个元素和后一个元素。…

WAIC | 上海人形机器人创新中心 | 最新演讲 | 详细整理

前言 笔者看了7月4号的人形机器人与具身智能发展论坛的直播,并在7月5日到了上海WAIC展会现场参观。这次大会的举办很有意义,听并看了各家的最新成果,拍了很多照片视频,部分演讲也录屏了在重复观看学习 稍后会相继整理创立穹彻智…

使用RAID与LVM磁盘阵列技术

前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 目录 一、RAID磁盘冗余阵列 1、部署磁盘整列 2、损坏磁盘阵列及修复 3、磁盘阵列备份盘 4、删除磁盘阵列 二、LVM逻辑卷管理器 致谢 一、RAID…

linux中可执行文件在运行过程中为什么不能拷贝覆盖

对于一个普通的文件,假如有两个文件,分别是file和file1,我们使用 cp file1 file的方式使用file1的内容来覆盖file的内容,这样是可以的。 但是对于可执行文件来说,当这个文件在执行的时候,是不能通过cp的方…

Python 算法交易实验76 QTV200日常推进

说明 最近实在太忙, 没太有空推进这个项目,我想还是尽量抽一点点时间推进具体的工程,然后更多的还是用碎片化的时间从整体上对qtv200进行设计完善。有些结构的问题其实是需要理清的,例如: 1 要先基于原始数据进行描述…

【ROS2】初级:客户端-编写一个简单的服务和客户端(Python)

目标:使用 Python 创建并运行服务节点和客户端节点。 教程级别:初学者 时间:20 分钟 目录 背景 先决条件 任务 1. 创建一个包2. 编写服务节点3. 编写客户端节点4. 构建并运行 摘要 下一步 相关内容 背景 当节点通过服务进行通信时&#xff0c…

【机器学习】机器学习重塑广告营销:精准触达,高效转化的未来之路

📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀目录 📒1. 引言📙2. 机器学习基础与广告营销的结合🧩机器学习在广告营销中的核心应用领域🌹用…

将大型语言模型模块化打造协作智能体

B UILDING C OOPERATIVE E MBODIED A GENTS MODULARLY WITH L ARGE L ANGUAGE M ODELS 论文链接: https://arxiv.org/abs/2307.02485https://arxiv.org/abs/2307.02485 1.概述 在去中心化控制及多任务环境中,多智能体合作问题因原始感官观察、高昂…

穿梭印度风情记:维乐 Angel Revo Halo坐垫,让每一寸旅程闪耀光辉!

想象骑乘在印度的万花筒世界中,斑斓色彩与悠久历史交织,每一转轮都是对神秘东方的深刻探索。在这样的骑行之旅中,维乐Angel Revo Halo坐垫不仅是你的坐骑上的宝石,更是舒适与探险的完美媒介。    探索印度的色彩与灵魂&#x…

每日一题~oj(贪心)

对于位置 i来说,如果 不选她,那她的贡献是 vali-1 *2,如果选他 ,那么她的贡献是 ai. 每一个数的贡献 是基于前一个数的贡献 来计算的。只要保证这个数的前一个数的贡献是最优的,那么以此类推下去,整体的val…