UniversalTransformer with Adaptive Computation Time(ACT)

news2024/11/15 8:28:26

在这里插入图片描述


原论文链接:https://arxiv.org/abs/1807.03819


Main code

import torch
import numpy as np

class PositionTimestepEmbedding(torch.nn.Module):

    def forward(self, x, t):

        device = x.device

        sequence_length = x.size(1)
        d_model = x.size(2)

        position_embedding = np.array([
            [
                pos / np.power(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)
            ] for pos in range(sequence_length)
        ])

        position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])
        position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])

        timestep_embedding = np.array([
            [
                t / np.power(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)
            ]
        ])

        timestep_embedding[:, 0::2] = np.sin(timestep_embedding[:, 0::2])
        timestep_embedding[:, 1::2] = np.sin(timestep_embedding[:, 1::2])

        embedding = position_embedding + timestep_embedding

        return x + torch.tensor(embedding, dtype=torch.float, requires_grad=False, device=device)

class MultiHeadAttention(torch.nn.Module):

    def __init__(self, d_model, num_heads, dropout=0.):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads

        self.head_dim = d_model // num_heads

        assert self.head_dim * num_heads == self.d_model, "d_model must be divisible by num_heads"

        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)

        self.dropout = torch.nn.Dropout(dropout)

        self.output = torch.nn.Linear(d_model, d_model)
        self.layer_norm = torch.nn.LayerNorm(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask, -np.inf)

        scores = scores.softmax(dim=-1)

        scores = self.dropout(scores)

        return torch.matmul(scores, v), scores

    def forward(self, q, k, v, mask=None):

        batch_size = q.size(0)

        residual = q

        if mask is not None:
            mask = mask.unsqueeze(1)

        q = self.query(q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(k).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(v).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        out, scores = self.scaled_dot_product_attention(q, k, v, mask)

        out = (
            out.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.head_dim)
        )
        out = self.output(out)

        out += residual

        return self.layer_norm(out)

class TransitionFunction(torch.nn.Module):

    def __init__(self, d_model, dim_transition, dropout=0.):
        super().__init__()

        self.linear1 = torch.nn.Linear(d_model, dim_transition)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(dim_transition, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.layer_norm = torch.nn.LayerNorm(d_model)

    def forward(self, x):

        y = self.linear1(x)
        y = self.relu(y)
        y = self.linear2(y)
        y = self.dropout(y)
        y = y + x

        return self.layer_norm(y)

class EncoderBasicLayer(torch.nn.Module):

    def __init__(self, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()

        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)

        self.transition = TransitionFunction(d_model, dim_transition, dropout)

    def forward(self, block_inputs, enc_self_attn_mask=None):

        self_attention_outputs = self.self_attention(block_inputs, block_inputs, block_inputs, enc_self_attn_mask)

        block_outputs = self.transition(self_attention_outputs)

        return block_outputs

class DecoderBasicLayer(torch.nn.Module):

    def __init__(self, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()

        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)

        self.attention_enc_dec = MultiHeadAttention(d_model, num_heads, dropout)

        self.transition = TransitionFunction(d_model, dim_transition, dropout)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask=None, dec_enc_attn_mask=None):

        dec_query = self.self_attention(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)

        block_outputs = self.attention_enc_dec(dec_query, enc_outputs, enc_outputs, dec_enc_attn_mask)

        block_outputs = self.transition(block_outputs)

        return block_outputs

