【机器学习】常用的分类算法代码实现

news2025/1/24 11:04:04

文章目录

  • 任务&数据集
  • 一、基算法
    • 1.1 决策树(Decision Tree)
    • 1.2 逻辑回归(Logistic Regression)
    • 1.3 支持向量机(Support Vector Machine, SVM)
  • 二、集成算法
    • 2.1 随机森林(Random Forest)
    • 2.2 AdaBoost(Adaptive Boosting)
    • 2.3 GBDT(Gradient Boosting Decision Tree)
    • 2.4 XGBoost(eXtreme Gradient Boosting)
    • 2.5 LightGBM(Light Gradient Boosting Machine)
  • 参考资料

任务&数据集

【任务描述】:基于鸢尾花数据集训练相应的分类器,实现分类。

【数据集】:鸢尾花数据集最初由Edgar Anderson 测量得到,而后在著名的统计学家和生物学家R.A Fisher于1936年发表的文章「The use of multiple measurements in taxonomic problems」中被使用,用它作为线性判别分析(Linear Discriminant Analysis)的一个例子,证明分类的统计方法。该数据集是在机器学习领域一个常用的数据集。

数据集地址:http://archive.ics.uci.edu/ml/datasets/Iris

在这里插入图片描述

一、基算法

1.1 决策树(Decision Tree)

(1)算法原理:决策树根据样本数据集的数据特征对数据集进行划分,直到针对所有特征都划分过,或者划分的数据子集的所有数据的类别标签相同。

(2)核心代码:

这里,我们给出完整的实现代码,其中一部分代码可复用(如数据加载、划分训练集测试集等),后面不再重复给出。

from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, mean_squared_error
import matplotlib.pyplot as plt
from sklearn2pmml import sklearn2pmml, PMMLPipeline  # 调用生成pmml文件方法

# 导入数据
iris = load_iris()
x = iris.data
y = iris.target
print(x.shape, y.shape)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=40)  # 划分训练集和验证集, random_state即随机种子设定为固定的值后,每次数据集的拆分结果一致,确保实验可重复性

# 训练模型
dc_tree = tree.DecisionTreeClassifier(criterion='entropy', min_samples_leaf=5)
dc_tree.fit(x_train, y_train)

# 测试模型
y_pred = dc_tree.predict(x_test)
acc = accuracy_score(y_test, y_pred)
print("acc:", acc)
mse = mean_squared_error(y_test, y_pred) ** 0.5
print("mse:", mse)

# 决策树可以可视化
fig = plt.figure(figsize=(20, 20))
tree.plot_tree(dc_tree, filled=True,
               feature_names=['sepal length', 'sepal width', 'petal length', 'petal width'],
               class_names=['Setosa', 'Versicolour', 'Virginica'])   # 山鸢尾、杂色鸢尾、维吉尼亚鸢尾

plt.savefig("./outputs/dt_iris.png", bbox_inches='tight', pad_inches=0.0)
plt.show()


# 生成PMML文件,方便通过多种工具查看、导出、调用等
pipeline = PMMLPipeline([("classifier", dc_tree)])
pipeline.fit(x_train, y_train)
sklearn2pmml(pipeline, "./outputs/dt_iris.pmml")

决策树可视化:
在这里插入图片描述

1.2 逻辑回归(Logistic Regression)

(1)算法原理:逻辑回归是一种线性分类器,通过logistic函数,将特征映射成一个概率值,来判断输入数据的类别。

(2)核心代码:

from sklearn.linear_model import LogisticRegression

# 训练模型
lr = LogisticRegression()
lr.fit(x_train, y_train)

1.3 支持向量机(Support Vector Machine, SVM)

(1)算法原理:寻找一个能够正确划分训练数据集并且几何间隔最大的分离超平面。

(2)核心代码:

from sklearn import svm

# 训练模型
svm = svm.SVC(C = 1.0, kernel = 'rbf', degree = 3, gamma = 'scale', coef0 = 0.0, shrinking = True, probability = False, tol = 0.001, cache_size = 200, class_weight = None, verbose = False, max_iter = -1, decision_function_shape = 'ovr', break_ties = False, random_state = None)
svm.fit(x_train, y_train)

二、集成算法

2.1 随机森林(Random Forest)

