SparkMlib 之逻辑回归及其案例

news2025/1/11 8:01:29

文章目录

    • 什么是逻辑回归?
    • 逻辑回归的优缺点
    • 逻辑回归示例——预测回头客
    • 逻辑回归示例——预测西瓜好坏
    • 逻辑回归示例——预测垃圾邮件

什么是逻辑回归?

逻辑回归是一种流行的预测分类响应的方法。它是预测结果概率的广义线性模型的特例。在逻辑回归中,可以通过使用二项式逻辑回归来预测二元结果,也可以通过使用多项式逻辑回归来预测多类结果。

常应用于以下类型的场景:

  1. 预测一个西瓜的好坏;
  2. 预测这封邮件是否是垃圾邮件;
  3. 预测用户是否会成为回头客等等

官网:分类和回归

逻辑回归的优缺点

优点:

  1. 训练速度较快,分类的时候,计算量仅仅只和特征的数目相关;
  2. 简单易理解,模型的可解释性非常好,从特征的权重可以看到不同的特征对最后结果的影响;
  3. 适合二分类问题,不需要缩放输入特征;
  4. 内存资源占用小,因为只需要存储各个维度的特征值。

缺点:

  1. 不能用 Logistic 回归去解决非线性问题,因为 Logistic 的决策面试线性的;
  2. 对多重共线性数据较为敏感;
  3. 很难处理数据不平衡的问题;
  4. 准确率并不是很高,因为形式非常的简单(非常类似线性模型),很难去拟合数据的真实分布;
  5. 逻辑回归本身无法筛选特征,有时会用 gbdt 来筛选特征,然后再上逻辑回归。

参考博客:逻辑回归的优缺点

逻辑回归示例——预测回头客

数据集下载:

链接:
https://pan.baidu.com/s/1AshgNxx1wOWhLgKxgjrZww?pwd=lz3l 

提取码:
lz3l

数据集介绍:

tb_train.csv训练集数据,其中共有五个字段,四个特征字段:user_id、age_range、gender、merchant_id,一个标签字段:label

训练集中的标签字段只有值 010 表示不是回头客,1 表示是回头客。

tb_test.csv测试集数据,其中共有五个字段,四个特征字段:user_id、age_range、gender、merchant_id,一个标签字段:label

测试集中的标签字段都为空值。

需求实现:

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object logistic{

    // TODO 预测用户是否会成为回头客

    def main(args: Array[String]): Unit = {

        val sc: SparkSession = SparkSession.builder().appName("logistic").master("local[*]").getOrCreate()

        // 1.加载训练集数据
        val train_rdd: RDD[Row] = sc.read
                .option("header", "true")
                .csv("tb_train.csv").rdd

        // 2.向量转换
        import sc.implicits._

        val train: DataFrame = train_rdd.map(lines => {
            val arr: Array[String] = lines.mkString(",").split(",")
            LabeledPoint(arr(4).toDouble, Vectors.dense(arr.slice(0, 4).map(_.toDouble)))
        }).toDF("label","features")

        // 3.创建逻辑回归对象
       val lr = new LogisticRegression()
        // 设置最大迭代次数与正则化参数
        lr.setMaxIter(10).setRegParam(0.01)

        // 4. 模型训练
        val model: LogisticRegressionModel = lr.fit(train)

        // 5.模型保存示例
        model.save("./logistic/")

        // 6.加载模型示例
        val regressionModel: LogisticRegressionModel = LogisticRegressionModel.load("./logistic/")

        // 7.加载测试集
        val test_rdd: RDD[Row] = sc.read
                .option("header", "true")
                .csv("tb_test.csv").rdd

        // 8.测试集变量转换
        val test: DataFrame = test_rdd.map(lines => {
            val arr: Array[String] = lines.mkString(",").split(",")
            LabeledPoint(0D, Vectors.dense(arr.slice(0, 4).map(_.toDouble)))
        }).toDF("label", "features")

        // 9.预测测试集数据的结果(不带标签)
        regressionModel
        	.transform(test.select("features"))
        	.select("features","prediction")
        	.limit(100)
        	.show(100)
    }

}

逻辑回归示例——预测西瓜好坏

数据集下载:

链接:
https://pan.baidu.com/s/1AshgNxx1wOWhLgKxgjrZww?pwd=lz3l 

提取码:
lz3l

数据集介绍:

西瓜集.csv 数据集中共有八个字段,六个特征字段:色泽、根蒂、敲声、纹理、脐部、触感,一个标签字段:好瓜,还有一个编号字段。

训练集中的随机百分之20的数据为测试集。

