基于线性支持向量机的词嵌入文本分类torch案例

news2024/9/28 17:31:39

在这里插入图片描述

一、前言

简介线性支持向量机,并使用线性支持向量机实现文本分类, 输入文本通过词嵌入方法转换成浮点张量,给出torch案例

线性支持向量机(Linear Support Vector Machine,简称Linear SVM)是一种常用的分类算法,它通过一个超平面来将数据分成两类。对于线性可分的数据集,线性SVM能够找到一个最优的超平面,使得距离最近的数据点到这个超平面的距离最大化,从而使得分类边界更加稳定。

二、项目介绍

在文本分类任务中,我们可以使用线性SVM来将文本分成两类,比如正面和负面。首先需要将文本转换成数字表示,这可以通过词嵌入(Word Embedding)方法来实现。词嵌入是将单词转换成向量表示的一种技术,它可以将单词之间的语义关系表达为向量之间的距离关系。在文本分类任务中,我们可以将每个单词转换成一个固定长度的向量,然后将所有单词的向量按照一定的顺序组合成一个文本向量,从而得到文本的数字表示。

三、Toy demo 项目展示

下面是使用PyTorch实现文本分类任务的示例代码,其中使用了线性SVM作为分类器:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class LinearSVM(nn.Module):
    def __init__(self, input_size):
        super(LinearSVM, self).__init__()
        self.linear = nn.Linear(input_size, 1)

    def forward(self, x):
        out = self.linear(x)
        return out

class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, labels, word_embedding):
        self.texts = texts
        self.labels = labels
        self.word_embedding = word_embedding

    def __getitem__(self, index):
        text = self.texts[index]
        label = self.labels[index]
        text_vec = np.mean([self.word_embedding[word] for word in text.split() if word in self.word_embedding], axis=0)
        text_vec = torch.from_numpy(text_vec).float()
        return text_vec, label

    def __len__(self):
        return len(self.labels)

# 定义超参数
embedding_size = 50
lr = 0.01
num_epochs = 10

# 加载数据集
train_texts = ['good movie', 'not a good movie', 'bad movie', 'an excellent movie', 'i loved it', 'could have been better', 'completely ridiculous', 'not worth watching', 'it was okay', 'awesome movie']
train_labels = [1., -1, -1, 1, 1, -1, -1, -1, 0, 1]
word_embedding = {'good': np.random.rand(embedding_size), 'movie': np.random.rand(embedding_size), 'not': np.random.rand(embedding_size), 'bad': np.random.rand(embedding_size), 'an': np.random.rand(embedding_size), 'excellent': np.random.rand(embedding_size), 'i': np.random.rand(embedding_size), 'loved': np.random.rand(embedding_size), 'it': np.random.rand(embedding_size), 'could': np.random.rand(embedding_size), 'have': np.random.rand(embedding_size), 'been': np.random.rand(embedding_size), 'better': np.random.rand(embedding_size), 'completely': np.random.rand(embedding_size), 'ridiculous': np.random.rand(embedding_size), 'worth': np.random.rand(embedding_size), 'watching': np.random.rand(embedding_size), 'was': np.random.rand(embedding_size), 'okay': np.random.rand(embedding_size), 'awesome': np.random.rand(embedding_size)}
train_dataset = TextDataset(train_texts, train_labels, word_embedding)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)

# 定义模型、损失函数和优化器
model = LinearSVM(embedding_size)
criterion = nn.HingeEmbeddingLoss()
optimizer = optim.SGD(model.parameters(), lr=lr)

# 训练模型
for epoch in range(num_epochs):
    for batch_data in train_dataloader:
        # print(batch_data)
        x, y = batch_data
        print("这里这里", y)

        x = x.unsqueeze(1)
        model.zero_grad()
        out = model(x)
        loss = criterion(out.squeeze(), y.float())
        loss.backward()
        optimizer.step()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

# 测试模型



    test_texts = ['good film', 'bad film']
    test_labels = [1, -1]
    test_dataset = TextDataset(test_texts, test_labels, word_embedding)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)

    with torch.no_grad():
        for batch_data in test_dataloader:
            x, y = batch_data
            x = x.unsqueeze(1)
            out = model(x)
            predicted = torch.sign(out.squeeze())
            print('Predicted:', predicted)
            print('True:', y)

四、运行结果

