【Pytroch】基于支持向量机算法的数据分类预测(Excel可直接替换数据)

news2025/1/12 1:51:24

【Pytroch】基于支持向量机算法的数据分类预测(Excel可直接替换数据)

  • 1.模型原理
  • 2.数学公式
  • 3.文件结构
  • 4.Excel数据
  • 5.下载地址
  • 6.完整代码
  • 7.运行结果

1.模型原理

支持向量机(Support Vector Machine,SVM)是一种强大的监督学习算法,用于二分类和多分类问题。它的主要思想是找到一个最优的超平面,可以在特征空间中将不同类别的数据点分隔开。

下面是使用PyTorch实现支持向量机算法的基本步骤和原理:

  1. 数据准备: 首先,你需要准备你的训练数据。每个数据点应该具有特征(Feature)和对应的标签(Label)。特征是用于描述数据点的属性,标签是数据点所属的类别。

  2. 数据预处理: 根据SVM的原理,数据点需要线性可分。因此,你可能需要进行一些数据预处理,如特征缩放或标准化,以确保数据线性可分。

  3. 定义模型: 在PyTorch中,你可以定义一个支持向量机模型作为一个线性模型,例如使用nn.Linear

  4. 定义损失函数: SVM的目标是最大化支持向量到超平面的距离,即最大化间隔(Margin)。这可以通过最小化损失函数来实现,通常使用hinge loss(合页损失)。PyTorch提供了nn.MultiMarginLoss损失函数,它可以用于SVM训练。

  5. 定义优化器: 选择一个优化器,如torch.optim.SGD,来更新模型的参数以最小化损失函数。

  6. 训练模型: 使用训练数据对模型进行训练。在每个训练步骤中,计算损失并通过优化器更新模型参数。

  7. 预测: 训练完成后,你可以使用训练好的模型对新的数据点进行分类预测。对于二分类问题,可以使用模型的输出值来判断数据点所属的类别。

2.数学公式

当使用支持向量机(SVM)进行数据分类预测时,目标是找到一个超平面(或者在高维空间中是一个超曲面),可以将不同类别的数据点有效地分隔开。以下是SVM的数学原理:

  1. 超平面方程: 在二维情况下,超平面可以表示为

    w 1 x 1 + w 2 x 2 + b = 0 w_1 x_1 + w_2 x_2 + b = 0 w1x1+w2x2+b=0

  2. 决策函数: 数据点 (x) 被分为两个类别的决策函数为

    f ( x ) = w T x + b f(x) = w^T x + b f(x)=wTx+b

  3. 间隔(Margin): 对于一个给定的超平面,数据点到超平面的距离被称为间隔。支持向量机的目标是找到能最大化间隔的超平面。间隔可以用下面的公式计算:

    间隔 = 2 ∥ w ∥ \text{间隔} = \frac{2}{\|w\|} 间隔=w2

  4. 支持向量: 支持向量是离超平面最近的那些数据点。这些点对于确定超平面的位置和间隔非常重要。支持向量到超平面的距离等于间隔。

  5. 最大化间隔: SVM 的目标是找到一个超平面,使得所有支持向量到该超平面的距离(即间隔)都最大化。这等价于最小化法向量的范数 (|w|),即:

    最小化 1 2 ∥ w ∥ 2 \text{最小化} \quad \frac{1}{2}\|w\|^2 最小化21w2

  6. 对偶问题和核函数: 对偶问题的解决方法涉及到拉格朗日乘子,可以得到一个关于训练数据点的内积的表达式。这样,如果直接在高维空间中计算内积是非常昂贵的,可以使用核函数来避免高维空间的计算。核函数将数据映射到更高维的空间,并在计算内积时使用高维空间的投影,从而实现了在高维空间中的计算,但在计算上却更加高效。

总之,SVM利用线性超平面来分隔不同类别的数据点,通过最大化支持向量到超平面的距离来实现分类。对偶问题和核函数使SVM能够处理非线性问题,并在高维空间中进行计算。以上是SVM的基本数学原理。

3.文件结构

在这里插入图片描述

iris.xlsx						% 可替换数据集
Main.py							% 主函数

4.Excel数据

在这里插入图片描述

5.下载地址

- 资源下载地址

6.完整代码

import torch
import torch.nn as nn
import pandas as pd
import numpy as np  # Don't forget to import numpy for the functions using it
import matplotlib.pyplot as plt  # Import matplotlib for plotting
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix

