使用 Amazon SageMaker 构建文本摘要应用

news2024/11/17 13:49:24

背景介绍

文本摘要,就是对给定的单个或者多个文档进行梗概,即在保证能够反映原文档的重要内容的情况下,尽可能地保持简明扼要。质量良好的文摘能够在信息检索过程中发挥重要的作用,比如利用文摘代替原文档参与索引,可以有效缩短检索的时间,同时也能减少检索结果中的冗余信息,提高用户体验。随着信息爆炸时代的到来,自动文摘逐渐成为自然语言处理领域的一项重要的研究课题。

文本摘要的需求来自多个我们真实的客户案例,对于大量的长文本对于新闻领域,金融领域,法律领域是司空见惯的。而在人力成本越来越高的今天,雇佣大量的专业人员进行信息精炼或者内容审核无疑要投入大量的资金。而自动文本摘要就显得意义非凡,具体来说,通过大量数据训练的深度学习模型可以在几百毫秒内产生长度可控的文本摘要,这大大地提升了摘要生成效率,节约了大量人力成本。

对于目前的技术,可以根据摘要产生的方式大体可以分为两类:1)抽取式文本摘要:找到一个文档中最重要的几个句子并对其进行拼接;2)生成式文本摘要:直接建模为序列到序列的生成问题,根据源文本直接递归生成摘要。对于抽取式摘要,其具备效率高,解释性强的优势,但是抽取得到的文本在语义连续性上相较生成式摘要有所不足,故这里我们主要展示生成式摘要。

Amazon SageMaker是亚马逊云计算(Amazon Web Service)的一项完全托管的机器学习平台服务,算法工程师和数据科学家可以基于此平台快速构建、训练和部署机器学习 (ML) 模型,而无需关注底层资源的管理和运维工作。它作为一个工具集,提供了用于机器学习的端到端的所有组件,包括数据标记、数据处理、算法设计、模型训练、训练调试、超参调优、模型部署、模型监控等,使得机器学习变得更为简单和轻松;同时,它依托于 Amazon 强大的底层资源,提供了高性能 CPU、GPU、弹性推理加速卡等丰富的计算资源和充足的算力,使得模型研发和部署更为轻松和高效。同时,本文还基于 Huggingface,Huggingface 是 NLP 著名的开源社区,并且与 Amazon SagaMaker 高度适配,可以在 Amazon SagaMaker 上以几行代码轻松实现 NLP 模型训练和部署。

亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者对接世界最前沿技术,观点,和项目,并将中国优秀开发者或技术推荐给全球云社区。如果你还没有关注/收藏,看到这里请一定不要匆匆划过,点这里让它成为你的技术宝库!

解决方案概览

在此示例中, 我们将使用 Amazon SageMaker 执行以下操作:

  • 环境准备
  • 下载数据集并将其进行数据预处理
  • 使用本地机器训练
  • 使用 Amazon SageMaker BYOS 进行模型训练
  • 托管部署及推理测试

环境准备

我们首先要创建一个 Amazon SageMaker Notebook,笔记本实例类型最好选择 ml.p3.2xlarge,因为本例中用到了本地机器训练的部分用来测试我们的代码,卷大小建议改成10GB或以上,因为运行该项目需要下载一些额外的数据。

 

笔记本启动后,打开页面上的终端,执行以下命令下载代码。

cd ~/SageMaker
git clone https://github.com/HaoranLv/nlp_transformer.git

下载数据集并将其进行数据预处理

这里给出若干开源的中英文数据集:

1.公开数据集 (英文)

  • XSUM,227k BBC articles
  • CNN/Dailymail,93k articles from the CNN, 220k articles from the Daily Mail
  • NEWSROOM,3M article-summary pairs written by authors and editors in the newsrooms of 38 major publications
  • Multi-News,56k pairs of news articles and their human-written summaries from the http://com
  • Gigaword,4M examples extracted from news articles,the task is to generate theheadline from the first sentence
  • arXiv, PubMed,two long documentdatasets of scientific publications from http://org(113k) andPubMed (215k). The task is to generate the abstract fromthe paper body.
  • BIGPATENT,3 millionU.S. patents along with human summaries under nine patent classification categories

