【动手学深度学习】关于“softmax回归的简单实现”报错的解决办法(含源代码)

news2025/1/13 11:44:54

目录:关于“softmax回归的简单实现”报错的解决办法

  • 一、前言
  • 二、实现步骤
    • 2.1 导包
    • 2.2 初始化模型参数
    • 2.3 重新审视Softmax的实现
    • 2.4 优化算法
    • 2.5 训练
    • 2.6 源代码
  • 三、问题出现
  • 四、问题的解决
  • 五、再跑代码
  • 六、改正后的函数源代码

一、前言

在之前的学习中,我们发现通过深度学习框架的高级API能够使实现线性回归变得更加容易。

同样,通过深度学习框架的高级API也能更方便地实现softmax回归模型。

本节继续使用Fashion-MNIST数据集,并保持批量大小为256。

二、实现步骤

2.1 导包

import torch
from torch import nn
from d2l import torch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

2.2 初始化模型参数

softmax回归的输出层是一个全连接层。 因此,为了实现我们的模型, 我们只需在Sequential中添加一个带有10个输出的全连接层。 同样,在这里Sequential并不是必要的, 但它是实现深度模型的基础。 我们仍然以均值0和标准差0.01随机初始化权重。

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

2.3 重新审视Softmax的实现

我们计算了模型的输出,然后将此输出送入交叉熵损失。 从数学上讲,这是一件完全合理的事情。 然而,从计算角度来看,指数可能会造成数值稳定性问题。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
我们也希望保留传统的softmax函数,以备我们需要评估通过模型输出的概率。 但是,我们没有将softmax概率传递到损失函数中, 而是在交叉熵损失函数中传递未规范化的预测,并同时计算softmax及其对数, 这是一种类似”LogSumExp技巧”的聪明方式。

loss = nn.CrossEntropyLoss(reduction='none')

2.4 优化算法

在这里,我们使用学习率为0.1的小批量随机梯度下降作为优化算法。 这与我们在线性回归例子中的相同,这说明了优化器的普适性。

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

2.5 训练

接下来我们调用上一节中定义的训练函数来训练模型:

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

2.6 源代码

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)


net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

三、问题出现

我们根据上面的过程,尝试运行,结果出现报错:

在这里插入图片描述

Traceback (most recent call last):
  File "d:\Code Project\15.动手学深度学习代码手撸\softmax_2.py", line 21, in <module>
    d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
  File "D:\Anaconda\envs\PyTorch\lib\site-packages\d2l\torch.py", line 324, in train_ch3
    train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
  File "D:\Anaconda\envs\PyTorch\lib\site-packages\d2l\torch.py", line 257, in train_epoch_ch3
    l.backward()
  File "D:\Anaconda\envs\PyTorch\lib\site-packages\torch\_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "D:\Anaconda\envs\PyTorch\lib\site-packages\torch\autograd\__init__.py", line 166, in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
  File "D:\Anaconda\envs\PyTorch\lib\site-packages\torch\autograd\__init__.py", line 67, in _make_grads
    raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

四、问题的解决

对源码进行修改:

在这里插入图片描述

五、再跑代码

在这里插入图片描述
依旧是以动图的形式展示!

