Sklearn入门

news2025/1/10 11:29:25

Scikit learn 也简称 sklearn, 是机器学习领域当中最知名的 python 模块之一.

Sklearn 包含了很多种机器学习的方式:

  1. Classification 分类
  2. Regression 回归
  3. Clustering 非监督分类
  4. Dimensionality reduction 数据降维
  5. Model Selection 模型选择
  6. Preprocessing 数据预处理

我们总能够从这些方法中挑选出一个适合于自己问题的, 然后解决自己的问题。

一般使用

看图选择学习方法

https://scikit-learn.org/stable/tutorial/machine_learning_map/index.html
从 START 开始,首先看数据的样本是否 >50,小于则需要收集更多的数据。

由图中,可以看到算法有四类,分类,回归,聚类,降维。

其中 分类和回归是监督式学习,即每个数据对应一个 label。 聚类 是非监督式学习,即没有 label。 另外一类是 降维,当数据集有很多很多属性的时候,可以通过 降维 算法把属性归纳起来。例如 20 个属性只变成 2 个,注意,这不是挑出 2 个,而是压缩成为 2 个,它们集合了 20 个属性的所有特征,相当于把重要的信息提取的更好,不重要的信息就不要了。

然后看问题属于哪一类问题,是分类还是回归,还是聚类,就选择相应的算法。 当然还要考虑数据的大小,例如 100K 是一个阈值。

可以发现有些方法是既可以作为分类,也可以作为回归,例如 SGD。

通用学习模式

要点

Sklearn 把所有机器学习的模式整合统一起来了,学会了一个模式就可以通吃其他不同类型的学习模式。

例如,分类器,

Sklearn 本身就有很多数据库,可以用来练习。 以 Iris 的数据为例,这种花有四个属性,花瓣的长宽,茎的长宽,根据这些属性把花分为三类。

我们要用 分类器 去把四种类型的花分开。

在这里插入图片描述

今天用 KNN classifier,就是选择几个临近点,综合它们做个平均来作为预测值。

导入模块

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

创建数据

iris=datasets.load_iris()
iris_x=iris.data
iris_y=iris.target
# 观察数,x有4个属性,y有3类
print(iris_x[:5, :])
print(iris_x.shape)
print(iris_y)
# 把数据集分为训练集和测试集
# (被分开后的数据集顺序会自动被打乱)
x_train,x_test,y_train,y_test=train_test_split(iris_x,iris_y,test_size=0.2)
print(x_test.shape)

在这里插入图片描述

建立模型-训练-预测

knn=KNeighborsClassifier()
knn.fit(x_train,y_train)
print(knn.predict(x_test))
# 将训练结果与真实值对比一下
print(y_test)

在这里插入图片描述

sklearn强大数据库datasets

要点

Scikit-learn(sklearn)是一个用于机器学习和数据挖掘的Python库,它包含了许多用于数据处理、特征工程、模型建立和评估等任务的工具和函数。datasets模块是scikit-learn中的一个重要模块,用于加载一些常用的数据集,以便进行实验和模型训练。以下是关于datasets模块的一些要点:

  1. 数据集的获取:datasets模块包含了一些经典的机器学习数据集,例如鸢尾花数据集(Iris)、手写数字数据集(Digits)、波士顿房价数据集(Boston Housing Prices)等。这些数据集可以帮助用户快速开始机器学习任务。

  2. 数据集的加载:使用sklearn.datasets.load_*()函数可以加载特定数据集。例如,load_iris()可以加载鸢尾花数据集,load_digits()可以加载手写数字数据集。加载后的数据通常是一个类似字典的数据结构,包括数据、标签、特征名称等信息。

  3. 数据集的描述:通过DESCR属性可以获取数据集的描述信息,包括数据集的特征、标签、来源等。

  4. 数据集的划分:datasets模块还提供了用于数据集划分的函数,例如train_test_split(),可以将数据集划分为训练集和测试集,用于模型的训练和评估。

  5. 示例数据集:datasets模块中还包含一些小规模的示例数据集,用于演示和测试代码。例如,make_classification()make_regression()函数可以生成分类和回归问题的合成数据集。

  6. 外部数据集加载:除了加载scikit-learn内置的数据集之外,datasets模块还提供了函数来加载外部数据集,例如fetch_openml()用于从OpenML数据库加载数据。

  7. 数据集的用途:datasets模块中的数据集通常用于机器学习算法的示例、实验和教育目的。用户可以使用这些数据集来学习和测试不同的机器学习技术,以便更好地理解和应用机器学习。

