支持向量机和梯度提升决策树

news2024/9/19 10:47:00

支持向量机(SVM, Support Vector Machine)

原理

支持向量机是一种监督学习模型,适用于分类和回归任务。SVM 的目标是找到一个能够最大化类别间间隔的超平面,以便更好地对数据进行分类。对于非线性数据,SVM 通过核技巧将数据映射到高维空间,使得在该空间中可以找到一个线性可分的超平面。

公式

SVM 的核心思想是找到一个超平面使得两类数据点之间的间隔最大化。给定训练数据集 (xi,yi),其中 xi 是特征向量,yi 是类别标签(+1 或 -1),SVM 通过求解以下优化问题来找到最优的超平面:

满足约束条件:

其中,w 是超平面的法向量,b 是偏置。

对于非线性数据,SVM 使用核函数 K(xi,xj) 将数据映射到高维空间。常用的核函数包括线性核、径向基核(RBF 核)和多项式核等。

生活场景应用的案例

邮件分类:SVM 可以用于电子邮件的垃圾邮件分类。假设我们有一个包含电子邮件内容的数据集,并使用特征如词频、长度等。我们可以使用 SVM 模型来预测某封邮件是否为垃圾邮件。

案例描述

假设我们有一个包含电子邮件内容和标签的数据集,包括以下特征:

  • 词频(Term Frequency)
  • 词汇量(Vocabulary Size)
  • 邮件长度(Email Length)
  • 特定关键词出现次数(Keyword Occurrence)

我们希望通过这些特征预测某封邮件是否为垃圾邮件。我们可以使用 SVM 模型进行训练和预测。训练完成后,我们可以使用模型来预测新邮件是否为垃圾邮件,并评估模型的性能。

代码解析

下面是一个使用 Python 实现上述垃圾邮件分类案例的示例,使用了 scikit-learn 库。

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# 创建示例数据
data = {
    'term_frequency': [0.5, 0.1, 0.3, 0.8, 0.2, 0.4, 0.7, 0.6, 0.9, 0.3],
    'vocabulary_size': [100, 150, 120, 200, 130, 140, 170, 180, 220, 110],
    'email_length': [500, 600, 550, 700, 580, 620, 680, 650, 720, 590],
    'keyword_occurrence': [2, 1, 3, 5, 1, 2, 4, 3, 6, 1],
    'spam': [1, 0, 0, 1, 0, 0, 1, 0, 1, 0]  # 1: Spam, 0: Not Spam
}

# 转换为 DataFrame
df = pd.DataFrame(data)

# 特征和目标变量
X = df[['term_frequency', 'vocabulary_size', 'email_length', 'keyword_occurrence']]
y = df['spam']

# 拆分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建 SVM 模型并训练
model = SVC(kernel='linear', random_state=42)
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)

# 评估模型
accuracy = accuracy_score(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred)
report = classification_report(y_test, y_pred)

print(f"Accuracy: {accuracy}")
print("Confusion Matrix:")
print(cm)
print("Classification Report:")
print(report)

在这个示例中:

  1. 我们创建了一个包含电子邮件信息的示例数据集,包括词频、词汇量、邮件长度和特定关键词出现次数等特征。
  2. 将数据集转换为 DataFrame 格式,并将特征和目标变量分开。
  3. 将数据集拆分为训练集和测试集。
  4. 使用训练集训练 SVM 分类模型。
  5. 通过测试集进行预测并评估模型的性能。
  6. 输出准确率(accuracy)、混淆矩阵(confusion matrix)和分类报告(classification report)。

这个案例展示了如何使用 SVM 模型来预测电子邮件是否为垃圾邮件,基于电子邮件的词频、词汇量、邮件长度和特定关键词出现次数等特征。模型训练完成后,可以用于预测新邮件是否为垃圾邮件,并帮助用户过滤垃圾邮件。

梯度提升决策树(GBDT, Gradient Boosting Decision Tree)

原理

梯度提升决策树(GBDT)是一种集成学习方法,通过迭代地训练多个决策树模型来逐步优化预测结果。每一棵新树都对之前树的预测误差进行拟合,从而减少整体的预测误差。GBDT 在回归和分类任务中都表现出色,尤其在处理非线性关系和复杂数据时具有优势。

GBDT 的核心思想是利用梯度下降算法来最小化损失函数。每一轮迭代中,GBDT 通过在负梯度方向上添加一个新模型来更新当前模型的预测。

公式

给定训练数据集 (xi,yi),GBDT 通过以下步骤构建模型:

  1. 初始化模型:

  1. 对于 m=1,2,...,Mm = 1, 2, ..., Mm=1,2,...,M:
  • 计算当前模型的负梯度(残差):

  • 拟合一个新的决策树模型 hm(x)h_m(x)hm(x) 来对负梯度进行拟合:

  • 更新模型:

其中,η 是学习率,控制每棵树对最终模型的贡献。

生活场景应用的案例

