【知识拓展】机器学习基础(二):什么是模型、自定义模型、模型训练、模型调优

news2024/11/22 9:38:15

前言

        接上文,前文对模型没有过多介绍,随着看的资料增多,对模型有了更多的自我认识,记录一下。要了解模型,我们先从零开始创建一个模型开始:

        最简单的方法是使用Python和scikit-learn库。关于scikit-learn库,在这做个简单介绍,类似的库和框架有很多如NumPy、Pandas、TensorFlow、PyTorch,这些不是本文重点,后续有必要再补充,这里不做详细介绍。

Scikit-learn

        Scikit-learn(以前称为scikits.learn,也称为sklearn)是一个强大的Python机器学习库,它集成了众多简单高效的机器学习算法,通过一套共用的接口进行调用,极大地方便了机器学习的应用和研究。

一个简单的案例

       以下是一个简单的代码示例,演示如何创建和训练一个线性回归模型来预测数据。主要是为了方便大家更好的理解模型。它包括了如下内容:

test.py

        主要包括准备数据集、创建和训练模型、评估模型性能、生成模型文件

# 安装scikit-learn库
# 运行以下命令来安装scikit-learn库,如果你还没有安装它:
# !pip install scikit-learn

# 导入必要的库
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import joblib

# 生成一个简单的数据集
# 假设我们有一些简单的线性数据,y = 2x + 1
X = np.array([[i] for i in range(10)])  # 特征(Feature)
y = np.array([2*i + 1 for i in range(10)])  # 标签(Label)

# 将数据集分成训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建线性回归模型
model = LinearRegression()

# 训练模型
model.fit(X_train, y_train)

# 使用测试集进行预测
y_pred = model.predict(X_test)

# 评估模型性能
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")


# 保存模型
joblib.dump(model, 'linear_regression_model.pkl')

test2.py

        加载模型,并输入内容,测试输出结果

from joblib import load

# 加载模型
model = load('linear_regression_model.pkl')

# 使用模型进行预测
# 假设你有一个名为X的输入数据,可以这样进行预测:
predictions = model.predict([[10],[11]])

print(predictions)

执行结果

        依次执行test.py、test2.py。结果如下:

        

案例分析

模型

         通过案例,我们可以了解到,模型相当于一个函数,输入后得到对应输出值,上述案例是一个最简单线性回归模型(y=2x+1)。

        上述模型尽管复杂性和性能可能与那些大规模数据集上训练的常见预训练模型规模、复杂度、性能有巨大差距,但本质上是一样的,也可以被看作一个预训练模型。

简单模型(我们的线性回归模型)
  1. 训练数据:在一个小型数据集上训练,可能只涉及非常简单的特征和目标。
  2. 模型复杂度:模型结构简单,如线性回归、简单的神经网络等。
  3. 初步训练:模型经过初步训练,学到了一些特定任务的特征。
  4. 持续提升:可以在更多的数据上继续训练,不断改进性能。
复杂预训练模型(如 BERT、ResNet)
  1. 训练数据:在大型、丰富的数据集上训练,如 ImageNet、Wikipedia 文本等。
  2. 模型复杂度:模型结构复杂,包含大量的层和参数,如深度卷积神经网络、变压器网络等。
  3. 初步训练:模型经过大规模数据集的训练,学到了广泛的特征和模式。
  4. 迁移学习:在特定任务上进行微调,通过较少的数据和训练时间达到高性能。

模型文件

        生成了一个linear_regression_model.pkl文件,它能够被加载并运行得到符合预期的结果。

.pkl 文件用于保存和加载机器学习模型,使得我们可以将训练好的模型持久化存储,通过使用 joblib 库生成(joblib 在内部使用了 pickle 进行序列化,但进行了优化以提高性能,特别是处理大量数据时。)。

        个人一句话总结,生成模型的过程模型的进行序列化和反序列化操作。