总的来说,sklearn.datasets模块提供了一个便捷的方式来访问和加载各种常见的数据集,以便于机器学习任务的实验和研究。通过使用这些数据集,用户可以更轻松地开始构建、训练和评估机器学习模型。

#自己生成数据
from sklearn import datasets

X, y = datasets.make_regression(n_samples=100, n_features=100, n_informative=10, n_targets=1, bias=0.0, effective_rank=None, tail_strength=0.5, noise=0.0, shuffle=True, coef=False, random_state=None)
print(X)  # 打印特征数据
print(y)  # 打印目标数据

在这里插入图片描述

导入模块

from sklearn import datasets
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt

导入数据-训练模型

loaded_data = load_diabetes()
data_x = loaded_data.data
data_y = loaded_data.target

model = LinearRegression()
model.fit(data_x, data_y)

predictions = model.predict(data_x[:4, :])
actual_values = data_y[:4]

print("Predictions:", predictions)
print("Actual Values:", actual_values)

在这里插入图片描述

sklearn的常用属性与功能

我们以load_diabetes 数据集,它是一个用于糖尿病预测的回归数据集进行分析

from sklearn import datasets
from sklearn.linear_model import LinearRegression

# 加载糖尿病预测数据集
loaded_data = datasets.load_diabetes()
data_X = loaded_data.data
data_y = loaded_data.target

# 创建线性回归模型
model = LinearRegression()

训练和预测

# 拟合数据
model.fit(data_X, data_y)

# 进行预测
predictions = model.predict(data_X[:4, :])
actual_values = data_y[:4]

print("Predictions:", predictions)
print("Actual Values:", actual_values)

在这里插入图片描述

参数和分数

然后,model.coef_model.intercept_ 属于 Model 的属性, 例如对于 LinearRegressor 这个模型,这两个属性分别输出模型的斜率和截距(与y轴的交点)。

print(model.coef_)
print(model.intercept_)

在这里插入图片描述
model.get_params() 是 scikit-learn 中用于获取模型超参数(hyperparameters)的方法。超参数是在训练模型之前设置的参数,它们不是由模型学习而来,而是在模型训练之前由用户指定的。这些超参数可以影响模型的行为和性能。

print(model.get_params())

在这里插入图片描述

model.score(data_X, data_y) 它可以对 Model 用 R 2 R^2 R2 的方式进行打分,输出精确度。

print(model.score(data_X, data_y))

在这里插入图片描述

高级使用

正则化Normalization

数据标准化

from sklearn import preprocessing
import numpy as np
a=np.array([[10,2.7,3.6],
            [-100,5,-2],
            [120,20,40]],dtype=np.float64)
print(preprocessing.scale(a))

在这里插入图片描述

数据标准化对机器学习成效的影响

from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.datasets._samples_generator import make_classification
from sklearn.svm import SVC
import matplotlib.pyplot as plt

X, y = make_classification(n_samples=300,n_features=2,
                           n_redundant=0,n_informative=2,
                           random_state=22,n_clusters_per_class=1,
                           scale=100)
# print(X)
# print(y)
plt.scatter(X[:,0],X[:,1],c=y)
plt.show()

X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3)
clf=SVC()
clf.fit(X_train,y_train)
print(clf.score(X_test,y_test))

X=preprocessing.scale(X)
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3)
clf=SVC()
clf.fit(X_train,y_train)
print(clf.score(X_test,y_test))

