Spark MLlib实践指南:从大数据推荐系统到客户流失预测的全流程建模

news2025/1/10 17:16:23

问题一

背景:

本题目基于用户数据,将据数据切分为训练集和验证集,供建模使用。训练集与测试集切分比例为8:2。

数据说明:

capter5_2ml.csv中每列数据分别为userId , movieId , rating , timestamp。

数据:

capter5_2ml.csv

题目:

使用Spark MLlib中的使用ALS算法给每个用户推荐某个商品。:

要求:

  ①设置迭代次数为5次,惩罚系数为0.01,得到评分的矩阵形式(2分)。

②对模型进行拟合,训练出合适的模型(2分)。

③为一组指定的用户生成十大电影推荐(4分)。

④生成前十名用户推荐的一组指定的电影(4分)。

⑤对结果进行正确输出(1分)。

代码:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.types.DataTypes;
import static org.apache.spark.sql.functions.*;

public class MovieRecommender {
    public static void main(String[] args) {
        // Step 1: 初始化Spark会话
        SparkSession spark = SparkSession.builder()
                .appName("MovieRecommender")
                .master("local[*]") // 本地模式运行
                .getOrCreate();

        // Step 2: 读取数据
        Dataset<Row> data = spark.read()
                .option("header", "true")
                .csv("capter5_2ml.csv");

        // 数据类型转换(userId、movieId、rating)
        data = data.withColumn("userId", data.col("userId").cast(DataTypes.IntegerType))
                .withColumn("movieId", data.col("movieId").cast(DataTypes.IntegerType))
                .withColumn("rating", data.col("rating").cast(DataTypes.FloatType));

        // 显示数据的前5行
        data.show(5);

        // Step 3: 数据划分为训练集和测试集,比例为8:2
        Dataset<Row>[] splits = data.randomSplit(new double[]{0.8, 0.2});
        Dataset<Row> training = splits[0];
        Dataset<Row> test = splits[1];

        // Step 4: 使用ALS模型训练
        ALS als = new ALS()
                .setMaxIter(5)            // 设置迭代次数
                .setRegParam(0.01)        // 设置正则化参数
                .setUserCol("userId")     // 用户ID列
                .setItemCol("movieId")    // 物品ID列
                .setRatingCol("rating")   // 评分列
                .setColdStartStrategy("drop"); // 丢弃冷启动数据

        // 模型拟合,训练模型
        ALSModel model = als.fit(training);

        // Step 5: 模型评估
        Dataset<Row> predictions = model.transform(test);
        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setMetricName("rmse")
                .setLabelCol("rating")
                .setPredictionCol("prediction");
        double rmse = evaluator.evaluate(predictions);
        System.out.println("Root-mean-square error = " + rmse);

        // Step 6: 为每个用户生成前10个电影推荐
        Dataset<Row> userRecs = model.recommendForAllUsers(10);
        userRecs.show(10, false); // 展示每个用户推荐的前10部电影

        // Step 7: 为前10名用户推荐指定电影
        Dataset<Row> topUsers = userRecs.select("userId").distinct().limit(10);
        Dataset<Row> topUserRecs = model.recommendForUserSubset(topUsers, 10);
        topUserRecs.show(false); // 展示前10个用户推荐的电影列表

        // Step 8: 输出推荐的电影和评分
        userRecs.select(col("userId"), explode(col("recommendations")).as("rec"))
                .select(col("userId"), col("rec.movieId"), col("rec.rating"))
                .show(false);

        // 关闭Spark会话
        spark.stop();
    }
}

初始化Spark会话

SparkSession spark = SparkSession.builder()

        .appName("MovieRecommender")

        .master("local[*]") // 本地模式运行

        .getOrCreate();
  • SparkSession:是Spark 2.0之后推荐使用的上下文对象,代替了旧版本的SQLContext和HiveContext。通过SparkSession.builder()来创建Spark会话。
  • .appName("MovieRecommender"):指定应用名称,方便在Spark UI中识别。
  • .master("local[*]"):指定运行模式为本地模式(local[*]表示使用所有可用的CPU核心)。
  • .getOrCreate():创建或获取现有的SparkSession。

2. 读取数据

Dataset<Row> data = spark.read()

        .option("header", "true")

        .csv("capter5_2ml.csv");
  • spark.read():使用Spark的DataFrame API来读取数据。
  • .option("header", "true"):指定文件的第一行是表头,这样可以自动识别列名。
  • .csv("capter5_2ml.csv"):读取CSV文件,创建一个Dataset<Row>对象(类似于DataFrame)。