文件格式

        为什么格式是,pkl,和常见的不同?主要原因是使用的库或框架不同,因为不同的机器学习框架和应用场景对模型的存储和加载有着不同的需求。

        机器学习模型可以以多种格式保存,常见的包括.pkl(Pickle)、.pt(PyTorch)、.h5(HDF5)、ONNX(Open Neural Network Exchange)、PMML(Predictive Model Markup Language)以及.bin格式

        每种格式都有其特定的用途和优劣势。.pkl 文件通常用于保存任意 Python 对象,包括机器学习模型,但可能受到 Python 版本和库版本的影响。.pt 文件是 PyTorch 框架的标准格式,专门用于保存 PyTorch 模型。.h5 文件通常与 Keras 和 TensorFlow 框架一起使用,支持压缩和高性能读写。ONNX 格式允许在不同的深度学习框架之间转换模型。PMML 是一个用于表示数据挖掘和机器学习模型的通用 XML 格式。.bin 文件通常用于保存预训练的词嵌入模型等,以二进制格式存储,适用于特定的应用场景。选择模型保存格式应考虑工作流程和系统集成需求。

.pt

.bin

问题

        这个模型也太简单了,输出值等于输入值乘以2+1。这还有什么训练的必要吗?

        确实这个例子确实非常简单,因为我们直接定义了线性关系 y=2x+1。在这种情况下,我们实际上已经知道了模型的参数(系数为2,截距为1),不需要进行训练。但是,实际中的数据通常更加复杂和不确定。

        

        下面我们进行一个更为实际的问题:

        一个基于糖尿病数据集的例子,我们将使用线性回归模型来预测糖尿病进展。

案例二

        以下代码是一个基于糖尿病数据集的例子,我们将使用线性回归模型来预测糖尿病进展。步骤如下:

        ①加载糖尿病数据集,并将特征和目标变量分别存储在 Xy 中。

        ②数据集分割:将数据集分为训练集和测试集,测试集占20%。

        ③创建和训练模型:创建线性回归模型并使用训练数据进行训练。

        ④预测和评估:使用测试数据进行预测,并计算均方误差(MSE)和R^2得分来评估模型性能。

        ⑤输出模型参数:打印模型的系数和截距。

        ⑥可视化结果:绘制实际值与预测值的散点图,并添加一条45度线以便直观对比预测效果。

# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

# 加载糖尿病数据集
diabetes = load_diabetes()
X = diabetes.data
y = diabetes.target

# 将数据集分成训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建线性回归模型
model = LinearRegression()

# 训练模型
model.fit(X_train, y_train)

# 使用测试集进行预测
y_pred = model.predict(X_test)

# 评估模型性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"Mean Squared Error: {mse}")
print(f"R^2 Score: {r2}")

# 输出模型的参数
print(f"Model Coefficients: {model.coef_}")
print(f"Model Intercept: {model.intercept_}")

# 可视化预测结果
plt.scatter(y_test, y_pred)
plt.xlabel('Actual Progression')
plt.ylabel('Predicted Progression')
plt.title('Actual vs Predicted Diabetes Progression')
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red')  # 45度线
plt.show()

案例二分析

训练模型

        首先要明白一点,在机器学习中,“训练”指的是模型从数据中学习并调整内部参数的过程,以便能够对新的数据进行预测。这个过程是通过对训练数据进行迭代优化来完成的,模型会尝试不同的参数组合,以最大程度地减少预测误差。        

        上述运行结果中可以看出来,训练后得到的评估结果得分并不高。

        当你运行一次程序,即调用 model.fit(X_train, y_train) 进行训练时,模型确实会根据提供的训练数据来学习,并调整自身的参数。但是,模型学到的内容并不会存储在你的脑子里,而是存储在模型的内部参数中。当你调用 model.predict(X_test) 对新的数据进行预测时,模型会使用已经学到的参数来进行预测。

        每次运行程序,模型都会重新进行训练,即使你之前已经运行过相同的程序。这是因为计算机程序的运行是一个临时性的过程,程序结束后,模型的状态并不会被保留。要想在多次运行中保持模型的状态,你可以将训练好的模型保存到文件中,在需要时再加载使用。

        我们可以增加训练数据,想必会更精准。我们数据量目前无法变化,下面将训练数据比例增加,可以看到得分变高了。

        