2.公开数据集 (中文)

  • 哈工大的新浪微博短文本摘要 LCSTS
  • 教育新闻自动摘要语料 chinese_abstractive_corpus
  • NLPCC 2017 task3 Single Document Summarization
  • 娱乐新闻等 “神策杯”2018高校算法大师赛

本文以 Multi-News 为例,数据分为两列,headlines 代表摘要,text 代表全文。由于文本数据集较小,故直接官网下载原始 csv 文件上传到 SageMaker Notebook 即可。如下是部分数据集样例。

 

找到 hp_data.ipynb 运行代码。

首先加载数据集

df=pd.read_csv(./data/hp/summary/news_summary.csv')

而后进行数据清洗

class Settings:

    TRAIN_DATA = "./data/hp/summary/news_summary_total.csv"
    Columns = ['headlines', 'text']
    encoding = 'latin-1'
    columns_dict = {"headlines": "headlines", "text": "text"}
    df_column_list = ['text', 'headlines']
    SUMMARIZE_KEY = ""
    SOURCE_TEXT_KEY = 'text'
    TEST_SIZE = 0.2
    BATCH_SIZE = 16
    source_max_token_len = 128
    target_max_token_len = 50
    train_df_len = 82332
    test_df_len = 20583
    
class Preprocess:
    def __init__(self):
        self.settings = Settings

    def clean_text(self, text):
        text = text.lower()
        text = re.sub('\[.*?\]', '', text)
        text = re.sub('https?://\S+|www\.\S+', '', text)
        text = re.sub('<.*?>+', '', text)
        text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
        text = re.sub('\n', '', text)
        text = re.sub('\w*\d\w*', '', text)
        return text

    def preprocess_data(self, data_path):
        df = pd.read_csv(data_path, encoding=self.settings.encoding, usecols=self.settings.Columns)
        # simpleT5 expects dataframe to have 2 columns: "source_text" and "target_text"
        df = df.rename(columns=self.settings.columns_dict)
        df = df[self.settings.df_column_list]
        # T5 model expects a task related prefix: since it is a summarization task, we will add a prefix "summarize: "
        df[self.settings.SOURCE_TEXT_KEY] = df[self.settings.SOURCE_TEXT_KEY]

        return df
settings=Settings
preprocess=Preprocess()
df = preprocess.preprocess_data(settings.TRAIN_DATA)

随后完成训练集和测试集的划分并分别保存:

df.to_csv('./data/hp/summary/news_summary_cleaned.csv',index=False)
df2=pd.read_csv('./data/hp/summary/news_summary_cleaned.csv')
order=['text','headlines']
df3=df2[order]
train_df, test_df = train_test_split(df3, test_size=0.2,random_state=100)
train_df.to_csv('./data/hp/summary/news_summary_cleaned_train.csv',index=False)
test_df.to_csv('./data/hp/summary/news_summary_cleaned_test.csv',index=False)

 

使用本地机器训练

在完成了上述的数据处理过程后,就可以进行模型训练了,下面的命令运行后即开始模型训练,代码会自动 Huggingface hub 中加载 google/pegasus-large 作为预训练模型,而后使用我们处理后的数据集进行模型训练。

!python -u examples/pytorch/summarization/run_summarization.py \
--model_name_or_path google/pegasus-large \
--do_train \
--do_eval \
--per_device_train_batch_size=2 \
--per_device_eval_batch_size=1 \
--save_strategy epoch \
--evaluation_strategy epoch \
--overwrite_output_dir \
--predict_with_generate \
--train_file './data/hp/summary/news_summary_cleaned_train.csv' \
--validation_file './data/hp/summary/news_summary_cleaned_test.csv' \
--text_column 'text' \
--summary_column 'headlines' \
--output_dir='./models/local_train/pegasus-hp' \
--num_train_epochs=1.0 \
--eval_steps=500 \
--save_total_limit=3 \
--source_prefix "summarize: " > train_pegasus.log

训练完成后,会提示日志信息如下。 

  

并且会对验证集的数据进行客观指标评估,这里使用 Rouge 进行评估。

 

模型结果文件及相应的日志等信息会自动保存在./models/local_train/pegasus-hp/checkpoint-500

 

我们可以直接用这个产生的模型文件进行本地推理。注意这里的模型文件地址的指定为你刚刚训练产生的。

import pandas as pd
df=pd.read_csv('./data/hp/summary/news_summary_cleaned_small_test.csv')
print('原文:',df.loc[0,'text'])
print('真实标签:',df.loc[0,'headlines'])
from transformers import pipeline
summarizer=pipeline("summarization",model="./models/local_train/Pegasus-hp/checkpoint-500")
print('模型预测:',summarizer(df.loc[0,'text'], max_length=50)[0]['summary_text'])

输出如下:

原文: Germany on Wednesday accused Vietnam of kidnapping a former Vietnamese oil executive Trinh Xuan Thanh, who allegedly sought asylum in Berlin, and taking him home to face accusations of corruption. Germany expelled a Vietnamese intelligence officer over the suspected kidnapping and demanded that Vietnam allow Thanh to return to Germany. However, Vietnam said Thanh had returned home by himself.
真实标签: Germany accuses Vietnam of kidnapping asylum seeker 
模型预测: Germany accuses Vietnam of kidnapping ex-oil exec, taking him home

 

到这里,就完成了一个模型的本地训练和推理过程。

使用 Amazon SageMaker BYOS 进行模型训练

在上文的范例中,我们使用本地环境一步步的训练了一个较小的模型,验证了我们的代码。现在,我们需要把代码进行整理,在 Amazon SageMaker 上,进行可扩展至分布式的托管训练任务。

首先,我们要将上文的训练代码整理至一个 python 脚本,然后使用 SageMaker 上预配置的 Huggingface 容器,我们提供了很多灵活的使用方式来使用该容器,具体可以参考 Hugging Face Estimator。

由于 SageMaker 预置的 Huggingface 容器已经具备推理逻辑, 故这里只需要将上一步中的训练脚本引入容器即可, 具体流程如下:

启动一个 Jupyter Notebook,选择 python3 作为解释器完成如下工作:

权限配置

import sagemaker
import os
sess = sagemaker.Session()
role = sagemaker.get_execution_role()

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

数据上传到 S3

# dataset used
dataset_name = ' news_summary'
# s3 key prefix for the data
s3_prefix = 'datasets/news_summary'
WORK_DIRECTORY = './data/'
data_location = sess.upload_data(WORK_DIRECTORY, key_prefix=s3_prefix)
data_location

 

定义超参数并初始化 estimator。

from sagemaker.huggingface import HuggingFace

# hyperparameters which are passed to the training job
hyperparameters={'text_column':'text',
                 'summary_column':'headlines',
                 'train_file':'/opt/ml/input/data/train/news_summary_cleaned_train.csv',
                 'validation_file':'/opt/ml/input/data/test/ news_summary_cleaned_test.csv',
                 'output_dir':'/opt/ml/model',
                 'do_train':True,
                 'do_eval':True,
                 'max_source_length': 128,
                 'max_target_length': 128,
                 'model_name_or_path': 't5-large',
                 'learning_rate': 3e-4,
                 'num_train_epochs': 1,
                 'per_device_train_batch_size': 2,#16
                 'gradient_accumulation_steps':2, 
                 'save_strategy':'epoch',
                 'evaluation_strategy':'epoch',
                 'save_total_limit':1,
                 }
distribution = {'smdistributed':{'dataparallel':{ 'enabled': True }}}
# create the Estimator
huggingface_estimator = HuggingFace(
        entry_point='run_paraphrase.py',
        source_dir='./scripts',
        instance_type='ml.p3.2xlarge',#'ml.p3dn.24xlarge'
        instance_count=1,
        role=role,
        max_run=24*60*60,
        transformers_version='4.6',
        pytorch_version='1.7',
        py_version='py36',
        volume_size=128,
        hyperparameters = hyperparameters,
#         distribution=distribution
)

启动模型训练。

huggingface_estimator.fit(
  {'train': data_location+'/news_summary_cleaned_train.csv',
   'test': data_location+'/news_summary_cleaned_test.csv',}
)

 

训练启动后,我们可以在 Amazon SageMaker 控制台看到这个训练任务,点进详情可以看到训练的日志输出,以及监控机器的 GPU、CPU、内存等的使用率等情况,以确认程序可以正常工作。训练完成后也可以在 CloudWatch 中查看训练日志。

 

托管部署及推理测试

完成训练后,我们可以轻松的将上面的模型部署成一个实时可在生产环境中调用的端口。

from sagemaker.huggingface.model import HuggingFaceModel

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
#    env= {'HF_TASK':'text-generation'},
   model_data="s3://sagemaker-us-west-2-847380964353/huggingface-pytorch-training-2022-04-19-05-56-07-474/output/model.tar.gz",  # path to your trained SageMaker model
   role=role,                                            # IAM role with permissions to create an endpoint
   transformers_version="4.6",                           # Transformers version used
   pytorch_version="1.7",                                # PyTorch version used
   py_version='py36',                                    # Python version used
    
)
predictor = huggingface_model.deploy(
   initial_instance_count=1,
   instance_type="ml.g4dn.xlarge"
)

 