3. 转换数据类型

data = data.withColumn("userId", data.col("userId").cast(DataTypes.IntegerType))

           .withColumn("movieId", data.col("movieId").cast(DataTypes.IntegerType))

           .withColumn("rating", data.col("rating").cast(DataTypes.FloatType));
  • withColumn():用来创建新的列或修改已有列的值。
  • .cast(DataTypes.IntegerType):将userId和movieId列的数据类型转换为整数。
  • .cast(DataTypes.FloatType):将rating列的数据类型转换为浮点数。

4. 数据展示

data.show(5);

  • show(5):显示前5行数据,方便确认数据读取和转换是否正确。

5. 数据集划分为训练集和测试集

Dataset<Row>[] splits = data.randomSplit(new double[]{0.8, 0.2});

Dataset<Row> training = splits[0];

Dataset<Row> test = splits[1];
  • randomSplit(new double[]{0.8, 0.2}):将数据随机划分为两个部分,80%用于训练,20%用于测试。
  • splits[0]:获取训练集。
  • splits[1]:获取测试集。

6. 使用ALS模型训练

ALS als = new ALS()

        .setMaxIter(5)            // 设置迭代次数

        .setRegParam(0.01)        // 设置正则化参数

        .setUserCol("userId")     // 用户ID列

        .setItemCol("movieId")    // 物品ID列

        .setRatingCol("rating")   // 评分列

        .setColdStartStrategy("drop"); // 丢弃冷启动数据
  • ALS():ALS(交替最小二乘法)是Spark MLlib用于协同过滤推荐系统的算法。
  • setMaxIter(5):设置最大迭代次数为5次。
  • setRegParam(0.01):设置正则化参数,防止过拟合。
  • setUserCol("userId"):指定用户列。
  • setItemCol("movieId"):指定物品列(电影)。
  • setRatingCol("rating"):指定评分列。
  • setColdStartStrategy("drop"):如果在预测时遇到冷启动问题(没有数据的用户或电影),则丢弃这些结果。

7. 模型拟合(训练)

ALSModel model = als.fit(training);

  • als.fit(training):在训练集上训练ALS模型,返回一个ALSModel对象。

8. 模型评估

Dataset<Row> predictions = model.transform(test);

RegressionEvaluator evaluator = new RegressionEvaluator()

        .setMetricName("rmse")

        .setLabelCol("rating")

        .setPredictionCol("prediction");

double rmse = evaluator.evaluate(predictions);

System.out.println("Root-mean-square error = " + rmse);
  • model.transform(test):使用训练好的模型对测试集进行预测,返回预测结果。
  • RegressionEvaluator:回归模型评估器,用于计算预测误差。
  • setMetricName("rmse"):指定使用均方根误差(RMSE)作为评估指标。
  • evaluate(predictions):对预测结果进行评估,计算RMSE值。

9. 为每个用户生成前10个电影推荐

Dataset<Row> userRecs = model.recommendForAllUsers(10);

userRecs.show(10, false);
  • recommendForAllUsers(10):为每个用户生成前10个电影推荐,结果存储在userRecs中。
  • show(10, false):显示前10个用户的推荐列表。

10. 为前10名用户生成推荐的电影

Dataset<Row> topUsers = userRecs.select("userId").distinct().limit(10);

Dataset<Row> topUserRecs = model.recommendForUserSubset(topUsers, 10);

topUserRecs.show(false);
  • select("userId").distinct():选择唯一的userId,获取不重复的用户。
  • limit(10):只选择前10个用户。
  • recommendForUserSubset(topUsers, 10):为前10个用户生成推荐的电影列表。

11. 输出推荐的电影和评分

userRecs.select(col("userId"), explode(col("recommendations")).as("rec"))

        .select(col("userId"), col("rec.movieId"), col("rec.rating"))

        .show(false);
  • explode(col("recommendations")):展开推荐的电影列表(每个用户的推荐电影是一个数组,explode将其展开为多行)。
  • select(col("userId"), col("rec.movieId"), col("rec.rating")):选择用户ID、电影ID和评分列进行展示。
  • show(false):显示完整数据。

12. 关闭Spark会话

spark.stop();

  • spark.stop():关闭Spark会话,释放资源。

问题二

