【机器学习】—机器学习和NLP预训练模型探索之旅

news2024/12/23 4:46:13

目录

一.预训练模型的基本概念

1.BERT模型

2 .GPT模型

二、预训练模型的应用

1.文本分类

使用BERT进行文本分类

2. 问答系统

使用BERT进行问答

三、预训练模型的优化

 1.模型压缩

1.1 剪枝

权重剪枝

2.模型量化

2.1 定点量化

使用PyTorch进行定点量化

3. 知识蒸馏

3.1 知识蒸馏的基本原理

3.2 实例代码:使用知识蒸馏训练学生模型

四、结论


随着数据量的增加和计算能力的提升,机器学习和自然语言处理技术得到了飞速发展。预训练模型作为其中的重要组成部分,通过在大规模数据集上进行预训练,使得模型可以捕捉到丰富的语义信息,从而在下游任务中表现出色。

一.预训练模型的基本概念

预训练模型是一种在大规模数据集上预先训练好的模型,可以作为其他任务的基础。预训练模型的优势在于其能够利用大规模数据集中的知识,提高模型的泛化能力和准确性。常见的预训练模型包括BERT(Bidirectional Encoder Representations from Transformers)、GPT(Generative Pre-trained Transformer)等。

1.BERT模型

BERT是由Google提出的一种双向编码器表示模型。BERT通过在大规模文本数据上进行掩码语言模型(Masked Language Model, MLM)和下一句预测(Next Sentence Prediction, NSP)的预训练,使得模型可以学习到深层次的语言表示。

2 .GPT模型

GPT由OpenAI提出,是一种基于Transformer的生成式预训练模型。GPT通过在大规模文本数据上进行自回归语言模型的预训练,使得模型可以生成连贯的文本。

二、预训练模型的应用

预训练模型在NLP领域有广泛的应用,包括但不限于文本分类、问答系统、机器翻译等。以下将介绍几个具体的应用实例。

1.文本分类

文本分类是将文本数据按照预定义的类别进行分类的任务。预训练模型可以通过在大规模文本数据上进行预训练,从而捕捉到丰富的语义信息,提高文本分类的准确性。

使用BERT进行文本分类

import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

# 加载预训练的BERT模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# 定义数据集
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# 准备数据
texts = ["I love this!", "I hate this!"]
labels = [1, 0]
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.1)

train_dataset = TextDataset(train_texts, train_labels, tokenizer, max_len=32)
val_dataset = TextDataset(val_texts, val_labels, tokenizer, max_len=32)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)

# 训练模型
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(3):
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

# 验证模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in val_loader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, predicted = torch.max(outputs.logits, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Validation Accuracy: {correct / total:.2f}')

2. 问答系统

问答系统是从文本中自动提取答案的任务。预训练模型可以通过在大规模问答数据上进行预训练,从而提高答案的准确性和相关性。

使用BERT进行问答

from transformers import BertForQuestionAnswering

# 加载预训练的BERT问答模型
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

# 输入问题和上下文
question = "What is the capital of France?"
context = "Paris is the capital of France."

# 编码输入
inputs = tokenizer.encode_plus(question, context, return_tensors='pt')

# 模型预测
outputs = model(**inputs)
start_scores = outputs.start_logits
end_scores = outputs.end_logits

# 获取答案的起始和结束位置
start_idx = torch.argmax(start_scores)
end_idx = torch.argmax(end_scores) + 1

# 解码答案
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][start_idx:end_idx]))
print(f'Answer: {answer}')

三、预训练模型的优化

在实际应用中,预训练模型的优化至关重要。常见的优化方法包括模型压缩、量化和蒸馏等。

 1.模型压缩

模型压缩是通过减少模型参数数量和计算量来提高模型效率的方法。压缩后的模型不仅运行速度更快,还能减少存储空间和内存占用。常见的模型压缩技术包括剪枝、量化和知识蒸馏等。

1.1 剪枝