这里这里 tensor([-1,  1])
这里这里 tensor([ 1, -1])
这里这里 tensor([1., 1.], dtype=torch.float64)
这里这里 tensor([ 0, -1])
这里这里 tensor([-1, -1])
Epoch [1/10], Loss: 0.8326
Predicted: tensor([1., 1.])
True: tensor([ 1, -1])
这里这里 tensor([-1, -1])
这里这里 tensor([-1, -1])
这里这里 tensor([ 1, -1])
这里这里 tensor([1, 1])
这里这里 tensor([1., 0.], dtype=torch.float64)
Epoch [2/10], Loss: 0.7192
Predicted: tensor([1., 1.])
True: tensor([ 1, -1])
这里这里 tensor([1., 1.])
这里这里 tensor([-1,  0])
这里这里 tensor([-1, -1])
这里这里 tensor([1, 1])
这里这里 tensor([-1, -1])
Epoch [3/10], Loss: 0.7749
Predicted: tensor([1., 1.])
True: tensor([ 1, -1])
这里这里 tensor([-1,  1])
这里这里 tensor([ 1, -1])
这里这里 tensor([ 0, -1])
这里这里 tensor([-1,  1])
这里这里 tensor([ 1., -1.], dtype=torch.float64)
Epoch [4/10], Loss: 0.5134
Predicted: tensor([1., 1.])
True: tensor([ 1, -1])
这里这里 tensor([ 1., -1.], dtype=torch.float64)
这里这里 tensor([-1,  1])
这里这里 tensor([-1, -1])
这里这里 tensor([0, 1])
这里这里 tensor([-1,  1])
Epoch [5/10], Loss: 0.3517
Predicted: tensor([1., 1.])
True: tensor([ 1, -1])
这里这里 tensor([-1,  1])
这里这里 tensor([1., 0.], dtype=torch.float64)
这里这里 tensor([ 1, -1])
这里这里 tensor([-1, -1])
这里这里 tensor([-1,  1])
Epoch [6/10], Loss: 0.4878
Predicted: tensor([1., 1.])
True: tensor([ 1, -1])
这里这里 tensor([-1,  1])
这里这里 tensor([1, 1])
这里这里 tensor([-1, -1])
这里这里 tensor([-1.,  1.])
这里这里 tensor([ 0, -1])
Epoch [7/10], Loss: 0.6830
Predicted: tensor([1., 1.])
True: tensor([ 1, -1])
这里这里 tensor([-1,  1])
这里这里 tensor([-1, -1])
这里这里 tensor([-1,  1])
这里这里 tensor([0., 1.])
这里这里 tensor([ 1, -1])
Epoch [8/10], Loss: 0.5232
Predicted: tensor([1., 1.])
True: tensor([ 1, -1])
这里这里 tensor([0, 1])
这里这里 tensor([ 1, -1])
这里这里 tensor([-1, -1])
这里这里 tensor([ 1, -1])
这里这里 tensor([ 1., -1.], dtype=torch.float64)
Epoch [9/10], Loss: 0.4531
Predicted: tensor([1., 1.])
True: tensor([ 1, -1])
这里这里 tensor([-1, -1])
这里这里 tensor([1, 1])
这里这里 tensor([-1,  1])
这里这里 tensor([-1,  0])
这里这里 tensor([-1.,  1.])
Epoch [10/10], Loss: 0.3971
Predicted: tensor([1., 1.])
True: tensor([ 1, -1])

五、损失函数

介绍nn.HingeEmbeddingLoss并使用

nn.HingeEmbeddingLossPyTorch中用于计算支持向量机的损失函数之一。它的作用是通过一个间隔边界将正样本和负样本分开。具体来说,该损失函数使用了一个margin参数,表示正负样本之间的间隔边界,然后计算正样本与该边界之间的距离和负样本与该边界之间的距离,并将它们相加。

该损失函数的数学公式如下:

l o s s ( x , y ) = 1 N ∑ i = 1 N max ⁡ ( 0 , − y i ( x i ⋅ w − b ) + m a r g i n ) loss(x,y) = \frac{1}{N}\sum_{i=1}^{N} \max(0, -y_i(x_i\cdot w - b) + margin) loss(x,y)=N1i=1Nmax(0,yi(xiwb)+margin)

其中, x x x表示输入样本, y y y表示对应的标签, w w w表示模型的权重, b b b表示模型的偏置, N N N表示样本数量, m a r g i n margin margin表示间隔边界。

在计算损失时,如果一个样本被正确地分类,则该样本的损失为0,否则根据该样本的标签和预测值,计算出该样本的距离,然后与 m a r g i n margin margin进行比较,得到该样本的损失值。