模型调用

from sagemaker.huggingface.model import HuggingFacePredictor
predictor=HuggingFacePredictor(endpoint_name='huggingface-pytorch-inference-2022-04-19-06-41-55-309')

import time
s=time.time()
df=pd.read_csv('./data/hp/summary/news_summary_cleaned_small_test.csv')
print('原文:',df.loc[0,'text'])
print('真实标签:',df.loc[0,'headlines'])
out=predictor.predict({
        'inputs': df.loc[0,'text'],
        "parameters": {"max_length": 256},
    })
e=time.time()
print('模型预测:' out)

 

输出如下:

原文: Germany on Wednesday accused Vietnam of kidnapping a former Vietnamese oil executive Trinh Xuan Thanh, who allegedly sought asylum in Berlin, and taking him home to face accusations of corruption. Germany expelled a Vietnamese intelligence officer over the suspected kidnapping and demanded that Vietnam allow Thanh to return to Germany. However, Vietnam said Thanh had returned home by himself.
真实标签: Germany accuses Vietnam of kidnapping asylum seeker 
模型预测: Germany accuses Vietnam of kidnapping ex-oil exec, taking him home

Amazon SageMaker

以上就是使用 Amazon SageMaker 构建文本摘要应用的全部过程,可以看到通过 Amazon SageMaker 可以非常便利地结合 Huggingface 进行 NLP 模型的搭建,训练,部署的全流程。整个过程仅需要准备训练脚本以及数据即可通过若干命令启动训练和部署,同时,我们后续还会推出,使用 Amaozn SageMaker 进行更多 NLP 相关任务的实现方式,敬请关注。