class SVM(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SVM, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.linear(x)


def train(model, X, y, num_epochs, learning_rate):
    criterion = nn.CrossEntropyLoss()  # Use CrossEntropyLoss for multi-class classification
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        inputs = torch.tensor(X, dtype=torch.float32)
        labels = torch.tensor(y, dtype=torch.long)  # Use long for class indices

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')


def test(model, X, y):
    inputs = torch.tensor(X, dtype=torch.float32)
    labels = torch.tensor(y, dtype=torch.long)  # Use long for class indices

    with torch.no_grad():
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted == labels).float().mean()
        print(f'Accuracy on test set: {accuracy:.2f}')

# Define the plot functions
def plot_confusion_matrix(conf_matrix, classes):
    plt.figure(figsize=(8, 6))
    plt.imshow(conf_matrix, cmap=plt.cm.Blues, interpolation='nearest')
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.tight_layout()
    plt.show()

def plot_predictions_vs_true(y_true, y_pred):
    plt.figure(figsize=(10, 6))
    plt.plot(y_true, 'go', label='True Labels')
    plt.plot(y_pred, 'rx', label='Predicted Labels')
    plt.title("True Labels vs Predicted Labels")
    plt.xlabel("Sample Index")
    plt.ylabel("Class Label")
    plt.legend()
    plt.show()


def main():
    data = pd.read_excel('iris.xlsx')
    X = data.iloc[:, :-1].values
    y = data.iloc[:, -1].values

    label_encoder = LabelEncoder()
    y = label_encoder.fit_transform(y)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    num_classes = len(label_encoder.classes_)
    model = SVM(X_train.shape[1], num_classes)
    num_epochs = 1000
    learning_rate = 0.001

    train(model, X_train, y_train, num_epochs, learning_rate)

    # Call the test function to get predictions
    inputs = torch.tensor(X_test, dtype=torch.float32)
    labels = torch.tensor(y_test, dtype=torch.long)
    with torch.no_grad():
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)

    # Convert torch tensors back to numpy arrays
    y_true = labels.numpy()
    y_pred = predicted.numpy()

    test(model, X_test, y_test)

    # Call the plot functions
    conf_matrix = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(conf_matrix, label_encoder.classes_)
    plot_predictions_vs_true(y_true, y_pred)


if __name__ == '__main__':
    main()


7.运行结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/872657.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Megatron-DeepSpeed】张量并行工具代码mpu详解(四):张量并行版Embedding层及交叉熵的实现及测试

相关博客 【Megatron-DeepSpeed】张量并行工具代码mpu详解(四):张量并行版Embedding层及交叉熵的实现及测试 【Megatron-DeepSpeed】张量并行工具代码mpu详解(三):张量并行层的实现及测试 【Megatron-DeepSpeed】张量并行工具代码mpu详解(一)&#xff1a…

时序预测 | MATLAB实现基于CNN卷积神经网络的时间序列预测-递归预测未来(多指标评价)

时序预测 | MATLAB实现基于CNN卷积神经网络的时间序列预测-递归预测未来(多指标评价) 目录 时序预测 | MATLAB实现基于CNN卷积神经网络的时间序列预测-递归预测未来(多指标评价)预测结果基本介绍程序设计参考资料 预测结果 基本介绍 1.Matlab实现CNN卷积神经网络时间序列预测未…

webpack中常见的Loader

目录 1.webpack中的loader是什么?配置方式 2. loader特性3.常见的loader 1.webpack中的loader是什么? loader 用于对模块的"源代码"进行转换,在 import 或"加载"模块时预处理文件 webpack做的事情,仅仅是分…

Linux printf函数输出问题

1.printf函数并不会直接将数据输出到屏幕,而是先放到缓冲区中。 原因是: 解决效率和性能的问题。 比如说,printf在打印数据到屏幕上的时候不经过缓冲区,而是直接调用内核,此时内核就相当于另外一个进程,这…

Linux之【进程间通信(IPC)】-总结篇

Linux之【进程间通信(IPC)】-总结篇 管道System V共享内存System V消息队列System V信号量IPC资源的管理方式 往期文章 1.进程间通信之管道 2.进程间通信之System V共享内存 管道 进程之间具有独立性,拥有自己的虚拟地址空间,因…

基于TorchViz详解计算图(附代码)

文章目录 0. 前言1. 计算图是什么?2. TorchViz的安装3. 计算图详解 0. 前言 按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,…

【学会动态规划】买卖股票的最佳时机 IV(18)

目录 动态规划怎么学? 1. 题目解析 2. 算法原理 1. 状态表示 2. 状态转移方程 3. 初始化 4. 填表顺序 5. 返回值 3. 代码编写 写在最后: 动态规划怎么学? 学习一个算法没有捷径,更何况是学习动态规划, 跟我…

