Pytorch神经网络-元组/列表如何喂到神经网络中

news2025/1/16 1:06:05

📚博客主页:knighthood2001

公众号:认知up吧 (目前正在带领大家一起提升认知,感兴趣可以来围观一下)

🎃知识星球:【认知up吧|成长|副业】介绍

❤️感谢大家点赞👍🏻收藏⭐评论✍🏻,您的三连就是我持续更新的动力❤️

🙏笔者水平有限,欢迎各位大佬指点,相互学习进步!

这篇文章适用于初学者,因为我就是这么过来的,最开始连网络中的参数为什么这么设置的都不知道。但是本文没有讲太多这方面的知识。

首先需要说明的是:

在PyTorch中,神经网络模型的输入通常需要是张量(tensor)类型。虽然你可以将元组作为输入传递给神经网络模型,但实际上在模型内部处理时,最终还是需要将数据转换为张量。

PyTorch的神经网络层(如nn.Linearnn.Conv2d等)的输入和输出都是张量。因此,在实际使用中,你需要确保将任何非张量类型的数据(如列表、元组等)转换为张量,以便能够在神经网络中进行计算。

经过我的一点点测试,发现好像pytorch中好像只能输入tensor格式的数据,而不能是tuple或者list数据类型。
在这里插入图片描述
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.f1 = nn.Linear(2, 32)
        self.f2 = nn.Linear(32, 4)

    def forward(self, x):
        x = self.f1(x)
        x = F.relu(x)
        action = self.f2(x)
        return action

这里我定义了两层网络,简单实现一个神经网络。
接下来我定义了一些数据类型,比如a是元组,c是列表,b和d都是经过转换成tensor张量格式的数据。

net = Net()

a = (4, 4)
b = torch.FloatTensor(a)

print(b)
print(net(b))

c = [4, 4]
d = torch.FloatTensor(c)
print(d)
print(net(d))

在这里插入图片描述
可以发现他们经过转换后长得一样了,然后才能把他们喂到神经网络中,否则就会出现上面的报错。

从中我们可以看出来,b和d都是包含两个数据的一维张量。这里好像是因为PyTorch 在执行 net(b) 时会自动将 1 维的输入张量 b 视作大小为 (1, 2) 的 batch,并将其送入神经网络进行前向传播计算。

所以可能将其变为二维张量输入网络比较好(个人观点),如下。

如果你想要将 b 转换为二维张量,可以使用 b.view(1, -1) 方法,这样 b 就会变成一个包含一个行和两列的二维张量。同理,d也是这样处理。这里的-1表示的就是,系统会根据总的数据数以及其他维度需要的数量,然后计算出-1所在的维度的数量。

m = b.view(1, -1)
print(m)
print(net(m))

在这里插入图片描述
发现结果一样,就是反向传播的梯度下降函数不同。

我查询了一下:

在 PyTorch 中,ViewBackward0 和 AddmmBackward0 是两种不同类型的 Autograd
Function,用于反向传播计算梯度。

  1. ViewBackward0:

    • ViewBackward0 是View 操作的反向传播函数。View 操作用于改变张量的形状,但不改变张量的数据内容。ViewBackward0 的作用是将梯度传播回 View
      操作之前的张量,以便在反向传播过程中正确更新梯度。
    • ViewBackward0 的主要功能是处理从 View 操作反向传播回来的梯度,确保梯度在形状变换后能够正确传播并更新。
  2. AddmmBackward0:

    • AddmmBackward0 是 addmm 操作(矩阵相加和矩阵乘法)的反向传播函数。addmm 函数用于计算矩阵相加和矩阵乘法的结果。AddmmBackward0 的作用是计算 addmm 操作对输入张量的梯度。
    • AddmmBackward0 主要负责处理 addmm 操作的反向传播过程,根据输出的梯度计算输入张量的梯度,并将其传播到上游的节点。

总的来说,ViewBackward0 用于处理 View 操作的反向传播,而 AddmmBackward0 用于处理 addmm
操作的反向传播。它们都是 Autograd Function 的一部分,负责计算和传播梯度,以支持 PyTorch 的自动微分功能。

本文的全部代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.f1 = nn.Linear(2, 32)
        self.f2 = nn.Linear(32, 4)

    def forward(self, x):
        x = self.f1(x)
        x = F.relu(x)
        action = self.f2(x)
        return action

net = Net()

a = (4, 4)
b = torch.FloatTensor(a)

print(b)
print(net(b))

m = b.view(1, -1)
print(m)
print(net(m))

c = [4, 4]
d = torch.FloatTensor(c)
print(d)
print(net(d))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1534435.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

设计编程网站集:生活部分:饮食+农业,植物(暂记)

这里写目录标题 植物相关综合教程**大型植物:****高大乔木(Trees):** 具有坚硬的木质茎,通常高度超过6米。例如,橡树、松树、榉树等。松树梧桐 **灌木(Shrubs):** 比乔木…

基于Jenkins + Argo 实现多集群的持续交付

作者:周靖峰,青云科技容器顾问,云原生爱好者,目前专注于 DevOps,云原生领域技术涉及 Kubernetes、KubeSphere、Argo。 前文概述 前面我们已经掌握了如何通过 Jenkins Argo CD 的方式实现单集群的持续交付&#xff0c…

基于Springboot的在线投稿系统+数据库+免费远程调试

项目介绍: Javaee项目,springboot项目。采用M(model)V(view)C(controller)三层体系结构,通过Spring SpringBoot Mybatis VueMavenLayui来实现。MySQL数据库作为系统数据储存平台&a…