(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/15.动手学深度学习代码手撸/softmax_2.py"
<Figure size 1500x900 with 1 Axes>
<Figure size 1500x900 with 1 Axes>
<Figure size 1500x900 with 1 Axes>
<Figure size 1500x900 with 1 Axes>
<Figure size 1500x900 with 1 Axes>
<Figure size 1500x900 with 1 Axes>
<Figure size 1500x900 with 1 Axes>
<Figure size 1500x900 with 1 Axes>
<Figure size 1500x900 with 1 Axes>
<Figure size 1500x900 with 1 Axes>

六、改正后的函数源代码

def train_epoch_ch3(net, train_iter, loss, updater):
    """The training loop defined in Chapter 3."""
    # Set the model to training mode
    if isinstance(net, torch.nn.Module):
        net.train()
    # Sum of training loss, sum of training accuracy, no. of examples
    metric = Accumulator(3)
    for X, y in train_iter:
        # Compute gradients and update parameters
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # Using PyTorch in-built optimizer & loss criterion
            updater.zero_grad()
            l.mean().backward()
            updater.step()
            # metric.add(float(l) * len(y), accuracy(y_hat, y),y.size().numel())
        else:
            # Using custom built optimizer & loss criterion
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # Return training loss and training accuracy
    return metric[0] / metric[2], metric[1] / metric[2]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/24535.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

题库系统(公众号免费调用)

题库系统(公众号免费调用) 本平台优点&#xff1a; 多题库查题、独立后台、响应速度快、全网平台可查、功能最全&#xff01; 1.想要给自己的公众号获得查题接口&#xff0c;只需要两步&#xff01; 2.题库&#xff1a; 查题校园题库&#xff1a;查题校园题库后台&#xff0…

arthas 源码构建

arthas 源码构建 git下载代码 git clone https://github.com/alibaba/arthas.git 若github被墙&#xff0c;可以在gitee搜索下载 maven clean 可以在项目目录执行 mvn clean &#xff0c; ide可以执行界面执行 maven package 可以在项目目录执行mvn package 问题记录 ja…

四肽Suc-AAPD-对硝基苯胺,165174-58-3

粒酶B底物succ - aapd - pna。也被ICE劈开了。 编号: 177581中文名称: 四肽Suc-Ala-Ala-Pro-Asp-对硝基苯胺CAS号: 165174-58-3单字母: Suc-AAPD-pNA三字母: Suc-Ala-Ala-Pro-Asp-pNA氨基酸个数: 4分子式: C25H32O11N6平均分子量: 592.56精确分子量: 592.21等电点(PI): -pH7.0时…

新进场的獴哥健康、至真健康们,讲不出互联网医疗的新故事

文丨智能相对论 作者丨沈浪 曾几何时&#xff0c;互联网医疗风靡一时&#xff0c;现如今潮水退去&#xff0c;当市场回归理性&#xff0c;赛道竞争趋于同质化&#xff0c;一批互联网医疗企业正在试图通过讲好新故事&#xff0c;来拉开品牌与品牌之间的商业差距&#xff0c;寻…

ArcGIS Pro 转换Smart3D生成的倾斜3D模型数据osgb——创建集成网格场景图层包

最近在做Arcgis 批处理的一些工作&#xff0c;然后再学习Python的同时&#xff0c;偶然觉得arcgis Pro是个好东西呢&#xff1f;然后结合近期的Smart3D倾斜3D模型数据&#xff0c;是否可以在arcgis里查看呢&#xff1f;带着这样的疑问和好奇&#xff0c;开始了arcgis Pro的学习…

【408专项篇】C语言笔记-第四章(选择与循环)

第四章&#xff1a;选择、循环 第一节&#xff1a;选择if-else 1. 关系表达式与逻辑表达式 if-else的判断条件结果还是真与假&#xff0c;即1或0&#xff0c;一般是关系表达式或逻辑表达式。 算术运算符的优先级高于关系运算符&#xff0c;关系运算符的优先级高于逻辑与和逻…

Webpack 5 超详细解读(四)

31.proxy 代理设置 为什么开发阶段需要设置代理,在开发阶段,我们需要请求后端接口,但是一般后端接口地址和我们本地的不在同一个服务中提供,这时进行访问就会存在跨域的问题,所以我们需要对我们的请求进行转啊操作。模拟跨域请求代码如下: https://api.github.com/users…

高项 沟通管理论文

3个过程&#xff1a; 1&#xff0c;规划沟通管理&#xff1a;根据干系人的信息需要和要求及组织的可用资产情况&#xff0c;制订合适的项目沟通方式和计划的过程。 2&#xff0c;管理沟通&#xff1a;根据沟通管理计划,生成、收集、分发、储存、检索及最终处置项目信息的过程…

Spring boot 实践Rabbitmq消息防丢失

之前看很多网上大佬的防丢失的文章&#xff0c;文章中理论知识偏多&#xff0c;所以自己想着实践一下&#xff0c;实践过程中也踩了一些坑&#xff0c;因此写出了这篇文章。如果文章有误人子弟的地方&#xff0c;望在评论区指出。 导致消息出现丢失的原因 发送时失败&#xff…

295348-87-7,AF 594 Succinimidyl Ester可用于成像和流式细胞分析

理论分析&#xff1a; 中文名&#xff1a;AF 594活性酯 英文名&#xff1a;AF 594 Succinimidyl Ester&#xff0c;Alexa Fluor 594 NHS Ester&#xff0c;AF 594 NHS Ester CAS号&#xff1a;295348-87-7 化学式&#xff1a;C39H37N3O13S2 分子量&#xff1a;819.85 ex/em : 5…

招投标业务总结

最近接了一个招标投标的项目&#xff0c;开发完成后&#xff0c;整个招投标的流程也就理清楚了&#xff0c;简介一下业务过程。业务主流程可以分为3个阶段: 招标方建立招标项目发布招标公告投标人参与竞标招标方开标评选公示 系统可以划分为两套&#xff0c;一个是给招标方使…

若依框架解读(微服务版)——2.模块间的调用逻辑(ruoyi-api模块)(OpenFeign)

模块之间的关系 我们可以了解到一共有这么多服务&#xff0c;我们先启动这三个服务 其中rouyi–api模块是远程调用也就是提取出来的openfeign的接口 ruoyi–commom是通用工具模块 其他几个都是独立的服务 ruoyi-api模块 api模块当中有几个提取出来的OpenFeign的接口 分别为文件…

华为机试 - ABR 车路协同场景

目录 题目描述 输入描述 输出描述 用例 题目解析 算法源码 题目描述 数轴有两个点的序列 A{A1&#xff0c; A2, …, Am}和 B{B1, B2, ..., Bn}&#xff0c; Ai 和 Bj 均为正整数&#xff0c; A、 B 已经从小到大排好序&#xff0c; A、 B 均肯定不为空&#xff0c; 给定…

大数据培训课程Reduce Join案例实操

Reduce Join案例实操 1&#xff0e;需求 表4-4 订单数据表t_order idpidamount100101110020221003033100401410050251006036 表4-5 商品信息表t_product pidpname01小米02华为03格力将商品信息表中数据根据商品pid合并到订单数据表中。 表4-6 最终数据形式 idpnameamount…

2022我的前端面试总结

Webpack Proxy工作原理&#xff1f;为什么能解决跨域 1. 是什么 webpack proxy&#xff0c;即webpack提供的代理服务 基本行为就是接收客户端发送的请求后转发给其他服务器 其目的是为了便于开发者在开发模式下解决跨域问题&#xff08;浏览器安全策略限制&#xff09; 想…

盘点 | 跨平台桌面应用开发的5大主流框架

受益于开源技术的发展&#xff0c;以及响应快速开发的实际业务需求&#xff0c;跨平台开发不仅限于移动端跨平台&#xff0c;桌面端虽然在市场应用方面场景不像移动端那么丰富&#xff0c;但也有市场的需求。 相对于个人开发者而言&#xff0c;跨平台框架的使用&#xff0c;主…

Vue开发 提交后台,二维码,自定义

1. 修改title和图标 资源可以放在static下面&#xff0c;给一个小的&#xff1a; 直接再index里面改&#xff1a; 不生效&#xff0c;需要在 vue.config.js 中增加&#xff1a; module.exports {pwa: {iconPaths: {favicon32: logo.png,favicon16: logo.png,appleTouchIcon:…

阿里巴巴全新SpringCloud实战笔记(全彩版)GitHub狂揽70000标星

最近小编淘到一份宝贝&#xff01; 先看看目录&#xff1a; 这份手册真的非常全面&#xff0c;涵盖了所有SpringCloud所有的内容&#xff0c;限于文章篇幅原因&#xff0c;只能以截图的形式展示出来&#xff0c;有需要的小伙伴可以文末获取↓↓↓ 直接展示内容&#xff1a; …

react redux 状态管理

1.store store是一个状态管理容器&#xff0c;它通过createStore创建&#xff0c;createStore接收initialState和reducer两个参数。它暴露了4个api分别是&#xff1a; getState() dispatch(action) subscribe(listener) replaceReducer 前三个是比较常用的api&#xff0c;之…

葡萄糖-聚乙二醇-二茂铁Ferrocene-PEG-Glucose

葡萄糖-聚乙二醇-二茂铁Ferrocene-PEG-Glucose&#xff0c;二茂铁&#xff0c;是一种具有芳香族性质的有机过渡金属化合物&#xff0c;化学式为Fe(C5H5)2&#xff0c;常温下为橙黄色粉末&#xff0c;有樟脑气味。熔点172℃-174℃&#xff0c;沸点249℃&#xff0c;100℃以上能升…