(1)算法原理:使用CART树作为弱分类器,将多个不同的决策树进行组合,利用这种组合来降低单棵决策树的可能带来的片面性和判断不准确性。

(2)核心代码:

from sklearn.ensemble import RandomForestClassifier

# 训练模型
clf = RandomForestClassifier(n_estimators = 100, criterion = 'gini', max_depth = None, min_samples_split = 2, min_samples_leaf = 1, min_weight_fraction_leaf = 0.0, max_features = 'auto', max_leaf_nodes = None, min_impurity_decrease = 0.0, min_impurity_split = None, bootstrap = True, oob_score = False, n_jobs = None, random_state = None, verbose = 0, warm_start = False, class_weight = None, ccp_alpha = 0.0, max_samples = None)
clf.fit(x_train, y_train)

2.2 AdaBoost(Adaptive Boosting)

(1)算法原理:AdaBoost算法中,前一个基本分类器分错的样本会得到加强,加权后的全体样本再次被用来训练下一个基本分类器。同时,在每一轮中加入一个新的弱分类器,直到达到某个预定的足够小的错误率或达到预先指定的最大迭代次数。

(2)核心代码:

from sklearn.ensemble import AdaBoostClassifier

# 训练模型
clf = AdaBoostClassifier(base_estimator = None, n_estimators = 50, learning_rate = 1.0, algorithm = 'SAMME.R', random_state = None)
clf.fit(x_train, y_train)

2.3 GBDT(Gradient Boosting Decision Tree)

(1)算法原理:GBDT是每次建立单个分类器时,是在之前建立的模型的损失函数的梯度下降方向。GBDT的核心在于每一棵树学的是之前所有树结论和的残差,残差就是真实值与预测值的差值,所以为了得到残差,GBDT中的树全部是回归树,之所以不用分类树,是因为分类的结果相减是没有意义的。

(2)核心代码:

from sklearn.ensemble import GradientBoostingClassifier

# 训练模型
clf = GradientBoostingClassifier(loss = 'deviance', learning_rate = 0.1, n_estimators = 100, subsample = 1.0, criterion = 'friedman_mse', min_samples_split = 2, min_samples_leaf = 1, min_weight_fraction_leaf = 0.0, max_depth = 3, min_impurity_decrease = 0.0, min_impurity_split = None, init = None, random_state = None, max_features = None, verbose = 0, max_leaf_nodes = None, warm_start = False, presort = 'deprecated', validation_fraction = 0.1, n_iter_no_change = None, tol = 0.0001, ccp_alpha = 0.0)
clf.fit(x_train, y_train)

2.4 XGBoost(eXtreme Gradient Boosting)

(1)算法原理:XGBoost的原理与GBDT基本相同,但XGB是在GBDT基础上的优化。相比而言,XGB主要有2点优化:

  • XGB支持并行,速度快;
  • 损失函数加入了正则项,防止过拟合。

(2)核心代码:

from xgboost import XGBClassifier

# 训练模型
xgb = XGBClassifier(max_depth=3, learning_rate=0.1, n_estimators=100, verbosity=1, silent=None, objective='binary:logistic', booster='gbtree', n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5, random_state=0, seed=None, missing=None)
xgb.fit(x_train, y_train)

# 特征重要性
print("Feature importances: ", list(xgb.feature_importances_))

这里,XGBoost模型的list(xgb.feature_importances_)可以输出各个分类特征对于分类的重要程度,可以帮助我们筛选出对分类任务较为重要的特征。

这里,XGBoost本身带有许多参数,我们可以采用网格搜索法 或者 Optuna 框架 实现调参:

  1. 基于GridSearch网格搜索法调参

当需要调整的参数量较少时,我们可以直接使用sklearn 中的 GridSearchCV 实现调参:

from sklearn.model_selection import GridSearchCV

# GridSearch网格搜索法调参
xgb = XGBClassifier(max_depth=3)

param_grid = {
    "learning_rate": [0.01, 0.1, 1],
    "n_estimators": [20, 40]
}

xgb = GridSearchCV(xgb, param_grid)
# 训练模型
xgb.fit(x_train, y_train)
# 输出最优参数
print("Best parameters found by grid search are: ", xgb.best_params_)
  1. 基于Optuna框架调参