背景:银行需要根据贷款用户的数据信息预测其是否有违约的可能,并对违约的可能性进行预测。对于银行业或者小贷机构而言,信用卡以及信贷服务是高风险和高收益的业务,如何通过用户的海量数据挖掘出用户潜在的信息即信用評分,并参与审批业务的决策从而提高了风险防控措施,该过程不仅提高了业务的审批效率而且给予了关键的决策,同时风险防控如果没有监测到位,对于银行业来说会造成不可估量的损失,因此这部分的工作是至关重要的。

本题目基于某贷款用户行为数据,将提供训练集和验证集供建模使用。

数据说明:

数据:

train.csv训练集

test.csv 测试集

测试集test.csv相比训练集只是少了一列Label,是需要我们去建模预测的

参考代码:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;

public class LoanDefaultPrediction {
    public static void main(String[] args) {
        // 初始化 Spark 会话
        SparkSession spark = SparkSession.builder()
                .appName("LoanDefaultPrediction")
                .master("local[*]")  // 本地模式运行
                .getOrCreate();

        // 读取训练数据
        Dataset<Row> trainData = spark.read()
                .option("header", "true")
                .option("inferSchema", "true")  // 自动推断数据类型
                .csv("train.csv");

        // 选择特征列进行训练 (排除label)
        String[] featureColumns = new String[]{"income", "age", "experience_years", "is_married", "city", "region", "current_job_years", "current_house_years", "house_ownership", "car_ownership", "profession"};

        // 将特征列汇总为单一向量
        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(featureColumns)
                .setOutputCol("features");

        // 将训练数据中的特征列组装成特征向量
        Dataset<Row> trainWithFeatures = assembler.transform(trainData);

        // 逻辑回归模型
        LogisticRegression lr = new LogisticRegression()
                .setLabelCol("label")   // 目标列
                .setFeaturesCol("features");  // 特征向量列

        // 训练模型
        lr.fit(trainWithFeatures);

        // 关闭 Spark 会话
        spark.stop();
    }
}

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.SparkSession;

public class LoanDefaultPredictionTest {
    public static void main(String[] args) {
        // 初始化 Spark 会话
        SparkSession spark = SparkSession.builder()
                .appName("LoanDefaultPrediction")
                .master("local[*]")  // 本地模式运行
                .getOrCreate();

        // 读取测试数据(没有label列)
        Dataset<Row> testData = spark.read()
                .option("header", "true")
                .option("inferSchema", "true")  // 自动推断数据类型
                .csv("test.csv");

        // 选择特征列
        String[] featureColumns = new String[]{"income", "age", "experience_years", "is_married", "city", "region", "current_job_years", "current_house_years", "house_ownership", "car_ownership", "profession"};

        // 将特征列汇总为单一向量
        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(featureColumns)
                .setOutputCol("features");

        // 将测试数据中的特征列组装成特征向量
        Dataset<Row> testWithFeatures = assembler.transform(testData);

        // 加载之前训练好的逻辑回归模型
        LogisticRegressionModel model = LogisticRegressionModel.load("path_to_saved_model");

        // 使用模型对测试数据进行预测
        Dataset<Row> predictions = model.transform(testWithFeatures);

        // 展示预测结果
        predictions.select("id", "prediction").show();

        // 模型评估
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("label")
                .setPredictionCol("prediction")
                .setMetricName("accuracy");

        double accuracy = evaluator.evaluate(predictions);
        System.out.println("Test set accuracy = " + accuracy);

        // 关闭 Spark 会话
        spark.stop();
    }
}

问题三

客户流失已成为每个希望提高品牌忠诚度的公司重点关注的问题,本题目基于某电信公司流失客户数据集,将提供训练集和验证集供建模使用,请回复验证集数据的模型计算结果文件和建模过程文档。

  1. 回复 当前模型的查全率和查准率分别是多少,数据描述
  2. 回复结果文件要求

结果文件请以逗号分隔符文本文件提供,包含以下字段:

  1. 用户标志
  2. 预测是否进入异常状态(1:异常状态;0-非异常状态)
  3. 建模过程文 档请以WORD文档形式提供,需要详细列出数据探索过程和建模思路。

数据说明:

字段名称

字段类型

中文名称和注释

USER_ID

VARCHAR(16)

用户标志(两文件里的用户标志没有关联性)

FLOW

DECIMAL(16)

当月流量(Byte)

FLOW_LAST_ONE

DECIMAL(16)

上一月流量(Byte)

FLOW_LAST_TWO

DECIMAL(16)

上两个月流量(Byte)

MONTH_FEE

DECIMAL(18,2)

当月收入(元)

MONTHS_3AVG

DECIMAL(18,2)