plt.scatter(X[:,0],X[:,1],c=y)
plt.show()
  • X, y = make_classification(n_samples=300,n_features=2, n_redundant=0,n_informative=2, random_state=22,n_clusters_per_class=1, scale=100)
    这段代码是使用Python中的make_classification函数来生成一个用于分类问题的模拟数据集。让我一步一步解释这个代码的各个参数和作用:
  1. X, y = make_classification(...)
    这一行代码创建了两个变量Xy,分别用于存储生成的模拟数据的特征和标签。
  2. n_samples=300
    这个参数指定了生成的数据集中样本的数量,本例中为300个样本。
  3. n_features=2
    这个参数指定了生成的数据集中的特征数量,本例中为2个特征。这表示每个样本有两个特征。
  4. n_redundant=0
    这个参数指定了生成的特征中不相关的特征的数量。在本例中,设置为0,表示没有不相关的特征。所有的特征都与分类任务相关。
  5. n_informative=2
    这个参数指定了生成的特征中信息相关的特征的数量。在本例中,设置为2,表示有2个特征与分类任务相关,这些特征包含了分类信息。
  6. random_state=22
    这个参数是随机数生成器的种子,用于确保每次运行代码时生成的模拟数据都是相同的。设置不同的种子会得到不同的数据集。
  7. n_clusters_per_class=1
    这个参数指定了每个类别中的聚类数量。在本例中,每个类别只有一个聚类,即没有内部聚类结构。
  8. scale=100
    这个参数用于控制生成的数据的尺度(缩放因子)。它会影响特征的值的范围。在本例中,特征值被放大了100倍,以增加数据的差异性和散布。

总的来说,这段代码生成了一个包含300个样本和2个特征的模拟数据集,这些特征中包含了用于分类任务的相关信息,没有不相关的特征。这个数据集适用于用于测试和演示分类算法的目的,因为我们知道数据的生成方式和真实的标签信息。

plt.scatter(X[:,0], X[:,1], c=y) 这行代码使用了 Matplotlib 库来创建散点图(scatter plot)。让我解释一下这行代码的各个部分:

  • plt 是 Matplotlib 库的别名,通常用于创建各种类型的图形和可视化。
  • scatter 是 Matplotlib 中用于创建散点图的函数。
    接下来,是函数的参数:
  • X[:,0]X[:,1] 是 NumPy 数组 X 的切片操作,用于获取数据集中的特征。X[:,0] 表示取所有行的第一个特征,X[:,1] 表示取所有行的第二个特征。这通常用于表示数据集中的两个特征,其中第一个特征在 x 轴上,第二个特征在 y 轴上。
  • c=y 是设置散点颜色的参数。y 是标签数组,它包含了每个数据点的类别或标签信息。通过将 c 参数设置为 y,您可以指定每个数据点的颜色,使散点图能够根据类别或标签来着色。每个类别通常使用不同的颜色。

因此,这行代码的作用是创建一个散点图,其中 x 轴和 y 轴分别表示数据集中的两个特征(X[:,0] 和 X[:,1]),而散点的颜色由标签数组 y 决定,从而可以可视化数据点在特征空间中的分布情况以及它们的类别信息。这种可视化通常用于数据探索和分类问题的可视化分析。

在这里插入图片描述
在这里插入图片描述

(这个结果我不太理解,为什么归一化大部分情况下反而会少一些,按理来说应该归一化后的效果会好,后面再想想吧。郁闷……)
在这里插入图片描述

检验神经网络(Evaluation)

Training and Test data

训练数据和测试数据就好比是我们在学习和考试时的两个不同阶段。

  1. 训练数据:这就像我们在学校上课时学习知识的阶段。在这个阶段,我们学习了各种知识和技能,例如数学、历史、语言等。机器学习模型也需要学习,所以我们用一部分数据来教它,让它知道如何做出正确的决策。这部分数据叫做训练数据。

  2. 测试数据:一旦我们学习了知识,就要进行考试来测试自己是否真的学会了。同样,机器学习模型也需要测试,以确保它学到了正确的东西。这时,我们用一部分不同的数据来测试模型,看看它是否能够正确地回答问题或者做出预测。这部分数据叫做测试数据。

训练数据和测试数据的分开使用是为了确保模型不仅能够背诵记住数据(所谓的"死记硬背"),还能够理解数据背后的规律,从而在未来遇到新问题时也能够做出正确的决策。就像我们学校学习知识并参加考试一样,机器学习模型需要训练和测试才能变得聪明和可靠。这个过程有助于我们确保模型在真实世界中的表现良好。

误差曲线