Optuna是一个开源的超参数优化(HPO)框架,用于自动执行超参数的搜索空间。 为了找到最佳的超参数集,Optuna使用贝叶斯方法。 它支持下面列出的各种类型的采样器:

  • GridSampler (使用网格搜索)
  • RandomSampler (使用随机采样)
  • TPESampler (使用树结构的Parzen估计器算法)
  • CmaEsSampler (使用CMA-ES算法)

具体实现可参考博客:基于 Optuna 的模型超参数优化

2.5 LightGBM(Light Gradient Boosting Machine)

(1)算法原理:LightGBM是在XGB基础上的优化,主要优化点在于速度更快、占用内存更小、精确度更高、支持类别变量。

(2)核心代码:

import lightgbm as lgb
from sklearn.model_selection import train_test_split

#3、训练、验证模型
gbm = lgb.LGBMRegressor(boosting_type = 'gbdt', num_leaves = 31, max_depth = -1, learning_rate = 0.05, n_estimators = 100, subsample_for_bin = 200000, objective = 'regression', class_weight = None, min_split_gain = 0.0, min_child_weight = 0.001, min_child_samples = 20, subsample = 1.0, subsample_freq = 0, colsample_bytree = 1.0, reg_alpha = 0.0, reg_lambda = 0.0, random_state = None, n_jobs = -1, silent = True, importance_type = 'split')
gbm.fit(X_train, y_train, eval_set = [(X_test, y_test)], eval_metric = 'l1', early_stopping_rounds = 10)
gbm.fit(X_train, y_train)

LightGBM中也涉及参数的调整,调参方式和XGBoost类似,此处不再赘述。

参考资料

几种风控算法的原理和代码实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1816040.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

哪个牌子洗地机最好?四款甄选佳品安利,质量放心

作为一个熟悉智能清洁家电的行业者,洗地机可谓是实用性最高的地面清洁工具,这个实用性一方面是清洁力强,它集合了扫地和拖地能力,另一方面是操作方便,清洁速度快。可是面对市面上种类繁多的智能清洁家电,往…

【Python】已完美解决:(Python键盘中断报错问题) KeyboardInterrupt

文章目录 一、问题背景二、可能出错的原因三、错误代码示例四、正确代码示例(结合实战场景)五、注意事项 已解决:Python中处理KeyboardInterrupt(键盘中断)报错问题 一、问题背景 在Python编程中,当我们运…

晨持绪科技:抖音网店怎么做有前景

在数字时代的浪潮中,抖音平台以其独特的魅力和庞大的用户基础成为电商的新阵地。开设一家有前景的抖音网店,不仅需要对市场脉搏有敏锐的洞察力,还需融合创新思维与数据驱动的营销策略。 明确定位是成功的先声。深入分析目标消费群体的需求与偏…

SpringCash

文章目录 简介引入依赖常用注解application.yml使用1. 启动类添加注解使用方法上添加注解 简介 Spring Cache是一个框架,实现了基于注解的缓存功能底层可以使用EHCache、Caffeine、Redis实现缓存。 注解一般放在Controller的方法上,CachePut 注解一般有…

【Java面试】十九、并发篇(下):线程池

文章目录 1、为什么要使用线程池2、线程池的执行原理2.1 七个核心参数2.2 线程池的执行原理 3、线程池用到的常见的阻塞队列有哪些4、如何确定核心线程数开多少个?5、线程池的种类有哪些?6、为什么不建议用Executors封装好的静态方法创建线程池7、线程池…

虚拟化 之七 详解构造带有 jailhouse 的 openEuler 系统

构造一个默认带有 jailhouse 的 openEuler 系统实际上就是创建一个包含 jailhouse 软件包的 openEuler 发行版,创建的过程在 x86 和 嵌入式平台差距很大,因此,本文我们分别进行详细介绍。 x86_64 平台 对于 x86_64 平台,如果手动从头创建(参考 Linux From Scratch)一个自…

Kubernetes 集群架构

etcd 集群状态存储:etcd 存储所有 Kubernetes 对象的状态,例如部署、pod、服务、配置映射和机密。配置管理:集群配置的更改存储在 etcd 中,允许 Kubernetes 管理和维护集群的所需状态。 注意:etcd 可能位于 kube-syst…

【ARM Cache 与 MMU/MPU 系列文章 2.1 -- 什么是 Cache PoP 及 PoDP ?】

