PyTorch典型函数之gather

news2024/10/5 7:30:57

PyTorch典型函数之gather

  • 作用描述
  • 函数详解
  • 典型应用场景
    • (1) 深度强化学习中计算损失函数
  • 参考链接

作用描述

图解torch.gather
如上图所示,假如我们有一个Tensor A(图左),要从A中提取一部分元素组成Tensor B(图右),这时可以用torch.gather来实现:

>>> import torch
>>> t1 = torch.Tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
>>> t2 = torch.gather(t1, 1, torch.tensor([[3,3],[0,2],[0,1]]))
>>> print(t2)
tensor([[ 4.,  4.],
        [ 5.,  7.],
        [ 9., 10.]])

图中每个方块代表一个值,图中数字代表这个值在该行中的序号,这里以dim=1,即按行提取为例。

对于二维Tensor而言,dim=0为按列提取,dim=1为按行提取。

函数详解

官网描述:

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

根据dim参数指定的轴来收集值。对于一个三维Tensor:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
  • 输入Tensor(input)和索引Tensor(index)必须维数一样。比如input给一个矩阵,index给个一维向量PyTorch就不知道要怎么办了。
  • 对于所有d != dim的维数d,需要满足index.size(d) <= input.size(d) 。(原文It is also required that index.size(d) <= input.size(d) for all dimensions d != dim.)
  • 输出Tensor和索引Tensor具有相同的形状。
  • 输入Tensor(input)和索引张量不会互相广播。(原文Note that input and index do not broadcast against each other.)

参数

  • input (Tensor) - the source tensor
  • dim (int) - the axis along which to index
  • index (LongTensor) - the indices of elements to gather

参数名传参

  • sparse_grad (bool, optional) - If True, gradient w.r.t. input will be a sparse tensor.
  • out (Tensor, optional) - the destination tensor

此外,下面两种用法等价:

input_tensor = torch.Tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
index_tensor = torch.tensor([[3,3],[0,2],[0,1]])

# 方法1
t1 = torch.gather(input_tensor , 1, index_tensor)

# 方法2
t2 = input_tensor.gather(1, index_tensor)

典型应用场景

(1) 深度强化学习中计算损失函数

在深度Q-network方法中,需要构建Q-network,并从经验区进行采样,根据采样计算损失函数并更新Q-network。

采样信息包括当前环境观测值(当前状态)和当前实际采取的行动。

之后根据当前环境观测值,通过Q-network计算各行为对应的Q值。

接下来用gather函数从各行为对应的Q值根据实际采取的行动提取其对应的Q值。

最后结合(1)根据实际行为计算出的当前状态Q值和(2)根据下一个环境观测值计算出的Q值进行MSELoss计算。

对应代码如下:

def calc_loss(batch, net, tgt_net, device='cpu'):
    states, actions, rewards, dones, next_states = batch

    states_v = torch.tensor(np.array(states, copy=False)).to(device) # 当前环境观察
    next_states_v = torch.tensor(np.array(next_states, copy=False)).to(device)  # 下一刻环境观察
    actions_v = torch.tensor(actions, dtype=torch.int64).to(device) # 当前采取的行动
    rewards_v = torch.tensor(rewards).to(device) # 采取当前行动后的奖励值
    done_mask = torch.BoolTensor(dones).to(device)

    # net(states_v)产生在输入环境为states_v情况下,各行动对应的Q值
    # 从net(states_v)中提取实际选择的行动对应的Q值,用于后面和Q值公式计算出的Q值期望计算MSELoss
    state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
    next_state_values = tgt_net(next_states_v).max(1)[0]
    next_state_values[done_mask] = 0.0
    next_state_values = next_state_values.detach()

    expected_state_action_values = next_state_values * GAMMA + rewards_v

    return nn.MSELoss()(state_action_values, expected_state_action_values)

参考链接

  1. torch.gather — PyTorch 2.0 documentation
  2. Deep-Reinforcement-Learning-Hands-On-Second-Edition

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/508635.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

7.外观模式C++用法示例