最近3个月平均收入(元)

BINDEXP_DATE

DATE

绑定到期时间

PHONE_CHANGE

INTEGER

当月是否更换终端

AGE

INTEGER

年龄

OPEN_DATE

DATE

开户时间

REMOVE_TAG

CHARACTER(1)

用户状态(‘A’:正常,其他异常)(验证集中不提供此字段)

import org.apache.spark.sql.Dataset;

import org.apache.spark.sql.Row;

import org.apache.spark.sql.SparkSession;

import org.apache.spark.ml.feature.VectorAssembler;

import org.apache.spark.ml.classification.LogisticRegression;

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;

import org.apache.spark.sql.types.DataTypes;

import static org.apache.spark.sql.functions.*;



public class CustomerChurnPrediction {



    public static void main(String[] args) {

        // 初始化 SparkSession

        SparkSession spark = SparkSession.builder()

                .appName("CustomerChurnPrediction")

                .master("local[*]")

                .getOrCreate();



        // 读取训练数据

        Dataset<Row> trainData = spark.read()

                .option("header", "true")

                .option("inferSchema", "true")

                .csv("train.csv");



        // 数据预处理

        // 转换日期为数值特征

        trainData = trainData.withColumn("days_until_bind_exp", datediff(current_date(), col("BINDEXP_DATE")))

                             .withColumn("days_since_open", datediff(current_date(), col("OPEN_DATE")));



        // 去除日期列(已转化为数值特征)

        trainData = trainData.drop("BINDEXP_DATE", "OPEN_DATE");



        // 标签处理,将REMOVE_TAG 'A' 转化为 0,其他为 1

        trainData = trainData.withColumn("label", when(col("REMOVE_TAG").equalTo("A"), 0).otherwise(1));



        // 特征列

        String[] featureColumns = new String[]{"FLOW", "FLOW_LAST_ONE", "FLOW_LAST_TWO", "MONTH_FEE", "MONTHS_3AVG",

                "PHONE_CHANGE", "AGE", "days_until_bind_exp", "days_since_open"};



        // 特征向量组装

        VectorAssembler assembler = new VectorAssembler()

                .setInputCols(featureColumns)

                .setOutputCol("features");



        // 将特征列向量化

        Dataset<Row> trainWithFeatures = assembler.transform(trainData);



        // 训练逻辑回归模型

        LogisticRegression lr = new LogisticRegression()

                .setLabelCol("label")

                .setFeaturesCol("features");



        // 模型拟合

        LogisticRegressionModel model = lr.fit(trainWithFeatures);



        // 读取验证集数据(没有label)

        Dataset<Row> testData = spark.read()

                .option("header", "true")

                .option("inferSchema", "true")

                .csv("test.csv");



        // 转换验证集中的日期为数值特征

        testData = testData.withColumn("days_until_bind_exp", datediff(current_date(), col("BINDEXP_DATE")))

                           .withColumn("days_since_open", datediff(current_date(), col("OPEN_DATE")));



        // 移除多余列

        testData = testData.drop("BINDEXP_DATE", "OPEN_DATE");



        // 将验证集特征向量化

        Dataset<Row> testWithFeatures = assembler.transform(testData);



        // 使用模型对验证集进行预测

        Dataset<Row> predictions = model.transform(testWithFeatures);



        // 输出查准率和查全率

        BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()

                .setLabelCol("label")

                .setMetricName("areaUnderROC");



        double accuracy = evaluator.evaluate(predictions);

        System.out.println("Model Accuracy = " + accuracy);



        // 提取需要的列并输出到文件

        predictions.select("USER_ID", "prediction")

                   .write()

                   .option("header", "true")

                   .csv("prediction_results.csv");



        // 关闭 SparkSession

        spark.stop();

    }

}

1. 数据理解

  • 任务背景:本次任务的目的是预测电信公司的客户是否会进入异常状态。给定的数据包括用户的流量、收入、终端更换情况、年龄、绑定到期时间等特征,并通过历史数据中的“用户状态”来训练模型。
  • 数据集:提供了训练集和验证集两个数据集,训练集中含有目标变量(用户是否进入异常状态),而验证集中只包含特征数据,需要我们预测其异常状态。
  • 数据字段解释
    • USER_ID:用户的唯一标识。
    • FLOW、FLOW_LAST_ONE、FLOW_LAST_TWO:用户当月及前两个月的流量数据。
    • MONTH_FEE、MONTHS_3AVG:用户当月及最近三个月的平均收入。
    • BINDEXP_DATE、OPEN_DATE:用户的绑定到期时间和开户时间。
    • PHONE_CHANGE:用户是否更换终端。
    • AGE:用户年龄。
    • REMOVE_TAG:训练集中提供的用户状态标签(‘A’表示正常,其他表示异常)。