参考资料

  • Amazon Sagemaker: https://docs.aws.amazon.com/sagemaker/index.html
  • Huggingface:Hugging Face – The AI community building the future.
  • Code Link:GitHub - HaoranLv/nlp_transformer: nlp task tools

本篇作者

吕浩然

亚马逊云科技应用科学家,长期从事计算机视觉,自然语言处理等领域的研究和开发工作。支持数据实验室项目,在时序预测,目标检测,OCR,自然语言生成等方向有丰富的算法开发以及落地实践经验。

文章来源:https://dev.amazoncloud.cn/column/article/630a0a80afd24c6ba216ffa8?sc_medium=regulartraffic&sc_campaign=crossplatform&sc_channel=CSDN

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/438430.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构复习题(包含答案)

第一章 概论 一、选择题 1、研究数据结构就是研究&#xff08; D &#xff09;。 A. 数据的逻辑结构 B. 数据的存储结构 C. 数据的逻辑结构和存储结构 D. 数据的逻辑结构、存储结构及其基本操作 2、算法分析的两个主要方面是&#xff08; A …

Multi-modal Alignment using Representation Codebook

Multi-modal Alignment using Representation Codebook 题目Multi-modal Alignment using Representation Codebook译题使用表示子空间的多模态对齐期刊/会议CVPR 摘要&#xff1a;对齐来自不同模态的信号是视觉语言表征学习&#xff08;representation learning&#xff09;…