误差曲线是用来帮助我们了解机器学习模型的性能如何随着不同情况的变化而变化的一种图形。通俗来说,它是一条曲线,告诉我们模型在不同情况下犯错误的程度。

让我们用一个例子来说明:

假设你是一名投篮手,你想知道在不同距离和角度下,你的投篮命中率如何变化。你决定进行实验,记录每次投篮的结果,并将其可视化为一条曲线。

  • 横轴表示投篮的距离,从近到远。
  • 纵轴表示命中率,从低到高。

你开始投篮,记录下了每一次投篮的结果。随着距离的增加,你的命中率逐渐下降,这就是误差曲线的趋势。这条曲线告诉你,当距离较近时,你的命中率较高,但当距离增加时,你的命中率下降了。

在机器学习中,我们使用误差曲线来了解模型在不同情况下的表现。例如,对于分类问题,我们可以绘制一个误差曲线,其中横轴表示模型的复杂度(例如,决策树的深度),纵轴表示错误率(模型犯错的次数)。通过观察这条曲线,我们可以找到模型的最佳复杂度,使其在不过度拟合或欠拟合的情况下表现最好。

总之,误差曲线是一种图形工具,帮助我们理解模型在不同条件下的性能变化,就像投篮手通过曲线了解了他的命中率如何随着距离的变化而变化一样。这有助于我们选择最适合问题的机器学习模型或参数设置。

准确度曲线

准确度曲线是用来评估机器学习模型在不同情况下的表现的一种图形工具。通俗来说,它是一条曲线,告诉我们当我们改变模型的某个设置或参数时,模型的准确度(正确预测的比例)会如何变化。

让我们用一个例子来说明:

假设你是一名厨师,你想知道在不同温度下烤蛋糕的时间对于最终的美味程度有何影响。你决定进行一系列实验,每次在不同温度下烤蛋糕,并记录每个蛋糕的美味度。然后,你将这些数据可视化为一条曲线。

  • 横轴表示烤蛋糕的温度,从低到高。
  • 纵轴表示美味度,从低到高。

你开始烤蛋糕,记录下了每个温度下蛋糕的美味度。随着温度的升高,蛋糕的美味度也逐渐上升,这就是准确度曲线的趋势。这条曲线告诉你,在不同温度下烤蛋糕的时间会影响蛋糕的美味度,高温下更容易得到美味的蛋糕。

在机器学习中,我们使用准确度曲线来了解模型在不同条件下的表现。例如,对于分类问题,我们可以绘制一个准确度曲线,其中横轴表示模型的某个参数的取值范围,纵轴表示模型的准确度。通过观察这条曲线,我们可以找到模型在哪个参数值下表现最佳,从而帮助我们选择最合适的模型设置。

总之,准确度曲线是一种工具,帮助我们理解模型在不同情况下的性能变化,就像厨师通过曲线了解了烤蛋糕温度和美味度之间的关系一样。这有助于我们优化模型的参数或设置,以提高其性能。

正则化

防止过拟合
看训练集和测试集的结果对比

神经网络的调参

逐渐增加神经层并绘制最终误差或精度的图表是一种常见的方法,用于评估不同神经网络结构的性能。这可以帮助你确定哪种网络结构在你的特定问题上表现最佳。以下是一般的步骤:

  1. 选择初始网络结构:首先,选择一个适度的初始神经网络结构,包括层数、每层的神经元数量、激活函数等。这个结构将作为基准模型。

  2. 定义一系列的网络结构:确定一系列不同的神经网络结构,可以逐渐增加层数、减少层数、改变神经元数量等。这些结构将构成你的实验组。

  3. 训练和评估:对于每个不同的网络结构,执行以下步骤:

    • 使用训练集训练神经网络。
    • 使用验证集评估模型的性能。你可以计算损失函数的值(误差)或其他性能指标(如准确度)。
  4. 记录结果:为每个不同的网络结构记录最终的误差或精度。你可以将这些结果保存在一个列表或数组中,以备后续绘制。

  5. 绘制图表:使用图表库(如Matplotlib)绘制一个图表,横轴表示不同的网络结构,纵轴表示最终的误差或精度。你可以使用折线图、柱状图或其他适当的图表类型。

  6. 分析结果:通过观察图表,你可以确定哪种网络结构在验证集上表现最佳。这可以帮助你选择最佳的神经网络结构,以进行进一步的训练和测试。

