TrustGeo代码理解(五)sublayers.py

news2025/1/23 8:01:35

代码链接:https://github.com/ICDM-UESTC/TrustGeo

一、导入模块

import torch
import torch.nn as nn
import torch.nn.functional as F

这段代码是一个简单的神经网络的定义,用于深度学习任务。

1、import torch:导入 PyTorch 库,提供张量(tensor)等深度学习操作的支持。

2、import torch.nn as nn:导入 PyTorch 中的神经网络模块,包括定义神经网络层的基本类。

3、import torch.nn.functional as F:导入 PyTorch 中的函数模块,包括一些激活函数、损失函数等。

二、ScaledDotProductAttention类定义(NN模型)

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        # attn[attn <= torch.quantile(attn, 0.8)] = 0
        # attn = torch.where(attn <= torch.mean(attn)*0.6, torch.full_like(attn, 0), attn)
        output = torch.matmul(attn, v)

        return output, attn

这段代码定义了一个 Scaled Dot-Product Attention 模块,这是 Transformer 模型中注意力机制的一部分。这个模块实现了 Scaled Dot-Product Attention 的计算,是 Transformer 模型中实现自注意力机制的关键组成部分。

分为几个部分展开描述:

(一)__init__()

def __init__(self, temperature, attn_dropout=0.1):
    super().__init__()
    self.temperature = temperature
    self.dropout = nn.Dropout(attn_dropout)

这是一个简单的自注意力(Self-Attention)模块的定义,其中包含了一个温度参数(temperature)和一个注意力丢弃率参数(attn_dropout)。主要用于实现一个简单的自注意力机制,其中包括对输入进行缩放(通过温度参数)以及应用注意力丢弃率。在实际应用中,这样的自注意力机制通常用于图神经网络等任务中,以捕捉输入序列中的重要信息。

1、def __init__(self, temperature, attn_dropout=0.1):这是类的构造函数,用于初始化SimpleAttention类的实例。参数包括temperatureattn_dropout,分别表示温度参数和注意力丢弃率参数。

2、super().__init__():调用父类的构造函数,确保正确地初始化继承自父类的属性。

3、self.temperature = temperature:将输入的temperature参数存储为类的属性,后续在注意力计算中使用。

4、self.dropout = nn.Dropout(attn_dropout):创建了一个 PyTorch 的 nn.Dropout 层,用于在注意力计算中应用丢弃率。attn_dropout 是一个可选参数,默认值为 0.1。

(二)forward()

def forward(self, q, k, v, mask=None):
    attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

    if mask is not None:
        attn = attn.masked_fill(mask == 0, -1e9)

    attn = self.dropout(F.softmax(attn, dim=-1))
    # attn[attn <= torch.quantile(attn, 0.8)] = 0
    # attn = torch.where(attn <= torch.mean(attn)*0.6, torch.full_like(attn, 0), attn)
    output = torch.matmul(attn, v)

    return output, attn

这是一个用于执行自注意力机制(Self-Attention)的前向传播函数。函数的整体功能是计算自注意力机制的输出,其中查询(q)、键(k)、值(v)是输入的特征表示。掩码(mask)是一个可选参数,用于屏蔽输入序列中的某些位置。通过计算注意力分数、应用 Softmax 函数和使用 dropout 进行正则化,该函数产生了自注意力的输出和相应的注意力权重。

1、def forward(self, q, k, v, mask=None):定义了前向传播函数,该函数接受查询(q)、键(k)、值(v)以及可选的掩码(mask)作为输入。

2、attn = torch.matmul(q / self.temperature, k.transpose(2, 3)):计算注意力分数。将查询和键进行点积操作,然后除以温度(temperature)以缩放注意力。这里采用了矩阵相乘的形式。

3、if mask is not None

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1315745.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Day62力扣打卡