【马蹄集】第二十二周——进位制与字符串专题

进位制与字符串专题 目录 MT2179 01操作MT2182 新十六进制MT2172 萨卡兹人MT2173 回文串等级MT2175 五彩斑斓的串 MT2179 01操作 难度:黄金    时间限制:1秒    占用内存:128M 题目描述 刚学二进制的小码哥对加减乘除还不熟,他…

DataGrip 安装 与 连接MySQL数据库

DataGrip 安装 与 连接MySQL数据库 Jetbrains是著名的编程工具商业软件提供商,旗下有很多软件。包括IDE、团队开发工具、插件和微软.Net辅助工具、包括自创语言Kotlin等。我们通常用的和说的全家桶,主要就是指它的IDE套件。Jetbrains的IDE工具都支持跨平…

web-Element

在vueapp里<div><!-- <h1>{{message}}</h1> --><element-view></element-view></div> <div><!-- <h1>{{message}}</h1> --><element-view></element-view></div>在view新建个文件 <t…

AIGC+游戏:一个被忽视的长赛道

&#xff08;图片来源&#xff1a;Pixels&#xff09; AIGC彻底变革了游戏&#xff0c;但还不够。 数科星球原创 作者丨苑晶 编辑丨大兔 消费还没彻底复苏&#xff0c;游戏却已经出现拐点。 在游戏热度猛增的背后&#xff0c;除了版号的利好因素外&#xff0c;AIGC技术的广泛…

项目实战 — 消息队列(8){网络通信设计②}

目录 一、客户端设计 &#x1f345; 1、设计三个核心类 &#x1f345; 2、完善Connection类 &#x1f384; 读取请求和响应、创建channel &#x1f384; 添加扫描线程 &#x1f384; 处理不同的响应 &#x1f384; 关闭连接 &#x1f345; 3、完善Channel类 &#x1f384; 编…

机器学习编译系列

机器学习编译MLC 1. 引言2. 机器学习编译--概述2.1 什么是机器学习编译 1. 引言 陈天奇目前任教于CMU&#xff0c;研究方向为机器学习系统。他是TVM、MXNET、XGBoost的主要作者。2022年夏天&#xff0c;陈天奇在B站开设了《机器学习编译》的课程。   《机器学习编译》课程共分…

2023最新水果编曲软件FL Studio 21.1.0.3267音频工作站电脑参考配置单及系统配置要求

音乐在人们心中的地位日益增高&#xff0c;近几年音乐选秀的节目更是层出不穷&#xff0c;喜爱音乐&#xff0c;创作音乐的朋友们也是越来越多&#xff0c;音乐的类型有很多&#xff0c;好比古典&#xff0c;流行&#xff0c;摇滚等等。对新手友好程度基本上在首位&#xff0c;…

全网最牛,Appium自动化测试框架-关键字驱动+数据驱动实战(一)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、关键字驱动框架…

Stm32-使用TB6612驱动电机及编码器测速

这里写目录标题 起因一、电机及编码器的参数二、硬件三、接线四、驱动电机1、TB6612电机驱动2、定时器的PWM模式驱动电机 五、编码器测速1、定时器的编码器接口模式2、定时器编码器模式测速的原理3、编码器模式的配置4、编码器模式相关代码5、测速方法 六、相关问题以及解答1、…

关于Cesium的常见需求整理之点位和弹窗(点位弹窗)

一、点位上图 ①在Cesium中&#xff0c;每个自定义的地图元素被视为一个entity对象&#xff0c;如果我们要添加点位到地图上&#xff0c;那就必须先创建一个entity对象。 var entity new Cesium.Entity({position: position, });以上代码我们创建了一个entity对象&#xff0…

Autosar通信入门系列06-聊聊CAN通信的线与机制与ACK应答

本文框架 1. 概述2. CAN通信的线与机制3. ACK应答机制理解 1. 概述 本文为Autosar通信入门系列介绍&#xff0c;如您对AutosarMCAL配置&#xff0c;通信&#xff0c;诊断等实战有更高需求&#xff0c;可以参见AutoSar 实战进阶系列专栏&#xff0c;快速链接&#xff1a;AutoSa…

数据库基础(增删改查)

目录 MySQL 背景知识 数据库基础操作 1.创建数据库 2.查看所有数据库 3.选中指定的数据库 4.删除数据库 数据库表操作 MySQL的数据类型 1.创建表 3.查看指定表的结构 4.删除表 增删改 新增操作 修改(Updata) 删除语句 面试题 查询操作 指定列查询 查询的列为表达式…