剪枝(Pruning)是一种通过删除模型中冗余或不重要的参数来减小模型大小的方法。剪枝可以在训练过程中或训练完成后进行。常见的剪枝方法包括:

  • 权重剪枝(Weight Pruning):删除绝对值较小的权重,认为这些权重对模型输出影响不大。
  • 结构剪枝(Structured Pruning):删除整个神经元或卷积核,减少模型的计算量和存储需求。

剪枝后的模型通常需要重新训练,以恢复或接近原始模型的性能。

权重剪枝
import torch
import torch.nn.utils.prune as prune

# 定义一个简单的模型
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = torch.nn.Linear(10, 10)
    
    def forward(self, x):
        return self.fc(x)

model = SimpleModel()

# 对模型的全连接层进行权重剪枝
prune.l1_unstructured(model.fc, name='weight', amount=0.5)

# 查看剪枝后的权重
print(model.fc.weight)

2.模型量化

模型量化是通过降低模型参数的精度来减少计算量的方法。量化通常通过将浮点数表示的权重和激活值转换为低精度表示(如8位整数)来实现。这可以显著减少模型的存储空间和计算开销,同时在硬件上加速模型推理。

2.1 定点量化

定点量化(Fixed-point Quantization)是将浮点数表示的权重和激活值转换为固定精度的整数表示。常见的定点量化包括8位整数量化(INT8),这种量化方法在不显著降低模型精度的情况下,可以大幅提升计算效率。

使用PyTorch进行定点量化
import torch
import torch.quantization

# 加载预训练模型
model = SimpleModel()

# 定义量化配置
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# 准备量化模型
model = torch.quantization.prepare(model, inplace=True)

# 模拟量化后的推理过程
# 这里应该使用训练数据对模型进行微调,但为了简单起见,省略此步骤
model = torch.quantization.convert(model, inplace=True)

# 查看量化后的模型
print(model)

3. 知识蒸馏

知识蒸馏(Knowledge Distillation)是通过将大模型(教师模型,Teacher Model)的知识转移到小模型(学生模型,Student Model)的方法,从而提高小模型的性能和效率。知识蒸馏的核心思想是通过教师模型的软标签(soft labels)指导学生模型的训练。

3.1 知识蒸馏的基本原理

在知识蒸馏过程中,学生模型不仅学习训练数据的真实标签,还学习教师模型对训练数据的输出,即软标签。软标签包含了更多的信息,比如类别之间的相似性,使学生模型能够更好地泛化。

蒸馏损失函数通常由两部分组成:

  • 交叉熵损失:衡量学生模型输出与真实标签之间的差异。
  • 蒸馏损失:衡量学生模型输出与教师模型软标签之间的差异。

总体损失函数为这两部分的加权和。

3.2 实例代码:使用知识蒸馏训练学生模型

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# 定义教师模型和学生模型
teacher_model = SimpleModel()
student_model = SimpleModel()

# 加载示例数据
data = torch.randn(100, 10)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(data, labels)
data_loader = DataLoader(dataset, batch_size=10, shuffle=True)

# 定义蒸馏训练函数
def distillation_train(student_model, teacher_model, data_loader, optimizer, temperature=2.0, alpha=0.5):
    teacher_model.eval()
    student_model.train()
    for data, labels in data_loader:
        optimizer.zero_grad()
        
        # 教师模型输出
        with torch.no_grad():
            teacher_logits = teacher_model(data)
        
        # 学生模型输出
        student_logits = student_model(data)
        
        # 计算蒸馏损失
        loss_ce = F.cross_entropy(student_logits, labels)
        loss_kl = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=1),
            F.softmax(teacher_logits / temperature, dim=1),
            reduction='batchmean'
        ) * (temperature ** 2)
        
        loss = alpha * loss_ce + (1.0 - alpha) * loss_kl
        loss.backward()
        optimizer.step()

# 定义优化器
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)