请注意,这个过程可能需要一些时间,因为对每个不同的网络结构进行训练和评估需要一些计算资源。此外,你还应该小心过拟合问题,确保验证集的性能能够反映模型在未见过的数据上的泛化能力。

交叉验证1 Cross-validation

Model基础验证法

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris=load_iris()
x=iris.data
y=iris.target

x_train,x_test,y_train,y_test=train_test_split(x,y,random_state=4)

knn=KNeighborsClassifier()
knn.fit(x_train,y_train)

print(knn.score(x_test,y_test))

random_state 参数用于控制数据集分割的随机性,以确保实验的可重复性和结果的一致性。如果你希望在不同运行中获得相同的分割结果,可以使用相同的 random_state 值。

在这里插入图片描述

Model交叉验证法

from sklearn.model_selection import cross_val_score

# 使用K折交叉验证模块
scores = cross_val_score(knn, x, y, cv=5, scoring='accuracy')

# 将5次的预测准确率打印出
print(scores)

# 将5次的预测准确率平均值打印出
print(scores.mean())

cv=5 指定了 K 折交叉验证中的 K 值,也就是将数据分成 5 个不同的子集来进行验证。这表示数据将被分为 5 份,每次使用其中一份作为测试集,其余四份作为训练集,然后重复这个过程 5 次,每次选择不同的测试集。
scoring='accuracy' 指定了用于评估模型性能的指标。在这里,使用准确度(accuracy),这是分类问题中常用的评估指标,表示模型正确分类的样本比例。
scores 是一个包含了 5 次交叉验证中每次的准确度得分的数组。每个元素表示模型在一个测试集上的准确度。

在这里插入图片描述

以准确率判断

import matplotlib.pyplot as plt

k_range=range(1,31)
k_score=[]

for k in k_range:
    knn=KNeighborsClassifier(n_neighbors=k)
    score=cross_val_score(knn,x,y,cv=10,scoring='accuracy')
    k_score.append(score.mean())
    
plt.plot(k_range,k_score)
plt.xlabel('k')
plt.ylabel('score')
plt.show()

n_neighbors 是 K 近邻算法中的一个超参数,用于指定在分类时考虑多少个最近邻居的类别。在这行代码中,k 是一个变量,可以代表任何整数值。通常,你会根据问题的性质和数据的特点来选择合适的 k 值。选择合适的 k 值对于 K 近邻算法的性能至关重要,因为它会直接影响到分类的准确性。

在这里插入图片描述

以平均方差判断

import matplotlib.pyplot as plt

k_range=range(1,31)
k_score=[]

for k in k_range:
    knn=KNeighborsClassifier(n_neighbors=k)
    loss=-cross_val_score(knn,x,y,cv=10,scoring='neg_mean_squared_error')
    k_score.append(loss.mean())

plt.plot(k_range,k_score)
plt.xlabel('k')
plt.ylabel('MSE')
plt.show()

scoring='neg_mean_squared_error' 表示使用均方误差的负值作为评估模型性能的指标。在 scikit-learn 中,评分指标通常是越高越好,因此负均方误差用于将均方误差的度量转化为一个与其他指标一致的度量。

在这里插入图片描述

交叉验证2 Cross-validation(Learning curve 检视过拟合)

from sklearn.model_selection import learning_curve #学习曲线模块
from sklearn.datasets import load_digits #digits数据集
from sklearn.svm import SVC #Support Vector Classifier
import matplotlib.pyplot as plt #可视化模块
import numpy as np

digits=load_digits()
x=digits.data
y=digits.target

train_sizes,train_loss,test_loss=learning_curve(SVC(gamma=0.001),x,y,cv=10,scoring='neg_mean_squared_error',
                                                train_sizes=[0.1,0.25,0.5,0.75,1])

train_loss_mean=-np.mean(train_loss,axis=1)
test_loss_mean=-np.mean(test_loss,axis=1)

plt.plot(train_sizes, train_loss_mean, 'o-', color="r",
         label="Training")