外观模式 一.外观模式1.原理2.特点3.外观模式与装饰器模式的异同4.应用场景C程序示例 一.外观模式 外观模式&#xff08;Facade Pattern&#xff09;是一种结构型设计模式&#xff0c;它提供了一个简单的接口&#xff0c;隐藏了一个或多个复杂的子系统的复杂性&#xff0c;并使…

图嵌入表示学习—Node Embeddings随机游走

Random Walk Approaches for Node Embeddings 一、随机游走基本概念 想象一个醉汉在图中随机的行走&#xff0c;其中走过的节点路径就是一个随机游走序列。 随机行走可以采取不同的策略&#xff0c;如行走的方向、每次行走的长度等。 二、图机器学习与NLP的关系 从图与NLP的…

posix线程的优先级测试

如果创建的线程不够多&#xff0c;有些问题是体现不出来的。 优先级打印&#xff1a; 测试目的&#xff1a;输出三种调度模式下的最大优先级和最小优先级 #include <stdio.h> #include <sys/socket.h> #include <sys/types.h> #include <fcntl.h> #…

Kubernetes_容器网络_01_Docker网络原理(二)

文章目录 一、前言二、被隔离的Docker容器三、网桥Bridge四、VethPair网络对五、统一宿主机上的两个Container容器通信六、宿主机访问其上的容器七、宿主机上的容器访问另一个宿主机八、尾声 一、前言 二、被隔离的Docker容器 Linux 网络&#xff0c;就包括&#xff1a;网卡&…

技术选型对比- RPC(Feign VS Dubbo)

协议 Dubbo 支持多传输协议: Dubbo、Rmi、http,可灵活配置。默认的Dubbo协议&#xff1a;利用Netty&#xff0c;TCP传输&#xff0c;单一、异步、长连接&#xff0c;适合数据量小(传送数据小&#xff0c;不然影响带宽&#xff0c;响应速度)、高并发和服务提供者远远少于消费者…

UnityWebGL+阿里云服务器+Apache完成项目搭建展示

一、服务器相关 Step1:租借一台阿里云服务器 我自己租借了一台北京的ECS服务器&#xff0c;有免费一年的活动&#xff0c;1 vCPU 2 GiB&#xff0c;我自己选择的Ubuntu系统&#xff0c;也可以选择Windows系统 Step2:进入远程连接 进入自己的服务器实例后&#xff0c;点击远程…

vue+elementui+nodejs机票航空飞机航班查询与推荐

语言 node.js 框架&#xff1a;Express 前端:Vue.js 数据库&#xff1a;mysql 数据库工具&#xff1a;Navicat 开发软件&#xff1a;VScode )本系统主要是为旅客提供更为便利的机票预定方式&#xff0c;同时提高民航的预定机票的工作效率。通过网络平台实现信息化和网络化&am…

关于Android的性能优化,主要是针对哪些方面的问题进行优化

前言 我们在开发Android的时候&#xff0c;经常会遇到一些性能问题&#xff1b;例如&#xff1a;卡顿、无响应&#xff0c;崩溃等&#xff0c;当然&#xff0c;这些问题为我们可以从日志来进行追踪&#xff0c;尽可能避免此类问题的发生&#xff0c;要解决这些问题&#xff0c…

mysql从零开始(05)----锁

全局锁 使用 # 启用全局锁 flush tables with read lock # 释放全局锁 unlock tables开启全局锁后&#xff0c;整个数据库就处于只读状态了&#xff0c;这种状态下&#xff0c;对数据的增删改操作、对表结构的更改操作都会被阻塞。 另外&#xff0c;当会话断开&#xff0c;全…

【1015. 可被 K 整除的最小整数】

来源&#xff1a;力扣&#xff08;LeetCode&#xff09; 描述&#xff1a; 给定正整数 k &#xff0c;你需要找出可以被 k 整除的、仅包含数字 1 的最 小 正整数 n 的长度。 返回 n 的长度。如果不存在这样的 n &#xff0c;就返回 -1。 注意&#xff1a; n 不符合 64 位带…

手把手教你在winform中将文本或文件路径拖到控件中