下面是一个使用nn.HingeEmbeddingLoss进行二分类任务的例子,其中使用的数据集为sklearn中的鸢尾花数据集:

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# 加载数据集并进行预处理
iris = load_iris()
X, y = iris.data, iris.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# 定义模型、损失函数和优化器
model = nn.Linear(4, 1)
criterion = nn.HingeEmbeddingLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 100
batch_size = 16
for epoch in range(num_epochs):
    for i in range(0, len(X_train), batch_size):
        inputs = torch.FloatTensor(X_train[i:i+batch_size])
        labels = torch.FloatTensor(y_train[i:i+batch_size])
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.squeeze(), labels.float())
        loss.backward()
        optimizer.step()

    # 计算测试集准确率
    with torch.no_grad():
        inputs = torch.FloatTensor(X_test)
        labels = torch.FloatTensor(y_test)
        outputs = model(inputs).squeeze()
        predicted = torch.sign(outputs)
        accuracy = (predicted == labels).sum().item() / len(y_test)
    print(f"Epoch {epoch+1}: Loss={loss.item():.4f}, Accuracy={accuracy:.4f}")

在上面的例子中,首先加载鸢尾花数据集并进行标准化处理。然后定义了一个包含一个线性层的模型,使用nn.HingeEmbeddingLoss是一个PyTorch中的损失函数,用于支持向量机(SVM)学习。在SVM中,目标是将两个类别的数据分开,并且在最大化间隔的同时最小化错误分类的数量。Hinge损失函数是一种常用的SVM损失函数,它对正确分类的样本给予0损失,对于错误分类的样本给予一个非零的损失,损失随着距离正确分类边界的距离线性增加。

HingeEmbeddingLoss需要输入一个标量值作为阈值,将大于等于该阈值的样本视为正样本,将小于该阈值的样本视为负样本。具体而言,对于一个大小为 N N N的批次,标签 y y y应该是一个大小为 N N N的张量,其中1表示正类,-1表示负类。如果 y i = 0 y_i = 0 yi=0,则样本 i i i将被忽略,即不计入损失函数中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/429789.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TiDB实战篇-TiDB Cluster部署

简介 部署TiDB Cluster部署,熟系集群的基础操作。 集群规划 机器拓扑 3pd,3tikv,1tidb_server.1tiflash,监控。 192.168.66.10192.168.66.20192.168.66.21 pd_servers tikv_servers tidb_servers tiflash_servers pd_servers tikv_servers monitoring_servers…

MySQL中使用IN()查询到底走不走索引?

MySQL中使用IN()查询到底走不走索引? 看数据量 EXPLAIN SELECT * from users WHERE is_doctor in (0,1); 很明显没走索引,下面再看一个sql。 EXPLAIN SELECT * from users WHERE is_doctor in (2,1);又走索引了,所以…

Yolov5一些知识

1 Yolov5四种网络模型 Yolov5官方代码中,给出的目标检测网络中一共有4个版本,分别是Yolov5s、Yolov5m、Yolov5l、Yolov5x四个模型。 1.1Yolov5网络结构图 eg:Yolov5s 2.1 Yolov3&Yolov4网络结构图 2.1.1 Yolov3网络结构图 Yolov3的网络结构是…

Matlab论文插图绘制模板第86期—带置信区间的折线图

在之前的文章中,分享了很多Matlab折线图的绘制模板: 进一步,分享一种特殊的折线图:带置信区间的折线图。 先来看一下成品效果: 特别提示:本期内容『数据代码』已上传资源群中,加群的朋友请自行…

【C++技能树】快速文本匹配 --正则表达式介绍与C++正则表达式使用

Halo,这里是Ppeua。平时主要更新C语言,C,数据结构算法…感兴趣就关注我吧!你定不会失望。 0.正则表达式存在必要性 在日常生活,或者刷题过程中我们难免需要检测一段字符是否需要是否符合规定,或在一大段字符中寻找自己想要的信息…

Mysql 数据库介绍

数据库介绍 数据库(Database)是按照数据结构来组织、存储和管理数据的仓库,每个数据库都有一个或多个不同的API接口用于创建,访问,管理,搜索和复制所保存的数据。 我们也可以将数据存储在文件中&#xff0…

支持m2的主板换m2硬盘无法识别的问题,主板:七彩虹H410-T

记录一下我的电脑换m2硬盘遇到无法读取的问题,也给有同样问题的人留个参考,特别是七彩虹主板 主板:七彩虹H410-T 遇到的问题: m2 硬盘插上主板后,开机无法识别,打开我的电脑没有相应的盘,设备…

代码随想录---142. 环形链表 II

给定一个链表的头节点 head ,返回链表开始入环的第一个节点。 如果链表无环,则返回 null。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评测系统内部使用整…

