【前沿模型解析】潜在扩散模型 2-3 | 手撕感知图像压缩 基础块 自注意力块

news2024/10/7 18:20:01

1 注意力机制回顾

同ResNet一样,注意力机制应该也是神经网络最重要的一部分了。

想象一下你在观看一场电影,但你的朋友在给你发短信。虽然你正在专心观看电影,但当你听到手机响起时,你会停下来查看短信,然后这时候电影的内容就会被忽略。这就是注意力机制的工作原理。

在处理输入序列时,比如一句话中的每个单词,注意力机制允许模型像你一样,专注于输入中的不同部分。模型可以根据输入的重要性动态地调整自己的注意力,注意自己觉得比较重要的部分,忽略一些不太重要的部分,以便更好地理解和处理序列数据。

具体来说,是通过q,k,v实现的

q(查询),k(键值)之间先进行计算,获得重要性权重w,w再作用于v

利用卷积操作确定q,k,v

q,k做运算得到w,缩放w

w和v做运行

最后残差

得到

2 Atten块的实现

在这里插入图片描述

2.1 初始化函数

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=3, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

2.2 前向传递函数

def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention 自注意力计算
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w) #[4,12,1024]
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)
  1. b,c,h,w = q.shape:假设q是一个四维张量,其中b表示batch size,c表示通道数,hw表示高度和宽度。

  2. q = q.reshape(b,c,h*w):将q张量重新形状为三维张量,其中第三维是原高度和宽度的乘积。这样做是为了方便后续计算。

  3. q = q.permute(0,2,1):交换张量维度,将第三维移动到第二维,这是为了后续计算方便。

  4. k = k.reshape(b,c,h*w):对k做和q类似的操作,将其形状改为三维张量。

  5. w_ = torch.bmm(q,k):计算qk的批次矩阵乘积(batch matrix multiplication),得到注意力权重的初始矩阵。这里的w_是一个b x (h*w) x (h*w)的张量,表示每个位置对应的注意力权重。

  6. w_ = w_ * (int(c)**(-0.5)):对初始注意力权重进行缩放,这里使用了一个缩放因子,通常是通道数的倒数的平方根。这个缩放是为了确保在计算注意力时不会因为通道数过大而导致梯度消失或梯度爆炸。

  7. w_ = torch.nn.functional.softmax(w_, dim=2):对注意力权重进行softmax操作,将其归一化为概率分布,表示每个位置的重要性。

这段代码的作用是实现自注意力机制中计算注意力权重的过程,其中qk分别代表查询(query)和键(key),通过计算它们的相似度得到注意力权重。

        # attend to values 加注意力到值上
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] [4,12,1024]*[4,1024,1024]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_
  1. v = v.reshape(b,c,h*w):将值(value)张量v重新形状为三维张量,其中第三维是原高度和宽度的乘积。这样做是为了方便后续计算。

  2. w_ = w_.permute(0,2,1):交换注意力权重w_张量的维度,将第三维移动到第二维,这是为了后续计算方便。

  3. h_ = torch.bmm(v,w_):计算值v和经过缩放的注意力权重w_的批次矩阵乘积(batch matrix multiplication),得到自注意力的输出。这里的h_是一个b x c x (h*w)的张量,表示每个位置经过注意力计算后的输出。

  4. h_ = h_.reshape(b,c,h,w):将h_张量重新形状为四维张量,恢复其原始的高度和宽度。

  5. h_ = self.proj_out(h_):通过一个全连接层proj_out对自注意力的输出h_进行线性变换和非线性变换,这个操作有助于提取特征并保持网络的表达能力。

最后,将输入x和自注意力的输出h_相加,得到最终的自注意力输出。这样做是为了在保留原始输入信息的同时,加入了经过自注意力计算后的新信息,从而使模型能够更好地理解输入序列的语义信息。

2.3 Atten注意力完整代码

from torch import nn
import torch
from einops import rearrange


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=3, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention 自注意力计算
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w) #[4,12,1024]
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values 加注意力到值上
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] [4,12,1024]*[4,1024,1024]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_

def make_attn(in_channels, attn_type="vanilla"):
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    else:
        return nn.Identity(in_channels)
    

atten_block=make_attn(12)
x=torch.ones(4,12,32,32)
y=atten_block(x)
print(y.shape)

3 源代码中的另一种注意力实现

源代码中还实现了LinearAttention,是另一种注意力机制

可以看看

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        k = k.softmax(dim=-1)  
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)

class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)

