机器学习中的工作流机制

news2024/10/2 3:32:57

机器学习中的工作流机制

在项目开发的时候,经常需要我们选择使用哪一种模型。同样的数据,可能决策树效果不错,朴素贝叶斯也不错,SVM也挺好。有没有一种方法能够让我们用一份数据,同时训练多个模型,并用某种直观的方式(包括模型得分),观察到模型在既有数据上的效果?有的,管线工作流pipeline就是专门干这个的,再配上决策边界,所有模型只用一眼,就能确定优劣,选择你的梦中情模。上效果图。

在这里插入图片描述

分为两行,上面是sklearn自带数据集中的数据,分两类。从第二列开始,每一列是某种模型在当前数据集中的拟合效果。如何查看某种模型效果好坏?从两个方面,左上角的模型得分,和图中颜色深浅,两种颜色的分解代表模型的决策边界。

下面是笔者自己的数据,分为4类。同样不同颜色的分界代表两种类型的判别边界。如果只看模型得分,那得分为100%的模型有5个,选再根据决策边界进一步确定更优秀的模型,为工程所用。这里贴出笔者所用代码供各位修改,也可以直接取官方代码修改

def loadTrainData():
    df = pd.read_csv('./your/dataset/path/data.csv')
    trainDataLabel = df.values
    nodeData = trainDataLabel[:, :2], trainDataLabel[:, -1]
    return nodeData

def trainAnalySave():
    from matplotlib.colors import ListedColormap
    import joblib

    from sklearn.datasets import make_circles, make_classification, make_moons
    from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
    from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
    from sklearn.gaussian_process import GaussianProcessClassifier
    from sklearn.gaussian_process.kernels import RBF
    from sklearn.inspection import DecisionBoundaryDisplay
    from sklearn.model_selection import train_test_split
    from sklearn.naive_bayes import GaussianNB
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.neural_network import MLPClassifier
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC
    from sklearn.tree import DecisionTreeClassifier

    names = [
        "Nearest Neighbors",
        "Linear SVM",
        "RBF SVM",
        "Gaussian Process",
        "Decision Tree",
        "Random Forest",
        "Neural Net",
        "AdaBoost",
        "Naive Bayes",
        "QDA",
    ]

    classifiers = [
        KNeighborsClassifier(3),
        SVC(kernel="linear", C=0.025),
        SVC(gamma=2, C=1),
        GaussianProcessClassifier(1.0 * RBF(1.0)),
        DecisionTreeClassifier(max_depth=5),
        RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
        MLPClassifier(alpha=1, max_iter=1000),
        AdaBoostClassifier(),
        GaussianNB(),
        QuadraticDiscriminantAnalysis(),
    ]

    # X, y = make_classification(
    #     n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1
    # )
    # rng = np.random.RandomState(2)
    # X += 2 * rng.uniform(size=X.shape)
    # linearly_separable = (X, y)

    nodeData = loadTrainData()

    datasets = [
        # make_moons(noise=0.3, random_state=0),
        make_circles(noise=0.2, factor=0.5, random_state=1),
        # linearly_separable,
        nodeData,
    ]

    # figure = plt.figure(figsize=(27, 9))
    figure = plt.figure(figsize=(15, 4))
    i = 1
    # iterate over datasets
    for ds_cnt, ds in enumerate(datasets):
        # preprocess dataset, split into training and test part
        X, y = ds
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.4, random_state=42
        )

        x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
        y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5

        # just plot the dataset first
        cm = plt.cm.RdBu
        cm_bright = ListedColormap(["#FF0000", "#00FF00", "#FFFF00", "#0000FF"])
        ax = plt.subplot(len(datasets), len(classifiers) + 1, i)
        if ds_cnt == 0:
            ax.set_title("Input data")
        # Plot the training points
        ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors="k")
        # Plot the testing points
        ax.scatter(
            X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.6, edgecolors="k"
        )
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xticks(())
        ax.set_yticks(())
        i += 1

        # iterate over classifiers
        for name, clf in zip(names, classifiers):
            ax = plt.subplot(len(datasets), len(classifiers) + 1, i)

            clf = make_pipeline(StandardScaler(), clf)
            clf.fit(X_train, y_train)
            score = clf.score(X_test, y_test)
            # DecisionBoundaryDisplay.from_estimator(
            #     clf, X, cmap=cm, alpha=0.8, ax=ax, eps=0.5
            # )

            # save satisfied model
            savedPath = r'..\models\sklearn\\'
            savedList = ["Nearest Neighbors", "RBF SVM", "Neural Net"]
            if name in savedList:
                joblib.dump(clf, savedPath + name + '.pkl')

            # Plot the training points
            ax.scatter(
                X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors="k"
            )
            # Plot the testing points
            ax.scatter(
                X_test[:, 0],
                X_test[:, 1],
                c=y_test,
                cmap=cm_bright,
                edgecolors="k",
                alpha=0.6,
            )

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_xticks(())
            ax.set_yticks(())
            if ds_cnt == 0:
                ax.set_title(name)
            ax.text(
                # x_max - 0.3,
                # y_min + 0.3,
                x_min + 0.4,
                y_max - 0.4 - ds_cnt,
                ("%.2f" % score),
                # ("%.2f" % score).lstrip("0"),
                # size=15,
                size=10,
                # horizontalalignment="right",
                horizontalalignment="left",
            )
            i += 1

    plt.tight_layout()
    plt.show()   

    nodeData = loadTrainData()
