torch.nn.parameter 生成可更新的 tensor (requires_grad = True)

news2024/11/30 2:50:49

torch.nn.parameter 是 PyTorch 中的一种特殊类型的 tensor,它主要用于存储神经网络中的参数。这些参数可以被自动求导和被优化器自动更新。使用 torch.nn.Parameter 定义的tensor 会被自动添加到模型的参数列表中。

在这里插入图片描述
\quad
torch.nn.Parameter 是继承自 torch.Tensor 的子类,其主要作用是作为 nn.Module 中的可训练参数使用。它与 torch.Tensor 的区别就是 nn.Parameter 生成的 tensor 自动被认为是 module 的可训练参数,即加入到 parameter() 这个迭代器中去;而 module 中 非 nn.Parameter() 的普通tensor是不在 parameter 中的。
\quad
注意到,nn.Parameter 的对象的 requires_grad 属性的默认值是 True,即是可被训练的,这与 torth.Tensor 对象的默认值相反。
\quad
nn.Module 类中,也是使用 nn.Parameter 来对每一个 module 的参数进行初始化的。以nn.Linear为例:

import torch.nn.Parameter as Parameter

class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`
    Examples::
        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

可以看到, Linear 在初始化时,weights 和 bias 都是使用 torch.nn.Parameter() 来生成的:

self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
self.bias = torch.nn.Parameter(torch.Tensor(out_features))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/344571.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vite+vue3搭建的工程热更新失效问题

前段时间开发新的项目,由于没有技术上的限制,所以选择了vitevue3ts来开发新的项目,一开始用vite来开发新项目过程挺顺利,确实比vue2webpack的项目高效些(为什么选择vite),但是过了一段时间后,不…

android组件化

学习流程:1.开源最佳实践:Android平台页面路由框架ARouter-阿里云开发者社区 (aliyun.com)2.中文ARouter使用API:https://github.com/alibaba/ARouter/blob/master/README_CN.md3.看当前文档后面的代码4.这是通俗易懂的文章:https…

使用 Nodejs、Express、Postgres、Docker 在 JavaScript 中构建 CRUD Rest API

让我们在 JavaScript 中创建一个 CRUD rest API,使用:节点.js表达续集Postgres码头工人码头工人组成介绍这是我们将要创建的应用程序架构的架构:我们将为基本的 CRUD 操作创建 5 个端点:创造阅读全部读一个更新删除我们将使用以下…

【H.264】码流解析 annexb vs avcc

H264码流解析及NALUAVCC和ANNEXB 前者是FLV容器、mp4 常用的。后者 是实时传输使用,所以是TS 一类的标准。VLC显示AVC1就是AVCC AVCC格式 也叫AVC1格式,MPEG-4格式,字节对齐,因此也叫Byte-Stream Format。用于mp4/flv/mkv, VideoToolbox。 – Annex-B格式 也叫MPEG-2 trans…

微信公众号扫码授权登录思路

引言 上学期研究了一下微信登录相关内容,也写了两三篇笔记,但是最后实际登录流程没有写,主要因为感觉功能完成有所欠缺,一直也没有好的思路;这两天我又看了看官方文档,重新构思了一下微信公众号登录相关的…

操作系统综合实验

实验目的 加深对进程概念理解,进一步认识进程并发执行掌握Linux系统的进程创建和终止操作掌握文件系统调用及文件标准子例程的编程方法掌握Linux下终端图形编程方法,能编写基于文本的图形界面掌握proc文件系统的使用 相关知识 Linux C编程中的头文件 …

知识点整合

⭐面试 自我介绍(优势岗位匹配度) 为什么来我们公司(对公司的了解) 讲讲做的项目(为什么这么做,思路和贡献) 跨部门涉案的业务流程 我们跨部门涉案业务主要是本系统、配合物流系统和罚没系…

二战字节跳动成功上岸,准备了小半年,拿27k也算不上很高吧~

先说下我基本情况,本科不是计算机专业,现在是学通信,然后做图像处理,可能面试官看我不是科班出身没有问太多计算机相关的问题,因为第一次找工作,字节的游戏专场又是最早开始的,就投递了&#xf…

SpringMVC传值

实现步骤 先看后台代码如何获取前端传过来的数据,直接在方法的参数列表中添加RequestParam(xxx),再在后面加上参数列表即可 不过这样的方式会出现一个问题:当前端页面没有提交相应的数据过来的时候,后台会出现异常,所…

Elasticsearch7.8.0版本进阶——数据读流程

目录一、数据读流程概述二、数据读流程步骤2.1、数据读流程图2.2、数据读流程步骤(从主分片或者副本分片检索文档的步骤顺序)2.3、数据读流程注意事项一、数据读流程概述 从主分片或者从其它任意副本分片检索文档。 二、数据读流程步骤 2.1、数据读流…

5_机试_递归和分治

一、递归 本章介绍程序设计中的另一个非常重要的思想一递归策略。递归是指函数直接或间接调用自身的一种方法,它通常可把一个复杂的大型问题层层转化为与原问题相似但规模较小的问题来求解。递归策略只需少量的程序就可描述解题过程所需的多次重复计算,…

谈谈Java多线程离不开的AQS

如果你想深入研究Java并发的话,那么AQS一定是绕不开的一块知识点,Java并发包很多的同步工具类底层都是基于AQS来实现的,比如我们工作中经常用的Lock工具ReentrantLock、栅栏CountDownLatch、信号量Semaphore等,而且关于AQS的知识点…

DDR4 信号说明

SDRAM Differential Clock :Differential clocks signal pairs , pair perrank . The crossing of the positive edgeand the negative edge of theircomplement are used to sample thecommand and control signals on theSDRAMSDRAM差分时钟:差分时钟信号对&#…

MagicThoughts|让ChatGPT变得更智能的Finetuned数据集

近两个月,ChatGPT无疑都是AI领域最炙手可热的话题。而它的成功,也引发了行业内外对于对话式AI、LLM模型商业化应用可能性的思考。诚然,尽管就目前来看ChatGPT对大部分问答都能基本做到“对答如流”。但是,ChatGPT本质上依旧是预训…

Flutter Modul集成到IOS项目

Flutter Modul集成到IOS项目中1. 创建一个Flutter Modul2.在既有应用中集成Flutter Modul2.1 Flutter的构建模式选择2.1.1 debug模式2.1.2 Release模式2.1.3 Podfile 模式2.2 Cocoapods管理依赖库集成方式2.3 直接在Xcode中集成framework2.4 Local Network Privacy Permissions…

采用 spring 配置文件管理Bean

文章目录采用 spring 配置文件管理Bean一、安装配置Maven二、Spring 框架1、Spring 官网三、Spring 容器演示-采用Spring配置文件管理Bean1、创建Manev项目2、添加Spring依赖3、创建杀龙骑士类4、创建勇敢骑士类5、采用传统方式让勇敢骑士完成杀龙任务6、采用Spring 容器让勇敢…

创建Ubuntu虚拟机与Windows共享的文件夹

目录 1、Windows创建一个共享文件夹 2、在虚拟机的设置中选择Windows下的共享文件夹 3、在Ubuntu中查看共享文件夹 1、Windows创建一个共享文件夹 该共享文件夹可以被Windows和Ubuntu访问,需要注意的是,Ubuntu在共享目录下的一些操作会受到限制&…

图解经典电路之OCL差分功放-三极管分立器件电路分析

下面从简到繁,从框架到细节的顺序讲解电路。即先讲框架,然后逐渐添加电路细节,所以大家跟上思路。 1、第一步,尽可能的抽象这个电路,等效如下: 图二 OCL等效电路 整个OCL电路,可以等效为一个大功率的运放,加上几个电阻电容构成了一个同向放大器,就是这么简单。 为了便…

Linux常用命令---系统常用命令

Linux系统常用命令场景一: 查看当前系统内核版本相关信息场景二: sosreport 命令场景三: 如何定位并确定命令?场景四:查看当前系统运行负载怎场景五: 查看当前系统的内存可用情况场景六:查看网卡…

【DOTA】目标检测数据集介绍与使用

every blog every motto: You can do more than you think. https://blog.csdn.net/weixin_39190382?typeblog 0. 前言 DOTA 数据集简单介绍 1. 正文 1.1 简介 数据集包含来自不同的传感器和平台的航拍图。每张图像的像素尺寸在 800 800 到 20,000 20,000 之间&#xf…