对于forward函数

  1. b, c, h, w = x.shape:假设输入张量x是一个四维张量,其中b表示batch size,c表示通道数,hw表示高度和宽度。

  2. qkv = self.to_qkv(x):将输入张量x通过一个线性变换(可能包括分别计算查询(query)、键(key)和值(value))得到qkv张量,其形状为b x (3*heads*c) x h x w,其中heads是多头注意力的头数。

  3. q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3):将qkv张量重新排列为三个张量qkv,分别表示查询、键和值,形状为b x heads x c x (h*w)

  4. k = k.softmax(dim=-1):对键张量k进行softmax操作,将其归一化为概率分布,以便计算注意力权重。

  5. context = torch.einsum('bhdn,bhen->bhde', k, v):使用torch.einsum函数计算注意力权重与值的加权和,得到上下文张量context,形状为b x heads x c x (h*w)

  6. out = torch.einsum('bhde,bhdn->bhen', context, q):使用torch.einsum函数计算上下文张量与查询张量的加权和,得到输出张量out,形状为b x heads x c x (h*w)

  7. out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w):将输出张量out重新排列为形状b x (heads*c) x h x w,恢复其原始形状。

  8. return self.to_out(out):将输出张量out通过一个线性变换得到最终的输出。

如果注意力机制type=None的话,则不进行注意力机制的计算~

用一个torch函数

nn.Identity 这是一个恒等变化的一个函数,不做任何处理

4 完整代码及其测试

from torch import nn
import torch
from einops import rearrange

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        k = k.softmax(dim=-1)  
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)

class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=3, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention 自注意力计算
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w) #[4,12,1024]
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values 加注意力到值上
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] [4,12,1024]*[4,1024,1024]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_

def make_attn(in_channels, attn_type="vanilla"):
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type=="line":
        return LinAttnBlock(in_channels)
    else:
        return nn.Identity(in_channels)
    