plt.plot(train_sizes, test_loss_mean, 'o-', color="g",
        label="Cross-validation")

plt.xlabel("Training examples")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

train_sizes=[0.1, 0.25, 0.5, 0.75, 1] 表示计算学习曲线时会使用不同比例的训练集,分别为总样本数量的 10%、25%、50%、75% 和 100%。这意味着会计算模型在不同规模的训练集上的性能。

print(train_sizes)
[ 161  404  808 1212 1617]

axis 参数是 NumPy 中的一个重要参数,它用于指定在进行数组操作时沿着哪个轴执行操作。在这个代码中,axis=1 用于计算均方误差的平均值时,表示沿着第二个轴(即列轴)进行操作。
具体来说,这行代码的目的是计算训练集上均方误差的平均值,而 axis=1 表示在每一行(样本)上执行均方误差的平均值计算。也就是说,对于每个训练集大小,train_loss 中的每一行都包含了一个样本集上的均方误差。通过使用 axis=1np.mean(train_loss, axis=1) 将计算每个训练集大小对应的所有样本的均方误差的平均值。
简而言之,axis=1 指定了在每个行上执行操作,以计算每个训练集大小的均方误差平均值。这是因为 train_loss 是一个二维数组,每行代表一个训练集大小,每列代表一个样本。

在这里插入图片描述

交叉验证3 Cross-validation(validation_curve 检视过拟合)

from sklearn.model_selection import validation_curve
from sklearn.datasets import load_digits #digits数据集
from sklearn.svm import SVC #Support Vector Classifier
import matplotlib.pyplot as plt #可视化模块
import numpy as np

#digits数据集
digits = load_digits()
X = digits.data
y = digits.target

#建立参数测试集
param_range = np.logspace(-6, -2.3, 5)

#使用validation_curve快速找出参数对模型的影响
train_loss, test_loss = validation_curve(
    SVC(), X, y, param_name='gamma', param_range=param_range, cv=10, scoring='neg_mean_squared_error')

#平均每一轮的平均方差
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

#可视化图形
plt.plot(param_range, train_loss_mean, 'o-', color="r",
         label="Training")
plt.plot(param_range, test_loss_mean, 'o-', color="g",
        label="Cross-validation")

plt.xlabel("gamma")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

param_range = np.logspace(-6, -2.3, 5) 这行代码用于创建一个包含 5 个值的数组 param_range,这些值是在对数刻度上均匀分布的。即再 10 的 -6 次方到 10 的 -2.3 次方之间均匀取值
#[1.00000000e-06 8.41395142e-06 7.07945784e-05 5.95662144e-04 5.01187234e-03]

param_name='gamma' 是用于 validation_curve 函数的一个参数,用于指定要在验证曲线分析中调整的超参数的名称。在这个特定的示例中,param_name 被设置为 'gamma',意味着你正在调整支持向量机(SVM)模型中的 gamma 超参数。

  • gamma 是 SVM 模型中的一个重要超参数,控制了决策边界的复杂度。具体来说,它影响了样本点被视为支持向量的权重。较小的 gamma 值使决策边界更加平滑,而较大的 gamma 值使决策边界更加复杂和灵活。
  • validation_curve 函数会在指定的超参数范围内,对指定的超参数(在这里是 'gamma')进行多次模型训练和评估,以了解不同超参数取值对模型性能的影响。这种分析通常用于超参数调优,以选择最佳的超参数值,以获得最佳的模型性能。

因此,在这个代码中,param_name='gamma' 告诉 validation_curve 函数,你希望调整的是 SVM 模型的 gamma 超参数,然后它将生成一组不同 gamma 值的模型,分析它们在验证集上的性能,以帮助你选择最合适的 gamma 值。

在这里插入图片描述

保存模型

先简单的训练一个SVC模型

from sklearn import svm
from sklearn import datasets

clf = svm.SVC()
iris = datasets.load_iris()
X, y = iris.data, iris.target
clf.fit(X,y)

使用pickle保存

import pickle
#保存Model(注:save文件夹要预先建立,否则会报错)
with open('save/clf.pickle','wb')as f:
    pickle.dump(clf,f)