需求实现:

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

object Watermelon {

    def main(args: Array[String]): Unit = {

        val sc: SparkSession = SparkSession
                .builder()
                .appName("watermelon")
                .master("local[*]").getOrCreate()

        // 1.加载训练数据集
        val train_rdd: RDD[String] = sc.read
                .option("header", "true")
                .textFile("西瓜集.csv")
                .rdd

        // 2.取出百分之80作为训练集,其余为测试集
        val data: Array[RDD[String]] = train_rdd.randomSplit(Array(0.8, 0.2))

        // 3.转换向量
        import sc.implicits._

        val trainDF: DataFrame = data(0).map(lines => {
            val arr: Array[String] = lines.split(",")
            LabeledPoint(
                if (arr(7).equals("是")) {
                    1D
                } else {
                    0D
                },
                Vectors.dense(
                    // 色泽转换
                    if (arr(1).equals("青绿")){
                        1D
                    }else if (arr(1).equals("乌黑")){
                        2D
                    }else{
                        3D
                    },
                    // 根蒂转换
                    if (arr(2).equals("硬挺")){
                        1D
                    }else if (arr(2).equals("蜷缩")){
                        2D
                    }else{
                        3D
                    },
                    // 敲声转换
                    if (arr(3).equals("清脆")){
                        1D
                    }else if (arr(3).equals("沉闷")){
                        2D
                    }else{
                        3D
                    },
                    // 纹理转换
                    if (arr(4).equals("清晰")){
                        1D
                    }else if (arr(4).equals("模糊")){
                        2D
                    }else{
                        3D
                    },
                    // 脐部转换
                    if (arr(5).equals("平坦")){
                        1D
                    }else if (arr(5).equals("凹陷")){
                        2D
                    }else{
                        3D
                    },
                    // 触感转换
                    if (arr(6).equals("软黏")){
                        1D
                    }else if (arr(6).equals("硬滑")){
                        2D
                    }else{
                        3D
                    }
                )
            )
        }).toDF("label", "features")


        // 4.创建逻辑回归模型
        val lr = new LogisticRegression()

        // 设置参数
        lr.setMaxIter(10).setRegParam(0.01)

        // 5.模型训练
        val model: LogisticRegressionModel = lr.fit(trainDF)

        // 6.将测试数据集转换为向量
        val testDF: DataFrame = data(1).map(lines => {
            val arr: Array[String] = lines.split(",")
            LabeledPoint(
                if (arr(7).equals("是")) {
                    1D
                } else {
                    0D
                },
                Vectors.dense(
                    // 色泽转换
                    if (arr(1).equals("青绿")){
                        1D
                    }else if (arr(1).equals("乌黑")){
                        2D
                    }else{
                        3D
                    },
                    // 根蒂转换
                    if (arr(2).equals("硬挺")){
                        1D
                    }else if (arr(2).equals("蜷缩")){
                        2D
                    }else{
                        3D
                    },
                    // 敲声转换
                    if (arr(3).equals("清脆")){
                        1D
                    }else if (arr(3).equals("沉闷")){
                        2D
                    }else{
                        3D
                    },
                    // 纹理转换
                    if (arr(4).equals("清晰")){
                        1D
                    }else if (arr(4).equals("模糊")){
                        2D
                    }else{
                        3D
                    },
                    // 脐部转换
                    if (arr(5).equals("平坦")){
                        1D
                    }else if (arr(5).equals("凹陷")){
                        2D
                    }else{
                        3D
                    },
                    // 触感转换
                    if (arr(6).equals("软黏")){
                        1D
                    }else if (arr(6).equals("硬滑")){
                        2D
                    }else{
                        3D
                    }
                )
            )
        }).toDF("label", "features")

        // 7.预测西瓜是否是好瓜(带标签)
        println("预测西瓜是否是好瓜(带标签):")
        model.transform(testDF)
                .select("label", "features","prediction")
                .show()

        // 8.预测西瓜是否是好瓜(不带标签)
        println("预测西瓜是否是好瓜(不带标签):")
        model.transform(testDF.select("features"))
                .select("features","prediction")
                .show()

    }

}

逻辑回归示例——预测垃圾邮件

直接看代码

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.{DataFrame, SparkSession}

object Email {

    // TODO 预测垃圾邮件

