Rust机器学习之Linfa

news2025/1/16 12:41:53

Rust机器学习之Linfa

众所周知,Python之所以能成为机器学习的首选语言,与其丰富易用的库有很大关系。某种程度上可以说是诸如numpypandasscikit-learnmatplotlibpytorchnetworks…等一系列科学计算和机器学习库成就了Python今天编程语言霸主的地位。基本上今天的机器学习任务主要就建立在上面列举的这6个库上。这6个库在Rust上都有对应的替代方案,我会带大家一起学习如何使用Rust及其库替代Python来更好得完成机器学习任务。

Python库Rust替代方案教程
numpyndarrayRust机器学习之ndarray
pandasPolars Rust机器学习之Polars
scikit-learnLinfaRust机器学习之Linfa
pytorchtch-rsRust机器学习之tch-rs
networkspetgraphRust机器学习之petgraph
matplotlibplottersRust机器学习之plotters

本文将带领大家用Linfa实现一个完整的Logistics回归,过程中带大家学习Linfa的基本用法。

数据和算法工程师偏爱Jupyter,为了跟Python保持一致的工作环境,文章中的示例都运行在Jupyter上。因此需要各位搭建Rust交互式编程环境(让Rust作为Jupyter的内核运行在Jupyter上),相关教程请参考 《Rust交互式编程环境搭建》。

在这里插入图片描述

文章目录

    • 什么是Linfa
    • 逻辑(Logistic)回归
    • 用Linfa实现逻辑回归
      • 安装Linfa
      • 加载数据
      • 数据探索
      • 模型训练
      • 模型优化
    • 总结

什么是Linfa

Linfa 是一组Rust高级库的集合,提供了常用的数据处理方法和机器学习算法。Linfa对标Python上的scikit-learn,专注于日常机器学习任务常用的预处理任务和经典机器学习算法,目前Linfa已经实现了scikit-learn中的全部算法,这些算法按算法类型组织在各子包中:

名字功能状态类别备注
clustering数据聚类无监督学习用于无标记数据的聚类,包括K-Means、高斯混合模型、DBSCAN和OPTICS等算法
kernel用于数据变换的核方法预处理将特征向量映射到更高维空间
linear线性回归部分拟合包含一般最小二乘法(OLS)、广义线性模型(GLM)
elasticnet弹性网络监督学习带有弹性网络约束的线性回归
logistic逻辑回归部分拟合包含两类逻辑回归模型
reduction降维预处理扩散映射和主成分分析(PCA)
trees决策树监督学习线性决策树
svm支持向量机监督学习标记数据集的分类或回归分析
hierarchical聚集层次聚类无监督学习聚类和构建聚类层次结构
bayes朴素贝叶斯监督学习包含高斯朴素贝叶斯
ica独立成分分析无监督学习包含FastICA实现
pls偏最小二乘法监督学习包含用于降维和回归的PLS估计
tsne降维无监督学习包含精确解和Barnes-Hut近似t-SNE
preprocessing标准化和向量化预处理包含各种常用数据预处理方法
nn最近邻和最小距离预处理空间索引结构和距离函数
ftrlFTRL-Proximal部分拟合包含L1和L2正则化

按类别进行一个分类整理会更清晰:

在这里插入图片描述

图1. Linfa子包分类

这些子包几乎涵盖了机器学习所需的所有方面。可以说,Linfa当前最新稳定版0.6.0的功能与scikit-learn完全一致。

逻辑(Logistic)回归

因为本文的重点是如何用Rust解决机器学习问题,所以我们不会深入研究逻辑回归的具体工作原理。然而,我们应该至少对它的含义有一个基本的理解。

逻辑回归是一种统计模型,用于测量结果的概率,如真/假、接受/拒绝等,也可以扩展到多个类别。逻辑回归内部使用logistic函数(也叫S曲线),该函数可以写成:
s ( x ) = 1 1 + e − x s(x) = \frac{1}{1+e^{-x}} s(x)=1+ex1
这个函数是一个S曲线,得到的结果在0和1之间,x的值越大,s(x)越接近1,x的值越小,s(x)越接近0,具体曲线如下:

