Pytorch实现多层LSTM模型,并增加emdedding、Dropout、权重共享等优化

news2024/9/28 11:22:22

简述

本文是 Pytorch封装简单RNN模型,进行中文训练及文本预测 一文的延申,主要做以下改动:

1.将nn.RNN替换为nn.LSTM,并设置多层LSTM:

既然使用pytorch了,自然不需要手动实现多层,注意nn.RNNnn.LSTM 在实例化时均有参数num_layers来指定层数,本文设置num_layers=2

2.新增emdedding层,替换掉原来的nn.functional.one_hot向量化,这样得到的emdedding层可以用来做词向量分布式表示;

3.在emdedding后、LSTM内部、LSTM后均增加Dropout层,来抑制过拟合:

nn.LSTM内部的Dropout可以通过实例化时的参数dropout来设置,需要注意pytorch仅在两层lstm之间应用Dropout,不会在最后一层的LSTM输出上应用Dropout

emdedding后、LSTM后与线性层之间则需要手动添加Dropout层。

4.考虑emdedding与最后的Linear层共享权重:

这样做可以在保证精度的情况下,减少学习参数,但本文代码没有实现该部分。

不考虑第四条时,模型结构如下:

在这里插入图片描述

代码

模型代码:

class MyLSTM(nn.Module):  
    def __init__(self, vocab_size, wordvec_size, hidden_size, num_layers=2, dropout=0.5):  
        super(MyLSTM, self).__init__()  
        self.vocab_size = vocab_size  
        self.word_vec_size = wordvec_size  
        self.hidden_size = hidden_size  
  
        self.embedding = nn.Embedding(vocab_size, wordvec_size)  
        self.dropout = nn.Dropout(dropout)  
        self.rnn = nn.LSTM(wordvec_size, hidden_size, num_layers=num_layers, dropout=dropout)  
        # self.rnn = rnn_layer  
        self.linear = nn.Linear(self.hidden_size, vocab_size)  
  
    def forward(self, x, h0=None, c0=None):  
        # nn.Embedding 需要的类型 (IntTensor or LongTensor)        # 传过来的X是(batch_size, seq), embedding之后 是(batch_size, seq, vocab_size)  
        # nn.LSTM 支持的X默认为(seq, batch_size, vocab_size)  
        # 若想用(batch_size, seq, vocab_size)作参数, 则需要在创建self.embedding实例时指定batch_first=True  
        # 这里用(seq, batch_size, vocab_size) 作参数,所以先给x转置,再embedding,以便再将结果传给lstm  
        x = x.T  
        x.long()  
        x = self.embedding(x)  
  
        x = self.dropout(x)  
  
        outputs = self.dropout(outputs)  
  
        outputs = outputs.reshape(-1, self.hidden_size)  
  
        outputs = self.linear(outputs)  
        return outputs, (h0, c0)  
  
    def init_state(self, device, batch_size=1):  
        return (torch.zeros((self.rnn.num_layers, batch_size, self.hidden_size), device=device),  
                torch.zeros((self.rnn.num_layers, batch_size, self.hidden_size), device=device))

训练代码:

模型应用可以参考 Pytorch封装简单RNN模型,进行中文训练及文本预测 一文。

def start_train():  
    # device = torch.device("cpu")  
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  
    print(f'\ndevice: {device}')  
  
    corpus, vocab = load_corpus("../data/COIG-CQIA/chengyu_qa.txt")  
  
    vocab_size = len(vocab)  
    wordvec_size = 100  
    hidden_size = 256  
    epochs = 1  
    batch_size = 50  
    learning_rate = 0.01  
    time_size = 4  
    max_grad_max_norm = 0.5  
    num_layers = 2  
    dropout = 0.5  
  
    dataset = make_dataset(corpus=corpus, time_size=time_size)  
    data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)  
  
    net = MyLSTM(vocab_size=vocab_size, wordvec_size=wordvec_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout)  
    net.to(device)  
  
    # print(net.state_dict())  
  
    criterion = nn.CrossEntropyLoss()  
    criterion.to(device)  
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)  
  
    writer = SummaryWriter('./train_logs')  
    # 随便定义个输入, 好使用add_graph  
    tmp = torch.randint(0, 100, size=(batch_size, time_size)).to(device)  
    h0, c0 = net.init_state(batch_size=batch_size, device=device)  
    writer.add_graph(net, [tmp, h0, c0])  
  
    loss_counter = 0  
    total_loss = 0  
    ppl_list = list()  
    total_train_step = 0  
  
    for epoch in range(epochs):  
        print('------------Epoch {}/{}'.format(epoch + 1, epochs))  
  
        for X, y in data_loader:  
            X, y = X.to(device), y.to(device)  
            # 这里batch_size=X.shape[0]是因为在加载数据时, DataLoader没有设置丢弃不完整的批次, 所以存在实际批次不满足设定的batch_size  
            h0, c0 = net.init_state(batch_size=X.shape[0], device=device)  
            outputs, (hn, cn) = net(X, h0, c0)  
            optimizer.zero_grad()  
            # y也变成 时间序列*批次大小的行数, 才和 outputs 一致  
            y = y.T.reshape(-1)  
            # 交叉熵的第二个参数需要LongTorch  
            loss = criterion(outputs, y.long())  
            loss.backward()  
            # 求完梯度之后可以考虑梯度裁剪, 再更新梯度  
            grad_clipping(net, max_grad_max_norm)  
            optimizer.step()  
  
            total_loss += loss.item()  
            loss_counter += 1  
            total_train_step += 1  
            if total_train_step % 10 == 0:  
                print(f'Epoch: {epoch + 1}, 累计训练次数: {total_train_step}, 本次loss: {loss.item():.4f}')  
                writer.add_scalar('train_loss', loss.item(), total_train_step)  
  
        ppl = np.exp(total_loss / loss_counter)  
        ppl_list.append(ppl)  
        print(f'Epoch {epoch + 1} 结束, batch_loss_average: {total_loss / loss_counter}, perplexity: {ppl}')  
        writer.add_scalar('ppl', ppl, epoch + 1)  
        total_loss = 0  
        loss_counter = 0  
  
        torch.save(net.state_dict(), './save/epoch_{}_ppl_{}.pth'.format(epoch + 1, ppl))  
  
    writer.close()  
    return net, ppl_list

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2089046.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JVM1-初识JVM