房价预测:GBDT 可以用于房价预测。假设我们有一个包含房屋特征的数据集,如面积、房间数、位置等。我们可以使用 GBDT 模型来预测房价。

案例描述

假设我们有一个包含房屋特征的数据集,包括以下特征:

  • 面积(Square Footage)
  • 房间数(Number of Rooms)
  • 房龄(Age of the House)
  • 位置评分(Location Score)

我们希望通过这些特征预测房价。我们可以使用 GBDT 模型进行训练和预测。训练完成后,我们可以使用模型来预测新房屋的房价,并评估模型的性能。

代码解析

下面是一个使用 Python 实现上述房价预测案例的示例,使用了 scikit-learn 库。

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt

# 创建示例数据
data = {
    'square_footage': [1500, 2500, 1800, 3000, 1200, 2000, 3500, 2800, 1600, 2200],
    'number_of_rooms': [3, 4, 3, 5, 2, 4, 5, 4, 3, 4],
    'age_of_house': [10, 5, 8, 2, 15, 7, 1, 4, 12, 6],
    'location_score': [7, 8, 7, 9, 5, 7, 10, 8, 6, 7],
    'price': [300000, 450000, 350000, 500000, 250000, 400000, 550000, 480000, 320000, 410000]
}

# 转换为 DataFrame
df = pd.DataFrame(data)

# 特征和目标变量
X = df[['square_footage', 'number_of_rooms', 'age_of_house', 'location_score']]
y = df['price']

# 拆分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建 GBDT 模型并训练
model = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)

# 评估模型
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"Mean Squared Error: {mse}")
print(f"R-squared: {r2}")

# 可视化特征重要性
feature_importance = model.feature_importances_
features = X.columns
indices = np.argsort(feature_importance)

plt.figure(figsize=(10, 6))
plt.title('Feature Importances')
plt.barh(range(len(indices)), feature_importance[indices], color='b', align='center')
plt.yticks(range(len(indices)), [features[i] for i in indices])
plt.xlabel('Relative Importance')
plt.show()

在这个示例中:

  1. 我们创建了一个包含房屋信息的示例数据集,包括面积、房间数、房龄和位置评分等特征。
  2. 将数据集转换为 DataFrame 格式,并将特征和目标变量分开。
  3. 将数据集拆分为训练集和测试集。
  4. 使用训练集训练 GBDT 回归模型。
  5. 通过测试集进行预测并评估模型的性能。
  6. 输出均方误差(Mean Squared Error, MSE)和决定系数(R-squared, R²)。

  1. 可视化特征的重要性,展示每个特征对预测结果的相对贡献。

这个案例展示了如何使用 GBDT 模型来预测房屋的价格,基于房屋的面积、房间数、房龄和位置评分等特征。模型训练完成后,可以用于预测新房屋的价格,并帮助用户在买房或卖房时做出更好的决策。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961921.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言程序设计17

程序设计17 问题17_1代码17_1结果17_1 问题17_2代码17_2结果17_2 问题17_3代码17_3结果17_3 问题17_1 下列给定程序的功能是 :调用函数 f u n fun fun 将指定源文件中的内容复制到指定的目标文件中,复制成功时函数返回 1 1 1 ,失败时返回 …

Brant-2:开启脑信号分析新篇章的基础模型

人工智能咨询培训老师叶梓 转载标明出处 脑信号,包括通过侵入性或非侵入性方式收集的脑电图(EEG)和立体脑电图(SEEG)等生物测量信息,为我们理解大脑的生理功能和相关疾病的机制提供了宝贵的洞见。然而脑信号…

vue3+cesium创建地图

1.我这边使用的是cdn引入形式 比较简单的方式 不需要下载依赖 在项目文件的index.html引入 这样cesium就会挂载到window对象上面去了 <!-- 引入cesium-js文件 --><script src"https://cesium.com/downloads/cesiumjs/releases/1.111/Build/Cesium/Cesium.js"…

Java二十三种设计模式-外观模式(9/23)

外观模式&#xff1a;简化复杂系统的统一接口 引言 外观模式&#xff08;Facade Pattern&#xff09;是一种结构型设计模式&#xff0c;它为子系统中的一组接口提供一个统一的高层接口。外观模式定义了一个可以与复杂子系统交互的简化接口&#xff0c;使得子系统更加易于使用…

科研绘图系列:R语言GWAS曼哈顿图(Manhattan plot)

介绍 曼哈顿图(Manhattan Plot)是一种常用于展示全基因组关联研究(Genome-Wide Association Study, GWAS)结果的图形。GWAS是一种研究方法,用于识别整个基因组中与特定疾病或性状相关的遗传变异。 特点: 染色体表示:曼哈顿图通常将每个染色体表示为一个水平条,染色体…

Git 基础操作手册:轻松掌握常用命令

Git 操作完全手册&#xff1a;轻松掌握常用命令 引言一、暂存&#xff1a;git add ✏️二、提交&#xff1a;git commit &#x1f4dd;三、拉取、拉取合并 &#x1f504;四、推送&#xff1a;git push &#x1f310;五、查看状态&#xff1a;git status &#x1f4ca;六、查看历…