Java安全 反序列化(3) CC1链-TransformedMap版

Java安全 反序列化(3) CC1链-TransformedMap版 本文尝试从CC1的挖掘思路出发,理解CC1的实现原理 文章目录 Java安全 反序列化(3) CC1链-TransformedMap版配置jdk版本和源代码配置前记 为什么可以利用一.CC链中的命令执行我们可以尝试一下通过InvokerTransformer.tr…

分布式异步任务框架celery

Celery介绍 github地址:GitHub - celery/celery: Distributed Task Queue (development branch) 文档地址:Celery - Distributed Task Queue — Celery 5.3.6 documentation 1.1 Celery是什么 celery时一个灵活且可靠的处理大量消息的分布式系统&…

数据库关系运算理论:传统的集合运算概念解析

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨ 🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢,在这里我会分享我的知识和经验。&am…

如何在wps的excel表格里面使用动态gif图

1、新建excel表格,粘贴gif图到表格里面,鼠标右键选择超链接。 找到源文件, 鼠标放到图片上的时候,待有个小手图标,双击鼠标可以放大看到动态gif图。 这种方式需要确保链接的原始文件位置和名称不能变化!&a…

网工内推 | 云计算工程师,HCIE认证优先,最高18k*14薪

01 杭州中港科技有限公司 招聘岗位:云计算工程师 职责描述: 1、承担云计算相关工程交付、业务上云及售前测试,从事虚拟化、桌面云、存储、服务器、数据中心、大数据、相关产品的工程项目交付或协助项目交付。 2、承担云计算维护工程师职责&…

深入理解Mysql索引底层原理(看这一篇文章就够了)

目录 前言 1、Mysql 索引底层数据结构选型 1.1 哈希表(Hash) 1.2 二叉查找树(BST) 1.3 AVL 树和红黑树 1.4 B 树 1.5 B树 2、Innodb 引擎和 Myisam 引擎的实现 2.1 MyISAM 引擎的底层实现(非聚集索引方式) 2.2 Innodb 引…

L4 级自动驾驶汽车发展综述

摘要:为了减小交通事故概率、降低运营成本、提高运营效率,实现安全、环保的出行,自动驾驶 技术的发展已成为大势所趋,而搭配有L4 级自动驾驶系统的车辆是将车辆驾驶全部交给系统。据此,介绍了自动驾驶汽车的主流技术解决方案;分析了国内外L4 级自动驾驶汽车的已发布车型、…

Python 安装目录及虚拟环境详解

Python 安装目录 原文链接:https://blog.csdn.net/xhyue_0209/article/details/106661191 Python 虚拟环境 python 虚拟环境图解 python 虚拟环境配置与详情 原文链接:https://www.cnblogs.com/hhaostudy/p/17321646.html

C语言易错知识点:二级指针、数组指针、函数指针

指针在C语言中非常关键,除开一些常见的指针用法,还有一些可能会比较生疏,但有时却也必不可少,本文章整理了一些易错知识点,希望能有所帮助! 1.二级指针: parr是一个指针数组,其中每…

std::shared_ptr与std::make_unique在类函数中的使用

在最近学习cartographer算法的时候,发现源码中大量的使用了std::shared_ptr与std::make_unique,对于这些东西之前不是很了解,为了更好的理解源代码,因此简单学习了一下这块内容的使用,在这里简单记个笔记。 std::shar…

【热门话题】深入浅出:npm常用命令详解与实践

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 标题:深入浅出:npm常用命令详解与实践引言一、npm基本概…

打流仪/网络测试仪这个市场还能怎么卷?

#喝了点,码点字# 以下为个人观点,看看就好,如有冒犯,私信删稿 都有哪些厂商在做打流仪/网络测试仪 -洋品牌:思博伦/Viavi-Spirent,是德/Keysight-Ixia,信雅纳/Lecroy-Xena, -国产…

睿尔曼超轻量仿人机械臂之-灵巧手动作编写及程序调用

一、灵巧手动作编写 1.连接设备 2. 运动控制 3. 参数设置 4 动作库使用 本软件可以设置灵巧手内部第 1-第 13 套动作序列数据,每套动作序列最多能有 8 步 分解动作,每一步分解动作的手指角度、运动速度、力度以及等待时间都可以单独设置。 步骤数&…

QT_day2:2024/3/21

作业1:使用QT完成一个登录界面 要求: 1. 需要使用Ui界面文件进行界面设计 2. ui界面上的组件相关设置,通过代码实现 3. 需要添加适当的动图 源代码: #include "widget.h" #include "ui_widget.h"Widget…

力扣由浅至深 每日一题.06 删除有序数组中的重复项

希望我们都能对抗生活的苦难,在乌云周围突破阴霾积极的生活 —— 24.3.16 删除有序数组中的重复项 提示 给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元…

贝尔曼方程【Bellman Equation】

强化学习笔记 主要基于b站西湖大学赵世钰老师的【强化学习的数学原理】课程,个人觉得赵老师的课件深入浅出,很适合入门. 第一章 强化学习基本概念 第二章 贝尔曼方程 文章目录 强化学习笔记一、状态值函数贝尔曼方程二、贝尔曼方程的向量形式三、动作值…

Windows系统部署GoLand结合内网穿透实现SSH远程Linux服务器开发调试

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法|MySQL| ​💫个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-HIOuHATnug3qMHzx {font-family:"trebuchet ms",verdana,arial,sans-serif;f…