# 读取
with open('save/clf.pickle','rb')as f:
    clf2=pickle.load(f)
    print(clf2.predict(X[0:100]))

这段代码演示了如何使用 Python 中的 pickle 模块保存和加载机器学习模型。

  1. 保存模型
    • pickle 模块用于序列化 Python 对象,包括机器学习模型。
    • with open('save/clf.pickle', 'wb') as f 打开一个文件 'save/clf.pickle',准备将模型保存到这个文件中。
    • pickle.dump(clf, f) 将机器学习模型 clf 保存到文件中。clf 是之前已经训练好的模型,它会被保存到文件 'save/clf.pickle' 中。
  2. 读取模型
    • with open('save/clf.pickle', 'rb') as f 打开之前保存的文件 'save/clf.pickle',准备加载模型。
    • clf2 = pickle.load(f) 从文件中加载模型,并将其存储在 clf2 变量中。
    • print(clf2.predict(X[0:100])) 使用加载的模型 clf2 对新数据 X[0:100] 进行预测。

这段代码的作用是将一个已训练好的机器学习模型保存到文件中,以便以后可以轻松地加载它并在新数据上进行预测。这在实际应用中很有用,因为你可以在训练模型之后将其保存,以避免每次都重新训练模型。然后,你可以随时加载模型并在需要时进行预测。

在这里插入图片描述

使用joblib保存

import joblib
#保存Model(注:save文件夹要预先建立,否则会报错)
joblib.dump(clf,'save/clf.pkl')

clf3=joblib.load('save/clf.pkl')
print(clf3.predict(X[0:100]))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1068214.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java-包装类

这里写目录标题 包装类(Wrapper)包装类和基本数据的转换 String VS StringBuffer VS StringBuilderStringStringBufferStringBuilder 包装类(Wrapper) 针对八种基本数据类型相应的引用类型 基本数据类型包装类booleanBooleancha…

C++设计模式(1)-- 单例模式

基本概念 在一个项目中,全局范围内,某个类的实例有且仅有一个,通过这个唯一实例向其他模块提供数据的全局访问,这种模式就叫单例模式,单例模式的典型应用就是任务队列 涉及一个类多对象操作的函数有以下几个&#xff…

嵌入式基础知识-IP地址与子网划分

本篇介绍IP地址与子网划分的一些基础知识,在嵌入式开发,使用网络功能时,需要了解网络的一些基础知识。 1 IP地址 1.1 IPv4与IPv6 对比信息IPv4IPv6长度32位128位地址表示形式点分十进制冒分十六进制表示示例192.168.5.1002002:0000:0000:0…

this关键字在不同上下文中的值是如何确定的?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 欢迎来到前端入门之旅!感兴趣的可以订阅本专栏哦!这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

最全解决docker配置kibana报错 Kibana server is not ready yet

问题复现: 在浏览器输入http://192.168.101.65:5601/ 访问kibana报错 Kibana server is not ready yet 问题报错: 首先查看kibana的日志 docker logs kibana 看到报错如下: {"type":"log","timestamp":&q…

【小笔记】复杂模型小数据可能会造成过拟合还是欠拟合?

【学而不思则罔,思而不学则殆】 10.8 问题 针对这个问题,我先问了一下文心一言 它回答了为什么会过拟合和欠拟合,但并没有回答我给的场景。 简单分析 分析模型 复杂模型就表示模型的拟合能力很强,对于数据中特征&#xff08…

如何保证 RabbitMQ 的消息可靠性?

项目开发中经常会使用消息队列来完成异步处理、应用解耦、流量控制等功能。虽然消息队列的出现解决了一些场景下的问题,但是同时也引出了一些问题,其中使用消息队列时如何保证消息的可靠性就是一个常见的问题。如果在项目中遇到需要保证消息一定被消费的…

Mybatis 拦截器(Mybatis插件原理)

Mybatis为我们提供了拦截器机制用于插件的开发,使用拦截器可以无侵入的开发Mybatis插件,Mybatis允许我们在SQL执行的过程中进行拦截,提供了以下可供拦截的接口: Executor:执行器ParameterHandler:参数处理…

深入解析PostgreSQL:命令和语法详解及使用指南

