深度学习:GPT-1的MindSpore实践

news2024/11/24 9:03:25

GPT-1简介

GPT-1(Generative Pre-trained Transformer)是2018年由Open AI提出的一个结合预训练和微调的用于解决文本理解和文本生成任务的模型。它的基础是Transformer架构,具有如下创新点:

  • NLP领域的迁移学习:通过最少的任务专项数据,利用预训练模型出色地完成具体的下游任务。
  • 语言建模作为预训练任务:使用无监督学习和大规模的文本语料库来训练模型
  • 为具体任务微调:采用预训练模型来适应监督任务

和BERT类似,GPT-1同样采取pre-train + fine-tune的思路:先基于大量未标注语料数据进行预训练, 后基于少量标注数据进行微调。但GPT-1在预训练任务思路和模型结构上与BERT有所差别。

GPT-1的目标是在预训练的过程中根据现有的所有词元,预测下一个词元。这个任务被称为“自回归语言建模”。

一个简单的例子:

输入序列为:“The sun rises in the”

训练数据的原句子为:“The sun rises in the east”

所以我们的目标输出为:“east”

将输入序列输入GPT模型,GPT根据输入预测下一个词元(“east”)在语料库中的概率分布

正确词元“east”作为一个“伪标签”来帮助模型训练

模型架构

GPT主要使用Transformer Decoder架构,但因为没有Encoder,所以在Transformer Decoder的基础上移除了计算Encoder与Decoder间注意力分数的Multi-Head Attention Layer。

Masked Multi-HeadSelf-Attention

Masked Multi-Head Self-Attention 是Multi-Head Attetion的变种。 最大的不同来自于MMSA的掩码机制,掩码机制防止模型通过观测未来的词元以进行“作弊”。

一个掩码词元<mask>被用于注意力分数矩阵,所以当前词元只能注意到序列中自己和自己之前的词元。未来的次元的注意力分数将被设为0以确保其在Softmax步骤后的实际贡献为0。

为什么掩码机制非常重要?

对于自回归任务,模型必须线性地生成词元,不能基于未来的信息预测下一个词元。

损失函数

GPT使用Cross-Entropy Loss作为损失函数:\mathcal{L} = - \sum_{t=1}^N \log P(w_t | w_1, w_2, \dots, w_{t-1})

交叉熵损失是这项任务的理想选择,因为它通过测量预测的概率分布与真实分布的距离来惩罚不正确的预测。它自然适于处理多类分类任务,其中模型从大量词汇表中选择一个标记。

模型输入

GPT-1的输入同样为句子或句子对,并添加Special Tokens。

  • [BOS]:表示句子的开始,(论文中给出的token表示为[START]),添加到序列最前;
  • [EOS]:表示序列的结束,(论文中的给出的[EXTRACT]),添加到序列最后,在进行分类任务时,会将 该special token对应的输出接入输出层;我们也可以理解为该token可以学习到整个句子的语义信息;
  • [SEP]:用于间隔句子对中的两个句子;
GPT Embedding 同样分为三类:token Embedding、Position Embedding、Segment Embedding

 

GPT-1模型具体参数

模型架构

  • 12个Transformer Decoder Block
  • hidden_size为768(模型输入和输出的向量纬度)
  • 注意力头数为12
  • FFN维度为3072
  • 词表(Vocab)大小为40000
  • 序列长度为512(上下文窗口长度)

训练过程

  • Adam优化器,超参数为:0.9, 0.99
  • 学习率:最大学习率:2.5x10e-4 使用2000步作为热身,随后线性衰退
  • 批大小:64
  • 梯度剪裁:1.0
  • Dropout率:0.1

训练过程

100000步,大约花费8张NVIDIA V100 GPU训练30天,共有117M参数。使用Xavier初始化,权重衰退为0.01。 

下游任务 

GPT按照生成式的逻辑统一了下游任务的应用模板,使用最后一个token([EOS]or[EXTRACT])对应的hidden state,输出到额外的输出层中,进行分类标签预测。
任务包括:文本分类(情感分类、新闻分类)、文本蕴含(根据前提推出假设)、文本语义相似度、多类选择(在多个next token中进行选择)

基于MindSpore微调GPT-1进行情感分类

# #安装mindnlp 0.4.0套件
# !pip install mindnlp
# !pip uninstall soundfile -y
# !pip install download
# !pip install jieba
# !pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.3.1/MindSpore/unified/aarch64/mindspore-2.3.1-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn

from mindnlp.dataset import load_dataset

from mindnlp.engine import Trainer

# loading dataset
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']

imdb_train.get_dataset_size()

import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    def tokenize(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']

    if shuffle:
        dataset = dataset.shuffle(batch_size)

    # map dataset
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})

    return dataset

from mindnlp.transformers import OpenAIGPTTokenizer
# tokenizer
gpt_tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')

# add sepcial token: <PAD>
special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

#为方便体验流程,把原本数据集的十分之一拿出来体验训练和评估,
imdb_train, _ = imdb_train.split([0.1, 0.9], randomize=False)

# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

