人工智能(Pytorch)搭建模型5-注意力机制模型的构建与GRU模型融合应用

news2024/12/29 2:37:42

大家好,我是微学AI,今天给大家介绍一下人工智能(Pytorch)搭建模型5-注意力机制模型的构建与GRU模型融合应用。注意力机制是一种神经网络模型,在序列到序列的任务中,可以帮助解决输入序列较长时难以获取全局信息的问题。该模型通过对输入序列不同部分赋予不同的权重,以便在每个时间步骤上更好地关注需要处理的信息。在编码器-解码器(Encoder-Decoder)框架中,编码器将输入序列映射为一系列向量,而解码器则在每个时间步骤上生成输出序列。在此过程中,解码器需要对编码器的所有时刻进行“注意”,以了解哪些输入对当前时间步骤最重要。

在注意力机制中,解码器会计算每个编码器输出与当前解码器隐藏状态之间的相关度,并将其转化为注意力权重,以确定每个编码器输出对当前时刻解码器状态的贡献。这些权重被用于加权求和编码器输出,从而得到一个上下文向量,该向量包含有关输入序列的重要信息,有助于提高模型的性能和泛化能力。

一、注意力机制模型构建

# 1. 导入所需库
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader

class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        max_len = encoder_outputs.size(1)
        repeated_hidden = hidden.unsqueeze(1).repeat(1, max_len, 1)
        energy = torch.tanh(self.attn(torch.cat((repeated_hidden, encoder_outputs), dim=2)))
        attention_scores = self.v(energy).squeeze(2)
        attention_weights = nn.functional.softmax(attention_scores, dim=1)
        context_vector = (encoder_outputs * attention_weights.unsqueeze(2)).sum(dim=1)
        return context_vector, attention_weights

 以上Attention类是注意力机制的神经网络模型,该模型接收两个输入参数:隐藏状态编码器输出。其中,隐藏状态是解码器中上一个时间步骤的输出,而编码器输出是编码器模型对输入序列进行编码后的输出。编码器输出和隐藏状态被用于计算上下文向量注意力权重。通过将隐藏状态和编码器输出进行拼接,然后将结果通过线性层进行处理,并使用tanh激活函数后得到能量矩阵(energy)。接着,使用另一个线性层(self.v)将能量矩阵转换成注意力得分(attention scores),并使用softmax函数转换成注意力权重(attention weights)。最后,根据注意力权重对编码器输出进行加权组合得到上下文向量。

整个过程可以简单概括为:先将隐藏状态和编码器输出连接起来,然后使用线性转换和tanh激活函数计算能量矩阵,再使用线性转换和softmax函数计算注意力权重,最后使用注意力权重对编码器输出进行加权组合得到上下文向量。

二、GRU模型构建+注意力机制

class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers, dropout=0.5):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.attention = Attention(hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, hidden = self.gru(x, h0)
        out, attention_weights = self.attention(hidden[-1], out)
        out = self.dropout(out)
        out = self.fc(out)
        return out

GRUModel类的初始化方法中,先调用了父类构造函数初始化。然后定义了一个GRU层,并将其输出传入Attention类中计算上下文向量和注意力权重。最后将上下文向量送入一个线性层,并加上dropout操作,为了防止过拟合现象,然后输出模型的预测结果。通过这个模型的设计,我们可以将输入序列和输出序列的长度变化对模型的性能影响降到最小,并且利用注意力机制使模型能够更好的关注序列中的重要信息。

三、数据生成与加载

# 3. 准备数据集
class SampleDataset(Dataset):
    def __init__(self):
        self.sequences = []
        self.labels = []
        for _ in range(1000):
            seq = torch.randn(10, 5)
            label = torch.zeros(2)
            if seq.sum() > 0:
                label[0] = 1
            else:
                label[1] = 1
            self.sequences.append(seq)
            self.labels.append(label)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]


train_set_split = int(0.8 * len(SampleDataset()))
train_set, test_set = torch.utils.data.random_split(SampleDataset(),
                                                    [train_set_split, len(SampleDataset()) - train_set_split])
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

四、模型训练