SpringMVC文件上传、异常处理、拦截器

SpringMVC文件上传、异常处理、拦截器 基本配置准备&#xff1a;maven项目模块 application.xml <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:xsi"http://www.w3.…

数据库系统-数据物理存储

文章目录 一、DBMS原理1.1 DB物理存储1.1.1 磁盘的结构&特性1.1.2 DBMS数据存储&查询原理记录&#xff1a;磁盘块 1.2 DB文件组织方法1.2.1 无序文件组织1.2.2 有序记录文件1.2.3 散列文件(Hash File)1.2.4 聚簇文件(Clustering File) 1.3 Oracle 物理存储简介 一、DBM…

程序员最常见的谎言

小伙伴们大家好&#xff0c;我是阿秀。 上周看到知乎上有位网友总结了自己的10年程序员生涯中最常说的一些谎言&#xff0c;一共有15条&#xff0c;看完我直呼内行&#xff01;&#xff01; 全中&#xff01;每一枪都中了&#xff01;每一条我都说过。 我基本都说过他说过的那些…

下载VMware安装Ubuntu18.04.6系统

文章目录 前言&#xff1a;1. 下载VMWare与Ubuntu镜像2. 安装VMware3. VMware下安装Ubuntu 前言&#xff1a; 这篇文章是为交叉编译做的铺垫&#xff0c;本文内容比较简单&#xff0c;按步骤操作即可。 1. 下载VMWare与Ubuntu镜像 为了节约大家时间&#xff0c;省的各种找&am…

【Unity VR开发】结合VRTK4.0:摄像机碰撞变黑

语录&#xff1a; 人是要长大的&#xff0c;有天你也会推着婴儿车幸福地在街上行走&#xff0c;而曾经的喜欢&#xff0c;不管曾经怎样&#xff0c;都会幻化成风&#xff0c;消失在时光的隧道。所以向前走&#xff0c;向前走&#xff0c;无须回头。 前言&#xff1a; 往往我们在…

基于html+css的图片展示12

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…

新品发布 | 智安网络【零信任SDP访问控制系统】正式上线!

近日&#xff0c;根据《2022年云办公行业研究报告》显示&#xff0c;2021-2023年&#xff0c;中国协同办公行业保持每年10%以上的增长率&#xff0c;远程办公人员的数量在过去十年中增长了115%&#xff0c;互联网时代已迎来远程办公需求大潮。 风口之下&#xff0c;员工对企业…

Go语言基础——通过获取网站API的cUrl,生成可直接执行的request代码(附:详细实战案例和源码)

作者&#xff1a;非妃是公主 专栏&#xff1a;《Golang》 博客地址&#xff1a;https://blog.csdn.net/myf_666 个性签&#xff1a;顺境不惰&#xff0c;逆境不馁&#xff0c;以心制境&#xff0c;万事可成。——曾国藩 文章目录 一、cUrl是什么&#xff1f;二、cUrl如何获取…

OpenAI-ChatGPT最新官方接口《从0到1生产最佳实例》全网最详细中英文实用指南和教程,助你零基础快速轻松掌握全新技术(十一)(附源码)