在这里插入图片描述

图2. Logistic函数图像

Logistic回归的目的是找到与给定数据集拟合最好的函数。简单地说,它模拟了数据中我们关注的随机变量(0或1)的概率。

在机器学习中,通常使用梯度下降来寻找最优模型,这是一种寻找局部最小值的优化方法。目标通常是计算误差,然后将误差最小化。

用Linfa实现逻辑回归

本文的目标是演示如何用Rust构建简单的机器学习应用。为了方便演示和阅读,我们这里使用一个仅包含100条记录的非常小的数据集。

我们还将跳过机器学习的数据准备工作,这里可能包括异常值处理、标准化、数据清洗等预处理步骤。这是数据科学的一个非常重要的部分,但这不在本文的重点,这部分内容大家可以阅读《Rust机器学习之ndarray》 和 《Rust机器学习之Polars》 。

我们使用的数据集和简单,其结构如下:

score1score2accepted
32.7228330406032343.307173064300630
64.039320415060178.031688020182321

第一列表示学生第一次考试的成绩,第二列表示第二次考试的成绩。这两列是我们数据集的特征;第三列是数据集的目标,表示该学生是否会被学校录取,1表示录取,0表示拒接。

我们机器学习任务的目标是训练一个模型,该模型可以根据两次考试的分数可靠地预测学生是否会被学校录取。我们将数据拆分为训练集和测试集,其中65条数据为训练集,保存在train.csv中;35条数据为测试集,保存在test.csv中。最后,我们将测试训练得到的模型在尚未观测的数据上是否表现良好。

安装Linfa

安装使用Linfa非常简单,只需要在Cargo .toml加入

[dependencies]
linfa = { version = "0.6.0", features = ["openblas-system"] }
linfa-logistic = "0.6.0"

这里我们需要linfalinfa-logistic两个包,其中linfa提供了基础工具集,linfa-logistic提供了逻辑回归算法。

这里我们还添加了openblas-system特性,让我们的底层计算运行在libopenblas上。Linfa支持多个BLAS/LAPACK后端:

LinuxmacOSWindows
OpenBLAS
Netlib
Intel MKL

如果你用的操作系统是macOS或Windows,这里请替换成intel-mkl-system

在机器学习中,我们更喜欢使用Jupyter。如果你已经搭建好Rust交互式编程环境(可以参考 《Rust交互式编程环境搭建》),可以直接通过下面代码引入linfalinfa-logistic :

:dep linfa = {version="0.6.0", features = ["openblas-system"]}
:dep linfa-logistic = {version="0.6.0"}

除了Linfa外,我们还需要用到ndarray来处理n维向量;用csvndarray-csv来加载csv格式的数据。

:dep ndarray = {version = "0.15.6"}
:dep ndarray-csv = {version = "0.5.1"}
:dep csv = {version = "1.1"}

加载数据

任何机器学习的第一步都是载入数据。我们这里也不例外。我们需要从.data/train.csv.data/test.csv文件中读取数据,并将其转换为ndarray,再用ndarray创建Linfa Dataset

fn load_data(path: &str) -> Dataset<f64, &'static str, Ix1> {
    let mut reader = ReaderBuilder::new()
        .has_headers(false)
        .delimiter(b',')
        .from_path(path)
        .expect("can create reader");

    let array: Array2<f64> = reader
        .deserialize_array2_dynamic()
        .expect("can deserialize array");

    let (data, targets) = (
        array.slice(s![.., 0..2]).to_owned(),
        array.column(2).to_owned(),
    );

    let feature_names = vec!["test 1", "test 2"];

    Dataset::new(data, targets).map_targets(|x| {
            if *x as usize == 1 {
                "accepted"
            } else {
                "denied"
            }
        })
        .with_feature_names(feature_names)
}

简单解释一下上面的代码。

首先我们用csv::ReaderBuilder读入csv文件。这里的has_headers(false)表示读入的文件没有表头,·.delimiter(b',')表示数据用逗号分隔。