# 4. 定义训练过程
def train(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        _, true_labels = torch.max(labels, 1)
        total += true_labels.size(0)
        correct += (predicted == true_labels).sum().item()

    print("Train Loss: {:.4f}, Acc: {:.2f}%".format(running_loss / (batch_idx + 1), 100 * correct / total))


# 5. 定义评估过程
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            _, true_labels = torch.max(labels, 1)
            total += true_labels.size(0)
            correct += (predicted == true_labels).sum().item()

    print("Test Loss: {:.4f}, Acc: {:.2f}%".format(running_loss / (batch_idx + 1), 100 * correct / total))


# 6. 训练模型并评估
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GRUModel(input_size=5, hidden_size=10, output_size=2, num_layers=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 100
for epoch in range(num_epochs):
    print("Epoch {}/{}".format(epoch + 1, num_epochs))
    train(model, train_loader, criterion, optimizer, device)
    evaluate(model, test_loader, criterion, device)

运行结果:

Epoch 97/100
Train Loss: 0.0264, Acc: 99.75%
Test Loss: 0.1267, Acc: 94.50%
Epoch 98/100
Train Loss: 0.0294, Acc: 99.75%
Test Loss: 0.1314, Acc: 95.00%
Epoch 99/100
Train Loss: 0.0286, Acc: 99.75%
Test Loss: 0.1280, Acc: 94.50%
Epoch 100/100
Train Loss: 0.0286, Acc: 99.75%
Test Loss: 0.1324, Acc: 95.50%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/496002.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

工具收集 - 键鼠模拟改建

工具收集 - 键鼠模拟&改建 AutoTinyFreeMouseClickerX-Mouse Button Control AutoTiny 官网:https://autotiny.cn AutoTiny软件是一款PC电脑端使用的自动化录制制作软件,不仅能够实现电脑自动化操作,而且可以控制手机实现自动化操作。 A…

uniApp 实现上传功能(七牛云,node获取上传token)

版本:uniAppvue2uview-ui 需求:利用uView_upload组件实现上传功能 难点:兼容性强,支持pc、App、h5; 1.使用leancloud 实现上传(兼容性弱) JS-SDK 只兼容pc、h5,运行到虚拟机上会报错——uniApp问答详情&am…

深度挖掘.c到.exe的整个过程,透过现象看本质

文章目录 程序的翻译环境和执行环境翻译环境编译预编译头文件的包含删除注释替换#define定义的符号 编译词法分析语法分析语义分析符号汇总 汇编 链接合并段表符号表的合并和重定位 执行环境 程序的翻译环境和执行环境 在ANSI C的任何一种实现中,存在两个不同的环境…

【Jenkins】使用java -jar jenkins.war --httpPort=XXXX启动Jenkins报错【解决方案】

使用java -jar jenkins.war --httpPortXXXX启动Jenkins报错【解决方案】 👉欢迎关注博主【米码收割机】 👉一起学习C、Python主流编程语言。 👉机器人、人工智能等相关领域开发技术。 👉主流开发、测试技能。 文章目录 使用java -…

学习分享|一文搞懂WiFi 6/7 以及选择路由器改造网络那些事

目录 什么是 WiFi 6 WiFi 6 功能特点 WIFI 6 与前几代对比 速度更快 延时更低 容量更大 更安全 更省电 WiFi 4~WiFi 6对比 WiFi 6 核心技术 WiFi 7 WiFi 世代列表 路由器常用技术扩展 2.5Ge 网口 WAN/LAN口复用/网口盲插 双WAN口 双LAN口端口聚合 mesh组网 聊…

实验四 文件系统原理与模拟实现

实验四 文件系统原理与模拟实现 代码资源地址 Java实现的混合索引和成组链接法算法资源-CSDN文库 实验目的: 了解操作系统中文件系统的结构和管理过程,掌握经典的算法:混合索引与成组链接法等方法。 实验内容: 编程模拟实现混合…

【Android取证篇】ADB版本更新详细步骤

【Android取证篇】ADB版本更新详细步骤 更新ADB版本,解决无法连接设备问题【蘇小沐】 ADB没有自动更新的命令,我们需要下载新的ADB进行替换更新。 1、ADB查找 打开任务管理器(快捷键shiftctrlEsc或WinX),在“详细信…

Arcgis通过矢量建筑面找到POI对应的标准地址

背景 有时候我们需要找到POI对应的标准地址,也许有很多的方法, 比如通过POI的地址数据和标准地址做匹配,用sql语句就能实现; 但是POI数据中也存在很多没有地址数据的,这时候只能通过空间关联来匹配对应的标准地址了,而空间关联也有不一样的方法,一个是通过空间连接,找…

数智化转型再加速,低代码开发助力企业转型

毫无疑问,随着数智化转型的加速,越来越多的企业正在把数智化战略提升到一个全新的高度,转型的进程也正从“浅层次”的数智化走向“深层次”数智化的阶段。 据权威机构数据统计,过去几年全球数字经济同比增长15.6%,采取…

DJ5-5/6 与设备无关的 I/O 软件、用户层的 I/O 软件

目录 5.5 与设备无关的 I/O 软件 5.5.1 与设备无关软件的概念 5.5.2 与设备无关的软件的功能 5.5.3 设备分配 5.5.4 逻辑设备名到物理设备名映射的实现 5.6 用户层的 I/O 软件 5.6.1 系统调用与库函数 5.6.2 假脱机技术 SPOOLing 5.5 与设备无关的 I/O 软件 …

鲲鹏昇腾开发者峰会2023举办

[2023年5月6日 广东东莞]今天,以“创未来 享非凡”为主题的鲲鹏昇腾开发者峰会2023在东莞松山湖举办。 鲲鹏产业生态繁荣,稳步发展,正在成为行业核心场景及科研领域首选,加速推动数字化转型;昇腾产业快速蓬勃向上&…

【大数据之Hadoop】二十五、生产调优-HDFS核心参数

1 NameNode内存生产配置 Hadoop3.x系列的NameNode内存是动态分配的,可以用jmap -heap 进程号 查看分配的内存。 在hadoop102中NameNode和DataNode的内存都是自动分配的,且相等。 根据经验: NameNode最小值为1G,每增加1百万个物理…

【JavaEE初阶】多线程带来的风险~线程安全

目录 🌟观察线程不安全的现象 🌟线程不安全的原因 🌈1、多个线程修改了同一个共享变量 🌈2、线程是抢占式执行的,CPU的调度是随机的 🌈3、指令执行时没有保证原子性 🌈4、多线程环境中内…

当无触控板和鼠标的情况下,如何开启触控板

背景:一次出行匆忙,忘记带鼠标,周围也无可用工具,主要是触控板当时也被我关闭了,下面讲述一下我是如何解决在没有鼠标的情况下开启触控板的。 首先我们开启电脑后, 存在两种思路去开启触控板 第一种方案…

加拿大访问学者签证材料清单

加拿大在教育、政府透明度、社会自由度以及生活品质等方面在国际上排名名列前茅,出于环境、社会氛围等因素,不少学者将目光聚焦于这个北美的发达国家。加拿大的访问学者签证属于工作签证,过去只要有邀请函就可以办理,但是自去年2月…

Python:Python底层原理:Python的整数是如何实现的

Python整数在底层存储方式 1. Python整数在底层对应的结构体 PyLongObject2.整数是怎么存储的2.1 整数0存储2.2 整数12.3 整数-12.4. 2**30 -12.5 . 2**302.6 . ob_digit[a, b, c] 对应整数计算 计算整数所占内存大小总结 Python的底层是C/C ,但是 C/C 能表示的整数…

Linux挂载新磁盘到根目录

添加磁盘到需要挂载的机器上 lsblk查看硬盘挂载情况,sdb,sdc为我新挂载的磁盘 fdisk -l查看挂载之前的分区情况 为新硬盘创建分区 fdisk /dev/sdb 终端会提示: Command (m for help):输入:n 依次输入p…

【HTTPS】

HTTP明文传输问题 窃听风险,比如通信链路上可以获取通信内容,用户号容易没。篡改风险,比如强制植入垃圾广告,视觉污染,用户眼容易瞎。冒充风险,比如冒充淘宝网站,用户钱容易没。 TLS协议解决H…

【雅可比左乘右乘】

常见雅可比左乘(以自变量R为例子,围绕旋转点p的旋转点的左扰动雅可比): 旋转点的右扰动雅可比(右乘): 左雅可比和右雅可比之间的区别在于它们各自描述了不同的变换方向。左雅可比将输入空…

硬件-6-基站和移动通信系统的演进

1G、2G、3G、4G、5G 移动通信技术发展简史 1 移动通信系统简介 移动通信系统从第一代移动通信系统(1G)开始逐渐发展,目前已经发展到第四代移动通信系统(4G),第五代移动通信系统(5G)也已经开始标准化,预计2020年商用,6G预计2030年…