文章目录 前言博主履历介绍&#xff1a;一、将txt文件的所有内容复制到 RichTextBox中二、将txt文件的一行内容移动到RichTextBox中三、将多个文件的全路径复制到 RichTextBox中四 、源码1、[Winform从入门到精通&#xff08;1&#xff09;——&#xff08;如何年入30万&#x…

「MIAOYUN」:降本增效,赋能传统企业数字化云原生转型 | 36kr 项目精选

作为新经济综合服务平台第一品牌&#xff0c;36氪自2019年落地四川站以来&#xff0c;不断通过新锐、深度的商业报道&#xff0c;陪跑、支持四川的新经济产业。通过挖掘本土优质项目&#xff0c;36氪四川帮助企业链接更多资源&#xff0c;助力企业成长&#xff0c;促进行业发展…

分布式系统概念和设计——命名服务设计和落地经验

分布式系统概念和设计 通过命名服务&#xff0c;客户进程可以根据名字获取资源或对象的地址等属性。 被命名的实体可以是多种类型&#xff0c;并且可由不同的服务管理。 命名服务 命名是一个分布式系统中的非常基础的问题&#xff0c;名字在分布式系统中代表了广泛的资源&#…

C语言:指针求解鸡兔同笼问题

题目&#xff1a;鸡兔同笼问题 要求&#xff1a;使用自定义函数void calc(int h, int f,int *c,int *r) 求解鸡兔同笼问题。 h 表示总的头数&#xff0c;f 表示总的脚数。 例子&#xff1a; 输入&#xff1a; 5 16 输出&#xff1a; 2 3 分析&#xff1a; 在该代码中&a…

05-Docker安装Mysql、Redis、Tomcat

Docker 安装 Mysql 以安装 Mysql 5.7为例&#xff1a; docker pull mysql:5.7Mysql 单机 Mysql 5.7安装 启动 Mysql 容器&#xff0c;并配置容器卷映射&#xff1a; docker run -d -p 3306:3306 \--privilegedtrue \-v /app/mysql/log:/var/log/mysql \-v /app/mysql/data:…

ASP.NET Core MVC 从入门到精通之文件上传

随着技术的发展&#xff0c;ASP.NET Core MVC也推出了好长时间&#xff0c;经过不断的版本更新迭代&#xff0c;已经越来越完善&#xff0c;本系列文章主要讲解ASP.NET Core MVC开发B/S系统过程中所涉及到的相关内容&#xff0c;适用于初学者&#xff0c;在校毕业生&#xff0c…

VMware NSX-T Data Center 3.2.2.1 - 数据中心网络全栈虚拟化

请访问原文链接&#xff1a;https://sysin.org/blog/vmware-nsx-t-3/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xff1a;sysin.org VMware NSX-T Data Center 3.2.2.1 | 30 MAR 2023 | Build 21487560 VMware NSX-T Data Center 3.2.2 | 08 …

NOA上车「清一色」自主品牌,哪些供应商正在突围前线

随着入门级L2进入普及周期&#xff0c;以NOA&#xff08;高速、城区&#xff09;为代表的L2/L2赛道&#xff0c;正在成为主机厂、硬件供应商、算法及软件方案商的下一波市场制高点的争夺阵地。 高工智能汽车研究院监测数据显示&#xff0c;2023年1-3月中国市场&#xff08;不含…

MySQL基础(十六)变量、流程控制与游标

1. 变量 在MySQL数据库的存储过程和函数中&#xff0c;可以使用变量来存储查询或计算的中间结果数据&#xff0c;或者输出最终的结果数据。 在 MySQL 数据库中&#xff0c;变量分为系统变量以及用户自定义变量。 1.1 系统变量 1.1.1 系统变量分类 变量由系统定义&#xff…

【Nacos在derby模式下密码忘记】使用derby的ij工具重置密码/修改密码

【问题描述】 nacos部署未用mysql,直接运行&#xff0c;使用了默认的derby数据库&#xff0c;这时候不一小心修改的密码给忘记了&#xff0c;无法登录 当时是部署在centos上的一个演示环境&#xff0c;没有采用mysql数据库&#xff0c;如果生产上&#xff0c;建议使用mysql。 …