2. 数据探索和可视化

  • 统计描述:对训练集中的数值型字段(如FLOW、MONTH_FEE、AGE等)进行统计描述,计算其最大值、最小值、均值、标准差等,帮助我们了解数据分布。
    • 最大值最小值:了解数据的范围,判断是否存在异常值。
    • 缺失值:检查是否有缺失数据,对有缺失值的字段,考虑填充或删除。
  • 特征分布分析:分析特征的分布,尤其是与目标变量(REMOVE_TAG)之间的关系。
    • 绘制流量、收入、年龄等特征的分布图,观察是否存在显著差异。

3. 数据预处理

  • 日期处理:将日期类型的特征BINDEXP_DATE和OPEN_DATE转换为数值型特征。比如,将它们转换为距离当前日期的天数,以便模型能理解时间间隔对用户状态的影响。
  • 类别变量处理:对于PHONE_CHANGE等离散类别变量,直接使用数值型(如0或1)表示是否更换终端。
  • 标签处理:将训练集中的REMOVE_TAG字段进行二元分类处理。A表示正常用户,转换为0,其他值表示异常用户,转换为1。
  • 特征归一化:由于不同特征的取值范围可能差别较大(如流量和收入单位不同),我们对数值特征进行标准化或归一化,以提高模型的训练效果。
    • 归一化后的特征将有助于模型更加高效地收敛。

4. 特征工程

  • 特征选择:在特征工程中选择与目标变量相关的特征。根据数据探索的结果,我们使用了以下特征:
    • 流量特征:FLOW、FLOW_LAST_ONE、FLOW_LAST_TWO
    • 收入特征:MONTH_FEE、MONTHS_3AVG
    • 终端更换:PHONE_CHANGE
    • 时间特征:days_until_bind_exp(距离绑定到期的天数)、days_since_open(距离开户的天数)
    • 用户年龄:AGE
  • 特征向量化:在模型训练中,我们将这些特征通过VectorAssembler进行特征向量化,方便输入到机器学习模型中。

5. 模型选择与训练

  • 模型选择:基于当前任务是一个二元分类问题,我们尝试了以下分类模型:
    • 逻辑回归(Logistic Regression):作为一个经典的线性模型,逻辑回归能够很好地处理二元分类问题。
    • 随机森林(Random Forest):能够处理高维数据并且具有良好的泛化能力。
    • 梯度提升树(Gradient Boosting Tree):能够通过迭代的方式进行优化,处理非线性关系。
  • 交叉验证:通过交叉验证(Cross Validation)对模型的超参数进行调优。最终,我们选择了表现最好的模型(例如逻辑回归),并使用其对验证集进行预测。
  • 超参数优化:对逻辑回归的正则化参数(regParam)和最大迭代次数(maxIter)进行调参。最终选择了合适的参数组合。

6. 模型评估

  • 模型性能指标
    • 查准率(Precision):表示模型预测出的正样本中有多少是实际的正样本。
    • 查全率(Recall):表示实际的正样本中有多少被模型正确识别为正样本。
    • F1-score:查准率和查全率的调和平均数,用来综合评估模型性能。
  • 混淆矩阵:通过混淆矩阵展示模型在测试集上的表现,查看真阳性(True Positive, TP)、假阳性(False Positive, FP)、真阴性(True Negative, TN)和假阴性(False Negative, FN)的数量,并由此计算出准确率、查准率和查全率。
  • ROC曲线与AUC:使用ROC曲线和AUC值评估模型的分类能力。AUC值越接近1,表示模型性能越好。

7. 模型预测与输出

  • 验证集预测:将验证集通过训练好的模型进行预测,生成每个用户的异常状态预测结果。
  • 结果输出:生成预测结果文件,包含用户标志和预测结果:
    • USER_ID:用户标志。
    • PREDICTION:预测结果,1 表示异常,0 表示正常。

结果文件以逗号分隔的CSV格式输出。

8. 结论与改进方向

  • 结论:当前模型能够较为准确地预测用户是否进入异常状态,但在某些特定情况下(如数据不平衡时),可能会导致查全率偏低。通过调优参数或采用其他更复杂的模型(如XGBoost),有望进一步提升模型的性能。
  • 改进方向
    • 处理数据不平衡问题(通过下采样或上采样)。
    • 尝试更多高级的模型,如XGBoost或深度学习模型。
    • 增加特征工程部分,例如对流量特征进行更多的交互处理或聚类分析。