# load GPT sequence classification model and set class=2
from mindnlp.transformers import OpenAIGPTForSequenceClassification  # Import the GPT model for sequence classification
from mindnlp import evaluate  # Import the evaluation module from MindNLP
import numpy as np  # Import NumPy for numerical operations

# Set up the GPT model for sequence classification with 2 output labels (binary classification).
model = OpenAIGPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)

# Set the padding token ID in the model configuration to match the tokenizer's padding token ID.
model.config.pad_token_id = gpt_tokenizer.pad_token_id

# Resize the token embedding layer to account for any added tokens (e.g., special tokens).
model.resize_token_embeddings(model.config.vocab_size + 3)

from mindnlp.engine import TrainingArguments  # Import training arguments for model training configuration.

# Define training arguments.
training_args = TrainingArguments(
    output_dir="gpt_imdb_finetune",  # Directory to save model checkpoints and outputs.
    evaluation_strategy="epoch",  # Evaluate the model at the end of each epoch.
    save_strategy="epoch",  # Save model checkpoints at the end of each epoch.
    logging_strategy="epoch",  # Log metrics and progress at the end of each epoch.
    load_best_model_at_end=True,  # Automatically load the best model (based on evaluation metrics) at the end of training.
    num_train_epochs=1.0,  # Number of training epochs (default is 1 for quick experimentation).
    learning_rate=2e-5  # Learning rate for the optimizer.
)

# Load the accuracy metric for evaluation.
metric = evaluate.load("accuracy")

# Define a function to compute metrics during evaluation.
def compute_metrics(eval_pred):
    logits, labels = eval_pred  # Unpack predictions (logits) and true labels.
    predictions = np.argmax(logits, axis=-1)  # Convert logits to class predictions using argmax.
    return metric.compute(predictions=predictions, references=labels)  # Compute accuracy metric.

# Initialize the Trainer class with the model, training arguments, datasets, and metric computation function.
trainer = Trainer(
    model=model,  # The GPT model to be fine-tuned.
    args=training_args,  # Training configuration arguments.
    train_dataset=dataset_train,  # Training dataset (must be preprocessed and tokenized).
    eval_dataset=dataset_val,  # Validation dataset for evaluation.
    compute_metrics=compute_metrics  # Metric computation function for evaluation.
)

# start training
trainer.train()

