构建文本摘要Baseline并且开始训练

news2024/11/25 15:31:01

构建文本摘要Baseline并且开始训练

基于前面word2vec的原理与训练实践、seq2seq模型的原理与实践以及attention机制,已经分别写了相关的文章来记录,此篇文章就是基于前面所学,开始着手训练文本摘要模型,当然仅是一个比较普通的baseline,后面还会不断优化模型。

构建seq2seq模型

首先利用上一节seq2seq实践中,封装的encoder、decoder和attention,集成到此模型中来,另外就是增加了一个训练技巧–teacher forcing。那么teacher forcing是啥意思呢?

在这里插入图片描述

seq2seq模型的输出为decoder解码出的一系列概率分布,因此采用何种方式进行解码,就显得尤为重要。如贪心解码(greedy search)teacher forcing以及介于两种之间的beam search等。
    贪心解码的思想是,预测 t 时刻输出的单词时,直接将t−1时刻的输出词汇表中概率最大的单词,作为t时刻的输入,因此可能导致如果前一个预测值就不准的话,后面一系列都不准的问题
    Teacher Forcing的方法是,预测 t时刻输出的单词时,直接将t−1时刻的实际单词,作为输入,因此可能带来的问题是,训练过程预测良好(因为有标签,即实际单词),但是测试过程极差(因为测试过程不会给对应的真实单词)
    实际应用中,往往采用介于这两种极端方式之间的解码方式,如beam search 等,具体思路是预测 t 时刻输出的单词时,保留t−1时刻的输出词汇表中概率最大的前K个单词,以此带来更多的可能性(解决第一个方法的缺陷);而且在训练过程,采用一定的概率P,来决定是否使用真实单词作为输入(解决第二个方法的缺陷)。greedy search 和beam search后面我们也会一一介绍,下面是teacher forcing的具体实现。

import tensorflow as tf

from src.seq2seq_tf2.model_layers import Encoder, BahdanauAttention, Decoder
from src.utils.gpu_utils import config_gpu
from src.utils.params_utils import get_params
from src.utils.wv_loader import load_embedding_matrix, Vocab


class Seq2Seq(tf.keras.Model):
    def __init__(self, params, vocab):
        super(Seq2Seq, self).__init__()
        self.embedding_matrix = load_embedding_matrix()
        self.params = params
        self.vocab = vocab
        self.batch_size = params["batch_size"]
        self.enc_units = params["enc_units"]
        self.dec_units = params["dec_units"]
        self.attn_units = params["attn_units"]
        self.encoder = Encoder(self.embedding_matrix,
                               self.enc_units,
                               self.batch_size)

        self.attention = BahdanauAttention(self.attn_units)

        self.decoder = Decoder(self.embedding_matrix,
                               self.dec_units,
                               self.batch_size)

    def teacher_decoder(self, dec_hidden, enc_output, dec_target):
        predictions = []

        # 第一个输入<START>
        dec_input = tf.expand_dims([self.vocab.START_DECODING_INDEX] * self.batch_size, 1)

        #  Teacher forcing 将target作为下一次的输入,依次解码
        for t in range(1, dec_target.shape[1]):
            # passing enc_output to the decoder
            # 应用decoder来一步一步预测生成词语概论分布
            pred, dec_hidden, _ = self.decoder(dec_input, dec_hidden, enc_output)
            dec_input = tf.expand_dims(dec_target[:, t], 1)

            predictions.append(pred)

        return tf.stack(predictions, 1), dec_hidden

开始训练

import tensorflow as tf

from src.seq2seq_tf2.seq2seq_model import Seq2Seq
from src.seq2seq_tf2.train_helper import train_model
from src.utils.gpu_utils import config_gpu
from src.utils.params_utils import get_params
from src.utils.wv_loader import Vocab


def train(params):
    # GPU资源配置
    config_gpu(use_cpu=True)

    # 读取vocab训练
    vocab = Vocab(params["vocab_path"], params["vocab_size"])

    params['vocab_size'] = vocab.count

    # 构建模型
    print("Building the model ...")
    model = Seq2Seq(params, vocab)

    # 获取保存管理者
    checkpoint = tf.train.Checkpoint(Seq2Seq=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, params['checkpoint_dir'], max_to_keep=5)

    # 训练模型
    train_model(model, vocab, params, checkpoint_manager)