Spark MLlib 完整代码总结

以下是使用 Spark MLlib 进行机器学习任务的完整流程代码总结。代码包含从数据预处理、特征工程到模型训练、评估和预测的各个步骤,适用于处理典型的二分类问题。

1. 引入依赖

在编写 Spark 应用时,首先需要引入所需的依赖包和库:

import org.apache.spark.sql.SparkSession;

import org.apache.spark.sql.Dataset;

import org.apache.spark.sql.Row;

import org.apache.spark.sql.functions;

import org.apache.spark.sql.types.*;

import org.apache.spark.ml.Pipeline;

import org.apache.spark.ml.PipelineModel;

import org.apache.spark.ml.feature.*;

import org.apache.spark.ml.classification.LogisticRegression;

import org.apache.spark.ml.classification.LogisticRegressionModel;

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;

import org.apache.spark.ml.tuning.CrossValidator;

import org.apache.spark.ml.tuning.ParamGridBuilder;

import org.apache.spark.ml.tuning.CrossValidatorModel;

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;

import org.apache.spark.ml.linalg.Vector;

2. 创建 Spark 会话

SparkSession spark = SparkSession.builder()

    .appName("Spark ML Example")

    .master("local[*]") // 可以根据环境修改

    .getOrCreate();

3. 加载数据

假设我们有两个数据集:训练集 train.csv 和测试集 test.csv,需要进行数据加载:

Dataset<Row> trainData = spark.read().option("header", "true")

    .option("inferSchema", "true")  // 推断数据类型

    .csv("path/to/train.csv");



Dataset<Row> testData = spark.read().option("header", "true")

    .option("inferSchema", "true")

    .csv("path/to/test.csv");

4. 数据预处理

4.1 处理缺失值

// 对数值列进行缺失值处理(比如用平均值填充)

trainData = trainData.na().fill(0); // 或者 .fill(“默认值”)

4.2 将标签转化为数值

// 假设标签列为 “label”

StringIndexer labelIndexer = new StringIndexer()

    .setInputCol("label")

    .setOutputCol("indexedLabel")

    .fit(trainData);

4.3 特征处理:向量组装

// 将所有特征列转化为特征向量

VectorAssembler assembler = new VectorAssembler()

    .setInputCols(new String[]{"feature1", "feature2", "feature3"})  // 根据实际特征列名修改

    .setOutputCol("features");



trainData = assembler.transform(trainData);

testData = assembler.transform(testData);

5. 模型训练

5.1 逻辑回归模型

LogisticRegression lr = new LogisticRegression()

    .setLabelCol("indexedLabel") // 标签列

    .setFeaturesCol("features")  // 特征列

    .setMaxIter(10)

    .setRegParam(0.01);

5.2 构建管道

Pipeline pipeline = new Pipeline()

    .setStages(new PipelineStage[]{labelIndexer, assembler, lr});

5.3 交叉验证与模型调优

// 参数网格调优

ParamGridBuilder paramGrid = new ParamGridBuilder()

    .addGrid(lr.regParam(), new double[]{0.1, 0.01})

    .addGrid(lr.maxIter(), new int[]{10, 20});



// 二分类评估器

BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()

    .setLabelCol("indexedLabel");



// 交叉验证

CrossValidator cv = new CrossValidator()

    .setEstimator(pipeline)

    .setEvaluator(evaluator)

    .setEstimatorParamMaps(paramGrid.build())

    .setNumFolds(5);



CrossValidatorModel cvModel = cv.fit(trainData);

6. 模型评估

6.1 在测试集上进行预测

Dataset<Row> predictions = cvModel.transform(testData);



// 显示前几条预测结果

predictions.select("user_id", "prediction", "probability").show(5);

6.2 计算模型的性能指标

// 评估 AUC(Area Under ROC Curve)

double auc = evaluator.evaluate(predictions);

System.out.println("AUC: " + auc);



// 混淆矩阵

MulticlassClassificationEvaluator multiEval = new MulticlassClassificationEvaluator()

    .setLabelCol("indexedLabel")

    .setMetricName("accuracy");



double accuracy = multiEval.evaluate(predictions);

System.out.println("Test Accuracy = " + accuracy);

7. 结果输出

将预测结果导出为 CSV 文件:

