transfomer中Decoder和Encoder的base_layer的源码实现

news2024/9/24 21:29:25

简介

Encoder和Decoder共同组成transfomer,分别对应图中左右浅绿色框内的部分.
在这里插入图片描述
Encoder:
目的:将输入的特征图转换为一系列自注意力的输出。
工作原理:首先,通过卷积神经网络(CNN)提取输入图像的特征。然后,这些特征通过一系列自注意力的变换层进行处理,每个变换层都会将特征映射进行编码并产生一个新的特征映射。这个过程旨在捕捉图像中的空间和通道依赖关系。
作用:通过处理输入特征,提取图像特征并进行自注意力操作,为后续的目标检测任务提供必要的特征信息。
Decoder:
目的:接受Encoder的输出,并生成对目标类别和边界框的预测。
工作原理:首先,它接收Encoder的输出,然后使用一系列解码器层对目标对象之间的关系和全局图像上下文进行推理。这些解码器层将最终的目标类别和边界框的预测作为输出。
作用:基于Encoder的输出和全局上下文信息,生成目标类别和边界框的预测结果。
总结:Encoder就是特征提取类似卷积;Decoder用于生成box,类似head

源码实现:

Encoder 通常是6个encoder_layer组成,Decoder 通常是6个decoder_layer组成
我实现了核心的BaseTransformerLayer层,可以用来定义encoder_layer和decoder_layer

具体源码及其注释如下,配好环境可直接运行(运行依赖于上一个博客的代码):

import torch
from torch import nn
from ZMultiheadAttention import MultiheadAttention  # 来自上一次写的attension


class FFN(nn.Module):
    def __init__(self,
                 embed_dim=256,
                 feedforward_channels=1024,
                 act_cfg='ReLU',
                 ffn_drop=0.,
                 ):
        super(FFN, self).__init__()
        self.l1 = nn.Linear(in_features=embed_dim, out_features=feedforward_channels)
        if act_cfg == 'ReLU':
            self.act1 = nn.ReLU(inplace=True)
        else:
            self.act1 = nn.SiLU(inplace=True)
        self.d1 = nn.Dropout(p=ffn_drop)
        self.l2 = nn.Linear(in_features=feedforward_channels, out_features=embed_dim)
        self.d2 = nn.Dropout(p=ffn_drop)

    def forward(self, x):
        tmp = self.d1(self.act1(self.l1(x)))
        tmp = self.d2(self.l2(tmp))
        x = tmp + x
        return x


# transfomer encode和decode的最小循环单元,用于打包self_attention或者cross_attention
class BaseTransformerLayer(nn.Module):
    def __init__(self,
                 attn_cfgs=[dict(embed_dim=64, num_heads=4), dict(embed_dim=64, num_heads=4)],
                 fnn_cfg=dict(embed_dim=64, feedforward_channels=128, act_cfg='ReLU', ffn_drop=0.),
                 operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')):
        super(BaseTransformerLayer, self).__init__()
        self.attentions = nn.ModuleList()
        # 搭建att层
        for attn_cfg in attn_cfgs:
            self.attentions.append(MultiheadAttention(**attn_cfg))
        self.embed_dims = self.attentions[0].embed_dim

        # 统计norm数量 并搭建
        self.norms = nn.ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(nn.LayerNorm(normalized_shape=self.embed_dims))

        # 统计ffn数量 并搭建
        self.ffns = nn.ModuleList()
        self.ffns.append(FFN(**fnn_cfg))
        self.operation_order = operation_order

    def forward(self, query, key=None, value=None, query_pos=None, key_pos=None):
        attn_index = 0
        norm_index = 0
        ffn_index = 0
        for order in self.operation_order:
            if order == 'self_attn':
                temp_key = temp_value = query  # 不用担心三个值一样,在attention里面会重映射qkv
                query, attention = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    query_pos=query_pos,
                    key_pos=query_pos)
                attn_index += 1
            elif order == 'cross_attn':
                query, attention = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    query_pos=query_pos,
                    key_pos=key_pos)
                attn_index += 1
            elif order == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1
            elif order == 'ffn':
                query = self.ffns[ffn_index](query)
                ffn_index += 1
        return query