接着用ndarray-csv库提供了deserialize_array2_dynamic()方法可以将csv格式数转换成ndarray::Array2(二维数组)。然后我们将此ndarray二维数组切分成featuretarget,我们的数据集中前两列是feature,最后一列是target

有了featuretarget我们就可以用Dataset::new(data, targets)创建Linfa Dataset。Dataset创建好后我们还对里面的数据做了些处理,map_targets中的闭包将target的值映射到字符串(0=“denied”;1=“accepted”),并用with_feature_namesfeature字段进行了命名。

最后将创建并处理好的Dataset对象返回给调用者。使用时只需要传入文件路径即可

let train = load_data("data/train.csv");
let test  = load_data("data/test.csv");

数据探索

在开始模型训练之前,我们先看一下数据的分布情况。

首先我们将数据分成正例和负例,在可视化时用两种不同颜色来区分两类数据。代码实现上很简单,只需要根据数据集中target的值将数据放入对应类型的列表中即可。代码实现如下:

let mut positive = vec![];
let mut negative = vec![];

let records = train.records().clone().into_raw_vec();
let features: Vec<&[f64]> = records.chunks(2).collect();
let targets = train.targets().clone().into_raw_vec();
for i in 0..features.len() {
    let feature = features.get(i).expect("feature exists");
    if let Some(&"accepted") = targets.get(i) {
        positive.push((feature[0], feature[1]));
    } else {
        negative.push((feature[0], feature[1]));
    }
}

有了数据后,我们用散点图将数据的分布描绘在图上。这里我使用plotters进行绘图,关于如何使用plotters进行数据可视化后面会有专门的教程教大家使用,这里大家先结合注释大体浏览一下代码功能:

:dep plotters = { version = "^0.3.0", default_features = false, features = ["evcxr", "all_series"] }

extern crate plotters;
use plotters::prelude::*;

evcxr_figure((640, 480), |root| {
    // 设置图表参数
    let mut ctx = ChartBuilder::on(&root)
        .set_label_area_size(LabelAreaPosition::Left, 40)// 设置y轴标签区域大小
        .set_label_area_size(LabelAreaPosition::Bottom, 40)// 设置x轴标签区域大小
        .build_cartesian_2d(0.0..120.0, 0.0..120.0) // 设置直角坐标系的范围
        .unwrap();

    // 设置网格
    ctx.configure_mesh().draw().unwrap();

    // 绘制正例散点图
    ctx.draw_series(
        positive
            .iter()
            .map(|point| TriangleMarker::new(*point, 5, &BLUE)),
    )
    .unwrap();
	
    // 绘制负例散点图
    ctx.draw_series(
        negative
            .iter()
            .map(|point| Circle::new(*point, 5, &RED)),
    )
    .unwrap();
    Ok(())
})

上代码输出的数据分布如下图:

在这里插入图片描述

图3. 训练集数据分布

模型训练

接下来我们正式进入模型构建环节。这个工作可以分为如下几步:

  1. 构造逻辑回归模型,并用训练集数据进行训练;
  2. 用测试集数据对训练出的模型进行测试;
  3. 构建混淆矩阵评估模型在测试集上的精度。

混淆矩阵本质上是一个 2 × 2 2 \times 2 2×2的表,它显示了真阳性(TP)、假阳性(FP)、真阴性(TN)和假阴性(FN),我们可以通过混淆矩阵计算模型的准确率、精确率和召回率等指标。

混淆矩阵预测值
PositiveNegative
真实值PositiveTPFN
NegativeFPTN

以上3步Linfa都有封装好的接口可以直接调用。

构造逻辑回归模型

Linfa提供LogisticRegression用于构造逻辑回归模型,下面代码创建逻辑回归模型,并用训练集进行训练:

let model = LogisticRegression::default()
        .max_iterations(max_iterations)
        .gradient_tolerance(0.0001)
        .fit(train)
        .expect("can train model");

其中max_iterations()方法用于设置最大迭代次数,gradient_tolerance()用于设置梯度下降的学习率,当变化值小于该值时则停止迭代。调大学习率可以提高算法速度,但是最终得到的可能是局部最优,不是全局最优。