请阅读【ARM Cache 及 MMU/MPU 系列文章专栏导读】 及【嵌入式开发学习必备专栏】 文章目录 PoP 及 PoDPCache PoDPCache PoP应用和影响PoP 及 PoDP Cache PoDP 点对深度持久性(Point of Deep Persistence, PoDP)是内存系统中的一个点,在该点达到的任何写操作即使在系统供电…

git下载项目登录账号或密码填写错误不弹出登录框

错误描述 登录账号或密码填写错误不弹出登录框 二、解决办法 控制面板\用户帐户\凭据管理器 找到对应的登录地址进行更新或者删除 再次拉取或者更新就会提示输入登录信息

Linux服务器安装Jupyter,并设置公网访问详细教程

本章教程,主要介绍如何在Linux服务器上安装jupyter,并可以通过公网地址进行访问。 一、安装jupyter pip install jupyter二、生成jupyter配置文件 jupyter notebook --generate-config三、编辑这个配置文件 找到配置文件并修改以下配置项: # 允许所有 IP 地址访问 c.Noteb…

nodejs湖北省智慧乡村旅游平台-计算机毕业设计源码 00232

摘 要 随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,旅游行业当然也不能排除在外。智慧乡村旅游平台是以实际运用为开发背景,运用软件工程开发方法,采…

【ai】blender4.1 安装插件

开源软件,所以资料充足插件及配置 下载插件插件是python开发的 编辑中的偏好设置 点击选中 点击一键切换中文英文 切换主题 插件源码

卸载MySQL5.0,安装MySQL8.0

卸载MySQL 1、以管理员身份运行cmd,删除MySQL服务 2、卸载MySQL 3、删除残余文件 4、清楚注册表 winR -> regedit 5、删除环境变量 安装MySQL步骤 官方下载地址 https://www.mysql.com/downloads/ 以上步骤即完成MySQL数据库安装。

XSKY 在金融行业:新一代分布式核心信创存储解决方案

近日,国家金融监督管理总局印发了《关于银行业保险业做好金融“五篇大文章”的指导意见》,在数字金融领域提出明确目标,要求银行业保险业数字化转型成效明显,数字化经营管理体系基本建成,数字化服务广泛普及&#xff0…

SaaS企业营销:如何通过联盟计划实现销售增长?

联盟营销计划在国外saas行业非常盛行,国内如何借鉴国外的成功案例运用联盟计划实现销售增长呢?林叔今天以最近新发现的leadpages为例分享下经验。 Leadpages是一款用户友好的落地页制作工具,提供多种预设计模板、A/B测试和分析功能&#xff0…

【嵌入式DIY实例】-Nokia 5110显示DHT11/DHT22传感器数据

Nokia 5110显示DHT11/DHT22传感器数据 文章目录 Nokia 5110显示DHT11/DHT22传感器数据1、硬件准备2、代码实现2.1 显示DHT11数据2.2 显示DHT22数据本文介绍如何将 ESP8266 NodeMCU 开发板 (ESP-12E) 与 DHT11 数字湿度和温度传感器以及诺基亚 5110 LCD 连接。 NodeMCU 从 DHT11…

EMI电路

PFC 功率部分 1 、整流桥是串联 2 、 PFC 电感串联 3 、二极管并联 4 、 MOSFET 并联 EMI电路图

2024年山西泵管阀展览会11月8日开展

2024中国(山西)国际水务科技博览会 暨水处理技术设备与泵管阀展览会 时间:2024年11月8-10日 地点:山西潇河国际会展中心 经研究,由山西省水处理行业协会、山西省水暖阀门商会、山西省固废产业协会联合主办的2024…

SpringBoot+MyBatis批量插入数据的三种方式

文章目录 1. 背景介绍2. 方案介绍2.1 第一种方案,用 for语句循环插入(不推荐)2.2 第二种方案,利用mybatis的foreach来实现循环插入(不推荐)2.3 第三种方案,使用sqlSessionFactory实现批量插入&a…

3.00003 postmaster守护线程的启动流程调用以及辅助流程的启动流程调用是怎样的

文章目录 架构图相关数据结构child_process_kinds[]数组 (launch_backend.c:179)相关函数main.cPostmasterMain(postmaster.c:489)ServerLoop(postmaster.c:1624)BackendStartup(postmaster.c:3544)postmaster_child_launch (launch_backend.c:265)StartChildProcess (postma…