第09课 Scratch入门篇:小鸡啄米-自制积木实现

小鸡啄米-自制积木 故事背景&#xff1a; 在上一章的案例中&#xff0c;实现了小鸡啄米的动画&#xff0c;但是发现太多的重复代码&#xff0c;是我们编程的时候代码泰国繁琐&#xff0c;我们可以使用自制积木&#xff0c;将相同的代码提取出来制作成一个新的积木&#xff0c;在…

ESP32CAM人工智能教学17

ESP32CAM人工智能教学17 内网穿透,视频上云 小智一心想让ESP32Cam视频能发送到云端,这样我们在任何地方,都能查看到家里的摄像头了,甚至能控制摄像头的小车了。 内网穿透的技术原理想要让ESP32Cam把视频发送到云端,就必须要做内网穿透。也就是用户的手机(或电脑),可以…

Windows中如何配置Gradle环境变量?

本篇教程,主要介绍,如何在Windows中配置Gradle7.4环境变量 一、下载安装包 安装包下载;gradle-7.4-all.zip 二、解压安装包

Python基础知识笔记---保留字

保留字&#xff0c;也称关键字&#xff0c;是指被编程语言内部定义并保留使用的标识符。 一、保留字概览 二、保留字用途 1. False&#xff1a;表示布尔值假。 2. None&#xff1a;表示空值或无值。 3. True&#xff1a;表示布尔值真。 4. not&#xff1a;布尔逻辑操作符…

LabVIEW 使用 I/O 服务器

I/O 服务器是共享变量引擎&#xff08;SVE&#xff09;插件&#xff0c;用于与不使用NI专有的NI发布-订阅协议&#xff08;NI-PSP&#xff09;的设备和应用程序通信。I/O 服务器充当LabVIEW VI中的共享变量与OPC、Modbus或EPICS数据标签之间的桥梁。它们插入SVE中&#xff0c;提…

生成式人工智能最重要的三个神经网络,从谷歌DeepDream、Magenta、到NVIDIA的StyleGAN

神经网络模型&#xff08;Neural Network Model&#xff09;是一种受生物大脑启发的机器学习模型&#xff0c;用于模拟人脑的结构和功能。它由大量相互连接的人工神经元&#xff08;节点&#xff09;组成&#xff0c;这些神经元按层级结构排列&#xff0c;通常包括输入层、隐藏…

OSI七层网络模型:构建网络通信的基石

在计算机网络领域&#xff0c;OSI&#xff08;Open Systems Interconnection&#xff09;七层模型是理解网络通信过程的关键框架。该模型将网络通信过程细分为七个层次&#xff0c;每一层都有其特定的功能和职责&#xff0c;共同协作完成数据从发送端到接收端的传输。接下来&am…

制品库nexus

详见&#xff1a;Sonatype Nexus Repository搭建与使用&#xff08;详细教程3.70.1&#xff09;-CSDN博客 注意事项&#xff1a; 1.java8环境使用nexus-3.69.0-02-java8-unix.tar.gz包 2.java11环境使用nexus-3.70.1-02-java11-unix.tar.gz包 3.注意使用制品库/etc/yum.repos.…

leetcode 1596 每位顾客经常订购的商品(postgresql)

需求 表&#xff1a;Customers ---------------------- | Column Name | Type | ---------------------- | customer_id | int | | name | varchar | ---------------------- customer_id 是该表主键 该表包含所有顾客的信息 表&#xff1a;Orders ---------------------- …

LeetCode 144.二叉树的前序遍历 C写法

LeetCode 144.二叉树的前序遍历 思路&#x1f9d0;&#xff1a; 遍历很简单&#xff0c;但是我们需要开空间进行值的存储&#xff0c;结点个数也可以用递归进行统计&#xff0c;开好空间就可以用数组进行值的存储&#xff0c;注意下标要么用全局&#xff0c;要么指针解引用&…

【Canvas与艺术】五色五角大楼

【成图】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head><title>五L莫比乌斯五角大楼</title><style type"text/css&qu…

解决国内 github.com 打不开进不去的方法

解决国内 github.com 打不开进不去的方法 修改host文件 电脑的hosts文件在下面这个地址&#xff0c;找到hosts文件 C:\Windows\System32\Drivers\etc 右键点击hosts文件,选择复制,然后粘贴到桌面上。右键点击桌面上的hosts文件,选择“用记事本打开该文件”,hosts 文件中需要写…

VPX电源模块性能检测需要测试哪些指标?

随着技术的不断进步&#xff0c;VPX电源模块已成为军工、通讯等关键行业不可或缺的组成部分。接下来&#xff0c;让我们一起了解VPX电源以及如何检测其性能。 什么是VPX电源&#xff1f; VPX电源是一种军用嵌入式计算机规范下的电源产品&#xff0c;采用了7排引脚阵列高速连接器…