import tensorflow as tf

# from src.pgn_tf2.batcher import batcher
from src.seq2seq_tf2.seq2seq_batcher import train_batch_generator
import time
from functools import partial


def train_model(model, vocab, params, checkpoint_manager):
    epochs = params['epochs']

    pad_index = vocab.word2id[vocab.PAD_TOKEN]

    # 获取vocab大小
    params['vocab_size'] = vocab.count

    optimizer = tf.keras.optimizers.Adam(name='Adam', learning_rate=params['learning_rate'])

    train_dataset, val_dataset, train_steps_per_epoch, val_steps_per_epoch = train_batch_generator(
        params['batch_size'], params['max_enc_len'], params['max_dec_len'], params['buffer_size']
    )

    for epoch in range(epochs):
        start = time.time()
        enc_hidden = model.encoder.initialize_hidden_state()

        total_loss = 0.
        running_loss = 0.
        for (batch, (inputs, target)) in enumerate(train_dataset.take(train_steps_per_epoch), start=1):

            batch_loss = train_step(model, inputs, target, enc_hidden,
                                    loss_function=partial(loss_function, pad_index=pad_index),
                                    optimizer=optimizer)
            total_loss += batch_loss

            if batch % 50 == 0:
                print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                             batch,
                                                             (total_loss - running_loss) / 50))
                running_loss = total_loss
        # saving (checkpoint) the model every 2 epochs
        if (epoch + 1) % 2 == 0:
            ckpt_save_path = checkpoint_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,
                                                                ckpt_save_path))

        valid_loss = evaluate(model, val_dataset, val_steps_per_epoch,
                              loss_func=partial(loss_function, pad_index=pad_index))

        print('Epoch {} Loss {:.4f}; val Loss {:.4f}'.format(
            epoch + 1, total_loss / train_steps_per_epoch, valid_loss)
        )

        print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))


# 定义损失函数
def loss_function(real, pred, pad_index):
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    mask = tf.math.logical_not(tf.math.equal(real, pad_index))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_mean(loss_)


def train_step(model, enc_inp, dec_target, enc_hidden, loss_function=None, optimizer=None, mode='train'):
    with tf.GradientTape() as tape:

        enc_output, enc_hidden = model.encoder(enc_inp, enc_hidden)
        # 第一个隐藏层输入
        dec_hidden = enc_hidden

        # 逐个预测序列
        predictions, _ = model.teacher_decoder(dec_hidden, enc_output, dec_target)

        batch_loss = loss_function(dec_target[:, 1:], predictions)

        if mode == 'train':
            variables = (model.encoder.trainable_variables + model.decoder.trainable_variables
                         + model.attention.trainable_variables)

            gradients = tape.gradient(batch_loss, variables)

            gradients, _ = tf.clip_by_global_norm(gradients, 1.0)

            optimizer.apply_gradients(zip(gradients, variables))

        return batch_loss


def evaluate(model, val_dataset, val_steps_per_epoch, loss_func):
    print('Starting evaluate ...')
    total_loss = 0.
    enc_hidden = model.encoder.initialize_hidden_state()
    for (batch, (inputs, target)) in enumerate(val_dataset.take(val_steps_per_epoch), start=1):
        batch_loss = train_step(model, inputs, target, enc_hidden,
                                loss_function=loss_func, mode='val')
        total_loss += batch_loss
    return total_loss / val_steps_per_epoch

from src.build_data.data_loader import load_dataset 
import tensorflow as tf
from src.utils import config
from tqdm import tqdm


def train_batch_generator(batch_size, max_enc_len=200, max_dec_len=50, buffer_size=5, sample_sum=None):
    # 加载数据集
    train_X, train_Y = load_dataset(config.train_x_path, config.train_y_path,
                                    max_enc_len, max_dec_len)
    val_X, val_Y = load_dataset(config.test_x_path, config.test_y_path,
                                max_enc_len, max_dec_len)
    if sample_sum:
        train_X = train_X[:sample_sum]
        train_Y = train_Y[:sample_sum]
    print(f'total {len(train_Y)} examples ...')
    train_dataset = tf.data.Dataset.from_tensor_slices((train_X, train_Y)).shuffle(len(train_X),
                                                                                   reshuffle_each_iteration=True)
    val_dataset = tf.data.Dataset.from_tensor_slices((val_X, val_Y)).shuffle(len(val_X),
                                                                             reshuffle_each_iteration=True)
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True).prefetch(buffer_size)
    val_dataset = val_dataset.batch(batch_size, drop_remainder=True).prefetch(buffer_size)
    train_steps_per_epoch = len(train_X) // batch_size
    val_steps_per_epoch = len(val_X) // batch_size
    return train_dataset, val_dataset, train_steps_per_epoch, val_steps_per_epoch