if __name__ == '__main__':
    query = torch.rand(size=(10, 2, 64))
    key = torch.rand(size=(5, 2, 64))
    value = torch.rand(size=(5, 2, 64))
    query_pos = torch.rand(size=(10, 2, 64))
    key_pos = torch.rand(size=(5, 2, 64))
    # encoder 通常是6个encoder_layer组成 每个encoder_layer['self_attn', 'norm', 'ffn', 'norm']
    encoder_layer = BaseTransformerLayer(attn_cfgs=[dict(embed_dim=64, num_heads=4)],
                                         fnn_cfg=dict(embed_dim=64, feedforward_channels=1024, act_cfg='ReLU',
                                                      ffn_drop=0.),
                                         operation_order=('self_attn', 'norm', 'ffn', 'norm'))

    encoder_layer_output = encoder_layer(query=query, query_pos=query_pos, key_pos=key_pos)

    # decoder 通常是6个decoder_layer组成 每个decoder_layer['self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm']
    decoder_layer = BaseTransformerLayer(attn_cfgs=[dict(embed_dim=64, num_heads=4), dict(embed_dim=64, num_heads=4)],
                                         fnn_cfg=dict(embed_dim=64, feedforward_channels=1024, act_cfg='ReLU',
                                                      ffn_drop=0.),
                                         operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))

    decoder_layer_output = decoder_layer(query=query, key=key, value=value, query_pos=query_pos, key_pos=key_pos)

    pass

具体流程说明:

Encoder 通常是6个encoder_layer组成,每个encoder_layer[‘self_attn’, ‘norm’, ‘ffn’, ‘norm’]
Decoder 通常是6个decoder_layer组成,每个decoder_layer[‘self_attn’, ‘norm’, ‘cross_attn’, ‘norm’, ‘ffn’, ‘norm’]
按照以上方式搭建网络即可
其中norm为LayerNorm,在样本内部进行归一化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1387907.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java如何修改windows计算机本地日期和时间?