文章目录 摘要引言基本操作安装与配置连接和退出 数据库操作创建数据库删除数据库切换数据库 表操作创建表删除表插入数据查询数据更新数据删除数据 索引和约束创建索引创建约束 用户管理创建用户授权用户修改用户密码 备份和恢复备份数据库恢复数据库 高级特性结语参考文献 摘…

在win10里顺利安装了apache2.4.41和php7.4.29以及mysql8.0.33

一、安装apache和php 最近在学习网站搭建。其中有一项内容是在windows操作系统里搭建apachephp环境。几天前根据一本书的上的说明尝试了一下,在win10操作系统里安装这两个软件:apache2.4.41和php7.4.29,安装以后apche能正常启动,…

【转载】LLM-Native 产品的变与不变

1. LLM-Native:AGI 的另一种路径 《银河系漫游指南》的作者——道格拉斯亚当斯曾经对「技术」一词做出这样一种解释: 「技术」是描述某种尚未发挥作用的东西的词汇。 这是一个充满实用主义的定义,这句话可以被更直观地表述为:当…

机器学习7:pytorch的逻辑回归

一、说明 逻辑回归模型是处理分类问题的最常见机器学习模型之一。二项式逻辑回归只是逻辑回归模型的一种类型。它指的是两个变量的分类,其中概率用于确定二元结果,因此“二项式”中的“bi”。结果为真或假 — 0 或 1。 二项式逻辑回归的一个例子是预测人…

安卓玩机----解锁system分区 可读写系统分区 magisk面具模块

玩机教程----安卓机型解锁system分区 任意修改删除系统文件 system分区可读写 参考上个博文可以了解到解锁system分区的有关常识。但目前很多机型都在安卓12 13 基础上。其实最简单的方法就在于刷写一个解锁system分区的第三方补丁包。在面具更新不能解锁系统分区的前提下。…

8.2 JUC - 5.CountdownLatch

目录 一、是什么?二、demo演示三、应用之同步等待多线程准备完毕四、 应用之同步等待多个远程调用结束五、CountDownLatch 原理 一、是什么? CountdownLatch 用来进行线程同步协作,等待所有线程完成倒计时。 其中构造参数用来初始化等待计数…

C#,数值计算——数据建模Fitab的计算方法与源程序

1 文本格式 using System; namespace Legalsoft.Truffer { /// <summary> /// Fitting Data to a Straight Line /// </summary> public class Fitab { private int ndata { get; set; } private double a { get; set; } …

RabbitMQ之Fanout(扇形) Exchange解读

目录 基本介绍 适用场景 springboot代码演示 演示架构 工程概述 RabbitConfig配置类&#xff1a;创建队列及交换机并进行绑定 MessageService业务类&#xff1a;发送消息及接收消息 主启动类RabbitMq01Application&#xff1a;实现ApplicationRunner接口 基本介绍 Fa…

跨域请求方案整理实践

项目场景&#xff1a; 调用接口进行手机验证提示,项目需要调用其它域名的接口,导致前端提示跨域问题 问题描述 前端调用其他域名接口时报错提示: index.html#/StatisticalAnalysisOfVacancy:1 Access to XMLHttpRequest at http://xxxxx/CustomerService/template/examineMes…

openGauss学习笔记-92 openGauss 数据库管理-内存优化表MOT管理-内存表特性-使用MOT-MOT使用MOT SQL覆盖和限制

文章目录 openGauss学习笔记-92 openGauss 数据库管理-内存优化表MOT管理-内存表特性-使用MOT-MOT使用MOT SQL覆盖和限制92.1 不支持的特性92.2 MOT限制92.3 不支持的DDL操作92.4 不支持的数据类型92.5 不支持的索引DDL和索引92.6 不支持的DML92.7 不支持的JIT功能&#xff08;…

ThingsBoard如何自定义tcp-transport

1、概述 很久没有更新了,一直忙于其他的事情,最近去搞了一个在ThingsBoard中自定义一个tcp-transport,用于连接使用tcp长连接的设备,目前使用tcp和mqtt协议连接服务端的设备还是很多,ThingsBoard的PE版提供了Integration是可以实现tcp的接入,但是CE版是没有提供接入tcp长…