目录 什么是JVM JVM的功能 解释和运行 内存管理 即时编译 Java性能低的主要原因和跨平台特性 常见的JVM 什么是JVM JVM 全称是 Java Virtual Machine,中文译名:Java虚拟机 JVM本质上是一个运行在计算机上的程序,它的职责是运行Java字…

AI大模型编写多线程并发框架(六十三):监听器优化·下

系列文章目录 文章目录 系列文章目录前言一、项目背景二、第十一轮对话-修正运行时数据三、修正任务计数器四、第十二轮对话-生成单元测试五、验证通过七、参考文章 前言 在这个充满技术创新的时代,AI大模型正成为开发者们的新宠。它们可以帮助我们完成从简单的问答…

C++,如何写单元测试用例?

文章目录 1. 概述1.1 什么是单元测试?1.2 为什么要做单元测试? 2. 写测试用例的方法3. 编写测试用例的通用原则3.1 目的性原则3.2 独立性原则3.3 可重复性原则3.4 小规模原则3.5 一致性原则3.6 自动化原则3.7 边界条件原则3.8 错误检测原则3.9 性能原则3…

西门子PLC控制激光读头,profient转Ethernet IP网关应用

在智能制造的浪潮下,企业对于生产线的灵活性、智能化水平以及数据交互能力提出了更高要求。西门子PLC以其高可靠性和丰富的功能模块,广泛应用于各种自动化生产线中。而激光读头作为精密测量与定位的关键设备,其高精度、非接触式测量特性在自动…

力扣862.和至少为K的最短子数组

力扣862.和至少为K的最短子数组 双端单调队列 前缀和 用单调队列存遍历过的前缀和&#xff0c;同时两个优化 1. 2. class Solution {public:int shortestSubarray(vector<int>& nums, int k) {int n nums.size(),ans n 1;long s[n1];s[0] 0L;for(int i0;i…

1999-2023年上市公司年报文本数据(PDF+TXT)

1999-2023年上市公司年报文本数据&#xff08;PDFTXT&#xff09; 1、时间&#xff1a;1999-2023年 2、来源&#xff1a;上市公司年度报告 3、范围&#xff1a;A股上市公司&#xff0c;5600企业&#xff0c;6.3W份 4、格式&#xff1a;PDFTXT 5、下载链接&#xff1a; 199…

东方通Web服务器(TongWeb)控制台部署改自动部署操作

首先将控制台部署改自动部署的应用进行解除部署&#xff0c;具体如下&#xff1a;登录TongWeb管理控制台&#xff0c;在左侧导航栏中点击“应用管理”&#xff0c;通过应用列表中第一列复选框选中要解除部署的应用&#xff0c;点击“解部署”&#xff0c;完成应用解除部署操作。…

4.Copy Constructor的构造操作

目录 1、对象赋值问题引入 2、Bitwise Copy Semantics&#xff08;位逐次拷贝&#xff09; 3、处理class virtual function 4、处理virtual base class subobject 1、对象赋值问题引入 在C中&#xff0c;有三种情况会以一个object的内容作为另一个class object的初值。这三…

Upload-labs靶场通过攻略