# 进行蒸馏训练
for epoch in range(10):
    distillation_train(student_model, teacher_model, data_loader, optimizer)

# 验证学生模型
student_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, labels in data_loader:
        outputs = student_model(data)
        _, predicted = torch.max(outputs, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Student Model Accuracy: {correct / total:.2f}')

四、结论

预训练模型在机器学习和自然语言处理领域具有重要意义。通过在大规模数据集上进行预训练,模型可以捕捉到丰富的语义信息,从而在下游任务中表现出色。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1685039.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HQL面试题练习 —— 品牌营销活动天数

题目来源:小红书 目录 1 题目2 建表语句3 题解 1 题目 有营销活动记录表,记录了每个品牌每次营销活动的开始日期和营销活动的结束日期,现需要统计出每个品牌的总营销天数。 注意: 1:苹果第一行数据的营销结束日期比第二行数据的营…

系统思考—跳出症状看全局

今年的《系统思考—跳出症状看全局》课程不断进行了迭代优化。通过一个企业的真实案例,我们与学员共同探讨了线性思考与系统思考的区别,并学习了如何从全局角度做出更加明智的决策,一切就绪,期待学员的共创。

xxe漏洞--xml外部实体注入漏洞

1.xxe漏洞介绍 XXE(XML External Entity Injection)是一种攻击技术,它允许攻击者注入恶意的外部实体到XML文档中。如果应用程序处理XML输入时未正确配置,攻击者可以利用这个漏洞访问受影响系统上的敏感文件、执行远程代码、探测内…

操作系统底层运行原理 —— 基于线程安全的消息机制

前言 学过Android应用开发的大概都知道Handler这个东东,这也是面试中老生常谈的问题。其实不仅仅是Android,iOS以及PC的操作系统,底层也离不开消息机制。这个属于生产消费者问题。 什么是生产者消费者模式 生产者消费者模式(Pr…

【UE Websocket】“WebSocket Server”插件使用记录

1. 在商城中下载“WebSocket Server”插件 该插件具有如下节点,基本可以满足WebSocket服务端的所有需求 2. 如果想创建一个基本的服务端,我们可以新建一个actor蓝图,添加如下节点 3. UE运行后,我们可以使用在线的websocket测试助手…

使用MicroPython和pyboard开发板(15):使用LCD和触摸传感器

使用LCD和触摸传感器 pybaord的pyb对LCD设备也进行了封装,可以使用官方的LCD显示屏。将LCD屏连接到开发板,连接后。 使用LCD 先用REPL来做个实验,在MicroPython提示符中输入以下指令。请确保LCD面板连接到pyboard的方式正确。 >>…

认识NXP新型微处理器:MCX工业和物联网微控制器

目录 概述 1 MCX工业和物联网微控制器介绍 2 MCX 系列微控制器类型 2.1 MCX N系列微控制器 2.1.1 主要特征 2.1.2 MCX N系列产品 2.1.3 MCX N9xx和N5xx MCU选型表 2.2 MCX A系列微控制器 2.2.1 主要特征 2.2.2 MCX A系列产品 2.2.3 MCX A MCU的架构 2.3 MCX W系…

Unity射击游戏开发教程:(24)创造不同的敌人

在这篇文章中,我们将讨论添加一个可以承受多次攻击的新敌人和一些动画来使事情变得栩栩如生。敌人没有任何移动或射击行为。这将有助于增强未来敌人的力量。 我们将声明一个 int 来存储敌人可以承受的攻击数量,并将其设置为 3。

Unity修改Project下的Assets的子文件的图标

Unity修改文件夹的图标 示例: 在右键可以创建指定文件夹。 github链接 https://github.com/SeaeeesSan/SimpleFolderIconCSDN资源的链接 https://download.csdn.net/download/GoodCooking/89347361 去GitHub下载支持原作者哦。重要的事情 截图来自GitHub 。 U…

信息系统项目管理师0127:工具与技术(8项目整合管理—8.6管理项目知识—8.6.2工具与技术)

点击查看专栏目录 文章目录 8.6.2 工具与技术8.6.2 工具与技术 专家判断管理项目知识过程中,应征求具备如下领域相关专业知识或接受过相关培训的个人或小组的意见,涉及的领域包括:知识管理、信息管理、组织学习、知识和信息管理工具以及来自其他项目的相关信息等。 知识管理…

【2024】高校网络安全管理运维赛

比赛时间:2024-05-06 Re-easyre 基本的base64换表,用CyberChef解密 Re-babyre 进入主函数,发现输入四次 看一下就知道是大数求解 (当初写的时候差不多 不知道为什么第四个总是算错…) from z3 import *s Solver() # 设置一个解方程的类…

产品经理-需求收集(二)

1. 什么是需求 指在一定的时期中,一定场景中,无论是心理上还是生理上的,用户有着某种“需要”,这种“需要”用户自己不一定知道的,有了这种“需要”后用户就有做某件事情的动机并促使达到其某种目的,这也就…

Redis 主从复制、哨兵与集群

一、Redis 主从复制 1. 主从复制的介绍 主从复制,是指将一台Redis服务器的数据,复制到其他的Redis服务器。前者称为主节点(Master),后者称为从节点(Slave);数据的复制是单向的,只能由主节点到从节点。 默认情况下&a…

如何快速从手动测试转向自动化测试

寻求具有无缝持续集成和持续交付 (CI/CD) 的高效 DevOps 管道比以往任何时候都更加重要。想象一下这样一个场景:您的软件组织显著减少了人工工作量、降低了成本,并更加自信地发布了软件更新。换句话说,通过将 Web UI 和 API 测试结合在一起&a…

展现金融科技前沿力量,ATFX于哥伦比亚金融博览会绽放光彩

不到半个月的时间里,高光时刻再度降临ATFX。而这一次,是ATFX不曾拥有的桂冠—“全球最佳在线经纪商”(Best Global Online Broker)。2024年5月15日至16日,拉丁美洲首屈一指的金融盛会—2024年哥伦比亚金融博览会(Money Expo Colombia 2024) 于…

前端开发攻略---用Vue实现无限滚动的几种方法

目录 1、原理 2、使用CSS动画 代码: 3、使用JS实现 代码: 1、原理 复制内容:将需要滚动的内容复制一次,并将这些副本放置在原始内容的后面。这样,当用户滚动到内容的末尾时,就会无缝地切换回到内容的起…

【Python】—— 公共的方法

目录 (一)公共操作 1.1 公共操作之运算符加号 1.2 公共操作之运算符乘号 1.3 公共操作之运算符判断数据是否存在 (二)公共方法 2.1 公共方法-len 2.2 公共方法-del 2.3 公共方法-max和min 2.4 公共方法-range 2.5 公共方…

如果 SEC 批准以太坊现货 ETF,会有更多山寨币 ETF 吗?

撰文:Protos 编译:Ismay,BlockBeats 文章来源香港Web3媒体Techun News 编者按:SEC 已与交易平台和 ETF 申请人就 19b-4 规则变更请求进行沟通,这表明以太坊现货 ETF 获批的可能性大大增加。与此同时山寨币投资者猜测…

嵌入式开发中树莓派和单片机关键区别

综合了几篇帖子作以信息收录:树莓派和单片机作为嵌入式系统领域中两种广泛使用的设备,各自有着不同的特性和应用场景,文章从五个方面进行比对展开。 架构与性能: 树莓派:是一款微型计算机,通常配备基于AR…

解决 git:OpenSSL SSL_read: SSL_ERROR_SYSCALL, errno 0

解决 git:OpenSSL SSL_read: SSL_ERROR_SYSCALL, errno 0 问题 git pull报错:fatal: unable to access ‘https://github.com/aircrushin/ultrav-music.git/’: Failed to connect to github.com port 443 after 21077 ms: Couldn’t connect to serve…