最后,调用.fit(train)开始用传入的训练集训练模型。

测试模型

模型训练好后,可以调用.predict(test)用测试集对模型进行测试:

let validation = model.set_threshold(threshold).predict(test);

这里set_threshold用来设置预测“正”类的概率阈值,默认值为0.5。

创建混淆矩阵

最有一步,我们根据测试的结果构造混淆矩阵。Linfa提供了confusion_matrix方法可以在测试结果上直接生成混淆矩阵:

let confusion_matrix = validation
        .confusion_matrix(test)
        .expect("can create confusion matrix");

至此,模型训练的核心步骤完成了。接下来我们需要找到训练效果最好的那个模型。

模型优化

上面构造的模型中有2个超参:迭代次数max_iterations决策阈值threshold。我们需要反复多次测试以找到这两个参数的最有值,为此我们需要构造循环多次调用上面的过程。

为了让调用更方便,我们需要先将上面的模型构造和训练过程封装成一个函数,传入训练集、测试集和两个超参,返回混淆矩阵。

fn train_and_test(
    train: &DatasetBase<
        ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>,
        ArrayBase<OwnedRepr<&'static str>, Dim<[usize; 1]>>,
    >,
    test: &DatasetBase<
        ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>,
        ArrayBase<OwnedRepr<&'static str>, Dim<[usize; 1]>>,
    >,
    threshold: f64,
    max_iterations: u64,
) -> ConfusionMatrix<&'static str> {
    let model = LogisticRegression::default()
        .max_iterations(max_iterations)
        .gradient_tolerance(0.0001)
        .fit(train)
        .expect("can train model");

    let validation = model.set_threshold(threshold).predict(test);

    let confusion_matrix = validation
        .confusion_matrix(test)
        .expect("can create confusion matrix");

    confusion_matrix
}

有了上面的函数,我们的循环寻找最优最优超参的代码写起来会很简单:

let mut max_accuracy_confusion_matrix = train_and_test(&train, &test, 0.01, 100);
let mut best_threshold = 0.0;
let mut best_max_iterations = 0;
let mut threshold = 0.02;

for max_iterations in (1000..5000).step_by(500) {
    while threshold < 1.0 {
        let confusion_matrix = train_and_test(&train, &test, threshold, max_iterations);

        if confusion_matrix.accuracy() > max_accuracy_confusion_matrix.accuracy() {
            max_accuracy_confusion_matrix = confusion_matrix;
            best_threshold = threshold;
            best_max_iterations = max_iterations;
        }
        threshold += 0.01;
    }
    threshold = 0.02;
}

println!(
    "最精确混淆矩阵: {:?}",
    max_accuracy_confusion_matrix
);
println!(
    "最优迭代次数: {}\n最优决策阈值: {}",
    best_max_iterations, best_threshold
);
println!("精确率:\t{}", max_accuracy_confusion_matrix.accuracy(),);
println!("准确率:\t{}", max_accuracy_confusion_matrix.precision(),);
println!("召回率:\t{}", max_accuracy_confusion_matrix.recall(),);

最终经过优化后,最优模型输出如下:

最精确混淆矩阵: 
classes    | denied     | accepted
denied     | 11         | 0
accepted   | 2          | 22

最优迭代次数: 1000
最优决策阈值: 0.37000000000000016
精确率: 0.94285715
准确率: 0.84615386
召回率: 1

从上面输出我们能看到,只有2个数据分类错误,模型的精确率为94%,模型看起来还不错。

总结

本文中,我们用Linfa训练了一个效果还不错的逻辑回归模型。尽管我们用的数据样本很少,只有100条,但是完整地向大家展示了如何用Linfa进行机器学习。

今天,Rust的机器学习生态已经非常完善,然而社区仍在不断努力,向着Python快速靠近。面向未来,Rust快速、安全的特性会使它成为机器学习领域不可忽视,甚至是主流的编程语言。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/44868.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MAX7800实现KWS20 demo演示】