【C++】vector的实现

模拟实现vector类前言一、迭代器二、重载 [ ]三、构造函数相关(重点)(1)构造函数(2)构造并使用n个值为value的元素初始化(3)区间构造(4)拷贝构造三、析构函数…

什么是科学

人人都是价值观-思辨专家_个人渣记录仅为自己搜索用的博客-CSDN博客 相关文章 人人都是中医爱好者 科学定义 关于“科学”这个词的定义,历史上曾出现过多种版本,但是目前为止还没有一个是世人公认的定义。 历史上达尔文(Charles Robert Darwin&#xff…

利用阿里云免费部署openai的Chatgpt国内直接用

背景 国内无法直接访问ChatGPT,一访问就显示 code 1020。而且最近OpenAI查的比较严格,开始大规模对亚洲地区开始封号,对于经常乱跳IP的、同一个ip一堆账号的、之前淘宝机刷账号的,账号被封的可能性极大。 那么有没有符合openai规定…

< element-Ui表格组件:表格多选功能回显勾选时因分页问题,导致无法勾选回显的全部数据 >

文章目录👉 前言👉 一、解决思路👉 二、实现代码(仅供参考,具体问题具体分析)> HTML模板> Js模板往期内容 💨👉 前言 在 Vue elementUi 开发中,elementUI中表格在…

Linux服务器怎么修改系统时间

Linux服务器怎么修改系统时间 linux服务器的系统时间,有的时候会产生误差,导致我们的程序出现一些延迟,或者其他的一些错误,那么怎么修改linux的系统时间呢? 我是艾西,今天又是跟linux小白分享小知识的时间…

C语言函数大全-- l 开头的函数

C语言函数大全 本篇介绍C语言函数大全-- l 开头的函数 1. labs&#xff0c;llabs 1.1 函数说明 函数声明函数功能long labs(long n);计算长整型的绝对值long long int llabs(long long int n);计算long long int 类型整数的绝对值 1.2 演示示例 #include <stdio.h> …

Python-Python基本用法(全:含基本语法、用户交互、流程控制、数据类型、函数、面向对象、读写文件、异常、断言等)

1 环境准备 编辑器&#xff1a;Welcome to Python.org 解释器&#xff1a;pycharm&#xff1a;Thank you for downloading PyCharm! (jetbrains.com) 2 Quick start 创建项目 new project create demo print(Dad!!)3 基本语法 3.1 print 直接打印 print(Dad!!)拼接打印…

记录-Vue.js模板编译过程揭秘:从模板字符串到渲染函数

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 Vue.js是一个基于组件化和响应式数据流的前端框架。当我们在Vue中编写模板代码时&#xff0c;它会被Vue编译器处理并转换为可被浏览器解析的JavaScript代码。Vue中的模板实际上是HTML标记和Vue指令的组…

STM32HAL库 串口USART的使用

STM32HAL库 串口USART的使用 文章目录STM32HAL库 串口USART的使用前言一、配置USART1串口通信引脚二、使用步骤三、串口中断回调函数1. 配置2. 在icode中增加usart.c和usart.h文件3. 中断处理对比4. 编写串口控制程序总结前言 本文为串口输出打印的hal库&#xff0c;参考洋桃电…

【LeetCode】剑指 Offer 57. 和为 s 的数字 p280 -- Java Version

1. 题目介绍&#xff08;57. 和为 s 的数字&#xff09; 面试题57&#xff1a;和为 s 的数字&#xff0c; 一共分为两小题&#xff1a; 题目一&#xff1a;和为 s 的两个数字题目二&#xff1a;和为 s 的连续正数序列 2. 题目1&#xff1a;和为s的两个数字 题目链接&#xff1…

图结构基本知识

图1. 相关概念2. 图的表示方式3. 图的遍历3.1 深度优先遍历&#xff08;DFS&#xff09;3.2 广度优先遍历&#xff08;BFS&#xff09;1. 相关概念 图G(V,E) &#xff1a;一种数据结构&#xff0c;可表示“多对多”关系&#xff0c;由顶点集V和边集E组成&#xff1b;顶点(vert…

数据库管理-第六十七期 SQL Domain 2(20230414)

数据库管理 2023-04-14第六十七期 SQL Domain 21 Domain函数示例总结第六十七期 SQL Domain 2 昨晚割接&#xff0c;搭了一套19c的ADG&#xff0c;今天睡了个懒觉&#xff0c;早上把笔记本内存扩到了64GB&#xff0c;主要是为了后面做实验。然后下午拼了个乐高&#xff0c;根据…