本文教程,主要介绍,在java中如何修改windows计算机本地日期和时间。 目录 一、程序代码 二、运行结果 一、程序代码 package com;import java.io.IOException;/**** Roc-xb*/ public class ChangeSystemDate {public static void main(String[] args)…

MySQL面试题 | 10.精选MySQL面试题

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

WinForms TreeView 控件:保持节点选中状态即使失去焦点

WinForms TreeView 控件:保持节点选中状态即使失去焦点 在 Windows 窗体(WinForms)应用程序中,TreeView 控件是一种非常有用的界面元素,允许用户以层次结构的方式浏览信息。然而,一个常见的用户界面问题是&…

关于浮点数的四舍五入问题

最近有关注到,在C/C中,对于浮点数的四舍五入,与实际的有一些出入,我打算今天总结一下,并解释一下这是为啥, 好了,下面进入正题,都是干货哦,认真看完,留下你的…

非常好用的Mac清理工具CleanMyMac X 4.14.7 如何取消您对CleanMyMac X的年度订购

CleanMyMac X 4.14.7是Mac平台上的一款非常著名同时非常好用的Mac清理工具。全方位扫描您的Mac系统,让垃圾无处藏身,您只需要轻松单击2次鼠标左键即可清理数G的垃圾,就这么简单。瞬间提升您Mac速度。 CleanMyMac X 4.14.7下载地址&#xff1a…

Linux Mii management/mdio子系统分析之三 mii_bus注册、注销及其驱动开发流程

(转载)原文链接:https://blog.csdn.net/u014044624/article/details/123303174 本篇是mii management/mdio模块分析的第三篇文章,本章我们主要介绍mii-bus的注册与注销接口。在前面的介绍中也已经说过,我们可以将mii-b…

如何增加服务器的高并发

随着互联网的快速发展和普及,越来越多的应用程序需要支持高并发的请求处理。在这种情况下增加服务器的高并发能力成为了一个热门的话题。下面简单的介绍如果提高服务器的高并发能力。 负载均衡 是把请求分发到多个服务器上,来实现请求的平衡和分担。负…

compose 实验

cd /opt mkdir compose_nginx cd compose_nginx mkdir nginx cd nginx/ 此时顺便将nginx安装包拖进来 vim Dockerfile mkdir /opt/compose_nginx/wwwroot echo "<h1>this is test web</h1>" > /opt/compose_nginx/wwwroot/index.html docker netw…

如何配置mybatisplus基础环境?

1.在pom文件&#xff08;都加上吧&#xff0c;以防万一&#xff09; 2.若当初有mybatis的依赖&#xff0c;要删除 3.在Mapper接口加上"extends BaseMapper<实体类型>" 4.更改yml文件内容 别名扫描包&#xff1a;是指实体类型 5.添加"extends ServiceIm…

SQL语句详解四-DQL(数据查询语言-约束)

约束 概述&#xff1a;对表中的数据进行限定&#xff0c;保证数据的正确性&#xff0c;有效性和完整性。 约束分类 约束关键字约束意思primary key主键约束not null非空约束unique唯一约束foreign key外键约束 例子&#xff1a;sname varchar(40) not null, – 代表 sname 这…

【C语言】指针知识点笔记(2)

目录 一、野指针 二、assert断言 三、指针的使用和传址调用 四、数组名的理解 五、使用指针访问数组 一、野指针 二、assert断言 三、指针的使用和传址调用 四、数组名的理解 五、使用指针访问数组

Web 服务器渗透测试清单

Web 服务器渗透测试在三个重要类别下进行&#xff1a;身份、分析和报告漏洞&#xff0c;例如身份验证弱点、配置错误和协议关系漏洞。 1. “进行一系列有条不紊且可重复的测试”是测试网络服务器是否能够解决所有不同应用程序漏洞的最佳方法。 2.“收集尽可能多的信息”关于…

AtCoder Beginner Contest 336 G. 16 Integers(图计数 欧拉路径转欧拉回路 矩阵树定理 best定理)

题目 给16个非负整数&#xff0c;x[i∈(0,1)][j∈(0,1)][k∈(0,1)][l∈(0,1)] 求长为n3的01串的方案数&#xff0c;满足长度为4的ijkl&#xff08;2*2*2*2&#xff0c;16种情况&#xff09;串恰为x[i][j][k][l]个 答案对998244353取模 思路来源 https://www.cnblogs.com/tz…

多线程并发与并行

&#x1f4d1;前言 本文主要是【并发与并行】——并发与并行的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是听风与他&#x1f947; ☁️博客首页&#xff1a;CSDN主页听风与他 &#x1f304;每日一句&…

03 顺序表

目录 线性表顺序表练习 线性表(Linear list)是n个具有相同特性的数据元素的有限序列。线性表是一种在实际中广泛使用的数据结构&#xff0c;常见的线性表&#xff1a;顺序表、链表、栈、队列、字符串。。。 线性表在逻辑上时线性结构&#xff0c;是连续的一条直线。但在物理结…

【PostgreSQL内核学习(二十一)—— 执行器(InitPlan)】

执行器&#xff08;InitPlan&#xff09; 概述InitPlan 函数代码段解释ExecInitNode 函数 总结 声明&#xff1a;本文的部分内容参考了他人的文章。在编写过程中&#xff0c;我们尊重他人的知识产权和学术成果&#xff0c;力求遵循合理使用原则&#xff0c;并在适用的情况下注明…

力扣每日一练(24-1-16)

我一开始想到的是&#xff0c;如果数字相同则加一。 然而&#xff0c;对了一点点&#xff0c;而已。 高手的方法不是普通人在几分钟内能想得出来的&#xff0c;hh 继续补充&#xff1a; 如果数字不同则减一&#xff0c;如果计数到达了0&#xff0c;则更新数字&#xff0c;最…

【极光系列】springboot集成redis

【极光系列】springboot集成redis tips&#xff1a;主要用于快速搭建环境以及部署项目入门 gitee地址 直接下载源码可用 https://gitee.com/shawsongyue/aurora.git模块&#xff1a;aurora_rediswindow安装redis安装步骤 1.下载资源包 直接下载解压&#xff1a;https://pa…

PHP项目如何自动化测试

开发和测试 测试和开发具有同等重要的作用 从一开始&#xff0c;测试和开发就是相向而行的。测试是开发团队的一支独立的、重要的支柱力量。 测试要具备独立性 独立分析业务需求&#xff0c;独立配置测试环境&#xff0c;独立编写测试脚本&#xff0c;独立开发测试工具。没有…

华硕原厂系统天选5Pro原厂Win11系统恢复安装过程方法

华硕原厂系统天选5Pro原厂Win11系统恢复安装过程方法 华硕原厂系统枪神8/枪神8plus原厂Win11系统恢复安装过程方法 还是老规矩&#xff0c;分3种安装方法 远程恢复安装&#xff1a;https://pan.baidu.com/s/166gtt2okmMmuPUL1Fo3Gpg?pwdm64f 提取码:m64f 支持型号&#x…