trainer.evaluate(dataset_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2246573.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CKA认证 | Day2 K8s内部监控与日志

第三章 Kubernetes监控与日志 1、查看集群资源状态 在 Kubernetes 集群中&#xff0c;查看集群资源状态和组件状态是非常重要的操作。以下是一些常用的命令和解释&#xff0c;帮助你更好地管理和监控 Kubernetes 集群。 1.1 查看master组件状态 Kubernetes 的 Master 组件包…

概念解读|K8s/容器云/裸金属/云原生...这些都有什么区别?

随着容器技术的日渐成熟&#xff0c;不少企业用户都对应用系统开展了容器化改造。而在容器基础架构层面&#xff0c;很多运维人员都更熟悉虚拟化环境&#xff0c;对“容器圈”的各种概念容易混淆&#xff1a;容器就是 Kubernetes 吗&#xff1f;容器云又是什么&#xff1f;容器…

JDBC编程---Java

目录 一、数据库编程的前置 二、Java的数据库编程----JDBC 1.概念 2.JDBC编程的优点 三.导入MySQL驱动包 四、JDBC编程的实战 1.创造数据源&#xff0c;并设置数据库所在的位置&#xff0c;三条固定写法 2.建立和数据库服务器之间的连接&#xff0c;连接好了后&#xff…

移动充储机器人“小奥”的多场景应用(上)

在当前现代化城市交通体系中&#xff0c;移动充储机器人“小奥”发挥着至关重要的作用。该机器人不仅是一个简单的设备&#xff0c;而是一个集成了高科技的移动充电站&#xff0c;为新能源汽车提供了一种前所未有的便捷充电解决方案。该机器人配备了先进的电池管理系统&#xf…

element dialog会隐藏body scroll 导致tab抖动 解决方案如下

element dialog会隐藏body scroll 导致tab抖动 解决方案如下 在dialog标签添加 :lockScroll"false"搞定

Android 功耗分析(底层篇)

最近在网上发现关于功耗分析系列的文章很少&#xff0c;介绍详细的更少&#xff0c;于是便想记录总结一下功耗分析的相关知识&#xff0c;有不对的地方希望大家多指出&#xff0c;互相学习。本系列分为底层篇和上层篇。 大概从基础知识&#xff0c;测试手法&#xff0c;以及案例…

Bugku CTF_Web——my-first-sqli

Bugku CTF_Web——my-first-sqli 进入靶场 随便输一个看看 点login没有任何回显 方法一&#xff1a; 上bp抓包 放到repeter测试 试试万能密码&#xff08;靶机过期了重新开了个靶机&#xff09; admin or 11--shellmates{SQLi_goeS_BrrRrRR}方法二&#xff1a; 拿包直接梭…

BUUCTF—Reverse—easyre(1)

非常简单的逆向 拿到exe文件先查下信息&#xff0c;是一个64位程序&#xff0c;没有加壳&#xff08;壳是对代码的加密&#xff0c;起混淆保护的作用&#xff0c;一般用来阻止逆向&#xff09;。 然后拖进IDA(64位)进行反汇编 打开以后就可以看到flag flag{this_Is_a_EaSyRe}

全面击破工程级复杂缓存难题

目录 一、走进业务中的缓存 &#xff08;一&#xff09;本地缓存 &#xff08;二&#xff09;分布式缓存 二、缓存更新模式分析 &#xff08;一&#xff09;Cache Aside Pattern&#xff08;旁路缓存模式&#xff09; 读操作流程 写操作流程 流程问题思考 问题1&#…

React基础知识一

写的东西太多了&#xff0c;照成csdn文档编辑器都开始卡顿了&#xff0c;所以分篇写。 1.安装React 需要安装下面三个包。 react:react核心包 react-dom:渲染需要用到的核心包 babel:将jsx语法转换成React代码的工具。&#xff08;没使用jsx可以不装&#xff09;1.1 在html中…

Vue3中使用:deep修改element-plus的样式无效怎么办?

前言&#xff1a;当我们用 vue3 :deep() 处理 elementui 中 el-dialog_body和el-dislog__header 的时候样式一直无法生效&#xff0c;遇到这种情况怎么办&#xff1f; 解决办法&#xff1a; 1.直接在 dialog 上面增加class 我试过&#xff0c;也不起作用&#xff0c;最后用这种…

鸿蒙进阶-状态管理

大家好啊&#xff0c;这里是鸿蒙开天组&#xff0c;今天我们来学习状态管理。 开始组件化开发之后&#xff0c;如何管理组件的状态会变得尤为重要&#xff0c;咱们接下来系统的学习一下这部分的内容 状态管理机制 在声明式UI编程框架中&#xff0c;UI是程序状态的运行结果&a…

leetcode:129. 求根节点到叶节点数字之和

给你一个二叉树的根节点 root &#xff0c;树中每个节点都存放有一个 0 到 9 之间的数字。 每条从根节点到叶节点的路径都代表一个数字&#xff1a; 例如&#xff0c;从根节点到叶节点的路径 1 -> 2 -> 3 表示数字 123 。 计算从根节点到叶节点生成的 所有数字之和 。…

(南京观海微电子)——GH7006+BOE2.6_GV026WVQ-N81-1QP0_800RGB480_MIPI_LVDS_RGB原理介绍

1. 原理介绍 2. 代码 // Model - GV026WVQ-1QP0 // IC - GH7006 // Width - 800 // Height - 480 // REV: - V01 // DATA - 20240507 // INTERFACE- MIPI //"Vfp" value"16" /> //"…

速度革命:esbuild如何改变前端构建游戏 (1)

什么是 esbuild&#xff1f; esbuild 是一款基于 Go 语言开发的 JavaScript 构建打包工具&#xff0c;以其卓越的性能著称。相比传统的构建工具&#xff08;如 Webpack&#xff09;&#xff0c;esbuild 在打包速度上有着显著的优势&#xff0c;能够将打包速度提升 10 到 100 倍…

Excel的图表使用和导出准备

目的 导出Excel图表是很多软件要求的功能之一&#xff0c;那如何导出Excel图表呢&#xff1f;或者说如何使用Excel图表。 一种方法是软件生成图片&#xff0c;然后把图片写到Excel上&#xff0c;这种方式&#xff0c;因为格式种种原因&#xff0c;导出的图片不漂亮&#xff0c…

自动化运维-Linux通用性日志切割脚本

一、公司提供的参考脚本&#xff1a; #!/bin/bash # 定义需要清理的文件 log_file("/mpjava/ly.mp.dfpv.acc.biz/bin/nohup.out""/mpjava/ly.mp.dfpv.acc.service/bin/nohup.out"# 添加更多微服务的日志目录路径 ) # 获取当天日期 date_now$(date %Y%m%d)…

Let‘s Encrypt SSL证书:acmessl.cn申请免费3个月证书

目录 一、CA机构 二、Lets Encrypt特点 三、申请SSL 一、CA机构 ‌Lets Encrypt‌是一个由非营利组织Internet Security Research Group (ISRG)运营的证书颁发机构&#xff08;CA&#xff09;&#xff0c;旨在通过自动化和开放的方式为全球网站提供免费、可靠的SSL/TLS证书。…

Java连接MySQL数据库进行增删改查操作

Test 1 首先去查看一下MySQL的版本&#xff1a;mysql -V&#xff08;在cmd中&#xff09;记得要启动MySQL服务在cmd中验证是否可以登录数据库成功&#xff1a;mysql -u root -p&#xff08;然后输入密码&#xff1a;root&#xff09;Test 2 在IDEA创建项目在SQLyog中创建数据…

从搭建uni-app+vue3工程开始

技术栈 uni-app、vue3、typescript、vite、sass、uview-plus、pinia 一、项目搭建 1、创建以 typescript 开发的工程 npx degit dcloudio/uni-preset-vue#vite-ts my-vue3-project2、安装sass npm install -D sass// 安装sass-loader&#xff0c;注意需要版本10&#xff0c;…