打卡记录 统计区间中的整数数目&#xff08;动态开点线段树&#xff09; 链接 class CountIntervals:__slots__ left, right, l, r, cntdef __init__(self, l1, r10 ** 9):self.left self.right Noneself.l, self.r, self.cnt l, r, 0def add(self, l: int, r: int) ->…

Spring cloud - 断路器 Resilience4J

其实文章的标题应该叫 Resilience4J&#xff0c;而不是Spring Cloud Resilience4J&#xff0c;不过由于正在对Spring cloud的一系列组件进行学习&#xff0c;为了统一&#xff0c;就这样吧。 概念区分 首先区分几个概念 Spring cloud 断路器&#xff1a;Spring Cloud的官网对…

Python的数据类型及举例集合、元组、列表之间的转换规则

Python语言有八种数据类型&#xff0c;有数字&#xff08;整数、浮点数、复数&#xff09;、字符串、字典、集合、元组、列表、布尔值、空值&#xff0c;下面我演示八种数据类型及集合、元组、列表三种类型之间的转换规则。 一、数据类型示例 下面我演示了八种数据类型&#…

Git使用rebase和merge区别

Git使用rebase和merge区别 模拟环境使用merge合并使用rebase 模拟环境 本地dev分支中DevTest增加addRole() 远程dev被同事提交增加了createResource() 使用merge合并 使用idea中merge解决冲突后, 推送远程dev后,日志图显示 使用rebase idea中使用功能rebase 解决冲突…

PyQt6 安装Qt Designer

前言&#xff1a;在Python自带的环境下&#xff0c;安装Qt Designer&#xff0c;并在PyCharm中配置designer工具。 在项目开发中&#xff0c;使用Python虚拟环境安装PyQt6-tools时&#xff0c;designer.exe会安装在虚拟环境的目录中&#xff1a;.venv\Lib\site-packages\qt6_a…

模板方法模式(行为型)

目录 一、前言 二、模板模式 三、带钩子的模板模式 四、总结 一、前言 模板方法模式是一种行为型设计模式&#xff0c;它定义了一个操作中的算法框架&#xff0c;将一些步骤延迟到子类中实现。这种模式是基于“开闭原则”的设计思想&#xff0c;即对扩展开放&#xff0c;对…

Microsoft visual studio 2013卸载方法

1、问 题 Microsoft visual studio 2013 无法通过【程序与功能】卸载 2、解决方法 使用微软的Microsoft visual studio 2013 专用卸载工具 工具下载链接&#xff1a;https://github.com/Microsoft/VisualStudioUninstaller/releases 或 链接&#xff1a;https://pan.baidu.c…

分布式事务seata使用示例及注意事项

分布式事务seata使用示例及注意事项 示例说明代码调用方&#xff08;微服务A&#xff09;服务方&#xff08;微服务B&#xff09; 测试测试一 &#xff0c;seata发挥作用&#xff0c;成功回滚&#xff01;测试二&#xff1a;处理feignclient接口的返回类型从Integer变成String,…

数理统计基础:参数估计与假设检验

在学习机器学习的过程中&#xff0c;我充分感受到概率与统计知识的重要性&#xff0c;熟悉相关概念思想对理解各种人工智能算法非常有意义&#xff0c;从而做到知其所以然。因此打算写这篇笔记&#xff0c;先好好梳理一下参数估计与假设检验的相关内容。 1 总体梳理 先从整体结…

OceanBase数据库初识

文章目录 说明分布式数据库发展发展历史OceanBase和传统数据库的对比总结 OceanBase数据库产品简介应用案例 OceanBase数据库产品OceanBase数据库内核OceanBase开发者中心&#xff08;ODC&#xff09;产品架构OMS核心功能简介 说明 本文仅供学习和交流学习内容参考官方的培训资…

年底了,千万不要跳槽..

&#x1f4e2;专注于分享软件测试干货内容&#xff0c;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01;&#x1f4e2;交流讨论&#xff1a;欢迎加入我们一起学习&#xff01;&#x1f4e2;资源分享&#xff1a;耗时200小时精选的「软件测试」资…