pass-01 1.写一个一句话木马 2.上传php文件 当我们上传php文件时 提示文件类型不正确 3.修改php后缀 通过修改php后缀为jpg 抓包再次修改成php文件 4.查看是否上传成功 页面显示图片 表示上传成功 pass-02 1.上传一个php文件 页面显示文件类型不正确 2.抓包修改 可以看…

【Python零基础】文件使用和异常处理

文章目录 前言一、从文件中读取数据二、向文件中写入数据三、异常四、存储数据总结 前言 本篇笔者将展示Python如何处理文件数据&#xff0c;包括文件内容的读取和写入操作&#xff0c;以及程序运行时异常模块的处理方式&#xff0c;保证我们写出健壮的代码。 一、从文件中读取…

Nature揭示应变不变的射频电子器件新突破,无线健康监测的前景

【行业背景】 可拉伸电子设备是未来柔性电子技术发展的重要趋势。这些设备在皮肤接口、健康监测、智能穿戴等领域发挥着关键作用&#xff0c;离不开高性能的射频&#xff08;RF&#xff09;电子组件。射频电子设备的功能依赖于其基板材料的电气性能&#xff0c;然而传统的弹性…

突发:Runway 从 HuggingFace 上删库跑路,究竟发生了什么?

&#x1f525; 突发新闻&#xff1a;Runway 从 HuggingFace 上删库跑路&#xff0c;究竟发生了什么&#xff1f; 1️⃣ Runway 从 HuggingFace 上删库跑路&#xff01;究竟是技术问题还是另有隐情&#xff1f; 最近科技圈内流传着一则令人瞠目结舌的消息&#xff1a;曾经为AI图…

5款自动生成文案的神器,助你轻松创作优质文案

随着人工智能技术的发展&#xff0c;生活中的很多工作都可以自动化操作&#xff0c;就连创作文案也不再会让人绞尽脑汁的去思考怎么写&#xff0c;因为有了自动生成文案的神器&#xff0c;从而使创作者在写作文案的过程中更加得心应手&#xff0c;并且不费吹灰之力便能拥有优质…

优思学院|精益生产中现场管理的7大工具

在现代制造业中&#xff0c;精益生产&#xff08;Lean Production&#xff09;已成为提升生产效率、确保产品质量的关键方法论。精益生产的核心思想在于消除浪费、持续改进&#xff0c;而要实现这些目标&#xff0c;依赖于一系列行之有效的管理工具。在这篇文章中&#xff0c;我…

爆品是测出来的,不是选出来的

我在亚马逊摸爬滚打了五年&#xff0c;深深感受到了"七分选品&#xff0c;三分运营"的重要性。不管你的产品图片、描述多么精美&#xff0c;如果不去精选和测试&#xff0c;很难保证能出单。我见过很多跨境新手在选品上卡了几个月&#xff0c;纠结于卖什么。但实际上…

一次VUE3 使用axios调用萤石云OpenAPI踩坑经历

通过调用萤石云的获取设备列表功能&#xff0c;我们可以根据 ACCESS_TOKEN 获取该用户下的设备列表。 Python 调用接口 根据接口文档[1]&#xff0c;使用Python&#xff0c;很轻松就能获取到该列表&#xff0c;代码如下&#xff08;该代码用于拼接生成vue代码&#xff0c;这是…

爱浦路云化核心网:支持百万用户规模,构筑超快海量连接网络

广州爱浦路网络技术有限公司&#xff08;简称&#xff1a;IPLOOK&#xff09;是全球领先的4G/5G/6G核心网厂商&#xff0c;致力于向全球客户提供端到端的移动通信解决方案&#xff0c;其产品和服务覆盖了卫星通信、能源通信、电网通信等多个重要领域。经过十二年的探索与发展&a…

英文论文格式编辑(二)

这里写自定义目录标题 正文部分段落格式段落对齐方式conclusion图片左右对齐 正文部分段落格式 出现下面这种箭头&#xff0c;是使用了标题格式 在这个样式里面修改 包括图片啥的&#xff0c;都别用标题格式&#xff0c;按道理来说&#xff0c;一个标题的箭头是能把下面的内…

如何构建短视频矩阵?云微客开启多账号协同作战

你有没有疑惑过&#xff0c;为什么有些账号每一次发布视频&#xff0c;都要艾特一下其他账号呢&#xff1f;那些被艾特的账号&#xff0c;你有点进去关注过吗&#xff1f;其实做过运营的都或多或少的接触过矩阵&#xff0c;短视频矩阵的玩法现在也逐步成为了趋势。企业通过多账…

深度学习分类模型训练代码模板

深度学习分类模型训练代码模板 简介 参数模块 采用argparse模块进行配置&#xff0c;便于服务器上训练&#xff0c;以及超参数记录。在服务器上进行训练时&#xff0c;通常采用命令行启动&#xff0c;或时采用sh脚本批量训练&#xff0c;这时候就需要从命令行传入一些参数&a…