模型调优

        如果我们选择的模型合适,经过大量的数据训练后,再评估时,得分会达到我们的预期,即大部分预测会准确,那么这个时候我们的模型就可以直接使用了。

        如果大量数据训练后,依旧达不到要求,例如上述训练后,得分依旧不高(因为我们没有数据了) 。

        这种情况下,我还想继续提高得分,应该怎么办?

        那就需要改进模型的性能,这就是我们常说的

        模型调优

        主要包括以下内容:

  1. 特征工程:优化特征工程可能是提高模型性能的有效方法。尝试添加新的特征、组合现有特征或者进行特征选择,以提供更丰富和更有信息量的特征。

  2. 模型调优:对于已有的模型,尝试调整其超参数或参数,以提高模型的性能。你可以使用交叉验证和网格搜索等技术来寻找最佳的参数组合。

  3. 集成学习:尝试集成多个模型的预测结果,如投票法、堆叠法等。集成学习可以通过组合多个模型的预测结果来减少方差,从而提高整体性能。

  4. 模型融合:将多个模型的预测结果进行加权平均或者其他组合方式,以获得更稳健的预测。模型融合可以结合不同模型的优点,进一步提高预测的准确性。

  5. 更多数据:尝试收集更多的数据来训练模型。更多的数据通常可以帮助模型更好地学习数据的模式,提高性能。

  6. 异常值处理:检测并处理异常值,以减少其对模型的影响,从而提高模型的鲁棒性和泛化能力。

模型选择

        当然,还有一个原因可能是你的模型并不适合这个场景,这个时候需要你更换模型。在实际应用中,选择合适的模型是一个挑战,特别是对于初学者来说,我们需要学习的有:

  1. 了解常用模型:学习一些常用的机器学习模型,了解它们的原理、优缺点和适用场景。这可以帮助你在需要时更快地选择合适的模型。

  2. 参考文献和教程:阅读机器学习领域的文献和教程,了解不同模型的特点和适用范围。你可以从书籍、博客、论文、在线课程等资源中获取信息。

  3. 交叉验证和网格搜索:使用交叉验证和网格搜索等技术来比较不同模型的性能。这些技术可以帮助你在给定数据集上评估多个模型,并选择性能最好的模型和参数组合。

  4. 集成学习:尝试集成学习方法,如随机森林和梯度提升树等。集成学习将多个基本模型的预测结果进行组合,通常可以获得比单个模型更好的性能。

  5. 领域知识:了解你所处理数据的领域知识也是选择模型的重要因素。有时候,根据数据的特点和背景知识,可以有针对性地选择合适的模型。

一些学习建议

        要想了解更多机器学习相关的内容,基础很重要以下推荐一些基础学习内容:

  • Coursera 课程 - 机器学习(Andrew Ng):

    • 这门课程由斯坦福大学教授 Andrew Ng 主讲,是深入理解机器学习基础的绝佳选择。课程涵盖了监督学习、无监督学习、深度学习等多个领域,并提供了丰富的编程作业来帮助学员巩固所学知识。
    • Coursera 课程链接
  • 书籍 - 《Python机器学习》(Python Machine Learning)(Sebastian Raschka, Vahid Mirjalili):

    • 这本书介绍了机器学习的基本理论和实践,以及如何使用 Python 和 scikit-learn 库实现常见的机器学习算法。书中包含大量示例代码和实践项目,适合初学者和有一定编程经验的人士。
    • 《Python机器学习》书籍链接
  • 网站 - Kaggle

    • Kaggle 是一个数据科学竞赛平台,提供丰富的数据集、教程和机器学习竞赛。你可以在 Kaggle 上找到很多与机器学习和数据科学相关的项目,学习他人的代码和技巧,还可以参与竞赛提升自己的能力。
    • Kaggle 网站链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1718836.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iframe内嵌网页自适应缩放 以展示源网页的比例尺寸