【MAX7800实现KWS20 demo演示】1. 概述2. 关键字定位演示2.1 构建固件&#xff1a;2.2 选择板卡2.3 MAX78000 EVKIT2.3.1 MAX78000 EVKIT下载程序2.3.2 MAX78000 EVKIT 跳线设置2.3.3 MAX78000 EVKIT 操作2.4 MAX78000 Feather2.4.1 MAX78000 Feather羽毛板下载固件2.4.2 MAX78…

IBM MQ MQCSP

一&#xff0c;概念 1.1 用途 用途&#xff1a;MQCSP 结构使授权服务能够验证用户 ID 和密码。您在 MQCONNX 调用上指定 MQCSP 连接安全参数结构。 警告&#xff1a;在某些情况下&#xff0c;客户端应用程序的 MQCSP 结构中的密码将以纯文本形式通过网络发送。要确保客户端应…

【学习笔记58】JavaScript面向对象

一、认识面向对象 &#xff08;一&#xff09;面向过程编程 按照程序执行的过程一步一步的完成程序代码 &#xff08;二&#xff09;面向对象编程 面向对象编程是一种编程的方式/模式官方&#xff1a;对一类具有相同属性和功能的程序代码抽象的描述&#xff0c;实现代码编程…

Triangle Attack: A Query-efficient Decision-based Adversarial Attack

Triangle Attack: A Query-efficient Decision-based Adversarial Attack 三角攻击:一种查询高效的基于决策的对抗性攻击 Abstract 基于决策的攻击对实际应用程序构成了严重的威胁&#xff0c;因为它将目标模型视为一个黑箱&#xff0c;只访问硬预测标签。最近已经做出了很大…

【计组】指令和运算1--《深入浅出计算机组成原理》(二)

一、计算机指令 1、指令 从软件工程师的角度来讲&#xff0c;CPU就是一个执行各种计算机指令&#xff08;Instruction Code&#xff09;的逻辑.。 这里的计算机指令&#xff0c;也可以叫做机器语言。 不同发CPU支持的机器语言不同&#xff0c;如个人电脑用的是Intel的CPU&a…

同样Java后端开发三年,朋友已经涨薪到了30k,而我才刚到12K。必须承认多背背八股文确实有奇效!

程序猿在世人眼里已经成为高薪、为人忠诚的代名词。 然而&#xff0c;小编要说的是&#xff0c;不是所有的程序员工资都是一样的。 世人所不知的是同为程序猿&#xff0c;薪资的差别还是很大的。 众所周知&#xff0c;目前互联网行业是众多行业中薪资待遇最好的&#xff0c;…

2022年NPDP新版教材知识集锦--【第四章节】(2)

【概念设计阶段】(全部获取文末) 概念描述提供了产品概念的优点和特征的定性描述&#xff0c;其必要性体现在&#xff1a; ①为开发团队的所有成员以及与项目相关的成员提供了清晰性和一致性。 ②是向潜在客户解释产品的重要手段之一。 典型流程&#xff1a; 2.1概念工程 …

python使用websocket服务并在fastAPI中启动websocket服务

依赖 pip install websockets-routes 代码 import asyncio import websockets import websockets_routes from websockets.legacy.server import WebSocketServerProtocol from websockets_routes import RoutedPath# 初始化一个router对象 router websockets_routes.Router()…

Archlinux安装软件的那些事

个人主页&#xff1a;董哥聊技术我是董哥&#xff0c;嵌入式领域新星创作者创作理念&#xff1a;专注分享高质量嵌入式文章&#xff0c;让大家读有所得&#xff01;文章目录1、ArchLinux1.1 ArchLinux原则1.2 软件包管理1.2.1 软件仓库1.2.2 包管理器2、Pacman2.1 pacman介绍2.…

什么是幂等性?四种接口幂等性方案详解!

幂等性在我们的工作中无处不在&#xff0c;无论是支付场景还是下订单等核心场景都会涉及&#xff0c;也是分布式系统最常遇到的问题&#xff0c;除此之外&#xff0c;也是大厂面试的重灾区。 知道了幂等性的重要性&#xff0c;下面我就详细介绍幂等性以及具体的解决方案&#…

SpringBoot中自动配置