// 只导出用户ID和预测结果

predictions.select("user_id", "prediction")

    .write()

    .option("header", "true")

    .csv("path/to/output_predictions.csv");

8. 模型持久化

将模型保存以供未来使用:

cvModel.write().overwrite().save("path/to/saved_model");

9. 模型加载(如有需要)

如果需要加载保存的模型以便进行预测:

CrossValidatorModel loadedModel = CrossValidatorModel.load("path/to/saved_model");



// 使用加载的模型进行预测

Dataset<Row> newPredictions = loadedModel.transform(newData);

10. 建模总结

  1. 数据探索
    • 对数据进行缺失值处理和简单的统计描述分析。
    • 特征和标签的处理,确保数据符合模型的输入要求。
  2. 特征工程
    • 将数值列转化为向量格式。
    • 使用StringIndexer对标签列进行编码处理。
  3. 模型训练
    • 选择逻辑回归作为初始模型。
    • 使用交叉验证和参数网格进行模型调优。
  4. 模型评估
    • 使用 AUC 和准确率作为评估指标。
    • 根据模型的性能优化超参数。
  5. 结果输出
    • 输出预测结果到文件,并保存模型以供未来预测使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2155775.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

详解 Linux 系统下的进程(下)

目录 一.进程控制 1.进程创建 a.Linux 系统中&#xff0c;如何创建一个进程&#xff1f; b.进程创建成功后&#xff0c;Linux 底层会为其做些什么&#xff1f; 2.进程终止 a.什么是进程终止&#xff1f; b.进程终止的方法有哪些&#xff1f; c.exit 与 _exit的区别 3.…

通过logstash同步elasticsearch数据

1 概述 logstash是一个对数据进行抽取、转换、输出的工具&#xff0c;能对接多种数据源和目标数据。本文介绍通过它来同步elasticsearch的数据。 2 环境 实验仅仅需要一台logstash机器和两台elasticsearch机器&#xff08;elasticsearch v7.1.0&#xff09;。本文用docker来模…

NLP 序列标注任务核心梳理

句向量标注 用 bert 生成句向量用 lstm 或 bert 承接 bert 的输出&#xff0c;保证模型可以学习到内容的连续性。此时 lstm 输入形状为&#xff1a; pooled_output.unsqueeze(0) (1, num_sentence, vector_size) 应用场景 词性标注句法分析 文本加标点 相当于粗粒度的分词任…

实时同步 解决存储问题 sersync

目录 1.sersync服务 2.sersync同步整体架构 ​编辑 3.rsync服务准备 4.sersync部署使用 5.修改配置文件 6.启动sersync 7.接入nfs服务 8.联调测试 1.sersync服务 sersync服务其实就是由两个服务组成一个是inotify服务和rsync服务组成 inotify服务用来监控那个…

Linux 文件系统(上)

目录 一.预备阶段 1.认识文件 2.OS对内存文件的管理 3.C库函数和系统调用接口 a.C库函数——fopen b.系统调用接口——open 二.理解文件描述符 1.一张图&#xff0c;详解文件描述符的由来 2.fd的分配规则 3.从fd的角度理解FILE 三.重定向和缓冲区 1.前置知识——理解…

网络安全-CSRF

一、环境 DVWA网上找 二、简单介绍 这个漏洞很早之前了&#xff0c;但是为了避免大家在面试等等的时候被问到&#xff0c;这里给大家温习一下 CSRF全程是没有黑客参与的&#xff0c;全程都是用户自己在操作 三、环境演练 这个是DVWA的提交表单页面&#xff0c;我这里伪造…

【2020工业图像异常检测文献】PaDiM

PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization 1、Background 在单类学习&#xff08;仅使用正常数据&#xff08;即“单一类”&#xff09;来训练模型&#xff09;环境中的异常检测和定位任务方法中&#xff0c;要么需要深度神经网…

结合HashMap与Java 8的Function和Optional消除ifelse判断

shigen坚持更新文章的博客写手&#xff0c;记录成长&#xff0c;分享认知&#xff0c;留住感动。个人IP&#xff1a;shigen 在文章的开头我们先从这些场景进入本期的问题&#xff1a; 业务代码中各种if-else有遇到过吗&#xff0c;有什么好的优化方式&#xff1b;java8出来这么…

鸿蒙开发(NEXT/API 12)【跨设备互通特性简介】协同服务