def load_dataset(x_path, y_path, max_enc_len, max_dec_len, sample_sum=None):
    x = np.load(x_path+".npy")
    y = np.load(y_path+".npy")

    if sample_sum:
        x = x[:sample_sum, :max_enc_len]
        y = y[:sample_sum, :max_dec_len]
    else:
        x = x[:, :max_enc_len]
        y = y[:, :max_dec_len]
    return x, y

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/84158.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[附源码]Node.js计算机毕业设计大学体育馆预约系统Express

项目运行 环境配置&#xff1a; Node.js最新版 Vscode Mysql5.7 HBuilderXNavicat11Vue。 项目技术&#xff1a; Express框架 Node.js Vue 等等组成&#xff0c;B/S模式 Vscode管理前后端分离等等。 环境需要 1.运行环境&#xff1a;最好是Nodejs最新版&#xff0c;我…

vue框架搭建大屏自适应方案

vue框架搭建大屏自适应方案 1.可使用flexible.js rem实现宽高&#xff0c;字体自适应 附上flexible.js代码 (function flexible(window, document) {var docEl document.documentElement;var dpr window.devicePixelRatio || 1;// adjust body font sizefunction setBody…

NocasRule负载均衡与服务实例的权重设置

NocasRule负载均衡 .yml 配置文件配置 server:port: 8080 spring:application:name: orderservicecloud:nacos:server-addr: localhost:8848 #nocas服务地址discovery:cluster-name: HZ #集群名字 userservice: #要做配置的微服务名称ribbon:NFLoadBalancerRuleClassName: com…

游戏开发57课 性能优化14

5. 内存优化 内存优化目的是加快IO&#xff0c;防止卡主线程&#xff0c;防止频繁操作&#xff08;创建/删除&#xff09;内存&#xff0c;避免内存碎片化和占用过高。 5.1 缓存法 与CPU的缓存计算类似&#xff0c;思路是将需要重复创建的对象缓存起来&#xff0c;销毁时将它…

安装、启动与停止Apache服务

安装、启动与停止Apache服务 安装Apache相关软件 [rootcentos7 ~]# rpm -q httpd [rootcentos7-1 ~]# mkdir /opt/centos //创建目录/opt/centos [rootcentos7-1 ~]# mount /dev/cdrom /opt/centos //挂载光盘到/opt/centos 下 mount: /dev/sr0 写保护…

Spring Boot 3.0.0正式发布,Banner不再支持图片增强可观测性

本文已被https://yourbatman.cn收录&#xff1b;女娲Knife-Initializr工程可公开访问啦&#xff1b;程序员专用网盘https://wangpan.yourbatman.cn&#xff1b;技术专栏源代码大本营&#xff1a;https://github.com/yourbatman/tech-column-learning&#xff1b;公号后台回复“…

openCV(一)基础背景

1 认识计算机视觉 2012年AlexNet模型在ImageNet图像分类中获得比赛冠军&#xff0c;深度学习开始在计算机视觉领域流行。早期的计算机视觉主要集中在重建方面&#xff0c;2012年以后在感知和重建两个领域都受到了深度学习的影响。应用场景包括自动驾驶、机器视觉、安防监控、其…

猿如意中的【PostgreSQL 数据库】工具详情介绍

猿如意中的【PostgreSQL 数据库】工具详情介绍 一、工具名称 PostgreSQL 数据库 二、下载安装渠道 PostgreSQL 数据库V14.2 通过CSDN官方开发的【猿如意】客户端进行下载安装。 2.1 什么是猿如意&#xff1f; 猿如意是一款面向开发者的辅助开发工具箱&#xff0c;包含了效…

jenkins-pipeline与变量