第一种&#xff1a; 给容器中的组件加上 ConfigurationProperties注解即可 测试&#xff1a; Component ConfigurationProperties(prefix "mycar") public class Car {private String brand;private Integer price;private Integer seatNum;public Integer getSeat…

币圈已死,绿色积分是全新的赛道吗?

近几年来&#xff0c;移动互联网行业的迅猛发展&#xff0c;快速改变着社会业态。尽管如此&#xff0c;仍有大量企业线上线下处于割裂状态&#xff0c;2020 年一场疫情的突然爆发&#xff0c;并持续到 2022年&#xff0c;对零售行业造成流量崩塌、供应链中断、市场供需下滑等压…

现代 CSS 高阶技巧,完美的波浪进度条效果。

将专注于实现复杂布局&#xff0c;兼容设备差异&#xff0c;制作酷炫动画&#xff0c;制作复杂交互&#xff0c;提升可访问性及构建奇思妙想效果等方面的内容。 在兼顾基础概述的同时&#xff0c;注重对技巧的挖掘&#xff0c;结合实际进行运用&#xff0c;欢迎大家关注。 正…

金属非金属如何去毛刺 机器人浮动去毛刺

毛刺的产生 在金属非金属零件的加工中&#xff0c;由于切削加工过程中塑性变形引起的毛边&#xff0c;或者是铸造、模锻等加工的飞边&#xff0c;或是焊接挤出的残料&#xff0c;这些与所要求的形状、尺寸有所出入&#xff0c;在被加工零件上派生出的多余部分即为毛刺&#xf…

音视频开发之 ALSA实战!

前言&#xff1a; 今天我们来分享一个开源的音频采集代码&#xff0c;现在大部分音频采集都是通过ALSA框架去采集&#xff0c;如果大家把ALSA采集代码学懂&#xff0c;那么大部分的音频采集都可以搞定。这个代码是用ALSA进行音频PCM的采集并保存到本地文件。一、alsa框架的介绍…

C#语言实例源码系列-实现输入框焦点变色和窗体拖拽改变大小

专栏分享点击跳转>Unity3D特效百例点击跳转>案例项目实战源码点击跳转>游戏脚本-辅助自动化点击跳转>Android控件全解手册 &#x1f449;关于作者 众所周知&#xff0c;人生是一个漫长的流程&#xff0c;不断克服困难&#xff0c;不断反思前进的过程。在这个过程中…

002.组合总和|||——回溯算法

1.题目链接&#xff1a; 216. 组合总和 III 2.解题思路&#xff1a; 2.1.题目要求&#xff1a; 给一个元素数量k和一个元素和n&#xff0c;要求从范围[1,2,3,4,5,6,7,8,9]中返回所有元素数量为k和元素和为n的组合。&#xff08;每个数字只能使用一次&#xff09; 比如输入k…

深度学习快速入门----Pytorch 系列2

注&#xff1a;参考B站‘小土堆’视频教程 视频链接&#xff1a;【PyTorch深度学习快速入门教程&#xff08;绝对通俗易懂&#xff01;&#xff09;【小土堆】 上一篇&#xff1a;深度学习快速入门----Pytorch 1 文章目录八、神经网络--非线性激活九、神经网络--线性层及其他层…

作为IT行业过来人,我有3个重要建议给后辈程序员!

见字如面&#xff0c;我是军哥&#xff01;作为一名 40 岁的 IT 老兵&#xff0c;我在年轻时踩了不少坑&#xff0c;我总结了其中最重要的 3 个并一次性分享给你&#xff0c;文章不长&#xff0c;你一定要看完哈&#xff5e;1、重视基础还不够&#xff0c;还要注重技术广度和深…

第2-4-8章 规则引擎Drools实战(1)-个人所得税计算器

文章目录9. Drools实战9.1 个人所得税计算器9.1.1 名词解释9.1.2 计算规则9.1.2.1 新税制主要有哪些变化&#xff1f;9.1.2.2 资较高人员本次个税较少&#xff0c;可能到年底扣税增加&#xff1f;9.1.2.3 关于年度汇算清缴9.1.2.4 个人所得税预扣率表&#xff08;居民个人工资、…