    def main(args: Array[String]): Unit = {

        val sc: SparkSession = SparkSession
                .builder()
                .appName("email")
                .master("local[*]").getOrCreate()

        // 训练数据集
        val train_data: DataFrame = sc.createDataFrame(Seq(
            ("you@example.com", "hope you are well", 0.0),
            ("raj@example.com", "nice to hear from you", 0.0),
            ("thomas@example.com", "happy holidays", 0.0),
            ("mark@example.com", "see you tomorrow", 0.0),
            ("dog@example.com", "save loan money", 1.0),
            ("xyz@example.com", "save money", 1.0),
            ("top10@example.com", "low interest rate", 1.0),
            ("marketing@example.com", "cheap loan", 1.0)
        )).toDF("email", "message", "label")

        // 1.使用分词器,对信息内容进行分词,指定输入与输出列
        val tokenizer: Tokenizer = new Tokenizer().setInputCol("message").setOutputCol("words")

        // 2.哈希词频统计,将同一个单词分配到同一个分区
        val hashingTF: HashingTF = new HashingTF().setNumFeatures(1000).setInputCol("words").setOutputCol("features")

        // 3.创建逻辑回归模型
        val lr = new LogisticRegression()

        // 设置参数
        lr.setMaxIter(10).setRegParam(0.01)

        // 4.设置管线,进行组合
        val pipeline: Pipeline = new Pipeline().setStages(Array(tokenizer,hashingTF, lr))

        // 5.生成训练模型
        val model: PipelineModel = pipeline.fit(train_data)

        // 6.创建测试数据集
         val test: DataFrame = sc.createDataFrame(Seq(
          ("you@example.com", "ab how are you"),
          ("jain@example.com", "ab hope doing well"),
          ("caren@example.com", "ab want some money"),
          ("zhou@example.com", "ab secure loan"),
          ("ted@example.com", "ab need loan")
        )).toDF("email", "message")

        // 7.对测试集进行预测
        model.transform(test)
                .select("email","message","prediction")
                .show()

    }

}

参考博客:Spark(五)————MLlib

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/45803.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

EasyRecovery2022中文版电脑端数据恢复软件

EasyRecovery2023数据恢复软件是一款文件恢复软件,能够恢复内容类型非常多,包括办公文档、文件夹、电子邮件、照片、音频等一些常用文件类型都是可以进行恢复,操作非常简单,只需要将存储设备连接到电脑上,运行EasyReco…

【全志T113-S3_100ask】16-1 linux系统驱动四线电阻屏(tpadc、tslib)

【全志T113-S3_100ask】16-1 linux系统使用TPADC驱动四线电阻屏(rtp、tslib)(一)背景(二)焊接鬼才(三)解析input上报事件(四)C语言解析input上报事件&#xf…

大数据技术——Flume简介安装配置使用案例

文章目录1. Flume 概述1.1 Flume简介1.2 Flume的特点1.3 Flume的基础架构2. Flume安装配置2.1 下载地址2.2 安装部署3. Flume 使用案例3.1 实时监控单个追加文件3.2 实时监控目录下多个新文件3.3 实时监控目录下的多个追加文件1. Flume 概述 1.1 Flume简介 Flume是一种可配置、…

【Linux】Linux的环境变量(PATH、env、子进程继承环境变量等)

文章目录环境变量1、从一个小案例认识环境变量PATH2、常用的环境变量相关指令与系统调用3、子进程如何继承环境变量的?4、测试其它环境变量环境变量 1、从一个小案例认识环境变量PATH 我们在shell中通过file查看文件信息,看到我们常使用的指令都是可执…

C++ 类的静态成员详解

目录 前言 一、类的静态成员 1.static关键字 2.静态成员变量 3.静态成员函数 二、程序样例 1.程序演示 2.程序截图 总结 前言 本文记录C中 static 修饰类成员成为静态成员,其中包括静态成员类别、作用和程序演示。 嫌文字啰嗦的可直接跳到最后的总结。 一、类的静…

特征提取 - 骨架、中轴和距离变换

目录 1. 介绍 骨架 skeleton 中轴变换 Medial axis transformation 距离变换 distance transform 2. 距离变换的代码实现 distanceTransform 函数介绍 normalize 函数介绍 取局部最大值 完整代码 3. comparation 1. 介绍 骨架 skeleton 骨架的定义:就是…

【毕业设计】33-基于单片机的直流电机的转速检测与控制设计(原理图工程+PCB工程+源代码工程+仿真工程+答辩论文)