class RecurrentEncoderBlock(torch.nn.Module):

    def __init__(self, num_layers, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()

        self.layers = torch.nn.ModuleList([
            EncoderBasicLayer(
                d_model,
                dim_transition,
                num_heads,
                dropout
            ) for _ in range(num_layers)
        ])

    def forward(self, x, enc_self_attn_mask=None):

        for l in self.layers:
            x = l(x, enc_self_attn_mask)

        return x

class RecurrentDecoderBlock(torch.nn.Module):

    def __init__(self, num_layers, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()

        self.layers = torch.nn.ModuleList([
            DecoderBasicLayer(
                d_model,
                dim_transition,
                num_heads,
                dropout
            ) for _ in range(num_layers)
        ])

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):

        for l in self.layers:
            dec_inputs = l(dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)

        return dec_inputs

class AdaptiveNetwork(torch.nn.Module):

    def __init__(self, d_model, dim_transition, epsilon, max_hop):
        super().__init__()

        self.threshold = 1.0 - epsilon
        self.max_hop = max_hop

        self.halting_predict = torch.nn.Sequential(
            torch.nn.Linear(d_model, dim_transition),
            torch.nn.ReLU(),
            torch.nn.Linear(dim_transition, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, x, mask, pos_time_embed, recurrent_block, encoder_output=None):

        device = x.device

        halting_probability = torch.zeros((x.size(0), x.size(1)), device=device)
        remainders = torch.zeros((x.size(0), x.size(1)), device=device)
        n_updates = torch.zeros((x.size(0), x.size(1)), device=device)

        previous = torch.zeros_like(x, device=device)

        step = 0

        while (((halting_probability < self.threshold) & (n_updates < self.max_hop)).byte().any()):

            x = x + pos_time_embed(x, step)

            p = self.halting_predict(x).squeeze(-1)

            still_running = (halting_probability < 1.0).float()

            new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running

            still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running

            halting_probability = halting_probability + p * still_running

            remainders = remainders + new_halted * (1 - halting_probability)

            halting_probability = halting_probability + new_halted * remainders

            n_updates = n_updates + still_running + new_halted

            update_weights = p * still_running + new_halted * remainders

            if encoder_output is not None:
                x = recurrent_block(x, encoder_output, mask[0], mask[1])

            else:
                x = recurrent_block(x, mask)

            previous = ((x * update_weights.unsqueeze(-1)) + (previous * (1 - update_weights.unsqueeze(-1))))

            step += 1

        return previous

class Encoder(torch.nn.Module):

    def __init__(self, epsilon, max_hop, num_layers, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()

        assert 0 < epsilon < 1, "0 < epsilon < 1 !!!"

        self.pos_time_embedding = PositionTimestepEmbedding()

        self.recurrent_block = RecurrentEncoderBlock(
            num_layers,
            d_model,
            dim_transition,
            num_heads,
            dropout
        )

        self.adaptive_network = AdaptiveNetwork(d_model, dim_transition, epsilon, max_hop)

    def forward(self, x, enc_self_attn_mask=None):

        return self.adaptive_network(x, enc_self_attn_mask, self.pos_time_embedding, self.recurrent_block)

class Decoder(torch.nn.Module):

    def __init__(self, epsilon, max_hop, num_layers, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()

        assert 0 < epsilon < 1, "0 < epsilon < 1 !!!"

        self.pos_time_embedding = PositionTimestepEmbedding()

        self.recurrent_block = RecurrentDecoderBlock(
            num_layers,
            d_model,
            dim_transition,
            num_heads,
            dropout
        )

        self.adaptive_network = AdaptiveNetwork(d_model, dim_transition, epsilon, max_hop)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):

        return self.adaptive_network(dec_inputs, (dec_self_attn_mask, dec_enc_attn_mask),
                                     self.pos_time_embedding, self.recurrent_block, enc_outputs)

class AdaptiveComputationTimeUniversalTransformer(torch.nn.Module):
    
    def __init__(self, d_model, dim_transition, num_heads, enc_attn_layers, dec_attn_layers, epsilon, max_hop, dropout=0.):
        super().__init__()

        self.encoder = Encoder(epsilon, max_hop, enc_attn_layers, d_model, dim_transition, num_heads, dropout)

        self.decoder = Decoder(epsilon, max_hop, dec_attn_layers, d_model, dim_transition, num_heads, dropout)

    def forward(self, src, tgt, enc_self_attn_mask=None, dec_self_attn_mask=None, dec_enc_attn_mask=None):

        enc_outputs = self.encoder(src, enc_self_attn_mask)

        return self.decoder(tgt, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)

Mask

# from https://zhuanlan.zhihu.com/p/403433120
def get_attn_subsequence_mask(seq):  # seq: [batch_size, tgt_len]
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # 生成上三角矩阵,[batch_size, tgt_len, tgt_len]
    subsequence_mask = torch.from_numpy(subsequence_mask).bool()  # [batch_size, tgt_len, tgt_len]
    return subsequence_mask

def get_attn_pad_mask(seq_q, seq_k):  # seq_q: [batch_size, seq_len] ,seq_k: [batch_size, seq_len]
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # 判断 输入那些含有P(=0),用1标记 ,[batch_size, 1, len_k]
    return pad_attn_mask.expand(batch_size, len_q, len_k)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1353670.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(学习打卡2)重学Java设计模式之六大设计原则

前言&#xff1a;听说有本很牛的关于Java设计模式的书——重学Java设计模式&#xff0c;然后买了(*^▽^*) 开始跟着小傅哥学Java设计模式吧&#xff0c;本文主要记录笔者的学习笔记和心得。 打卡&#xff01;打卡&#xff01; 六大设计原则 &#xff08;引读&#xff1a;这里…

AIGC带给开发者的冲击

未来会有两种开发者&#xff0c;一种是会使用AIGC工具的开发者另一种是不会使用AIGC的开发者&#xff0c;AIGC的出现提高了开发效率和代码质量&#xff0c;对开发者意味着需要不断学习和适应新的技术和工作范式&#xff0c;开发者可以把更多的精力放在高级抽象的定义以及更高维…

(16)Linux 进程等待 wait/waitpid 的 status 参数

前言&#xff1a;我们开始讲解进程等待&#xff0c;简单地讲解 wait 函数&#xff0c;然后我们主要讲解 waitpid 函数。由于 wait 只有一个参数 status&#xff0c;且 waitpid 有三个参数且其中一个也是 status&#xff0c;我们本章重点讲解这个 status 参数。 一、进程等待&a…

学习笔记240102 --- 表单无法输入,是否data中没有提前声明导致的

前端框架 &#xff1a;vue2.x 第三方ui组件&#xff1a;ElementUI 操作系统&#xff1a;windows 浏览器&#xff1a;chrome 谷歌 问题描述 表单使用中&#xff0c;没有在data中提前声明参数&#xff0c;当数据回显时&#xff0c;表单无法输入 <el-form :model"queryPa…

Oracle-数据库迁移之后性能变慢问题分析

问题背景&#xff1a; ​一套Oracle11.2.0.4的RAC集群&#xff0c;通过Dataguard switchover方式迁移到新机器之后&#xff0c;运行第一天应用报障说应用性能慢&#xff0c;需要进行性能问题排查 问题分析&#xff1a; 首先&#xff0c;登陆到服务器&#xff0c;用TOP看一眼两个…

CMake入门教程【基础篇】CMake+Linux gcc构建C++项目

文章目录 1.概述2.GCC与CMake介绍3.安装CMake和GCC4.代码示例 1.概述 在Linux环境下&#xff0c;使用CMake结合GCC&#xff08;GNU Compiler Collection&#xff09;进行项目构建是一种常见且高效的方法。CMake作为一个跨平台的构建系统&#xff0c;可以生成适用于不同编译器的…

CentOS 7 实战指南:文件操作命令详解

写在前面 想要快速掌握 CentOS 7 系统下的文件操作技巧吗&#xff1f;不用担心&#xff01;我为你准备了一篇详细的技术文章&#xff0c;涵盖了各种常用的文件操作命令。无论您是初学者还是有一定经验的用户&#xff0c;这篇文章都能帮助您加深对 CentOS 7 文件操作的理解&…

海外住宅IP代理的工作原理和应用场景分析,新手必看

海外住宅IP代理作为一种技术解决方案&#xff0c;为用户提供了访问全球网络资源和维护隐私安全的方法。本文将介绍海外住宅IP代理的工作原理和应用场景&#xff0c;帮助读者更好地理解和利用这一技术。 一、工作原理 海外住宅IP代理的工作原理基于代理服务器和IP地址的转发。它…

阿里云系统盘测评ESSD、SSD和高效云盘IOPS、吞吐量性能参数表

阿里云服务器系统盘或数据盘支持多种云盘类型&#xff0c;如高效云盘、ESSD Entry云盘、SSD云盘、ESSD云盘、ESSD PL-X云盘及ESSD AutoPL云盘等&#xff0c;阿里云百科aliyunbaike.com详细介绍不同云盘说明及单盘容量、最大/最小IOPS、最大/最小吞吐量、单路随机写平均时延等性…

HackTheBox - Medium - Linux - BroScience

BroScience BroScience 是一款中等难度的 Linux 机器&#xff0c;其特点是 Web 应用程序容易受到“LFI”的攻击。通过读取目标上的任意文件的能力&#xff0c;攻击者可以深入了解帐户激活码的生成方式&#xff0c;从而能够创建一组可能有效的令牌来激活新创建的帐户。登录后&a…

李沐机器学习系列1--- 线性规划

1 Introduction 1.1 线性回归函数 典型的线性回归函数 f ( x ) w ⃗ ⋅ x ⃗ f(x)\vec{w} \cdot \vec{x} f(x)w ⋅x 现实生活中&#xff0c;简单的线性回归问题很少&#xff0c;这里有一个简单的线性回归问题。房子的价格和房子的面积以及房子的年龄假设成线性关系。 p r …

如何做好设备维护管理?这款设备管理系统值得推荐

在现代化的工业生产中&#xff0c;设备的高效运行是保障生产安全和效率的关键因素。然而&#xff0c;在企业实际的设备维护管理业务中&#xff0c;仍面临着许多难题与痛点&#xff1a; 设备档案管理乱&#xff1a;传统管理方式下&#xff0c;如果想查询设备的历史巡检、维修、…

[雷池WAF]长亭雷池WAF配置基于健康监测的负载均衡,实现故障自动切换上游服务器

为了进一步加强内网安全&#xff0c;在原有硬WAF的基础上&#xff0c;又在内网使用的社区版的雷池WAF&#xff0c;作为应用上层的软WAF。从而实现多WAF防护的架构。 经过进一步了解&#xff0c;发现雷池WAF的上游转发代理是基于Tengine的&#xff0c;所以萌生出了一个想法&…

用通俗易懂的方式讲解大模型:在 CPU 服务器上部署 ChatGLM3-6B 模型

大语言模型&#xff08;LLM&#xff09;的量化技术可以大大降低 LLM 部署所需的计算资源&#xff0c;模型量化后可以将 LLM 的显存使用量降低数倍&#xff0c;甚至可以将 LLM 转换为完全无需显存的模型&#xff0c;这对于 LLM 的推广使用来说是非常有吸引力的。 本文将介绍如何…

双侧电源系统距离保护MATLAB仿真模型

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 系统原始数据 双侧电源系统模型如图所示&#xff1a; 仿真模型搭建 将线路AB分成Line1和Line2&#xff0c;将线路BC分成Line3和Line4&#xff0c;用三相电压电流测量模块作为系统母线&#xff0c;根据系统已…

在Cadence中单独添加或删除器件与修改网络的方法

首先需要在设置中使能 ,添加或修改逻辑选项。 添加或删除器件&#xff0c;点击logic-part&#xff0c;选择需要添加或删除的器件&#xff0c;这里的器件必须是PCB中已经有的器件&#xff0c;Refdes中输入添加或删除的器件标号&#xff0c;点击Add添加。 添加完成后就会显示在R1…

Linux学习记录——삼십삼 http协议

文章目录 1、URL2、http协议的宏观构成3、详细理解http协议1、http请求2、http响应1、有效载荷格式2、有效载荷长度3、客户端要访问的资源类型4、修改响应写法5、处理不同的请求6、跳转 3、请求方法&#xff08;GET/POST&#xff09;4、HTTP状态码&#xff08;实现3和4开头的&a…

基于深度学习的交通标志图像分类识别系统

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 本文详细探讨了一基于深度学习的交通标志图像识别系统。采用TensorFlow和Keras框架&#xff0c;利用卷积神经网络&#xff08;CNN&#xff09;进行模型训练和预测&#xff0c;并引入VGG16迁移学习…

x-cmd pkg | trafilatura - 网络爬虫和搜索引擎优化工具

目录 简介首次用户技术特点竞品和相关作品进一步阅读 简介 trafilatura 是一个用于从网页上提取文本的命令行工具和 python 包: 提供网络爬虫、下载、抓取以及提取主要文本、元数据和评论等功能可帮助网站导航和从站点地图和提要中提取链接无需数据库&#xff0c;输出即可转换…

深入了解Apache 日志,Apache 日志分析工具

Apache Web 服务器在企业中广泛用于托管其网站和 Web 应用程序&#xff0c;Apache 服务器生成的原始日志提供有关 Apache 服务器托管的网站如何处理用户请求以及访问您的网站时经常遇到的错误的重要信息。 什么是 Apache 日志 Apache 日志包含 Apache Web 服务器处理的所有事…