需求:这是我最近开发的低代码平台遇到的需求 ,要求将配置好的应用在弹框中预览(将预览网页内嵌入弹框中) 但是内嵌进入后 他会截取一部分(我源网站网页尺寸 是1980x1080 或者 3060X2160等等) 但是我这个dialog弹框只有我自定义的1000多px的宽高 他只会展示我iframe网页的一部分…

Docker安装Zookeeper(单机)

Docker安装Zookeeper(单机) 目录 Docker安装Zookeeper(单机)拉取镜像创建目录添加配置文件启动容器测试 拉取镜像 docker pull zookeeper创建目录 mkdir -p /data/zookeeper/data # 数据挂载目录 mkdir -p /data/zookeeper/conf…

身份认证与口令攻击

身份认证与口令攻击 身份认证身份认证的五种方式口令认证静态口令动态口令(一次性口令)动态口令分类 密码学认证一次性口令认证S/KEY协议改进的S/KEY协议 其于共享密钥的认证 口令行为规律和口令猜测口令规律口令猜测 口令破解操作系统口令破解Windows密码存储机制Windows密码破…

一步将 CentOS 7.x 原地迁移至 RHEL 7.9

《OpenShift / RHEL / DevSecOps 汇总目录》 在《在离线环境中将 CentOS 7.X 原地迁移至 RHEL 7.9》一文中为了实现从 CentOS 7.X 原地迁移至 RHEL 7.9,我们第一步先将一个测试环境 CentOS 7.5 升级到 CentOS 7.9,然后在第二步使用 convert2rhel &…

太阳能语音警示杆在户外的应用及其作用

一、太阳能语音警示杆的主要应用领域 交通管理:在城市道路、乡村公路、高速公路等交通要道,太阳能语音警示杆可以用于提醒驾驶员注意前方路况、减速慢行或者避让施工区域。例如,在临时施工路段,警示杆可以播放“前方施工&#xf…

HTML语义化标签

<header> 主要用于网页整体顶部&#xff0c;<article>头部&#xff0c;<section>头部 <nav> 导航&#xff0c;一般有主要导航&#xff0c;路径导航&#xff0c;章节导航&#xff0c;内容目录导航 <main> 网页主要区域&#xff0c;一般一个网页…

Mysql基础教程(11):DISTINCT

MySQL DISTINCT 用法和实例 当使用 SELECT 查询数据时&#xff0c;我们可能会得到一些重复的行。比如学生表中有很多重复的年龄。如果想得到一个唯一的、没有重复记录的结果集&#xff0c;就需要用到 DISTINCT 关键字。 MySQL DISTINCT用法 在 SELECT 语句中使用 DISTINCT 关…

STM32高级控制定时器之输入捕获模式

目录 概述 1 输入捕获模式 1.1 原理介绍 1.2 实现步骤 1.3 发生输入捕获流程 2 使用STM32Cube配置工程 2.1 软件环境 2.2 配置参数 2.3 生成项目文件 3 功能实现 3.1 PWM调制占空比函数 3.2 应用函数库 4 测试 4.1 功能框图 4.2 运行结果 源代码下载地址&#xf…

chrome调试手机网页

前期准备 1、 PC端安装好chrmoe浏览器 2、 安卓手机安装好chrmoe浏览器 3、 数据线 原文地址&#xff1a;https://lengmo714.top/343880cb.html 手机打开调试模式 进入手机设置&#xff0c;找到开发者模式&#xff0c;然后启用USB调试 打开PC端chrome调试功能 1、点击chr…

部署专属网页版ChatGPT-Next-Web

背景 工作学习中经常使用chat-gpt, 需求是多端使用gpt问答&#xff0c;因此搭建一个网页版本方便多个平台使用。最后选择了 ChatGPT-Next-Web 部署说明 一键部署自己的web页面&#xff0c;因为是使用免费的vercel托管的&#xff0c;vercel节点在全球都有&#xff0c;理论上突…

OAK相机如何将 YOLOv10 模型转换成 blob 格式?