typora-root-url: ./ 【毕业设计】33-基于单片机的直流电机的转速检测与控制设计(原理图工程PCB工程源代码工程仿真工程答辩论文) 文章目录typora-root-url: ./【毕业设计】33-基于单片机的直流电机的转速检测与控制设计(原理图工程PCB工程源…

盘点国内主流数字孪生厂商!你了解几家?

在国内,主流的数字孪生解决方案厂商包括华龙迅达、精航伟泰、羚数智能、力控科技、华力创通、同元软控、优也科技、51world、卡奥斯、摩尔元数、易知微、木棉树软件等。由于中国数字孪生市场仍处于早期发展阶段,且受限于建模、仿真和基于数据融合的数字线…

基于单RGB相机的全新三维表示方法|NeurIPS 2022

随着深度学习的发展,基于单张RGB图像的人体三维重建取得了持续进展。 但基于现有的表示方法,如参数化模型、体素栅格、三角网格和隐式神经表示,难以构筑兼顾高质量结果和实时速度的系统。 针对上述问题,天津大学团队联合清华大学…

Linux用户管理

文章目录一. 引子二. 用户管理1. 用户切换2. 注销用户3. 添加用户4. 设置用户密码5. 删除用户6. 查询用户信息三. 用户组管理1. 新增用户组2. 新增用户时添加组3. 修改用户的组四. 用户和组相关文件1. /etc/passwd2. /etc/shadow3. /etc/group一. 引子 Linux是一个多用户、多任…

【JavaScript作用域】

JavaScript作用域1 本节目标2 作用域2.1 作用域概述2.2 全局作用域2.3 局部作用域3 变量的作用域3.1 变量作用域的分类3.2 全局变量3.3 局部变量3.4 从执行效率看全局变量与局部变量3.5 JS没有块级作用域4 作用域链1 本节目标 说出JavaScript的两种作用域区分全局变量和局部变…

TinyML:是否是FPGA在人工智能方面的最佳应用?

TinyML 也是机器学习的一种,他的特点就是缩小深度学习网络可以在微型硬件中使用,主要应用在智能设备上。超低功耗嵌入式设备正在“入侵”我们的世界,借助新的嵌入式机器学习框架,它们将进一步推动人工智能驱动的物联网设备的普及。…

机器学习:一文从入门到读懂PCA(主成分分析)

深度学习:PCA白化前置知识内积的几何意义基基变换不同基下的向量变换逆矩阵不同基下的空间变换方差协方差协方差矩阵协方差矩阵对角化特征值分解、空间变换主成分分析(PCA)两个原则公式推导求解流程代码实现PCA的优缺点优点缺点前置知识 维度…

【测试沉思录】18.如何测试微信小程序?

作者:雷远缘 编辑:毕小烦 一. 先知道小程序是什么 啥是小程序? “小程序是一种不需要下载安装即可使用的应用,它实现了应用 “触手可及” 的梦想,用户扫一扫或者搜一下即可打开应用。也体现了 “用完即走” 的理念&am…

[附源码]Python计算机毕业设计SSM基于Java的民宿运营管理网站(程序+LW)

环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 Maven管理等…

详解 Spring Boot 项目中的配置文件

目录 1. Spring Boot 项目中配日文件的作用是什么 2. Spring Boot 配置文件的两种格式 3. properties 配置文件 3.1 properties 配置文件的基本语法 3.2 properties 配置文件的分类 3.3 如何读取配置文件 3.4 properties 配置文件的优缺点分析 4. yml 配置文件 4.1 yml …

【JavaSE】初识泛型

大家好!我是保护小周ღ,本期为大家带来的是 Java的泛型,会来大家初步了解什么是泛型,以及泛型的使用,感受一手泛型的思想,面向对象编程太爽了~ 目录 一、泛型是什么? 二、泛型的语法 三、包…

Java给图片增加水印,根据图片大小自适应,右下角/斜角/平铺

Hi,I’m Shendi 最近写自己的文件服务器,上传图片时需要自动增加水印,在这里记录一下 文章目录效果展示读取图片从 byte[] 读取图片获取画板绘制水印根据图片大小自适应水印大小右下角文字水印斜角水印平铺水印图片水印输出图片水印就是在图片…

《剑指 Offer 》—58 - I. 翻转单词顺序

《剑指 Offer 》—58 - I. 翻转单词顺序 注意:本题与151 题相同:https://leetcode-cn.com/problems/reverse-words-in-a-string/ 注意:此题对比原题有改动 文章目录《剑指 Offer 》—58 - I. 翻转单词顺序一、题目内容二、个人答案&#xf…

Git 打patch (打补丁)的使用

patch 的使用 一般是diff ,apply ,format-patch,am 1 生成patch git diff > test.patch 这个是打补丁(test.patch自己取的名字,这个命令可以看出没有指定修改的问题所以默认把所有修改的文件都打patch了,同时还需要注意,这里是本地修改的没有执行add缓存的) 如果想指定某…