关于Linux你必须知道的五件事

Linux是一种开源操作系统 (OS)。操作系统是直接管理系统硬件和资源&#xff08;如 CPU、内存和存储&#xff09;的软件。操作系统位于应用程序和硬件之间&#xff0c;并在所有软件和执行工作的物理资源之间建立连接。 俄罗斯军方计划用 Astra Linux 取代 Windows&#xff01;为…

【数据结构】双链表的定义和操作

目录 1.双链表的定义 2.双链表的创建和初始化 3.双链表的插入节点操作 4.双链表的删除节点操作 5.双链表的查找节点操作 6.双链表的更新节点操作 7.完整代码 &#x1f308;嗨&#xff01;我是Filotimo__&#x1f308;。很高兴与大家相识&#xff0c;希望我的博客能对你有所帮助…

RuoYi-Cloud诺依微服务项目

1、架构图 从图中解析出RuoYi-Cloud 使用微服务技术栈 网关&#xff1a;Gateway远程调用&#xff1a;Ribbon/Feign注册中心&#xff1a;Nacos Discovery熔断降级&#xff1a;Sentinel配置中心&#xff1a;Nacos Config链路追踪&#xff1a;Sleuth ZipKin/SkyWalking &#x…

leetcode(力扣) 89. 格雷编码 (规律题)

文章目录 题目描述思路分析完整代码 题目描述 n 位格雷码序列 是一个由 2n 个整数组成的序列&#xff0c;其中&#xff1a; 每个整数都在范围 [0, 2n - 1] 内&#xff08;含 0 和 2n - 1&#xff09; 第一个整数是 0 一个整数在序列中出现 不超过一次 每对 相邻 整数的二进制表…

vue3 使用antd 报错Uncaught TypeError--【已解决】

问题现象 使用最基本的 ant-design-vue 按钮demo 都报错 报错文字如下 Uncaught TypeError: Cannot read properties of undefined (reading value)at ReactiveEffect.fn (ant-design-vue.js?v597f5366:6693:87)at ReactiveEffect.run (chunk-K2VKR2AM.js?v25c381c3:461:…

用文本创建图表的工具PlantUML

什么是 PlantUML &#xff1f; PlantUML 是一种开源工具&#xff0c;允许用户从纯文本语言创建图表。除了各种 UML 图之外&#xff0c;PlantUML 还支持各种其他软件开发相关格式&#xff0c;以及 JSON 和 YAML 文件的可视化。PlantUML 语言是特定领域语言的一个示例。 什么是 P…

Shopee ERP:提升电商管理效率的终极解决方案

Shopee ERP&#xff08;Enterprise Resource Planning&#xff0c;企业资源规划&#xff09;是一款专为Shopee卖家设计的集成化电商管理软件。通过使用Shopee ERP系统&#xff0c;卖家可以更高效地管理他们的在线商店&#xff0c;实现库存管理、订单处理、物流跟踪、财务管理、…

【理论篇】SaTokenException: 非Web上下文无法获取Request问题解决 -理论篇

在我们使用sa-token安全框架的时候&#xff0c;有时候会提示&#xff1a;SaTokenException:非Web上下文无法获取Request 错误截图&#xff1a; 在官方网站中&#xff0c;查看常见问题排查&#xff1a; 错误追踪&#xff1a; 跟着源码可以看到如下代码&#xff1a; 从源码中&a…

【Spring教程30】Spring框架实战:从零开始学习SpringMVC 之 Rest风格简介与RESTful入门案例

目录 1 REST简介2 RESTful入门案例2.1 环境准备2.2 思路分析2.3 修改RESTful风格 3 知识点总结 欢迎大家回到《Java教程之Spring30天快速入门》&#xff0c;本教程所有示例均基于Maven实现&#xff0c;如果您对Maven还很陌生&#xff0c;请移步本人的博文《如何在windows11下安…