Production Best Practices 生产最佳实例 前言Introduction 导言Setting up your organization 设置您的组织Managing billing limits 管理计费限额API keys API密钥Staging accounts 演示账户 Building your prototype 构建您的原型Additional tips 其它技巧 Techniques for i…

【Java 数据结构】队列的实现及相关OJ题

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了 博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点!人生格言&#xff1a;当你的才华撑不起你的野心的时候,你就应该静下心来学习! 欢迎志同道合的朋友一起加油喔&#x1f9be;&am…

图的搜索算法---8.2 ZOJ1002-Fire Net--合理布置碉堡

目录 问题&#xff1a; 分析&#xff1a; 主要的算法代码&#xff1a; 完整代码&#xff1a; 出问题的代码&#xff1a; 原因&#xff1a; 正确代码&#xff1a; 代码分析&#xff1a; 算法函数讲解&#xff1a; CanPut函数 solve函数 运行结果&#xff1a; 问题&am…

勇创世界一流!移动云为我国数字经济发展提供有力支撑

今年2月&#xff0c;中共中央、国务院印发了《数字中国建设整体布局规划》&#xff0c;通过顶层设计及布局&#xff0c;擘画了我国数字经济发展蓝图。近日在国新办举行的第六届数字中国建设峰会新闻发布会上&#xff0c;相关负责人也对我国数字经济现阶段取得的成绩进行了总结&…

IPWorks VoIP 2022 Crack

网络电话组件 IPWorks VoIP 提供 SIP 和 IVR 组件&#xff0c;旨在促进 CTI 应用程序中的常见 VoIP 操作。快速集成功能以建立拨出电话、接听来电以及根据您的自定义 IVR 菜单路由呼叫。还支持其他 SIP 功能&#xff0c;例如文本到语音、播放预先录制的消息、通话录音和电话会议…

总结:helm

一、介绍 Helm是k8s的包管理工具&#xff0c;类似Linux系统常用的 apt、yum等包管理工具&#xff0c;基于go 语言开发。使用helm可以简化k8s应用部署 二、基本概念 Helm的基本概念 Chart&#xff1a;一个 Helm 包&#xff0c;其中包含了运行一个应用所需要的镜像、依赖和资源…

flask实际开发:flask和nginx如何配置支持websocket

使用pycharm启动flask项目有坑&#xff0c;先修改pycharm设置 1、点击Edit Confiturations 2、配置启动方式 1 新增启动配置 2 选择使用python命令执行 3 给配置设置一个名字 4 设置要启动的模块的位置&#xff0c;flask基本都是app.py 模块 最后别忘记&#xff1a;点击右侧的…

Linux网络-传输层UDP/TCP详解

目录 计算机网络的层状结构 UDP协议 UDP报文格式 理解UDP/TCP报文的本质 UDP的特点 UDP的缓冲区 sendto/recvfrom/send/recv/write/read IO类接口 UDP是全双工的 UDP注意事项 UDP协议&#xff0c;实现简单聊天室&#xff08;服务端客户端&#xff09; TCP协议 TCP协…

SpringBoot集成 ElasticSearch

Spring Boot 集成 ElasticSearch 对于ElasticSearch比较陌生的小伙伴可以先看看ElasticSearch的概述ElasticSearch安装、启动、操作及概念简介 好的开始啦~ 1、基础操作 1.1、导入依赖 <dependency><groupId>org.springframework.boot</groupId><arti…

【是C++,不是C艹】 第一个C++程序 | 命名空间 | 输入输出

&#x1f49e;&#x1f49e; 欢迎来到 Claffic 的博客 &#x1f49e;&#x1f49e; &#x1f449;专栏&#xff1a;《是C&#xff0c;不是C艹》&#x1f448; 前言&#xff1a; 在认识了C的来历之后&#xff0c;我们就要开始正式学习C了&#xff0c;系好安全带&#xff0c;准备…