编辑&#xff1a;OAK中国 首发&#xff1a;oakchina.cn 喜欢的话&#xff0c;请多多&#x1f44d;⭐️✍ 内容可能会不定期更新&#xff0c;官网内容都是最新的&#xff0c;请查看首发地址链接。 Hello&#xff0c;大家好&#xff0c;这里是OAK中国&#xff0c;我是Ashely。 专…

NVIDIA Blackwell Architecture

本文翻译自&#xff1a;NVIDIA Blackwell Architecture https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/ 文章目录 了解技术突破1、新型人工智能超级芯片2、第二代 Transformer 引擎3、Secure AI4、NVLink 和 NVLink 交换机5、解压缩引擎6、可…

景源畅信数字:抖音新手如何找好自己的发布领域?

在短视频的浪潮中&#xff0c;抖音以其独特的魅力吸引了众多用户。对于刚踏入这个平台的新手来说&#xff0c;找到适合自己的发布领域至关重要。那么&#xff0c;如何在这个充满竞争的平台上找到自己的定位呢?接下来&#xff0c;就让我们一起来探讨这个问题。 一、明确兴趣爱好…

Java18+ springboot+mysql +Thymeleaf 技术架构开发的全套同城服务家政上门系统源码(APP用户端+APP服务端+PC管理端)

Java springbootmysql Thymeleaf 技术架构开发的全套同城服务家政上门系统源码&#xff08;APP用户端APP服务端PC管理端&#xff09; 家政上门预约系统&#xff1a;该系统综合运用springboot、java1.8、vue移动支付、微信授权登录等技术&#xff0c;由用户小程序、站长小程序、…

AI网络爬虫:无限下拉滚动页面的另类爬取方法

现在很多网页都是无限下拉滚动的。可以拉动到底部&#xff0c;然后保存网页为mhtml格式文件。 接着&#xff0c;在ChatGPT中输入提示词&#xff1a; 你是一个Python编程高手&#xff0c;要完成一个关于爬取网页内容的Python脚本的任务&#xff0c;下面是具体步骤&#xff1a; …

利用依赖结构矩阵管理架构债务

本文讨论了如何利用依赖结构矩阵&#xff08;DSM&#xff0c;Dependency Structure Matrix&#xff09;管理和识别架构债务&#xff0c;并通过示例应用展示了这一过程。原文: Managing Architecture Debt with Dependency Structure Matrix Vlado Paunovic Unsplash 技术债务&a…

imx6ull - 制作烧录SD卡

1、参考NXP官方的手册《i.MX_Linux_Users_Guide.pdf》的这一章节&#xff1a; 1、SD卡分区 提示&#xff1a;我们常用的SD卡一个扇区的大小是512字节。 先说一下i.MX6ULL使用SD卡启动时的分区情况&#xff0c;NXP官方给的镜像布局结构如下所示&#xff1a; 可以看到&#xff0c…

simulink基础学习笔记

写在前面 这个笔记是看B站UP 快乐的宇航boy 所出的simulink基础教程系列视频过程中记下来的&#xff0c;写的很粗糙不完整&#xff0c;也不会补。视频教程很细跟着做就行。 lesson1-7节的笔记up有&#xff0c;可以加up的群&#xff0c;里面大佬挺活跃的。 lesson8 for循环 For …

【项目管理知识】项目质量管理措施

1、持续改进&#xff08;PDCA&#xff09; 戴明循环或称PDCA循环、PDSA循环。戴明循环的研究起源于20世纪20年代&#xff0c;先是有着“统计质量控制之父”之称的著名的统计学家沃特阿曼德休哈特&#xff08;Walter A. Shewhart&#xff09;在当时引入了“计划-执行-检查&…

统计各个商品今年销售额与去年销售额的增长率及排名变化

文章目录 测试数据需求说明需求实现分步解析 测试数据 -- 创建商品表 DROP TABLE IF EXISTS products; CREATE TABLE products (product_id INT,product_name STRING );INSERT INTO products VALUES (1, Product A), (2, Product B), (3, Product C), (4, Product D), (5, Pro…