if __name__ == '__main__':
    trainAnalySave()

注意,这里的DecisionBoundaryDisplay模块,需要安装sklearn的较新版本,因而python也需要较高版本。

最后打个广告,如果有想进修服务器开发相关的技能,这里是可以让你秒变大神的时光隧道。 enjoy~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/836469.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于STM32103移植FreeRTOS

目录 一、FreeRTOS协议栈下载 二、准备工程文件与协议代码 三、移植FreeRTOS协议栈 一、FreeRTOS协议栈下载 1、官网下载 FreeRTOS - Market leading RTOS (Real Time Operating System) for embedded systems with Internet of Things extensionshttps://www.freertos.or…

SQL-每日一题【1193. 每月交易 I】

题目 Table: Transactions 编写一个 sql 查询来查找每个月和每个国家/地区的事务数及其总金额、已批准的事务数及其总金额。 以 任意顺序 返回结果表。 查询结果格式如下所示。 示例 1: 解题思路 1.题目要求我们查找每个月和每个国家/地区的事务数及其总金额、已批准的事务数…

Flutter(八)事件处理与通知

1.原始指针事件处理 一次完整的事件分为三个阶段:手指按下、手指移动、和手指抬起,而更高级别的手势(如点击、双击、拖动等)都是基于这些原始事件的。 Listener 组件 Flutter中可以使用Listener来监听原始触摸事件 Listener({…

Go语言开发者的Apache Arrow使用指南:读写Parquet文件

Apache Arrow是一种开放的、与语言无关的列式内存格式,在本系列文章[1]的前几篇中,我们都聚焦于内存表示[2]与内存操作[3]。 但对于一个数据库系统或大数据分析平台来说,数据不能也无法一直放在内存中,虽说目前内存很大也足够便宜…

FuncGPT竟然限免?!速来体验全球首个AIGF!

近日,飞算SoFlu软件机器人重磅推出全新功能——FuncGPT(慧函数)。FuncGPT(慧函数)是一款函数AI生产器,它能够根据用户的需求快速生成 Java 语言的函数代码。FuncGPT(慧函数)拥有强大…

行政资产管理信息系统

行政资产管理信息系统是通过专业设计开发的资产管理解决方案,旨在为企业建立和完善资产管理体系。该系统可以有效地控制资产的购买和应用,从而节省资金,完成资产的有效管理。   资产管理信息系统的核心功能是统一资产管理,可以…

java 版本企业招标投标管理系统源码+多个行业+tbms+及时准确+全程电子化

​ 功能描述 1、门户管理:所有用户可在门户页面查看所有的公告信息及相关的通知信息。主要板块包含:招标公告、非招标公告、系统通知、政策法规。 2、立项管理:企业用户可对需要采购的项目进行立项申请,并提交审批,查…

kernel pwn入门

Linux Kernel 介绍 Linux 内核是 Linux 操作系统的核心组件,它提供了操作系统的基本功能和服务。它是一个开源软件,由 Linus Torvalds 在 1991 年开始开发,并得到了全球广泛的贡献和支持。 Linux 内核的主要功能包括进程管理、内存管理、文…

2023-08-04 Untiy进阶 C#知识补充4——C#5主要功能与语法

文章目录 一、概述二、回顾——线程三、线程池四、Task 任务类五、同步和异步 ​ 注意:在此仅提及 Unity 开发中会用到的一些功能和特性,对于不适合在 Unity 中使用的内容会忽略。 一、概述 C# 5 调用方信息特性(C# 进阶内容)异步…

途乐证券|俄罗斯宣布9月削减石油出口量

当地时间周四,美股兜售潮仍在持续,三大股指连续第二个交易日团体收跌。到收盘,道指跌落0.19%,标普500指数跌落0.25%,纳指跌幅为0.10%。 美国ISM7月非制造业PMI下滑 数据面上,美国供应办理协会ISM周四发布的…

掌握好视频翻译软件的使用方法,帮你跨越语言障碍

嘿,翻译小达人们,你知道吗,当你看到一段充满神秘符号的英语视频,脑袋里冒出一大片问号的时候,别慌!我们有比手动翻译更妙的解决办法——视频翻译。嗯,这货可不一般,它能帮你解读视频…

锂电设备振动监测:早发现问题,早预防故障

锂电设备在生产过程中的振动问题可能导致设备故障、损坏和生产线停机,对企业产生严重影响。为了确保锂电池生产的稳定性和可靠性,振动监测成为了关键一步。通过引入智能无线温振传感器及其监测分析软件,企业可以早发现问题、早预防故障&#…

[PaddlePaddle] [学习笔记] PaddlePaddle 官方文档 —— 使用Python和NumPy构建神经网络模型

1. 机器学习和深度学习综述 1.1 人工智能、机器学习、深度学习的关系 近些年人工智能、机器学习和深度学习的概念十分火热,但很多从业者却很难说清它们之间的关系,外行人更是雾里看花。在研究深度学习之前,先从三个概念的正本清源开始。概括…

节能延寿:ARM Cortex-M微控制器下的低功耗定时器应用

嵌入式系统的开发在现代科技中发挥着至关重要的作用。它们被广泛应用于从智能家居到工业自动化的各种领域。在本文中,我们将聚焦于使用ARM Cortex-M系列微控制器实现低功耗定时器的应用。我们将详细介绍在嵌入式系统中如何实现低功耗的定时器功能,并附上代码示例。 嵌入式系…

详聊API接口?淘宝API接口在ERP系统中扮演者什么角色?

什么是API? API全称应用程序编程接口(Application Programming Interface),是一组用于访问某个软件或硬件的协议、规则和工具集合。电商API就是各大电商平台提供给开发者访问平台数据的接口。目前,主流电商平台如淘宝…

Ubuntu 22.04安装和使用ROS1可行吗

可行。 测试结果 ROS1可以一直使用下去的,这一点不用担心。Ubuntu会一直维护的。 简要介绍 Debian发行版^_^ AI:在Ubuntu 22.04上安装ROS1是可行的,但需要注意ROS1对Ubuntu的支持只到20.04。因此,如果要在22.04上安装ROS1&am…

Kubernetes v1.20 二进制部署

架构 k8s集群master01:192.168.80.101 kube-apiserver kube-controller-manager kube-scheduler etcd k8s集群master02:192.168.80.102 k8s集群node01:192.168.80.103 kubelet kube-proxy docker k8s集群node02:192.168.80…

问题解决和批判性思维是软件工程的重要核心

软件工程的重心在于问题解决和批判性思维(合理设计和架构降低复杂度),而非仅局限于编程。 许多人误以为软件工程就只是编程,即用编程语言编写指令,让计算机按照这些指令行事。但实际上,软件工程的内涵远超…

Leetcode-每日一题【剑指 Offer 04. 二维数组中的查找】

题目 在一个 n * m 的二维数组中,每一行都按照从左到右 非递减 的顺序排序,每一列都按照从上到下 非递减 的顺序排序。请完成一个高效的函数,输入这样的一个二维数组和一个整数,判断数组中是否含有该整数。 示例: 现有矩阵 matri…

机器学习基础之《特征工程(3)—特征预处理》

一、什么是特征预处理 通过一些转换函数将特征数据转换成更加适合算法模型的特征数据过程 处理前,特征值是数值,处理后,进行了特征缩放