跨设备互通提供跨设备的相机、扫描、图库访问能力&#xff0c;平板或2in1设备可以调用手机的相机、扫描、图库等功能。 说明 本章节以拍照为例展开介绍&#xff0c;扫描、图库功能的使用与拍照类似。 用户在平板或2in1设备上使用富文本类编辑应用&#xff08;如&#xff1a;…

学习 git 命令行的简单操作, 能够将代码上传到 Gitee 上

首先登录自己的gitee并创建好仓库 将仓库与Linux终端做链接 比如说我这里已经创建好了一个我的Linux学习仓库 点开克隆/下载&#xff1a; 在你的终端中粘贴上图中1中的指令 此时他会让你输入你的用户名和密码&#xff0c;用户名就是上图中3中Username for ....中后面你的一个…

预付费计量系统实体模型

1. 预付费计量系统实体模型 A generic entity model for electricity payment metering systems is shown in Figure 2. Although it provides a limited perspective, it does serve to convey certain essential concepts. 关于电子式预付费电表系统的实体模型见图 2…

李宏毅结构化学习 03

文章目录 一、Sequence Labeling 问题概述二、Hidden Markov Model(HMM)三、Conditional Random Field(CRF)四、Structured Perceptron/SVM五、Towards Deep Learning 一、Sequence Labeling 问题概述 二、Hidden Markov Model(HMM) 上图 training data 中的黑色字为x&#xff…

如何备份SqlServer数据库

第一步&#xff1a;登录你要备份的服务器数据库ssms 第二步&#xff1a;选择你要备份的数据库 此处已PZ-SJCS 数据库为例 右键该数据库-->任务-->备份 第三步&#xff1a;选择你备份的类型备份组件等&#xff0c;目标磁盘 &#xff0c;点击添加选择将你备份的文件备份那…

全面详尽的 PHP 环境搭建教程

目录 目录 PHP 环境搭建概述 在 Windows 上搭建 PHP 环境 使用集成环境 XAMPP 安装步骤 配置和测试 常用配置 手动安装 Apache、PHP 和 MySQL 安装 Apache 安装 PHP 安装 MySQL 配置 PHP 连接 MySQL 在 Linux 上搭建 PHP 环境 使用 LAMP 方案 安装 Apache 安装 …

【25.6】C++智能交友系统

常见错误总结 const-1 如下代码会报错 原因如下&#xff1a; man是一个const修饰的对象&#xff0c;即man不能修改任何内容&#xff0c;但是man所调用的play函数只是一个普通的函数&#xff0c;所以出现了报错。我们需要在play函数中加上const修饰&#xff0c;或者删除man对…

《论分布式存储系统架构设计》写作框架,软考高级系统架构设计师

论文真题 分布式存储系统&#xff08;Distributed Storage System&#xff09;通常将数据分散存储在多台独立的设备上。传统的网络存储系统采用集中的存储服务器存放所有数据&#xff0c;存储服务器成为系统性能的瓶颈&#xff0c;也是可靠性和安全性的焦点&#xff0c;不能满…

FreeRTOS-时间片调度

FreeRTOS-时间片调度 一、时间片调度简介二、时间片调度实验 一、时间片调度简介 同等优先级任务轮流的享有相同的CPU时间(可设置)&#xff0c;叫时间片&#xff0c;在FreeRTOS中&#xff0c;一个时间片就等于SysTick中断周期&#xff0c;所以说时间片大小取决于滴答定时器中断…

windows安装Anaconda教程

一、简介 Anaconda 是一个开源的 Python 和 R 语言的分发平台&#xff0c;专为科学计算和数据分析设计。它包含了包管理器 Conda&#xff0c;可以方便地安装和管理库、环境和依赖项。此外&#xff0c;Anaconda 还附带了许多数据科学工具和库&#xff0c;如 Jupyter Notebook 和…

【HTTPS】中间人攻击和证书的验证

中间人攻击 服务器可以创建出一堆公钥和私钥&#xff0c;黑客也可以按照同样的方式&#xff0c;创建一对公钥和私钥&#xff0c;冒充自己是服务器&#xff08;搅屎棍&#xff09; 黑客自己也能生成一对公钥和私钥。生成公钥和私钥的算法是开放的&#xff0c;服务器能生产&…

iOS17找不到developer mode

iOS17找不到开发者模式 developer mode 下载过app之后、弹窗Developer Mode Required之后&#xff0c;这个菜单就出现了&#xff08;之前死活找不到&#xff09;。 背景&#xff1a;用蒲公英分发测试app&#xff0c;有个同事买了新机(iphone 15 pro max)&#xff0c;添加了白名…