atten_block=make_attn(12)
x=torch.ones(4,12,32,32)
y=atten_block(x)
print(y.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1580958.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++可变模板参数与包装器(function、bind)

文章目录 一、 可变参数模板1. 概念2. 参数包值的获取 二、 包装器1. 什么是包装器2. 为什么要有包装器3. std::function(1) function概念(2) 利用function解决实例化多份问题(3) 包装器的其他使用场景&…

水资源管理系统:守护生命之源,构建和谐水生态

水资源是维系地球生态平衡和人类社会可持续发展的重要基础。然而,随着人口增长、工业化和城市化的加速,水资源短缺、水质污染和生态破坏等问题日益凸显。在这样的背景下,构建一个全面、高效、智能的水资源管理系统显得尤为迫切和必要。 项目…

Google Cookie意见征求底部弹窗

关于欧盟 Cookie 通知 根据2024年欧盟的《通用数据保护条例》以及其他相关法规,要求google cookie的使用必须征求用户的同意,才能进行收集用户数据信息,因此跨境独立站,如果做欧洲市场,就必须弹出cookie收集数据弹窗&a…

阿赵UE学习笔记——26、动画混合空间

阿赵UE学习笔记目录   大家好,我是阿赵。   继续学习虚幻引擎的使用。之前学习了通过蓝图直接控制动画播放,或者通过动画状态机去控制播放。这次来学习一种比较细致的动画控制播放方式,叫做动画混合空间。 一、使用的情景 假设我们现在需…

链表的中间结点——每日一题

题目链接: OJ链接 题目: 给你单链表的头结点 head ,请你找出并返回链表的中间结点。 如果有两个中间结点,则返回第二个中间结点。 示例 1: 输入:head [1,2,3,4,5] 输出:[3,4,5] 解释&…

2023一个前端人的杂谈

酒香也怕巷子深 年底提车,回河北过年,一路总是旅游的牌子,后来去满城滑雪,随拍了几张照片,才更加感受河北的魅力。 感觉仅仅是这一抹黄昏,就让这一行物超所值了,原来那句宣传语所言非虚:这么近,那么美,周末到河北,然而我认为实际的好处,可能不止如此。 作为一个出…

ADC电路项目1——10bit SAR ADC 设计,smic18工艺,有工艺库,有效位数ENOB为9.8

分享一个入门SAR ADC的完整电路项目,适合新手小白学习 10bit 20MHz SAR ADC(WX:didadidadidida313,加我备注:CSDN 10 bit SAR ADC,谢绝白嫖哈) 概述: 本设计采用 smic18mmrf CMOS 工艺&#xf…

23linux 自定义shell文件系统

打印环境变量,把当前子进程所有环境变量打印出来 环境变量也是一张表(指针数组以null结尾,最后条件不满足就退出了 ) 用子进程调用 结论1 当我们进行程序替换的时候 ,子进程对应的环境变量(子进程的环境变…

2024年3月30日~2024年4月7日周报

文章目录 一、前言二、创意收集2.1 多任务学习2.1.1 多任务学习的定义与优势2.1.2 多任务学习的分类 2.2 边缘检测2.2.1 基础理论2.2.2 sobel代码介绍2.2.3 canny代码介绍 三、《地震速度模型超分辨率的多任务学习》3.1 M-RUDSR架构3.2 详细介绍3.3 实验设置 四、实验五、小结5…

网工内推 | 深信服、宁德时代,最高20K招安全工程师,包吃包住

01 深信服科技 招聘岗位:安全服务工程师 职责描述: 1.负责现场安全服务项目工作内容,包含渗透测试、安全扫描、基线核查、应急响应等; 2.协助用户完成安全测试漏洞整改、复测工作; 3.为用户提供网络、主机、业务系统等…

数据库讲解---(数据查询)【MySQL版本】

零.前言 数据库讲解(MySQL版)(超详细)【第一章】-CSDN博客 数据库-ER图教程_e-r图数据库-CSDN博客 数据库讲解(MySQL版)(超详细)【第二章】【上】-CSDN博客 数据库讲解---(SQL语…

[从0开始AIGC][Transformer相关]:算法的时间和空间复杂度

一、算法的时间和空间复杂度 文章目录 一、算法的时间和空间复杂度1、时间复杂度2、空间复杂度 二、Transformer的时间复杂度分析1、 self-attention 的时间复杂度2、 多头注意力机制的时间复杂度 三、transformer的空间复杂度 算法是指用来操作数据、解决程序问题的一组方法。…

专题十二、字符串

字符串 1. 字符串字面量1.1 字符串字面量中的转义序列1.2 延续字符串字面量1.3 如何存储字符串字面量1.4 字符串字面量的操作1.5 字符串字面量与字符常量 2. 字符串变量2.1 初始化字符串变量2.2 字符数组与字符指针 3. 字符串的读和写3.1 用 printf 函数和 puts 函数写字符串3.…

通俗白话了解资产负债现金利润三张表

看到一本小书不错《财务小白轻松入门》,里面通俗说了三张表之间的关系。贴图摘录下:

SSL证书添加与ICP备案,对于SpringBoot的要求

配置了SSL证书之后,在SpringBoot的resources文件夹里的application.properties会添加以下代码: server.port443 不需要添加server.address。不然会报错。 https类型的请求默认在Postman里面不可请求。 经过SSL证书处理的网页,链接中使默认…

RAG文本加载和分块调研

文本加载和分块 一、文本加载 文本加载是RAG文本增强检索重要环节。文件有不同类型(excel、word、ppt、pdf、png、html、eps、gif、mp4、zip等),衍生出了很多第三方库。使用python处理文件是各种python开发岗位都需要的操作。主要涉及到的标准…

TryHackMe - HTTP Request Smuggling

学完、打完后的复习 HTTP 1 这部分比较简单,直接略过 HTTP2请求走私 首先要了解HTTP2的结构,与HTTP1之间的一些差异 HTTP2中不再使用CRLF来作为字段的边界限定,而是在二进制中直接通过长度、名字、值长度、值,来确认边界 而这…

【C++】用红黑树封装map和set

我们之前学的map和set在stl源码中都是用红黑树封装实现的,当然,我们也可以模拟来实现一下。在实现之前,我们也可以看一下stl源码是如何实现的。我们上篇博客写的红黑树里面只是一个pair对象,这对于set来说显然是不合适的&#xff…

软件可靠性基本概念_1.定义和定量描述

1.软件可靠性定义 软件可靠性(Software Reliability)是软件产品在规定的条件下和规定的时间区间完成规定功能的能力。规定的条件是指直接与软件运行相关的使用该软件的计算机系统的状态和软件的输入条件,或统称为软件运行时的外部输入条件&am…

【MATLAB源码-第183期】基于matlab的图像处理GUI很全面包括滤波,灰度,边缘提取,RGB亮度调节,二值化等。

操作环境: MATLAB 2022a 1、算法描述 1. RGB颜色亮度调整 1.1 RGB颜色模型 RGB颜色模型是一种加色模型,使用红色(R)、绿色(G)、蓝色(B)三种颜色的不同组合来表示各种颜色。每种…