本文介绍如何在pipeline中使用变量 使用jenkins预定义的环境变量 jenkins预先定义了一些环境变量&#xff0c;在pipeline中使用${env.key}来调用 另外安装了第三方插件&#xff0c;会有新的环境变量&#xff0c;可以使用插件Environment Inject来查看 在pipeline中使用预定义…

Java二维数组项目练习

T1.显示所有书店客户的信息 示例代码 public static void main(String[] args) {String[][] users{{"1100","18","100"},{"1101","24","834"},{"1102","13","20000"},{"1103…

软件测试——用例篇

文章目录为什么在测试前要设计测试用例基于需求设计测试用例等价类边界值错误猜测法场景法因果图正交法为什么在测试前要设计测试用例 测试用例是执行测试的依据。可以复用&#xff08;回归测试的时候&#xff09;衡量需求的覆盖率自动化测试的依据有借鉴意义&#xff0c;后续…

OH----原子量的妙用--保护usb时序

1、问题&#xff1a; 展锐平台&#xff0c;usb otg高概率不能正确检测识别到 2、思路&#xff1a; usb使用musb控制器&#xff0c;展锐的平台处理代码是musb_sprd.c&#xff0c;在这个文件中对usb mode做检测和切换&#xff0c;log级别跳到最高&#xff0c;在probe中的关键函…

用 Taichi 加速 Python:提速 100+ 倍!

Python 已经成为世界上最流行的编程语言&#xff0c;尤其在深度学习、数据科学等领域占据主导地位。但是由于其解释执行的属性&#xff0c;Python 较低的性能很影响它在计算密集&#xff08;比如多重 for 循环&#xff09;的场景下发挥作用&#xff0c;实在让人又爱又恨。如果你…

PAT(乙级)2022年冬季考试

此前先后花了十元去做了乙级题&#xff0c;从最开始分别是70&#xff0c;35&#xff0c;43&#xff0c;33&#xff08;途中做了RobpCom,只搞定了签到题&#xff09;&#xff0c;想着报今年的冬季赛&#xff0c;但是报名费有点高啊&#xff0c;加上做下来感觉不怎么样&#xff0…

[附源码]Python计算机毕业设计Django的黄河文化科普网站

项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等等。 环境需要 1.运行环境&#xff1a;最好是python3.7.7&#xff0c;…

GitHub搜索开源项目

GitHub的流行&#xff0c; GitHub在开源世界的受欢迎程度自不必多言。再加上今天&#xff0c;GitHub官方又搞了个大新闻&#xff1a;私有仓库也改为免费使用&#xff0c;这在原来可是需要真金白银的买的。可见微软收购后&#xff0c;依然没有改变 GitHub 的定位&#xff0c;甚至…

使用高德地图展示点位和信息窗体展示数据及播放视频

使用高德地图做了一个在地图展示点位&#xff0c;并通过点击&#xff0c;显示直播的功能&#xff0c;这个任务是为了之后大屏做准备。 这是一个能展示多个点标记&#xff0c;并在点击的时候弹出信息窗体&#xff0c;并在信息窗体中播放视频&#xff0c;且展示相关信息以及操作…

【Lilishop商城】No3-6.模块详细设计,商品模块-2(商品及强关联附属 商品sku、批发、图册等等)的详细设计

仅涉及后端&#xff0c;全部目录看顶部专栏&#xff0c;代码、文档、接口路径在&#xff1a; 【Lilishop商城】记录一下B2B2C商城系统学习笔记~_清晨敲代码的博客-CSDN博客 全篇会结合业务介绍重点设计逻辑&#xff0c;其中重点包括接口类、业务类&#xff0c;具体的结合源代码…

队列的练习题

用队列实现栈 请你仅使用两个队列实现一个后入先出&#xff08;LIFO&#xff09;的栈&#xff0c;并支持普通栈的全部四种操作&#xff08;push、top、pop 和 empty&#xff09; 实现 MyStack 类&#xff1a; void push(int x) 将元素 x 压入栈顶int pop()移除并返回栈顶元素…

I2C总线式驱动开发

文章目录前言一、Linux内核对I2C总线的支持1.1、理解I2C设备驱动、I2C总线驱动以及I2C核心之间的关系1.2、i2c二级外设驱动开发涉及到核心结构体及其相关接口函数&#xff1a;二、I2C总线二级外设驱动开发方法-名称匹配2.1、i2c二级外设client框架&#xff1a;2.2、i2c二级外设…